import torch
import time
import evaluate
from .logging_utils import Averager
from datasets.iterable_dataset import IterableDataset
import wandb
import hydra
from tqdm import tqdm

def maybe_save_checkpoint(accelerator, args):
    """Save the accelerator state when a checkpoint is due.

    A checkpoint is due when ``args.current_train_step`` is past
    ``args.optim.total_steps`` or is a multiple of
    ``args.checkpoint.every_steps``. Nothing is saved when ``args.mode`` is
    ``"ft"``.

    Args:
        accelerator: Object exposing ``save_state(output_dir=...)``.
        args: Run configuration; reads ``current_train_step``,
            ``optim.total_steps``, ``checkpoint.every_steps``, ``mode`` and
            ``seed``.
    """
    step = args.current_train_step
    is_due = (
        step > args.optim.total_steps
        or step % args.checkpoint.every_steps == 0
    )
    if not is_due or args.mode == "ft":
        return

    # The target directory depends only on the seed and the wandb run id,
    # not on the step.
    accelerator.save_state(
        output_dir=f"/work/YamadaU/takezawa/polyak/nanoT5/checkpoints/{args.seed}/{wandb.run.id}"
    )
        
def maybe_eval_predict(model, dataloader, logger, args, tokenizer):
    """Run evaluation (and prediction in ``'ft'`` mode) when it is due.

    Evaluation is due when ``args.current_train_step`` is past
    ``args.optim.total_steps`` or is a multiple of ``args.eval.every_steps``.

    Returns:
        The value returned by ``eval`` when evaluation ran, otherwise ``None``.
    """
    step = args.current_train_step
    is_due = step > args.optim.total_steps or step % args.eval.every_steps == 0
    if not is_due:
        return None

    model.eval()
    with torch.no_grad():
        loss_value = eval(model, dataloader, logger, args, tokenizer)
        if args.mode == 'ft':
            predict(model, dataloader, logger, args, tokenizer)

    # Restart the logging timer so time spent evaluating is not counted.
    args.last_log = time.time()
    model.train()
    return loss_value

def maybe_logging(averager, args, model, optimizer, logger):
    """Flush averaged training stats every ``args.logging.every_steps`` steps.

    On a logging step the extra stats are added to ``averager``, its average
    is sent to ``logger.log_stats`` under the ``'train/'`` prefix, and
    ``args.last_log`` is reset to the current time. Other steps are a no-op.
    """
    if args.current_train_step % args.logging.every_steps != 0:
        return

    averager.update(extra_stats(args, model, optimizer))

    logger.log_stats(
        stats=averager.average(),
        step=args.current_train_step,
        args=args,
        prefix='train/',
    )

    args.last_log = time.time()

        
def maybe_grad_clip_and_grad_calc(accelerator, model, args):
    """Optionally clip gradients and report their global L2 norm.

    Args:
        accelerator: Object exposing ``clip_grad_norm_``; only used when
            ``args.optim.grad_clip`` is positive.
        model: Module whose ``parameters()`` carry the gradients.
        args: Run configuration; reads ``optim.grad_clip`` and
            ``logging.grad_l2``.

    Returns:
        ``{'grad_l2': norm}`` when ``args.logging.grad_l2`` is truthy,
        otherwise an empty dict. When clipping ran, ``norm`` is whatever
        ``accelerator.clip_grad_norm_`` returned; otherwise it is a Python
        float computed here.
    """
    grad_l2 = None
    if args.optim.grad_clip > 0:
        grad_l2 = accelerator.clip_grad_norm_(
            parameters=model.parameters(),
            max_norm=args.optim.grad_clip,
            norm_type=2,
        )

    if not args.logging.grad_l2:
        return {}

    if grad_l2 is None:
        # Parameters without a gradient (frozen or unused in the last
        # backward pass) have ``grad is None`` and contribute nothing.
        squared_sum = sum(
            p.grad.detach().norm(2).item() ** 2
            for p in model.parameters()
            if p.grad is not None
        )
        grad_l2 = squared_sum ** 0.5

    return {'grad_l2': grad_l2}


def extra_stats(args, model, optimizer):
    """Collect per-log-interval statistics that are not produced per batch.

    Returns:
        A dict with ``'lr'`` (learning rate of the first param group) and
        ``'seconds_per_step'`` (time since ``args.last_log`` divided by
        ``args.logging.every_steps``); ``'weights_l2'`` (global L2 norm of the
        model parameters) is included first when ``args.logging.weights_l2``
        is truthy.
    """
    collected = {}

    if args.logging.weights_l2:
        squared_total = 0
        for param in model.parameters():
            squared_total += param.detach().norm(2).item() ** 2
        collected['weights_l2'] = squared_total ** 0.5

    collected['lr'] = optimizer.param_groups[0]['lr']

    elapsed = time.time() - args.last_log
    collected['seconds_per_step'] = elapsed / args.logging.every_steps

    return collected


def forward(model, batch, calc_acc=False):
    """Run the model on one batch and summarise the result.

    Args:
        model: Callable invoked as ``model(**batch)``; its result must expose
            ``loss`` and, when ``calc_acc`` is set, ``logits``.
        batch: Mapping of keyword arguments for the model; ``batch["labels"]``
            is read when ``calc_acc`` is set.
        calc_acc: Also compute the fraction of label positions whose argmax
            prediction equals the label.

    Returns:
        ``(loss, stats)`` where ``loss`` is the model's loss object and
        ``stats`` holds ``'loss'`` as a Python float plus, optionally,
        ``'accuracy'``.
    """
    outputs = model(**batch)
    loss = outputs.loss

    stats = {'loss': loss.detach().float().item()}

    if calc_acc:
        labels = batch["labels"]
        # Every label position counts towards the denominator.
        n_correct = (outputs.logits.argmax(-1) == labels).sum().item()
        stats['accuracy'] = n_correct / labels.numel()

    return loss, stats


def eval(model, dataloader, logger, args, tokenizer):
    """Evaluate on a capped number of batches and log the averaged stats.

    Note that this function shadows the built-in ``eval`` within the module.

    Iteration stops as soon as the 1-based batch index equals
    ``args.eval.corrected_steps * args.optim.grad_acc``; the batch at that
    index is not evaluated. The averaged stats, plus the elapsed ``'time'``,
    are sent to ``logger.log_stats`` under the ``'eval/'`` prefix.

    Returns:
        The averaged ``"loss"`` entry.
    """
    args.last_log = time.time()
    running = Averager()

    for batch_id, batch in enumerate(dataloader, start=1):
        reached_cap = batch_id == args.eval.corrected_steps * args.optim.grad_acc
        if reached_cap:
            break

        batch_stats = forward(model, batch, calc_acc=True)[1]
        running.update(batch_stats)

    elapsed = time.time() - args.last_log
    running.update({'time': elapsed})
    summary = running.average()

    logger.log_stats(
        stats=summary,
        step=args.current_train_step,
        args=args,
        prefix='eval/',
    )
    return summary["loss"]


def eval_all(model, dataloader, logger, args, tokenizer, val_batches=2530):
    """Evaluate the whole dataloader, split into a validation and a test part.

    Batches whose 1-based index is at most ``val_batches`` feed the validation
    average; all later batches feed the test average. The model is put in
    eval mode for the pass and switched back to train mode afterwards.

    Args:
        model: Model passed to ``forward``.
        dataloader: Iterable of batches.
        logger: Not used by this function.
        args: Run configuration; ``args.last_log`` is reset at the start.
        tokenizer: Not used by this function.
        val_batches: Number of leading batches treated as validation data.
            The default keeps the previously hard-coded split of 2530.
            NOTE(review): a removed commented-out assert expected 5060 batches
            in total, i.e. presumably two equal halves — confirm.

    Returns:
        ``(val_loss, test_loss)``: the averaged ``"loss"`` of each part.
    """
    args.last_log = time.time()
    averager_val = Averager()
    averager_test = Averager()

    # Keeps the progress print below well-defined for an empty dataloader.
    batch_id = 0

    model.eval()
    with torch.no_grad():
        for batch_id, batch in enumerate(tqdm(dataloader), start=1):
            _, stats = forward(model, batch, calc_acc=True)

            if batch_id <= val_batches:
                averager_val.update(stats)
            else:
                averager_test.update(stats)

    print(f"batch_id {batch_id}")

    averager_val.update({'time': time.time() - args.last_log})
    averager_test.update({'time': time.time() - args.last_log})
    averaged_val_stats = averager_val.average()
    averaged_test_stats = averager_test.average()
    model.train()

    return averaged_val_stats["loss"], averaged_test_stats["loss"]


def eval_estimated_val(model, dataloader, logger, args, tokenizer, max_batches=500):
    """Estimate the validation loss from the first batches of ``dataloader``.

    The model is put in eval mode for the pass and switched back to train
    mode afterwards.

    Args:
        model: Model passed to ``forward``.
        dataloader: Iterable of batches.
        logger: Not used by this function.
        args: Run configuration; ``args.last_log`` is reset at the start.
        tokenizer: Not used by this function.
        max_batches: Number of leading batches to average over. The default
            keeps the previously hard-coded limit of 500.

    Returns:
        The averaged ``"loss"`` over the evaluated batches.
    """
    args.last_log = time.time()
    averager_val = Averager()

    model.eval()
    with torch.no_grad():
        for batch_id, batch in enumerate(tqdm(dataloader), start=1):
            # Check the limit before the forward pass so no batch is run
            # through the model only to be discarded.
            if batch_id > max_batches:
                break

            _, stats = forward(model, batch, calc_acc=True)
            averager_val.update(stats)

    averager_val.update({'time': time.time() - args.last_log})
    averaged_val_stats = averager_val.average()
    model.train()

    return averaged_val_stats["loss"]


def store_grad_norm(model):
    """Return the gradient norm of every named parameter of ``model``.

    Names come from ``model.named_parameters()`` so each name is paired with
    its own parameter. ``state_dict()`` keys cannot be used for this because
    they also list buffers, which have no counterpart in ``parameters()``.

    Returns:
        A dict mapping ``"grad/<parameter name>"`` to the norm of that
        parameter's gradient as a tensor, or to ``None`` when the parameter
        has no gradient.
    """
    grad_norms = {}
    for name, param in model.named_parameters():
        grad = param.grad
        grad_norms[f"grad/{name}"] = (
            None if grad is None else torch.norm(grad.detach())
        )
    return grad_norms


def predict(model, dataloader, logger, args, tokenizer):
    """Generate predictions for ``dataloader`` and log the mean ROUGE-L score.

    Predictions come from ``model.generate`` and are compared with the decoded
    ``batch["labels"]``. The per-sample ``rougeL`` scores are averaged, scaled
    by 100, and logged with the elapsed time under the ``"test/"`` prefix.

    Returns:
        The averaged ROUGE-L score scaled by 100.
    """
    args.last_log = time.time()
    rouge = evaluate.load('rouge')

    def decode(token_ids):
        # Replaces -100 entries in place with the pad token id before decoding.
        token_ids[token_ids == -100] = tokenizer.pad_token_id
        texts = tokenizer.batch_decode(
            token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return [text.strip() for text in texts]

    seen_so_far = 0
    for step, batch in enumerate(dataloader):
        generated = model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=args.data.max_target_len,
            generation_config=model.generation_config,
        )
        hypotheses = decode(generated)
        targets = decode(batch["labels"])

        # If we are in a multiprocess environment, the last batch has duplicates
        is_last_batch = step == len(dataloader) - 1
        if is_last_batch:
            remaining = len(dataloader.dataset) - seen_so_far
            hypotheses = hypotheses[:remaining]
            targets = targets[:remaining]
        else:
            seen_so_far += len(targets)

        rouge.add_batch(predictions=hypotheses, references=targets)

    scores = rouge.compute(use_stemmer=True, use_aggregator=False)["rougeL"]
    rougeL = sum(scores) * 100 / len(scores)

    logger.log_stats(
        stats={
            "rougeL": rougeL,
            "time": time.time() - args.last_log,
        },
        step=args.current_train_step,
        args=args,
        prefix="test/",
    )

    return rougeL


def train(model, train_dataloader, test_dataloader, accelerator, lr_scheduler,
          optimizer, logger, args, tokenizer):
    """Run the training loop, then checkpoint, evaluate and log final metrics.

    Training continues until ``args.current_train_step`` exceeds
    ``args.optim.total_steps``. Gradients are accumulated over
    ``args.optim.grad_acc`` batches per optimizer step. ``args.current_train_step``
    and ``args.current_epoch`` are updated in place.

    Args:
        model: Model being trained.
        train_dataloader: Source of training batches.
        test_dataloader: Dataloader used for the estimated validation loss,
            the final val/test losses and, in ``"ft"`` mode, ROUGE-L.
        accelerator: Used for ``backward``, gradient clipping and checkpointing.
        lr_scheduler: Stepped once per optimizer step.
        optimizer: Stepped every ``args.optim.grad_acc`` batches.
        logger: Passed through to the logging and evaluation helpers.
        args: Run configuration and mutable training state.
        tokenizer: Passed through to the evaluation helpers.
    """
    model.train()

    print(f"wandb id {wandb.run.id}")
    
    train_averager = Averager()
    # Baseline estimated validation loss before any update, logged one step
    # before the first training step.
    estimated_val_loss = eval_estimated_val(model, test_dataloader, logger, args, tokenizer)
    #grad_norms = store_grad_norm(model)
    wandb.log({"iteration": args.current_train_step-1,
               "train/minibatch_loss": None,
               "test/full_loss": None,
               "val/full_loss": None,
               "val/estimated_loss": estimated_val_loss,
               "rougeL": None}, step=args.current_train_step-1)
    
    # Each pass of this loop is one sweep over train_dataloader.
    while args.current_train_step <= args.optim.total_steps:
        if isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(args.current_epoch)

        # In case there is a remainder from previous epoch, we need to reset the optimizer
        optimizer.zero_grad(set_to_none=True)

        for batch_id, batch in enumerate(train_dataloader, start=1):
            if args.current_train_step > args.optim.total_steps:
                break

            loss, stats = forward(model, batch)
            # Scale the loss so the accumulated gradient is an average over
            # the grad_acc micro-batches.
            accelerator.backward(loss / args.optim.grad_acc)
            train_averager.update(stats)

            # An optimizer step happens once every grad_acc batches.
            if batch_id % args.optim.grad_acc == 0:
                stats = maybe_grad_clip_and_grad_calc(accelerator, model, args)
                train_averager.update(stats)

                # Optimizers whose configured name contains "polyak" or "sps"
                # are given the loss of the latest micro-batch; presumably they
                # use it for a loss-dependent step size — TODO confirm.
                if "polyak" in args.optim.name or "sps" in args.optim.name:
                    optimizer.step(loss)
                else:
                    optimizer.step()
                    
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                maybe_logging(train_averager, args, model, optimizer, logger)

                # The estimated validation loss is refreshed every 200 steps
                # and logged as None on all other steps.
                if args.current_train_step % 200 == 0:
                    estimated_val_loss = eval_estimated_val(model, test_dataloader, logger, args, tokenizer)
                else:
                    estimated_val_loss = None

                # `loss` here is the object returned by forward for the latest
                # micro-batch of this accumulation window.
                wandb.log({"iteration": args.current_train_step,
                           "train/minibatch_loss": loss,
                           "test/full_loss": None,
                           "val/full_loss": None,
                           "val/estimated_loss": estimated_val_loss,
                           "rougeL": None}, step=args.current_train_step)
                
                args.current_train_step += 1

        args.current_epoch += 1

    # Checkpointing is only attempted here, after the training loop has ended.
    maybe_save_checkpoint(accelerator, args)

    print(f'log path: checkpoint-{args.mode}-{args.current_train_step}-{wandb.run.id}')
    print("start final evaluation")
    val_loss, test_loss = eval_all(model, test_dataloader, logger, args, tokenizer)

    # ROUGE-L is only computed in fine-tuning mode.
    if args.mode == "ft":
        model.eval()
        with torch.no_grad():
            rougeL = predict(model, test_dataloader, logger, args, tokenizer)
    else:
        rougeL = None
        
    #grad_norms = store_grad_norm(model)            
    # NOTE(review): this final record uses the key "test/minibatch_loss", which
    # the earlier wandb.log calls do not use — confirm it is intended.
    wandb.log({"iteration": None,
               "train/minibatch_loss": None,
               "test/minibatch_loss": None,
               "val/full_loss": val_loss,
               "test/full_loss": test_loss,
               "val/estimated_loss": None,
               "rougeL": rougeL},
              step=args.current_train_step)

