import sys
sys.path.append(".")

import argparse
import os
import time
import json
import logging
import pprint
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset

def average_excluding_min_max(numbers):
    """Return the mean of *numbers* after dropping one minimum and one maximum.

    Only the first occurrence of the smallest value and the first occurrence
    of the largest value are discarded; the input list itself is not modified.

    Raises:
        ValueError: if *numbers* has two or fewer elements.
    """
    if not len(numbers) > 2:
        raise ValueError("The list must contain more than two elements.")

    # Both extremes are taken from the untouched input, then removed one at a
    # time from a working copy.
    trimmed = numbers.copy()
    for extreme in (min(numbers), max(numbers)):
        trimmed.remove(extreme)

    total = sum(trimmed)
    return total / len(trimmed)

def _measure_decode_throughput(model, input_id, attn_mask, genlen):
    """Run one prefill followed by greedy decoding and time the decode steps.

    The prefill call ``model(input_id, attn_mask)`` is not timed. Each of the
    ``genlen - 1`` single-token decode steps is bracketed by CUDA events and
    the elapsed times are summed.

    Args:
        model: causal LM whose outputs expose ``logits`` and ``past_key_values``.
        input_id: prompt token ids; ``.item()`` is called on each predicted
            token, so the batch must contain a single sequence.
        attn_mask: attention mask passed with the prompt.
        genlen: total number of tokens to produce (first one from the prefill).

    Returns:
        ``(throughput, generation_length)`` where throughput is decode tokens
        per second and generation_length is the number of generated token ids.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_time_ms = 0

    with torch.no_grad():
        outputs = model(input_id, attn_mask)
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids = [pred_token_idx.item()]

        for _ in range(genlen - 1):
            start.record()
            outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values)
            end.record()
            # elapsed_time() requires both events to have completed.
            torch.cuda.synchronize()
            total_time_ms += start.elapsed_time(end)
            past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            generated_ids.append(pred_token_idx.item())

    # Drop the references to the cache before the next run allocates a new one.
    del outputs
    del past_key_values

    throughput = (genlen - 1) / (total_time_ms / 1000)
    return throughput, len(generated_ids)


def main(args):
    """Benchmark decode throughput on a synthetic all-ones prompt and print a report.

    Loads the model, optionally applies the ``score_kv`` monkeypatch for the
    selected method, runs ``args.num_warmups`` untimed-for-reporting warmups and
    ``args.num_runs`` measured runs, then prints the mean throughput with the
    best and worst run excluded, plus peak CUDA memory.

    Args:
        args: parsed command-line namespace. ``args.mode`` and ``args.method``
            are both set to the resolved method name as a side effect.

    Raises:
        ValueError: if no method is given, if ``num_runs`` is not larger than
            2, or if ``genlen`` is smaller than 2.
    """
    # Resolve the method: ``--method`` wins, ``--mode`` is the fallback. The
    # original unconditionally overwrote mode with method and then crashed on
    # ``None.lower()`` when ``--method`` was omitted.
    method = args.method if args.method is not None else getattr(args, "mode", None)
    if method is None:
        raise ValueError("A KV cache method must be given via --method or --mode")
    args.method = method
    args.mode = method

    # Fail fast instead of after loading the model and running every iteration.
    if args.num_runs <= 2:
        raise ValueError("num_runs must be larger than 2")
    if args.genlen < 2:
        # genlen - 1 decode steps are timed; zero steps would divide by zero.
        raise ValueError("genlen must be at least 2")

    set_seed(args.seed)
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    is_full_kv = method.lower() == 'fullkv'
    if not is_full_kv:
        from score_kv.monkeypatch import replace_llama, replace_mistral 
        replace_llama(method)
        replace_mistral(method)

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=args.use_cache,
        attn_implementation=args.attn_implementation
    )

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model.eval()

    # Compression settings attached to the model config; presumably read by
    # the patched attention in score_kv (not visible here) -- TODO confirm.
    config = model.model.config
    config.window_size = getattr(args, "window_size", 8)
    config.base_capacity = args.max_capacity_prompt
    config.kernel_size = args.kernel_size
    config.skip = args.skip
    # NOTE(review): --normalize is parsed but this stays hard-coded to True to
    # keep the existing default behavior -- confirm intent.
    config.normalize = True
    config.pooling = args.pooling
    config.floor = args.floor_alpha

    # Synthetic prompt: a single sequence of token id 1 with a full mask.
    input_id = torch.ones((1, args.seqlen), dtype=torch.int64).to(model.device)
    attn_mask = torch.ones((1, args.seqlen), dtype=torch.int64).to(model.device)
    context_length = input_id.shape[-1]

    # NOTE(review): the original guarded each run with `args.mode in []`, which
    # is always False, so every run raised ValueError and the benchmark could
    # never execute. The gate is removed; confirm whether score_kv.monkeypatch
    # rejects unknown method names on its own.
    for _ in range(args.num_warmups):
        _measure_decode_throughput(model, input_id, attn_mask, args.genlen)

    result_list = []
    generation_length = 0
    for _ in range(args.num_runs):
        throughput, generation_length = _measure_decode_throughput(
            model, input_id, attn_mask, args.genlen
        )
        result_list.append(throughput)

    mean_throughput = average_excluding_min_max(result_list)

    print(f"\nMode: {args.mode}")
    print(f"Context Length = {context_length}")
    if is_full_kv:
        print(f"Context Capacity = {context_length}")
    else:
        print(f"Context Capacity = {args.max_capacity_prompt}")
    print(f"Generation Length = {generation_length}")
    print(f"Throughput: {(mean_throughput):.5f} Tokens/sec")
    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1000**2 / 1000:.2f} GB\n")

def str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This parser
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays accepted as False: it was the only way to turn
    # the flag off under the previous ``type=bool`` behavior.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Model / reproducibility
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="model name of model path")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--save_path", default="", type=str, help="Path to save the output")
    
    # KV-cache compression settings
    parser.add_argument("--mode", type=str, default="fastkv")
    parser.add_argument("--window_size", type=int, default=8)
    parser.add_argument("--max_capacity_prompt", type=int, default=512)
    parser.add_argument("--kernel_size", type=int, default=7)
    parser.add_argument("--pooling", type=str, default="maxpool")

    parser.add_argument("--skip", type=int, default=-1)
    parser.add_argument('--floor_alpha', type=float, default=0.2)
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--pyram', action='store_true')
    parser.add_argument('--pyram_beta', default=20,type=int)
    parser.add_argument('--gqa_support', action='store_true')

    # Benchmark shape: prompt length, tokens generated per run, iteration counts
    parser.add_argument("--seqlen", type=int, default=131072, help="")
    parser.add_argument("--genlen", type=int, default=128, help="")
    parser.add_argument("--num_warmups", type=int, default=2, help="")
    parser.add_argument("--num_runs", type=int, default=10, help="num_runs must be larger than 2")

    parser.add_argument("--attn_implementation", type=str, default="flash_attention_2", choices=["flash_attention_2", "sdpa", "eager"])
    # str2bool instead of bool so that "--use_cache False" really disables it.
    parser.add_argument("--use_cache", type=str2bool, default=True, help="")
    parser.add_argument("--method", type=str, default=None, help="KV cache compression method")
    args = parser.parse_args()

    main(args)
