import torch
import torch.nn as nn
from tqdm import tqdm
import os
from loguru import logger

import hydra
from omegaconf import DictConfig, OmegaConf

from src.dataset import get_wikitext2
from src.dataset import get_c4
from src.dataset import get_c4_ko
from src.model import load_model_and_tokenizer
from src.recorder import recorder


def get_ppl_eval_loaders(name, tokenizer, seqlen):
    """Build the test-split loader used for perplexity evaluation of ``name``.

    The dataset is chosen by substring match on ``name``:
    ``"wikitext2"`` first, then ``"c4"``. For C4, the Korean variant is used when
    ``"ko"`` also appears in ``name``. Each ``src.dataset`` helper is called with
    ``(-1, 42, seqlen, tokenizer, "test")``. Presumably -1 means "all samples" and
    42 is the seed; confirm against ``src.dataset``. Only the first value it
    returns (the loader) is kept.

    Raises:
        NotImplementedError: if ``name`` matches none of the known datasets.
    """
    if "wikitext2" in name:
        build_loader = get_wikitext2
    elif "c4" in name:
        build_loader = get_c4_ko if "ko" in name else get_c4
    else:
        raise NotImplementedError(f"No such dataset: {name}")

    loader, _unused_encoding = build_loader(-1, 42, seqlen, tokenizer, "test")
    return loader


def _load_or_build_testloader(dataset, tokenizer, model_name, seqlen):
    """Return the test loader for ``dataset``, caching it in /tmp across runs.

    The cache key is built from the dataset name, ``model_name`` (with ``/``
    replaced by ``_``) and ``seqlen``. The tokenizer itself is not part of
    the key.
    """
    cache_path = f"/tmp/{dataset}_test_{model_name.replace('/', '_')}_{seqlen}_all.cache"
    if os.path.exists(cache_path):
        # NOTE(review): torch.load unpickles this file, and /tmp is typically
        # world-writable. A planted cache file could therefore execute
        # arbitrary code. Consider moving the cache to a private directory.
        return torch.load(cache_path)
    testloader = get_ppl_eval_loaders(dataset, tokenizer, seqlen)
    torch.save(testloader, cache_path)
    return testloader


@torch.no_grad()
def eval_ppl(model, tokenizer, model_name, datasets, seqlen):
    """Compute the perplexity of ``model`` on each dataset in ``datasets``.

    Args:
        model: callable as ``model(input_ids, use_cache=False)``, returning an
            object with a ``.logits`` tensor. It must expose ``.device``.
        tokenizer: passed through to the dataset loaders.
        model_name: used only to build the loader cache file name.
        datasets: comma-separated dataset names, e.g. ``"wikitext2,c4"``.
            Surrounding whitespace is ignored and empty entries are skipped.
        seqlen: sequence length requested from the loaders.

    Returns:
        dict mapping each dataset name to its perplexity, as a Python float.
        The perplexity is ``exp(total NLL / number of scored tokens)``.

    Raises:
        ValueError: if a dataset yields no tokens to score.
    """
    # Built once. Mean reduction keeps the per-batch loss well-conditioned
    # even for half-precision logits. It is re-weighted by the token count
    # below, so every token counts equally across batches.
    loss_fct = nn.CrossEntropyLoss()
    results = {}

    for dataset in (entry.strip() for entry in datasets.split(",")):
        if not dataset:
            continue
        testloader = _load_or_build_testloader(dataset, tokenizer, model_name, seqlen)

        total_nll = 0.0
        total_tokens = 0
        for inp, labels in tqdm(testloader):
            inp = inp.to(model.device)
            labels = labels.to(model.device)

            logits = model(inp, use_cache=False).logits
            # Next-token prediction: logits at position t score the label at t+1.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:]

            # Positions labelled with ignore_index are excluded from the mean
            # by CrossEntropyLoss, so they are excluded from the count too.
            n_tokens = int((shift_labels != loss_fct.ignore_index).sum().item())
            if n_tokens == 0:
                continue
            mean_loss = loss_fct(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            total_nll += mean_loss.float().item() * n_tokens
            total_tokens += n_tokens

        if total_tokens == 0:
            raise ValueError(f"No tokens to evaluate for dataset '{dataset}'")
        # exp in float64 so large average NLLs do not overflow prematurely.
        ppl = torch.exp(torch.tensor(total_nll / total_tokens, dtype=torch.float64))
        results[dataset] = ppl.item()

    return results


@hydra.main(version_base=None, config_path="src/conf", config_name="ppl_config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: load the model, evaluate perplexity and write the results.

    Reads ``model_name_or_path``, ``quantizer`` (including ``save_postfix``),
    ``dataset``, ``seqlen`` and ``output_dir`` from ``cfg``. It writes one
    ``"<dataset>: <ppl>"`` line per dataset to a text file in ``output_dir``.
    It then prints whatever values the project ``recorder`` collected.
    """
    # Route loguru output through tqdm.write so log lines do not break progress bars.
    logger.remove()
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO")
    logger.info(f"Evaluating PPL with the following config: {cfg}")

    # Create the output directory before the slow evaluation, so a bad path
    # fails immediately rather than after all datasets are scored.
    os.makedirs(cfg.output_dir, exist_ok=True)

    model, tokenizer = load_model_and_tokenizer(cfg.model_name_or_path, cfg.quantizer, forward_quant=True)

    logger.info("Start evaluating ppl...")
    logger.info(f"*model: {cfg.model_name_or_path}")
    logger.info(f"*datasets: {cfg.dataset}")
    logger.info(f"*sequence length {cfg.seqlen}")
    results = eval_ppl(model, tokenizer, cfg.model_name_or_path, cfg.dataset, cfg.seqlen)

    model_tag = cfg.model_name_or_path.split("/")[-1]
    result_path = os.path.join(
        cfg.output_dir,
        f"{cfg.dataset}_results_{model_tag}_{cfg.quantizer.save_postfix}.txt",
    )
    with open(result_path, "w", encoding="utf-8") as f:
        for dataset, ppl in results.items():
            # Include the dataset name: with several datasets a bare PPL is ambiguous.
            logger.info(f"PPL on {dataset}: {ppl}")
            print(f"{dataset}: {ppl}", file=f)

    recorder.agg_everything()
    print(recorder.values)


if __name__ == '__main__':
    main()
