import argparse
import os
import re
import json
from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoConfig
import numpy as np

from fastchat.model import get_conversation_template
from modify_llama import convert_kvcache_llama_heavy_recent, convert_llama_channel_config


# Total number of tokens (prompt + generated reply) allowed per generate() call.
MAX_CONTEXT_LENGTH = 2048


def remaining_token_budget(prompt_length, context_length=MAX_CONTEXT_LENGTH):
    """Return how many new tokens may be generated after the prompt.

    Args:
        prompt_length: Number of tokens already consumed by the prompt.
        context_length: Total token window shared by prompt and reply.

    Returns:
        ``context_length - prompt_length``, clamped at 0 so the caller never
        passes a zero-or-negative ``max_new_tokens`` to ``generate``.
    """
    return max(context_length - prompt_length, 0)


def _encode_user_turn(conv, tokenizer, user_text):
    """Append *user_text* and an empty reply slot to *conv*; return prompt token ids on the GPU."""
    conv.append_message(conv.roles[0], user_text)
    # The None message is the placeholder later filled by update_last_message().
    conv.append_message(conv.roles[1], None)
    return tokenizer(conv.get_prompt(), return_tensors="pt").input_ids.cuda()


if __name__ == "__main__":
    # Interactive chat loop against a Llama model patched by the project-local
    # modify_llama helpers. Type "quit" (or close stdin) to exit.
    #
    # NOTE(review): the numbers after each model name below look like reference
    # evaluation scores (presumably perplexity) -- confirm before relying on them.

    # llama-2-7b: 5.47
    # model_path = "meta-llama/Llama-2-7b-hf"
    # channel_path = "llama2-7b-channel-config.json"
    # channel_path = "llama2-7b-qk-channel-config.json"

    # llama-2-7b-chat: 6.94
    model_path = "meta-llama/Llama-2-7b-chat-hf"
    # channel_path = "llama2-7b-chat-channel-config.json"
    channel_path = "llama2-7b-chat-qk-channel-config.json"

    # llama-7b: 5.68
    # model_path = "/home/ec2-user/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16"
    # channel_path = "llama-7b-channel-config.json"
    # channel_path = "llama-7b-qk-channel-config.json"

    # opt-6.7b: 10.86
    # model_path = "/home/ec2-user/.cache/huggingface/hub/models--facebook--opt-6.7b/snapshots/a45aa65bbeb77c1558bc99bedc6779195462dab0"

    # Load the weights, cast to fp16 and move to the GPU.
    # model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()
    model = LlamaForCausalLM.from_pretrained(model_path).half().cuda()
    # tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer = LlamaTokenizer.from_pretrained(model_path)

    config = AutoConfig.from_pretrained(model_path)

    # Channel configuration is read from a JSON file next to the script.
    with open(channel_path, "r", encoding="utf-8") as f:
        channel_config = json.load(f)

    # double sparsity
    # NOTE(review): the meaning of the positional arguments 16, 2, 4 is defined
    # in modify_llama and is not visible from this file -- confirm there.
    model = convert_kvcache_llama_heavy_recent(model, config, 16, 2, 4)
    model = convert_llama_channel_config(model, channel_config, "qk")

    conv = get_conversation_template(model_path)

    while True:
        print(f"{conv.roles[0]}:", end="")
        try:
            inp = input()
        except EOFError:
            # stdin was closed (Ctrl-D or exhausted pipe): leave the loop
            # instead of crashing with a traceback.
            print()
            break
        if inp == "quit":
            break

        input_ids = _encode_user_turn(conv, tokenizer, inp)
        prompt_length = input_ids.shape[-1]

        if remaining_token_budget(prompt_length) == 0:
            # The accumulated history no longer leaves room for a reply.
            # Start a fresh conversation that contains only the current message.
            print("[context window full; starting a new conversation]")
            conv = get_conversation_template(model_path)
            input_ids = _encode_user_turn(conv, tokenizer, inp)
            prompt_length = input_ids.shape[-1]

        max_new_tokens = remaining_token_budget(prompt_length)
        if max_new_tokens == 0:
            # Even on its own the message fills the whole window; drop it and
            # ask again rather than calling generate() with no token budget.
            print(f"[message is too long for the {MAX_CONTEXT_LENGTH}-token window; please shorten it]")
            conv = get_conversation_template(model_path)
            continue

        output = model.generate(input_ids, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True)[0]

        # generate() returns prompt + continuation; keep only the new tokens.
        output = output[prompt_length:]
        output = tokenizer.batch_decode([output], skip_special_tokens=True)[0]

        print(f"{conv.roles[1]}:{output}")
        # Fill the reply placeholder so the next prompt includes this answer.
        conv.update_last_message(output)
