import os
import pathlib
from dataclasses import dataclass

import torch
import transformers
from argparse_dataclass import ArgumentParser
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm, trange
import json

from transformers import DynamicCache

from quick_extend.models.load_model import ModelConfig, load_model


@dataclass()
class EvalConfig(ModelConfig):
    """Perplexity-evaluation options layered on top of ``ModelConfig``.

    Fields (declaration order is preserved for the generated ``__init__``):
        stride: window length and step, in tokens. Values <= 0 make ``job_ppl``
            fall back to the model's ``max_position_embeddings`` (or 2048).
        overwrite: recompute even when the result JSON already exists.
        disable_mem_efficient_ce: when set, ``job_ppl`` does not pass
            ``use_efficient_ce=True`` to the model's forward call.
    """

    stride: int = 4_096
    overwrite: bool = False
    disable_mem_efficient_ce: bool = False


def _ppl_output_path(args: EvalConfig) -> str:
    """Return the JSON result path, encoding the model name and attention settings."""
    model_name = args.init_from_checkpoint if args.init_from_checkpoint else args.model
    model_name = model_name.replace("/", "_").replace('.', '_')
    if args.use_infinigen:
        return (f'./cache/llama_eval/{model_name}-infinigen'
                f'-tk{args.hip_top_k_elems}-es{args.end_sink_tokens}.json')
    if args.disable_hip:
        return f'./cache/llama_eval/{model_name}-nohip.json'
    return (f'./cache/llama_eval/{model_name}'
            f'-tk{args.hip_top_k_elems}-bq{args.hip_block_size_q}-bk{args.hip_block_size_k}'
            f'-ss{args.start_sink_tokens}-es{args.end_sink_tokens}-dl{args.dense_layers}.json')


def _load_wikitext_encodings(args: EvalConfig, tokenizer) -> torch.Tensor:
    """Tokenize the WikiText-2 (raw) test split, caching the input ids under ./cache."""
    os.makedirs('./cache', exist_ok=True)
    # The cache is keyed by `args.model` (the tokenizer source), not by the checkpoint.
    cache_path = f'./cache/llama_eval_{args.model.replace("/", "_")}.pth'
    if os.path.exists(cache_path):
        return torch.load(cache_path)
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids
    torch.save(encodings, cache_path)
    return encodings


def _infinigen_window_nll(args: EvalConfig, model, input_ids: torch.Tensor,
                          target_ids: torch.Tensor) -> torch.Tensor:
    """Mean next-token NLL of one window: KV-cache prefill, then token-by-token decode.

    Returns the summed cross-entropy divided by the number of scored labels
    (shifted labels that are not -100), matching the HF ``labels=`` loss.
    """
    loss_fct = CrossEntropyLoss(reduction='sum')
    vocab_size = model.config.vocab_size
    window_len = input_ids.size(1)

    prefill_len = args.hip_top_k_elems + args.end_sink_tokens
    outputs = model(
        input_ids[:, :prefill_len],
        past_key_values=DynamicCache(),
        use_cache=True,
    )
    past_key_values = outputs.past_key_values

    # Logits at position t predict token t + 1, so prefill scores tokens 1..prefill_len.
    shift_logits = outputs.logits.view(-1, vocab_size)
    shift_labels = target_ids[:, 1:prefill_len + 1].view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    neg_log_likelihood = loss_fct(shift_logits, shift_labels)
    print(f"ppl={torch.exp(neg_log_likelihood / prefill_len).item():.3f}")

    progress = trange(prefill_len, window_len)
    for i in progress:
        outputs = model(
            input_ids[:, i:i + 1],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # The last token has no successor inside the window, so nothing to score.
        if i < window_len - 1:
            shift_logits = outputs.logits.view(-1, vocab_size)
            shift_labels = target_ids[:, i + 1:i + 2].view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            neg_log_likelihood += loss

            if i % 10 == 0:
                # i + 1 tokens (1..i+1) have been scored so far.
                progress.set_description(f"cur={loss.item():.3f} "
                                         f"ppl={torch.exp(neg_log_likelihood / (i + 1)).item():.3f}")

    # Normalise by the tokens actually scored (every position after the first, minus
    # masked labels). Dividing by the window length, as before, biased PPL low.
    num_scored = int((target_ids[:, 1:] != -100).sum().item())
    return neg_log_likelihood / num_scored


@torch.inference_mode()
def job_ppl(args: EvalConfig, model, tokenizer: transformers.LlamaTokenizer, device):
    """Compute sliding-window perplexity on WikiText-2 test and save it as JSON.

    The result ``{'ppl': <float>}`` is written to a path under ``./cache/llama_eval/``
    derived from the model and attention settings. If that file already exists and
    ``args.overwrite`` is false, nothing is computed.

    Args:
        args: evaluation/model configuration.
        model: causal LM called either with ``labels=`` (uses ``outputs.loss``) or,
            in InfiniGen mode, with a KV cache (uses ``logits``/``past_key_values``).
        tokenizer: tokenizer used to build (and cache) the WikiText-2 input ids.
        device: device the input windows are moved to.

    Raises:
        ValueError: if the corpus is too short to yield any evaluation window.
    """
    outfile = _ppl_output_path(args)
    pathlib.Path(outfile).parent.mkdir(parents=True, exist_ok=True)
    print("Will write to", outfile)
    if os.path.exists(outfile) and not args.overwrite:
        print(f'PPL already computed, skipping: {outfile}')
        return

    encodings = _load_wikitext_encodings(args, tokenizer)

    max_length = model.config.max_position_embeddings if hasattr(model, 'config') else 2048
    max_length = stride = args.stride if args.stride > 0 else max_length
    seq_len = encodings.size(1)

    # The last window start is dropped deliberately (it yields a short window).
    window_starts = range(0, seq_len, stride)[:-1]
    if len(window_starts) == 0:
        raise ValueError(
            f'Corpus has {seq_len} tokens; need more than stride={stride} to evaluate.')

    model.eval()
    nlls = []
    prev_end_loc = 0
    with tqdm(window_starts, dynamic_ncols=True) as pbar:
        for begin_loc in pbar:
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            # Mask context already scored by a previous window (a no-op while
            # max_length == stride, since windows do not overlap).
            target_ids[:, :-trg_len] = -100

            if args.use_infinigen:
                neg_log_likelihood = _infinigen_window_nll(args, model, input_ids, target_ids)
            else:
                kwargs = {} if args.disable_mem_efficient_ce else {'use_efficient_ce': True}
                outputs = model(input_ids, labels=target_ids, **kwargs)
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood.cpu())
            prev_end_loc = end_loc

            ppl = torch.exp(torch.stack(nlls).mean()).item()
            pbar.set_description(f"ppl: {ppl:.3f}")

            if end_loc == seq_len:
                break

    ppl = torch.exp(torch.stack(nlls).mean()).item()

    # The parent directory of `outfile` was created above.
    with open(outfile, 'w') as f:
        json.dump({'ppl': ppl}, f)

    print(f'PPL: {ppl:.4f}')


def main():
    """CLI entry point: parse EvalConfig, load the model, and run the PPL job on CUDA."""
    config = ArgumentParser(EvalConfig).parse_args()
    print(config)

    model, tokenizer = load_model(model_config=config, for_training=False)
    job_ppl(config, model, tokenizer, 'cuda')


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
