import json
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoConfig
from datasets import load_dataset
from functools import partial
import gc

from debug import convert_kvcache_llama_heavy_recent, convert_llama_channel_config

def evaluate(model, tokenizer, max_seq_len=2048):
    """Compute the perplexity of ``model`` on the WikiText-2 (raw) test split.

    The test split's ``text`` entries are joined with ``"\\n\\n"``, tokenized
    once, and cut into ``nsamples = num_tokens // max_seq_len``
    non-overlapping windows of ``max_seq_len`` tokens; a trailing partial
    window is dropped. Each window is passed to the model in its own forward
    call.

    Args:
        model: causal LM exposing ``.device`` and ``.eval()``, whose
            ``model(input_ids).logits`` gives per-position vocabulary logits.
        tokenizer: callable as ``tokenizer(text, return_tensors='pt')``,
            returning an object with an ``input_ids`` tensor.
        max_seq_len: window length in tokens (default 2048).

    Returns:
        0-d tensor with the perplexity, i.e. exp of the mean over windows of
        each window's mean token cross-entropy.

    Raises:
        ValueError: if the tokenized corpus is shorter than one window.
    """
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    input_ids = tokenizer("\n\n".join(testdata['text']), return_tensors='pt').input_ids

    torch.cuda.empty_cache()
    gc.collect()

    # Move the whole token stream once; windows below are views of it.
    input_ids = input_ids.to(model.device)
    print(input_ids.shape)
    nsamples = input_ids.shape[1] // max_seq_len
    if nsamples == 0:
        raise ValueError(
            f"tokenized corpus has {input_ids.shape[1]} tokens, "
            f"fewer than one window of {max_seq_len}"
        )
    model = model.eval()
    loss_fct = nn.CrossEntropyLoss()

    # NOTE(review): an earlier note here read "57 -> nan", presumably meaning
    # window 57 produced a NaN loss under some configuration — confirm.
    nlls = []
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        batch = input_ids[:, (i * max_seq_len):((i + 1) * max_seq_len)]
        with torch.no_grad():
            lm_logits = model(batch).logits
            # Logit at position t scores token t+1: drop the last logit and
            # the first label. Upcast so the loss is computed in fp32.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        # ``loss`` is the mean over the window's max_seq_len - 1 predictions.
        # Scaling by max_seq_len here and dividing by nsamples * max_seq_len
        # below cancels out, so the result is exp(mean per-window loss).
        nlls.append(loss * max_seq_len)
        # Drop the large logits tensors before asking the caching allocator
        # to release memory; otherwise they stay alive until the next window.
        del batch, lm_logits, shift_logits
        torch.cuda.empty_cache()
        gc.collect()

    return torch.exp(torch.stack(nlls).sum() / (nsamples * max_seq_len))

# Script: evaluate WikiText-2 perplexity after applying the KV-cache and
# channel-config conversions from ``debug``. Select a model by leaving exactly
# one ``model_path`` (and one matching ``channel_path``) uncommented below.
# The number after each model name is the perplexity recorded for it
# (NOTE(review): presumably the unmodified baseline — confirm).

# llama-2-7b: 5.47
# model_path = "meta-llama/Llama-2-7b-hf"
# channel_path = "llama2-7b-channel-config.json"
# channel_path = "llama2-7b-qk-channel-config.json"

# llama-2-7b-chat: 6.94
model_path = "meta-llama/Llama-2-7b-chat-hf"
# channel_path = "llama2-7b-chat-channel-config.json"
channel_path = "llama2-7b-chat-qk-channel-config.json"

# llama-7b: 5.68
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16"
# channel_path = "llama-7b-channel-config.json"
# channel_path = "llama-7b-qk-channel-config.json"

# vicuna-7b-v1.5-16k: 7.15
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--lmsys--vicuna-7b-v1.5-16k/snapshots/9a93d7d11fac7f3f9074510b80092b53bc1a5bec"
# channel_path = "vicuna-7b-v1.5-16k-channel-config.json"
# channel_path = "vicuna-7b-v1.5-16k-qk-channel-config.json"

# opt-6.7b: 10.86
# NOTE(review): no channel config is listed for OPT, and the Llama-specific
# loaders/conversions below presumably do not apply to it — confirm.
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--facebook--opt-6.7b/snapshots/a45aa65bbeb77c1558bc99bedc6779195462dab0"

# Generic Auto* loaders kept as an alternative to the Llama-specific classes.
# model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()
model = LlamaForCausalLM.from_pretrained(model_path).half().cuda()
# tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = LlamaTokenizer.from_pretrained(model_path)

config = AutoConfig.from_pretrained(model_path)

with open(channel_path, "r", encoding="utf-8") as f:
    channel_config = json.load(f)

# NOTE(review): 128 and 32 are passed positionally; given the function name
# they are presumably the heavy-hitter and recent-token KV-cache budgets —
# confirm against debug.convert_kvcache_llama_heavy_recent.
model = convert_kvcache_llama_heavy_recent(model, config, 128, 32)
# "qk" selects the channel-config variant; it should agree with the
# "-qk-" channel_path chosen above.
model = convert_llama_channel_config(model, channel_config, "qk")


print(evaluate(model, tokenizer))


