import torch
import torch.nn as nn
from tqdm import tqdm
import os
from loguru import logger

import hydra
from omegaconf import DictConfig, OmegaConf

from src.dataset import get_wikitext2
from src.dataset import get_c4

from src.dataset import get_c4_ko
from src.model import load_model_and_tokenizer
from src.recorder import recorder


def get_ppl_eval_loaders(name, tokenizer, seqlen):
    """Return the test-split loader for the dataset identified by ``name``.

    Matching is by substring: a name containing ``"wikitext2"`` selects
    WikiText-2; otherwise a name containing ``"c4"`` selects C4, using the
    Korean variant when the name also contains ``"ko"``.

    Args:
        name: Dataset identifier.
        tokenizer: Tokenizer forwarded to the dataset helper.
        seqlen: Sequence length forwarded to the dataset helper.

    Returns:
        The first element of the pair returned by the dataset helper; the
        second element is discarded.

    Raises:
        NotImplementedError: If ``name`` matches none of the known datasets.
    """
    if "wikitext2" in name:
        fetch = get_wikitext2
    elif "c4" not in name:
        raise NotImplementedError(f"No such dataset: {name}")
    else:
        fetch = get_c4_ko if "ko" in name else get_c4

    # NOTE(review): -1 and 42 are presumably "all samples" and the seed --
    # confirm against the helpers in src.dataset.
    loader, _ = fetch(-1, 42, seqlen, tokenizer, "test")
    return loader

def collate_fn(batch):
    """Merge a list of ``(input_ids, labels)`` pairs into two batched tensors.

    Each field is concatenated along dimension 0, so samples that already
    carry a leading axis of size 1 end up stacked into one batch tensor.

    Args:
        batch: Sequence of ``(input_ids, labels)`` tensor pairs.

    Returns:
        A tuple ``(input_ids, labels)`` of concatenated tensors.
    """
    # zip(*batch) transposes the samples into one column per field.
    merged_ids, merged_labels = [torch.cat(column, dim=0) for column in zip(*batch)]
    return merged_ids, merged_labels

@torch.no_grad()
def eval_ppl(model, tokenizer, model_name, datasets, seqlen):
    """Compute perplexity by decoding one token at a time through the KV cache.

    For every dataset the test samples are loaded (from an on-disk cache when
    available), batched, and fed to ``model`` token by token with
    ``use_cache=True``. The per-step cross-entropy of the next-token
    prediction is collected and the perplexity is the exponential of the mean
    over all collected losses.

    Args:
        model: Causal LM exposing ``.device`` and returning an object with
            ``.logits`` and ``.past_key_values`` when called.
        tokenizer: Tokenizer forwarded to the dataset loader on a cache miss.
        model_name: Model identifier; used only to build the cache file name.
        datasets: Comma-separated dataset names.
        seqlen: Requested sequence length; used for loading and the cache key.

    Returns:
        Dict mapping each dataset name to its perplexity as a Python float.
    """
    # Kept as a function-local import, but executed once per call instead of
    # once per dataset.
    from torch.utils.data import DataLoader

    batch_size = 8
    # reduction='none' keeps one loss value per sequence in the batch.
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    results = {}

    for dataset in datasets.split(","):
        cache_testloader = (
            f"/tmp/{dataset}_test_{model_name.replace('/', '_')}_{seqlen}_all.cache"
        )
        if os.path.exists(cache_testloader):
            # NOTE(review): torch.load unpickles the file, and the path in
            # /tmp is predictable -- only safe on a trusted machine.
            test_samples = torch.load(cache_testloader)
        else:
            test_samples = get_ppl_eval_loaders(dataset, tokenizer, seqlen)
            # Write the cache only on a miss; previously the file was
            # rewritten on every run, even right after being loaded.
            torch.save(test_samples, cache_testloader)

        # To limit the amount of data while debugging:
        # test_samples = test_samples[:10]

        nlls = []
        loader = DataLoader(test_samples, batch_size=batch_size, collate_fn=collate_fn)
        for inp, labels in tqdm(loader):
            inp = inp.to(model.device)
            labels = labels.to(model.device)

            # Use a separate name so the ``seqlen`` argument (part of the
            # cache key for the next dataset) is not overwritten.
            curr_batch_size, num_tokens = inp.shape
            past_key_values = None

            # The last token has no successor, so it is never used as input.
            for i in tqdm(range(num_tokens - 1)):
                input_ids = inp[:, i].unsqueeze(1)  # (B, 1)
                target = labels[:, i + 1]  # the next token is the target
                position_ids = torch.full(
                    (curr_batch_size, 1), i, dtype=torch.long, device=model.device
                )  # (B, 1)

                outputs = model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                logits = outputs.logits[:, -1, :]  # logits at the last position
                past_key_values = outputs.past_key_values  # carry the cache forward

                nlls.append(loss_fct(logits, target))  # (B,)

        ppl = torch.exp(torch.cat(nlls, dim=0).mean())
        results[dataset] = ppl.item()

    return results


@hydra.main(version_base=None, config_path="src/conf", config_name="ppl_config")
def main(cfg: DictConfig) -> None:
    """Run the perplexity evaluation described by ``cfg`` and save the results.

    Loads the model and tokenizer, evaluates perplexity on ``cfg.dataset``,
    writes one ``<dataset>: <ppl>`` line per dataset into a text file under
    ``cfg.output_dir`` (suffixed with ``_generative``), and finally prints the
    aggregated recorder values.
    """

    def _tqdm_sink(msg):
        # Emit log records through tqdm.write, as the original sink did.
        tqdm.write(msg, end="")

    logger.remove()
    logger.add(_tqdm_sink, colorize=True, level="INFO")
    logger.info(f"Evaluating PPL with the following config: {cfg}")

    model, tokenizer = load_model_and_tokenizer(
        cfg.model_name_or_path,
        cfg.quantizer,
        forward_quant=False,
    )

    for banner in (
        "Start evaluating ppl...",
        f"*model: {cfg.model_name_or_path}",
        f"*datasets: {cfg.dataset}",
        f"*sequence length {cfg.seqlen}",
    ):
        logger.info(banner)

    results = eval_ppl(
        model, tokenizer, cfg.model_name_or_path, cfg.dataset, cfg.seqlen
    )

    # The config itself is updated so the suffixed directory is what cfg holds.
    cfg.output_dir += "_generative"
    os.makedirs(cfg.output_dir, exist_ok=True)

    result_name = f"{cfg.dataset}_results_{cfg.model_name_or_path.split('/')[-1]}_{cfg.quantizer.save_postfix}.txt"
    result_path = os.path.join(cfg.output_dir, result_name)
    with open(result_path, "w") as out:
        for dataset, ppl in results.items():
            logger.info(f"PPL: {ppl}")
            print(f"{dataset}: {ppl}", file=out)

    # Optional debugging: per-layer means of the recorded similarities.
    # import numpy as np
    # for key_ in ["k_cossim", "v_cossim"]:
    #     for i in range(32):
    #         print(key_, i, np.mean(recorder.values[key_][i]))
    recorder.agg_everything()
    print(recorder.values)
    # Optional debugging: dump the aggregated values for later visualisation.
    # import pickle
    # with open("vis_data/mean_std_test.pickle", "wb") as f:
    #     pickle.dump(recorder.values, f)
    
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
