import json
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoConfig
from datasets import load_dataset
from functools import partial
import gc

from modify_llama import convert_kvcache_llama_heavy_recent, convert_llama_channel_config
from modify_mistral import convert_kvcache_mistral_heavy_recent, convert_mistral_channel_config
# from modify_mixtral import convert_kvcache_mixtral_heavy_recent, convert_mixtral_channel_config
from streaming_llama import convert_streaming
from rtn_llama import convert_rtn
from offload_llama import convert_kvcache_llama_offloading, convert_llama_offloading_channel_config
from offload_mistral import convert_kvcache_mistral_offloading, convert_mistral_offloading_channel_config


def evaluate(model, tokenizer, max_seq_len=2048):
    """Return the perplexity of ``model`` on the WikiText-2 raw test split.

    The test split's lines are joined with two newlines, tokenized once, and
    cut into ``nsamples = num_tokens // max_seq_len`` non-overlapping windows
    (the trailing partial window is dropped).  Each window is scored with
    next-token cross-entropy and the result is ``exp`` of the mean
    per-window loss.

    Args:
        model: causal LM whose call returns an object with ``.logits``
            indexed as (batch, seq, vocab) below.
        tokenizer: tokenizer whose output exposes ``input_ids`` as a tensor.
        max_seq_len: window length in tokens; 2048 is the previously
            hard-coded value.

    Returns:
        0-dim float32 tensor holding the perplexity.

    Raises:
        ValueError: if the tokenized test set is shorter than one window.
    """
    testenc = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    testenc = tokenizer("\n\n".join(testenc['text']), return_tensors='pt')

    torch.cuda.empty_cache()
    gc.collect()

    testenc = testenc.input_ids.to(model.device)
    print(testenc.shape)
    nsamples = testenc.shape[1] // max_seq_len
    if nsamples == 0:
        # Otherwise torch.stack([]) below fails with an unhelpful message.
        raise ValueError(
            f"test set has {testenc.shape[1]} tokens, fewer than one "
            f"window of max_seq_len={max_seq_len}"
        )
    model = model.eval()

    # The loss module is stateless; build it once instead of per window.
    loss_fct = nn.CrossEntropyLoss()
    nlls = []

    # NOTE(review): original author's note "57 -> nan" -- presumably window 57
    # once produced a NaN loss; meaning unconfirmed.

    gc.collect()
    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
        # testenc already lives on model.device, so no extra .to() is needed.
        batch = testenc[:, (i * max_seq_len):((i + 1) * max_seq_len)]
        with torch.no_grad():
            lm_logits = model(batch).logits
            # Position t predicts token t+1: drop the last logit / first label.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # With a multi-GPU device_map the LM head (and so the logits) can
            # sit on a different device than the inputs; cross-entropy needs
            # logits and labels on the same device.
            shift_labels = batch[:, 1:].to(shift_logits.device)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        # loss is a mean over the window; scaling by max_seq_len here and
        # dividing by nsamples * max_seq_len below yields exp(mean loss).
        neg_log_likelihood = loss.float() * max_seq_len
        nlls.append(neg_log_likelihood)
        # Release every per-window activation (not just the input) so that
        # empty_cache() can actually return the logits buffers.
        del batch, lm_logits, shift_logits, shift_labels, loss
        torch.cuda.empty_cache()
        gc.collect()

    return torch.exp(torch.stack(nlls).sum() / (nsamples * max_seq_len))

# ---------------------------------------------------------------------------
# Model / channel-config selection.
# Leave exactly one `model_path` and one matching `channel_path` uncommented
# (if several are active, the last assignment wins).  The number after each
# model name is presumably that model's baseline WikiText-2 perplexity as
# reported by evaluate() -- TODO confirm.  `channel_path` names the JSON file
# loaded into `channel_config` further down and passed to the
# *_channel_config converters.
# NOTE(review): the loader below is hard-wired to LlamaForCausalLM /
# LlamaTokenizer; non-Llama entries (Mistral, Mixtral, OPT) likely need the
# commented-out Auto* classes and the matching converters as well.
# ---------------------------------------------------------------------------
# llama-2-7b: 5.47
model_path = "meta-llama/Llama-2-7b-hf"
# channel_path = "llama2-7b-channel-config.json"
channel_path = "llama2-7b-qk-channel-config.json"

# llama-2-7b-chat: 6.94
# model_path = "meta-llama/Llama-2-7b-chat-hf"
# channel_path = "llama2-7b-chat-channel-config.json"
# channel_path = "llama2-7b-chat-qk-channel-config.json"

# llama-7b: 5.68
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16"
# channel_path = "llama-7b-channel-config.json"
# channel_path = "llama-7b-qk-channel-config.json"

# vicuna-7b-v1.5-16k: 7.15
# model_path = "lmsys/vicuna-7b-v1.5-16k"
# channel_path = "vicuna-7b-v1.5-16k-channel-config.json"
# channel_path = "vicuna-7b-v1.5-16k-qk-channel-config.json"

# mistral-7b: 5.25
# (use with the convert_*mistral* converters, not the llama ones)
# model_path = "mistralai/Mistral-7B-v0.1"
# channel_path = "mistral-7b-v0.1-channel-config.json"
# channel_path = "mistral-7b-v0.1-qk-channel-config.json"

# llama-2-70b-chat: 4.65
# (any path containing "70b" selects the hand-written 8-GPU device_map below)
# model_path = "meta-llama/Llama-2-70b-chat-hf"
# channel_path = "llama2-70b-chat-channel-config.json"
# channel_path = "llama2-70b-chat-qk-channel-config.json"

# llama-2-70b: 3.32
# model_path = "meta-llama/Llama-2-70b-hf"
# channel_path = "llama2-70b-channel-config.json"
# channel_path = "llama2-70b-qk-channel-config.json"

# mixtral-8x7B-v0.1: 3.84
# (the modify_mixtral import is currently commented out at the top of the file)
# model_path = "mistralai/Mixtral-8x7B-v0.1"
# channel_path = "mixtral-8x7b-channel-config.json"
# channel_path = "mixtral-8x7b-qk-channel-config.json"


# opt-6.7b: 10.86
# (no channel config is listed for OPT)
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--facebook--opt-6.7b/snapshots/a45aa65bbeb77c1558bc99bedc6779195462dab0"

# Placement for the 70B checkpoints: 80 decoder layers, 10 per GPU across
# GPUs 0-7 (TODO: support more than 8 x a10g), embeddings on GPU 0 and the
# final norm + LM head on GPU 7.  Every other model uses "auto".
if "70b" not in model_path:
    device_map = "auto"
else:
    device_map = {
        "model.embed_tokens": 0,
        "model.norm": 7,
        "lm_head": 7,
        **{f"model.layers.{layer}": layer // 10 for layer in range(80)},
    }

# NOTE: device_map is built above but deliberately NOT passed to
# from_pretrained -- the whole model is moved onto one GPU with .cuda().
# To shard instead, add "device_map": device_map to kwargs and drop .cuda().
# For non-Llama checkpoints swap in AutoModelForCausalLM / AutoTokenizer.
kwargs = dict(torch_dtype=torch.float16)

model = LlamaForCausalLM.from_pretrained(model_path, **kwargs)
model = model.cuda()
tokenizer = LlamaTokenizer.from_pretrained(model_path)

config = AutoConfig.from_pretrained(model_path)




# Channel configuration handed to the *_channel_config converters.  JSON is
# UTF-8 by specification, so pin the encoding instead of relying on the
# platform's locale default.
with open(channel_path, encoding="utf-8") as f:
    channel_config = json.load(f)

# ---------------------------------------------------------------------------
# KV-cache method selection: enable one conversion and, where listed, the
# *_channel_config call directly after it.  Each call rebinds `model`, and the
# channel-config call presumably configures modules installed by the preceding
# conversion, so keep each pair together and in this order.
# NOTE(review): the positional arguments (128, 2/4, [0,1,31], ...) are opaque
# from this file -- see the corresponding modify_*/offload_*/streaming_*/rtn_*
# module.  [0,1,31] looks like layer indices, with 31 being the last layer of
# a 32-layer model; confirm before reusing with 70B checkpoints (80 layers).
# ---------------------------------------------------------------------------
# model = convert_kvcache_llama_heavy_recent(model, config, 128, 4, 4)
# model = convert_llama_channel_config(model, channel_config, "qk")

# TODO: fix multi-gpu for offloading
model = convert_kvcache_llama_offloading(model, config, 128, 2, 2, [0,1,31])
model = convert_llama_offloading_channel_config(model, channel_config, "qk")

# model = convert_kvcache_mistral_heavy_recent(model, config, 128, 4, 4)
# model = convert_mistral_channel_config(model, channel_config, "qk") #TODO: no k outlier for gqa

# model = convert_kvcache_mistral_offloading(model, config, 128, 2, 4, [0,1,31])
# model = convert_mistral_offloading_channel_config(model, channel_config, "qk")

# (requires re-enabling the modify_mixtral import at the top of the file)
# model = convert_kvcache_mixtral_heavy_recent(model, config, 128, 4, 4)
# model = convert_mixtral_channel_config(model, channel_config, "q") #TODO: no k outlier for gqa

# model = convert_streaming(model, config, 128, 4)

# model = convert_rtn(model, config, 2)

# Free whatever the conversion step left behind, then report perplexity.
# evaluate() also switches to eval mode; repeating it here is harmless.
model.eval()
gc.collect()
torch.cuda.empty_cache()

perplexity = evaluate(model, tokenizer)
print(perplexity)


