import sys
sys.path.append(".")

import argparse

import torch

from transformers import AutoTokenizer, set_seed, DynamicCache
import utils
import json

def average_excluding_min_max(numbers):
    """Return the arithmetic mean of ``numbers`` with one min and one max dropped.

    Exactly one occurrence of the smallest value and one occurrence of the
    largest value are discarded before averaging. The caller's list is not
    modified; a copy is trimmed instead.

    Args:
        numbers: a list of numbers (it must support ``.copy()``).

    Returns:
        The mean of the remaining values.

    Raises:
        ValueError: if ``numbers`` has two or fewer elements.
    """
    if not len(numbers) > 2:
        raise ValueError("The list must contain more than two elements.")

    lowest, highest = min(numbers), max(numbers)
    trimmed = numbers.copy()
    for extreme in (lowest, highest):
        # list.remove drops only the first match, so duplicates of the
        # extreme values beyond the first one stay in the average.
        trimmed.remove(extreme)

    return sum(trimmed) / len(trimmed)


def _load_model(args):
    """Load the causal LM selected by ``args.mode``.

    ``'take'`` loads the TAKE-patched Llama (only TAO major version 9 is
    supported), configured from the TAKE-specific CLI options. ``'full_kv'``
    loads the stock ``transformers`` Llama. Both use flash-attention-2, fp16
    and ``device_map='auto'``.

    Raises:
        ValueError: on an unsupported TAO version or mode.
    """
    if args.mode == 'take':
        version_num = args.version.split(".")[0]
        if version_num != "9":
            raise ValueError("Invalid version number for TAO")
        # Imported inside the branch so full_kv runs never import the take package.
        from take.take.chunk import TakeKwargs
        from take.take.transformers_take.llama.modeling_llama_take import LlamaForCausalLM
        tao_kwargs = TakeKwargs(
            kv_budget=args.kv_budget,
            task_query_len=args.task_query_len,
            kv_warmup_budget=args.kv_warmup_budget,
            kv_prune_trigger_size=args.kv_prune_trigger_size,
            chunk_size=args.chunk_size,
            kernel_size=args.kernel_size,
            pooling=args.pooling,
            warmup_layers=args.warmup_layers,
            chunk_window_size=args.chunk_window_size,
            chunk_sink=args.chunk_sink,
            use_task_cache=args.use_task_cache,
            alpha=args.alpha,
            separators=None,
            test_performance=args.test_performance
        )
        return LlamaForCausalLM.from_pretrained(args.model, device_map='auto', attn_implementation='flash_attention_2',
                                                torch_dtype=torch.float16, tao_kwargs=tao_kwargs,
                                                tokenizer_path=args.model)
    if args.mode == 'full_kv':
        from transformers import LlamaForCausalLM
        return LlamaForCausalLM.from_pretrained(args.model, device_map='auto', attn_implementation='flash_attention_2',
                                                torch_dtype=torch.float16)
    raise ValueError(f"We does not support {args.mode} mode")


def _time_prefill(model, input_id, attn_mask, position_ids):
    """Run one prefill forward pass and return its GPU time in milliseconds.

    Each call uses a fresh ``DynamicCache``, so no run reuses KV state from a
    previous run. The outputs and the cache are locals, so they are released
    when this returns and a following ``utils.cleanup_memory()`` can reclaim
    them.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    past_key_values = DynamicCache()
    with torch.no_grad():
        start.record()
        model(input_id, attn_mask, position_ids=position_ids, past_key_values=past_key_values)
        end.record()
    # elapsed_time is only valid once both events have completed on the GPU.
    torch.cuda.synchronize()
    return start.elapsed_time(end)


def main(args):
    """Benchmark prefill latency (TTFT) and peak CUDA memory for one mode.

    Loads the model for ``args.mode`` and runs ``args.num_warmups`` untimed
    prefills, then ``args.num_runs`` timed prefills over a synthetic sequence
    of ``args.seqlen`` tokens. It prints the mean TTFT with the fastest and
    slowest runs excluded, plus the peak allocated CUDA memory. The results
    are written to ``<mode>_ttft_prefill_seqlen<seqlen>.json``, inside
    ``args.save_path`` when that is non-empty, otherwise in the current
    directory.

    Raises:
        ValueError: if ``args.num_runs`` <= 2, or the mode or TAO version is
            unsupported.
    """
    import os

    # The reported mean drops the min and max run, so three runs are the
    # minimum. Fail here instead of after the slow model load and all runs.
    if args.num_runs <= 2:
        raise ValueError("num_runs must be larger than 2")

    set_seed(args.seed)

    model = _load_model(args)
    model.eval()

    # Synthetic input: one sequence of all-ones token ids with a full attention
    # mask (token content presumably does not matter for latency; TODO confirm
    # for the take pruning path).
    input_id = torch.ones((1, args.seqlen), dtype=torch.int64, device=model.device)
    attn_mask = torch.ones((1, args.seqlen), dtype=torch.int64, device=model.device)
    context_length = input_id.shape[-1]
    position_ids = torch.arange(context_length, dtype=torch.int64, device=model.device).unsqueeze(0)

    # Warmup runs are timed but discarded.
    for _ in range(args.num_warmups):
        _time_prefill(model, input_id, attn_mask, position_ids)
        utils.cleanup_memory()

    result_list = []
    for _ in range(args.num_runs):
        result_list.append(_time_prefill(model, input_id, attn_mask, position_ids))
        utils.cleanup_memory()

    mean_ttft = average_excluding_min_max(result_list)
    # Peak statistics are never reset here, so this includes model weights and warmup runs.
    max_memory_gb = torch.cuda.max_memory_allocated() / 1000**2 / 1000

    print(f"\nMode: {args.mode}")
    print(f"Context Length = {context_length}")
    if args.mode == "full_kv":
        print(f"Context Capacity = {context_length}")
    elif args.mode == "take":
        print(f"Context Capacity = {args.kv_budget}")
        print(f"TAO Version = {args.version}")
    print(f"TTFT: {(mean_ttft):.5f} msec")
    print(f"Number of Warmup Runs: {args.num_warmups}, Number of Runs: {args.num_runs}, Min & Max values excluded")
    print(f"Max memory allocated: {max_memory_gb:.2f} GB\n")

    results = {
        "mode": args.mode,
        "model": args.model,
        "context_length": context_length,
        "ttft_msec": mean_ttft,
        "max_memory_gb": max_memory_gb,
        "all_ttft_times": result_list
    }

    # save_path is treated as an output directory; an empty value keeps the
    # previous behavior of writing into the current working directory.
    save_dir = getattr(args, "save_path", "") or ""
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, f"{args.mode}_ttft_prefill_seqlen{args.seqlen}.json")

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"Results saved to: {filename}")

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` treats any non-empty string, including
        "False", as True, so the value is parsed explicitly here.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    # Model Arguments
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="model name or model path")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--save_path", default="", type=str, help="Path to save the output")

    # KV Compression
    parser.add_argument("--mode", type=str, default="take", choices=["full_kv", "take"])

    # TAO specific parameters
    parser.add_argument("--version", type=str, default="9.1", help="TAO version; only major version 9 is currently supported, e.g. 9.1")
    parser.add_argument("--use_task_cache", type=_str2bool, default=True)
    parser.add_argument("--kv_prune_trigger_size", type=int, default=4096)
    parser.add_argument("--chunk_size", type=int, default=4096, choices=[1024, 2048, 4096, 8192])
    parser.add_argument("--task_query_len", type=int, default=15)
    parser.add_argument("--warmup_layers", type=int, default=16)
    parser.add_argument("--kv_budget", type=int, default=512, choices=[128, 256, 512, 1024, 2048, 4096, 8192])
    parser.add_argument("--kv_warmup_budget", type=int, default=10000)
    parser.add_argument("--alpha", type=float, default=0.32)
    parser.add_argument("--chunk_window_size", type=int, default=16)
    parser.add_argument("--chunk_sink", type=int, default=16)
    parser.add_argument("--kernel_size", type=int, default=7, choices=[1, 3, 5, 7, 9, 11, 13, 15])
    parser.add_argument("--pooling", type=str, default="avg", choices=["avg", "max"])
    parser.add_argument("--test_performance", type=_str2bool, default=True)

    # Benchmark Option
    parser.add_argument("--seqlen", type=int, default=131072, help="Number of tokens in the synthetic prefill sequence")
    parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup prefill runs excluded from the results")
    parser.add_argument("--num_runs", type=int, default=10, help="num_runs must be larger than 2")

    args = parser.parse_args()

    # The reported TTFT drops the min and max run, so at least three runs are needed.
    if args.num_runs <= 2:
        parser.error("--num_runs must be larger than 2")

    main(args)
    utils.cleanup_memory()
