import argparse
import time
import os
import numpy as np
import random
import os.path as osp
import urllib.request
import json
import torch
import re
from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from fastchat.model import get_conversation_template
from utils import seed_everything, get_memory_info, get_hf_model_path, get_kv_size

from algorithms.h2o import LlamaAttentionH2O
from algorithms.sink import LlamaAttentionSink
from algorithms.hyper import LlamaAttentionHyper
from algorithms.exact import LlamaAttentionExact


def load_testcases(test_file):
    """Load test cases from a JSON Lines file.

    Args:
        test_file: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the locale's default
    # encoding, which differs between machines.
    with open(test_file, "r", encoding="utf-8") as json_file:
        # Skip blank lines (e.g. a trailing empty line at the end of the
        # file); json.loads would otherwise raise on them.
        return [json.loads(line) for line in json_file if line.strip()]


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode tokens one at a time, reusing the KV cache.

    The first forward pass consumes the whole prompt; every later pass feeds
    only the previously predicted token together with the cache returned by
    the previous pass.

    Args:
        model: Callable invoked as ``model(input_ids=..., past_key_values=...,
            use_cache=True)`` whose result exposes ``logits`` and
            ``past_key_values``.
        tokenizer: Object providing ``eos_token_id``.
        input_ids: Prompt token ids. Each prediction is read with ``.item()``,
            so only a single sequence (batch size 1) is supported.
        past_key_values: Cache to start from, or ``None``.
        max_gen_len: Maximum number of tokens to generate. At least one token
            is always generated, even when this is smaller than 1.

    Returns:
        A tuple ``(generated_ids, past_key_values)``: the list of generated
        token ids (including the EOS token when one was produced) and the
        cache returned by the last forward pass.
    """
    eos_token_id = tokenizer.eos_token_id
    generated_ids = []
    next_input = input_ids

    # max(..., 1) keeps the historical guarantee that the prompt is always
    # run through the model and one token is produced.
    for _ in range(max(max_gen_len, 1)):
        outputs = model(
            input_ids=next_input,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        # Greedy choice: highest-scoring token at the last position.
        next_input = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        token_id = next_input.item()
        generated_ids.append(token_id)

        # Check EOS for every generated token, including the very first one,
        # so decoding does not continue past an immediate end-of-sequence.
        if eos_token_id is not None and token_id == eos_token_id:
            break

    return generated_ids, past_key_values


def get_output_dir(args):
    """Build the predictions directory path for the evaluated model.

    Args:
        args: Namespace providing ``model_name_or_path`` and ``task``.

    Returns:
        ``./evaluation/<task>/predictions_orig/<name>`` where ``<name>`` is
        the last "/"-separated component of ``args.model_name_or_path``.
    """
    # Strip every trailing slash (not just one) so paths such as "model//"
    # still yield a non-empty name, and an empty path does not raise
    # IndexError.
    name = args.model_name_or_path.rstrip("/").split("/")[-1]
    return f"./evaluation/{args.task}/predictions_orig/{name}"


def test_lines_one_sample(model, tokenizer, test_case, output_file, idx, args):
    """Run one line-retrieval test case and record the result.

    The model is queried with the test case's prompt (how depends on
    ``args.model_name_or_path``), the last integer in the answer is taken as
    the prediction, and a one-line summary is printed and, unless
    ``args.debug`` is set, written to ``output_file``.

    Args:
        model: The model under test.
        tokenizer: Tokenizer matching ``model``.
        test_case: Mapping with at least ``"prompt"`` and
            ``"expected_number"``; the MPT branches also read
            ``"random_idx"``.
        output_file: Path of the per-run response file. It is truncated when
            ``idx == 0`` and appended to otherwise.
        idx: Index of this test case within the run.
        args: Parsed command-line arguments; ``model_name_or_path``,
            ``kv_cache_method`` and ``debug`` are read here.

    Returns:
        A tuple ``(is_correct, prompt_length, summary)`` where ``is_correct``
        is ``expected_number == response_number``.
    """
    prompt = test_case["prompt"]
    expected_number = test_case["expected_number"]

    # Only the local greedy-decoding branch produces a KV cache. Initialise
    # it up front so the summary code below works for every branch instead of
    # raising UnboundLocalError.
    past_key_values = None

    if "mosaicml/mpt-7b-storywriter" in args.model_name_or_path:
        from transformers import pipeline
        pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device='cuda:0')
        # Use next word prediction to get storywriter answer
        prompt += f'Line <{test_case["random_idx"][0]}>: <REGISTER_CONTENT> is'
        prompt_length = len(tokenizer(prompt).input_ids)
        with torch.autocast('cuda', dtype=torch.bfloat16):
            output = pipe(prompt, max_new_tokens=15, do_sample=True, use_cache=True)[0]['generated_text'][len(prompt):]
    elif args.model_name_or_path in ["THUDM/chatglm2-6b", "THUDM/chatglm3-6b"]:
        prompt_length = len(tokenizer(prompt).input_ids)
        output, _ = model.chat(tokenizer, prompt, history=[], max_length=16384)
    elif "gpt-" in args.model_name_or_path:
        # NOTE(review): retrieve_from_openai is not defined or imported in the
        # visible part of this file -- confirm it is available at runtime.
        prompt_length, output = retrieve_from_openai(prompt, args.model_name_or_path)
    elif "claude" in args.model_name_or_path:
        # NOTE(review): retrieve_from_anthropic is not defined or imported in
        # the visible part of this file -- confirm it is available at runtime.
        prompt_length, output = retrieve_from_anthropic(prompt, args.model_name_or_path)
    else:
        if "longchat" in args.model_name_or_path:
            conv = get_conversation_template("vicuna")
        else:
            conv = get_conversation_template(args.model_name_or_path)

        if "mosaicml/mpt-30b-chat" in args.model_name_or_path:
            prompt += f'Answer in the format <{test_case["random_idx"][0]}> <REGISTER_CONTENT>.'

        # Wrap the prompt in the chat template, leaving the reply slot open.
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        encoded = tokenizer(prompt, return_tensors="pt")
        prompt_length = encoded.input_ids.shape[-1]

        device = getattr(model, "device", "cpu")

        max_gen_len = 100
        generated_ids, past_key_values = greedy_generate(
            model, tokenizer, encoded.input_ids.to(device), None, max_gen_len
        )
        output = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)[0]

    # Reset the per-layer cache state kept by these attention implementations
    # before the next sample is processed.
    if args.kv_cache_method in ['h2o', 'kmeans', 'kcenter', 'weighted_kcenter']:
        for layer in model.model.layers:
            layer.self_attn._clean_cache()

    # Use the last integer that appears in the model output as the prediction.
    numbers = re.findall(r"\d+", output)
    if numbers:
        response_number = int(numbers[-1])
    else:
        print("Got unparsable result")
        response_number = -1

    summary = f"Label: {expected_number}, Predict: {output}, Parsed: {response_number}, length: {prompt_length}".replace('\n', ' ')
    if past_key_values is not None:
        summary += f", kv_size: {list(past_key_values[0][0].size())}, kv_memory: {get_kv_size(past_key_values):.3f}"
        # Drop the local reference to the cache before emptying the CUDA
        # allocator cache below.
        del past_key_values

    summary += f", {get_memory_info()}"

    print(summary)
    if not args.debug:
        # The first sample starts a fresh file; later samples append to it.
        with open(output_file, "w" if idx == 0 else "a") as f:
            f.write(summary + "\n")
    torch.cuda.empty_cache()
    return expected_number == response_number, prompt_length, summary


@torch.no_grad()
def main():
    """Evaluate a causal LM on the line-retrieval task with a chosen KV cache.

    Parses the command line, builds the model on empty weights, replaces each
    decoder layer's ``self_attn`` with the attention class selected by
    ``--kv_cache_method``, loads the checkpoint, and runs every test case of
    each configured prompt size. Per-sample summaries and the final accuracy
    are written under the output directory unless ``--debug`` is given.

    Raises:
        NotImplementedError: If ``--kv_cache_method`` is not one of the
            supported methods.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name-or-path", type=str, default="lmsys/longchat-7b-v1.5-32k", help="model path")
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    parser.add_argument("--max_gpu_memory", type=int, default=40, help="max per gpu memory in GiB. A100 is 40 or 80.")
    parser.add_argument("--longchat_flash_attn", action='store_true', help="Only apply to longchat models. Whether to enable flash attention to save memory, but slower.")
    parser.add_argument("--longchat_ratio", type=int, default=8, help="Only apply to longchat models. Use ratio=8 for 16K context length model. Only ratio=8 is supported now.")
    parser.add_argument("--kv_cache_method", type=str, default="exact", help="KV cache method")
    parser.add_argument("--framework", type=str, default=None, help="Framework for serving")
    parser.add_argument("--task", type=str, default="lines", help="either lines or xxx")

    parser.add_argument("--hh_size", type=int, default=8, help="size of heavy-hitter to keep")
    parser.add_argument("--recent_size", type=int, default=2048, help="size of recent key/value to keep")

    parser.add_argument("--debug", action='store_true')
    args = parser.parse_args()

    # Methods that take a cache budget: attention class plus the label used
    # when echoing --hh_size for that method.
    budgeted_methods = {
        'h2o': (LlamaAttentionH2O, "hh_size"),
        'sink': (LlamaAttentionSink, "start_size"),
        'kcenter': (LlamaAttentionHyper, "cluster_size"),
        'kmeans': (LlamaAttentionHyper, "cluster_size"),
        'weighted_kcenter': (LlamaAttentionHyper, "cluster_size"),
    }
    # Methods that additionally record their name in config.method.
    cluster_methods = ('kcenter', 'kmeans', 'weighted_kcenter')
    method = args.kv_cache_method

    seed_everything(args.seed)
    output_dir = get_output_dir(args)
    if method in budgeted_methods:
        output_dir = output_dir + f"-{method}-h{args.hh_size}_r{args.recent_size}"
    else:
        output_dir = output_dir + f"-{method}"

    # Nothing is written in debug mode, so the directory is not needed then.
    if not args.debug:
        os.makedirs(output_dir, exist_ok=True)

    print(f"output_dir: {output_dir}")

    chpt_path = get_hf_model_path(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    print("define the model | ", end='')
    tic = time.time()
    # Build the architecture under init_empty_weights; the real weights are
    # loaded by load_checkpoint_and_dispatch below.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        print(f"kv_cache_method: {method}, ", end='')
        if method == 'exact':
            attention_cls = LlamaAttentionExact
        elif method in budgeted_methods:
            attention_cls, size_label = budgeted_methods[method]
            print(f"{size_label}: {args.hh_size}, recent_size: {args.recent_size}")
            config.hh_size = args.hh_size
            config.recent_size = args.recent_size
            config.cache_size = config.hh_size * 2 + config.recent_size
            if method in cluster_methods:
                config.method = method
        else:
            supported = ['exact', *budgeted_methods]
            raise NotImplementedError(
                f"Unsupported kv_cache_method {method!r}; expected one of {supported}"
            )

        # Swap every decoder layer's attention for the selected implementation.
        for i in range(config.num_hidden_layers):
            model.model.layers[i].self_attn = attention_cls(config, i)

    model = load_checkpoint_and_dispatch(model, checkpoint=chpt_path, device_map="auto", dtype=torch.float16)
    model = model.eval()
    print(f"time: {time.time() - tic:.4f} sec")

    print(f"[1] {get_memory_info()}")
    # Other prompt sizes used for this task: [200, 300, 400, 500, 600, 680].
    for num_lines in [200]:

        print(f"************ Start testing {num_lines} lines per LRT prompt ************")
        test_file = f"./datasets/lines/testcases_orig/{num_lines}_lines.jsonl"

        output_file = os.path.join(output_dir, f"{num_lines}_response.txt")
        num_correct = 0
        avg_length = 0

        test_cases = load_testcases(test_file)
        num_cases = len(test_cases)
        if num_cases == 0:
            # Avoid dividing by zero below when the test file has no cases.
            print(f"No test cases found in {test_file}, skipping")
            continue

        pbar = tqdm(test_cases)
        for idx, test_case in enumerate(pbar):
            correct, prompt_length, _ = test_lines_one_sample(model=model, tokenizer=tokenizer, test_case=test_case, output_file=output_file, idx=idx, args=args)
            avg_length += prompt_length / num_cases
            num_correct += correct
            # Running accuracy over the samples processed so far (dividing by
            # the total would under-report until the last sample).
            pbar.set_description(f"{idx}/{num_cases}, correct: {correct}, n_correct: {num_correct}, acc: {num_correct / (idx + 1)}")
        accuracy = num_correct / num_cases
        if not args.debug:
            with open(output_file, "a+") as f:
                f.write(f"Accuracy: {accuracy}")

        print(f"************ Finish testing {num_lines} lines per prompt with average prompt length {avg_length}, accuracy: {accuracy} ************")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
    