import os

import datasets
from dotenv import load_dotenv
import numpy as np
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from transformers.optimization import Adafactor, AdafactorSchedule
import torch
from torch import nn
from tqdm import tqdm
import wandb

from text_ood.transformer.transformer import Transformer
from text_ood.dataset.wmt import WMTDataset


def eval_loss(model, loader, tokenizer, src_lang, trg_lang, input_ctx_len, device):
    """Return the mean per-token cross-entropy of ``model`` over ``loader``.

    Each batch must expose ``batch['translation'][src_lang]`` and
    ``batch['translation'][trg_lang]``. Source and target are tokenized
    together; the decoder input is the target ids shifted right by one
    position (pad token first), and only non-pad target positions
    contribute to the loss.

    Args:
        model: Callable taking ``input_ids``, ``attention_mask``,
            ``target_input_ids`` and ``target_attention_mask`` keyword
            arguments and returning logits whose last dimension equals
            ``len(tokenizer)``. It is switched to eval mode and left there.
        loader: Iterable of batches.
        tokenizer: Tokenizer called with ``text_target``; must provide
            ``pad_token_id`` and ``__len__``.
        src_lang: Key of the source-language texts inside ``translation``.
        trg_lang: Key of the target-language texts inside ``translation``.
        input_ctx_len: ``max_length`` used for truncation.
        device: Device the tokenized tensors are moved to.

    Returns:
        A scalar tensor: the mean of all per-token losses across batches.
    """
    model.eval()
    pad_id = tokenizer.pad_token_id
    vocab_size = len(tokenizer)
    per_token_losses = []

    with torch.no_grad():
        for batch in loader:
            pair = batch['translation']
            encoded = tokenizer(
                pair[src_lang],
                text_target=pair[trg_lang],
                truncation=True,
                padding=True,
                max_length=input_ctx_len,
                return_tensors='pt',
            )

            source_ids = encoded['input_ids'].to(device)
            source_mask = encoded['attention_mask'].to(device)
            labels = encoded['labels'].to(device)
            label_mask = labels.ne(pad_id).int()
            batch_size = labels.size(0)

            # Shift right: prepend one column, then drop the last position.
            start_column = torch.full((batch_size, 1), pad_id, device=device)
            decoder_ids = torch.cat((start_column, labels), dim=1)[:, :-1]
            # The prepended column is int64, so the shifted mask is promoted
            # to int64 exactly as with an int-filled ``torch.full``.
            visible_column = torch.ones((batch_size, 1), dtype=torch.long, device=device)
            decoder_mask = torch.cat((visible_column, label_mask), dim=1)[:, :-1]

            logits = model(
                input_ids=source_ids,
                attention_mask=source_mask,
                target_input_ids=decoder_ids,
                target_attention_mask=decoder_mask,
            )

            # Keep only the rows/entries that belong to real target tokens.
            kept_logits = torch.masked_select(
                logits, label_mask.unsqueeze(-1).bool()
            ).reshape(-1, vocab_size)
            kept_labels = torch.masked_select(labels, label_mask.bool())
            per_token_losses.append(
                nn.functional.cross_entropy(kept_logits, kept_labels, reduction='none')
            )

        mean_loss = torch.cat(per_token_losses, dim=0).mean()
    return mean_loss



def save_checkpoint(model, optimizer, epoch, n_steps, scheduler, filename='checkpoint.pth.tar'):
    """Serialize a training checkpoint to ``filename`` with ``torch.save``.

    The stored dict contains the ``epoch`` and ``n_steps`` counters plus the
    ``state_dict()`` of the model (under ``'state_dict'``), the optimizer and
    the scheduler.
    """
    print(f"=> saving checkpoint '{filename}'")
    checkpoint = dict(
        epoch=epoch,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        n_steps=n_steps,
        scheduler=scheduler.state_dict(),
    )
    torch.save(checkpoint, filename)


def main():
    """Train the en->fr WMT translation Transformer and log to Weights & Biases.

    Configuration comes from the environment (a ``.env`` file is loaded
    first): ``WMT_ROOT`` is passed to the datasets as ``cache_dir`` and
    ``WMT_MODEL_CHECKPOINT`` is the file the model weights are written to.

    Training runs for ``max_steps`` optimizer steps, each accumulated over
    ``gradient_accumulation_steps`` micro-batches. Every ``n_eval_steps``
    optimizer steps the test-split loss is logged and the model
    ``state_dict`` is saved to the checkpoint path.

    Raises:
        KeyError: If ``WMT_MODEL_CHECKPOINT`` or ``WMT_ROOT`` is not set.
    """
    load_dotenv()

    # Read up front so a missing variable fails before any expensive setup.
    checkpoint_path = os.environ['WMT_MODEL_CHECKPOINT']

    wandb.init(project='ap-ood-transformer')

    torch.backends.cuda.enable_flash_sdp(True)

    device = 'cuda'

    train_dataset = WMTDataset(split='train', cache_dir=os.environ['WMT_ROOT'])
    test_dataset = WMTDataset(split='test', cache_dir=os.environ['WMT_ROOT'])

    tokenizer = AutoTokenizer.from_pretrained('google-t5/t5-small')
    model = Transformer(len(tokenizer)).to(device)

    input_ctx_len = 512

    src_lang = 'en'
    trg_lang = 'fr'

    per_device_batch_size = 128
    gradient_step_batch_size = 1024
    max_steps = 100_000
    n_eval_steps = 1_000

    # Number of micro-batches whose gradients are summed per optimizer step.
    gradient_accumulation_steps = gradient_step_batch_size // per_device_batch_size

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=per_device_batch_size, drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=per_device_batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=max_steps)

    n_steps = 0
    epoch = 0

    pbar = tqdm(total=max_steps)
    # Micro-batch losses / token counts since the last optimizer step. They
    # persist across epoch boundaries so accumulation is never cut short.
    losses = []
    n_tokens = []
    while n_steps < max_steps:
        for batch in train_loader:
            # Re-enable training mode on every batch: eval_loss() switches
            # the model to eval mode.
            model.train()
            tokenized = tokenizer(batch['translation'][src_lang],
                                  text_target=batch['translation'][trg_lang],
                                  truncation=True,
                                  padding=True,
                                  max_length=input_ctx_len,
                                  return_tensors='pt')

            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)
            target_input_ids = tokenized['labels'].to(device)
            target_attention_mask = (target_input_ids != tokenizer.pad_token_id).int()
            # Decoder input: targets shifted right by one position, with the
            # pad token as the first decoder token.
            target_input_ids_input = torch.concat([torch.full([len(target_input_ids), 1], tokenizer.pad_token_id, device=device), target_input_ids], dim=1)[:, :-1]
            target_attention_mask_input = torch.concat([torch.full([len(target_attention_mask), 1], 1, device=device), target_attention_mask], dim=1)[:, :-1]
            with torch.autocast(device, torch.bfloat16):
                predictions = model(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    target_input_ids=target_input_ids_input,
                                    target_attention_mask=target_attention_mask_input)

                # Score only non-pad target positions.
                predictions = torch.masked_select(predictions, target_attention_mask.unsqueeze(-1).bool()).reshape(-1, len(tokenizer))
                target_tokens = torch.masked_select(target_input_ids, target_attention_mask.bool())
                loss = nn.CrossEntropyLoss()(predictions, target_tokens)

            # Scale so the accumulated gradient is the mean over micro-batches.
            (loss / gradient_accumulation_steps).backward()
            # Detach: the stored value is only used for logging and must not
            # keep a reference to the autograd graph.
            losses.append(loss.detach())
            n_tokens.append(len(target_tokens))
            if len(losses) < gradient_accumulation_steps:
                continue

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            n_steps += 1
            pbar.update()

            wandb.log({
                'train/loss': (sum(losses) / len(losses)).item(),
                'general/epoch': epoch,
                'general/step': n_steps,
                'train/lr': optimizer.param_groups[0]['lr'],
                'train/n_tokens': sum(n_tokens),
            },
            step=n_steps)
            losses = []
            n_tokens = []

            if n_steps % n_eval_steps == 0:
                val_loss = eval_loss(model, test_loader, tokenizer, src_lang, trg_lang, input_ctx_len, device)
                wandb.log({
                    'val/loss': val_loss,
                    'general/epoch': epoch,
                    'general/step': n_steps,
                    'train/lr': optimizer.param_groups[0]['lr']
                },
                step=n_steps)
                print(f'Saving model in: {checkpoint_path}')
                # A bare filename has an empty directory part, and
                # os.makedirs('') raises FileNotFoundError.
                checkpoint_dir = os.path.dirname(checkpoint_path)
                if checkpoint_dir:
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_path)

            if n_steps >= max_steps:
                break

        epoch = epoch + 1

    pbar.close()
    wandb.finish()


# Run training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
