import os
import sys
import json
import torch
import argparse
from termcolor import colored
from transformers import AutoTokenizer
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(PROJECT_ROOT)
from model_hub import LlamaModel



def parse_args(argv=None):
    """Parse command-line options for the retrieval test script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            original command-line behaviour; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes ``batch_size``, ``device``,
        ``dtype``, ``attn_type``, ``model_name``, ``task_name`` and
        ``context_len``.
    """
    parser = argparse.ArgumentParser(description="Test example")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device")
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], help="Dtype")
    parser.add_argument("--attn_type", type=str, default="Retr_Attn_V1",
                        choices=["Retr_Attn_V1"],
                        help="Attention method")
    parser.add_argument("--model_name", type=str, default="gradientai/Llama-3-8B-Instruct-262k",
                        choices=["gradientai/Llama-3-8B-Instruct-262k", "01-ai/Yi-9B-200K", "01-ai/Yi-6B-200K"],
                        help="huggingface model name")
    parser.add_argument("--task_name", type=str, default="needle", choices=["multivalue", "passkey", "needle"],
                        help="Test task name")
    parser.add_argument("--context_len", type=int, default=16000, help="Input context length")

    return parser.parse_args(argv)


def load_model(model_name, max_len, dtype, device):
    """Build the LlamaModel wrapper for a supported model id.

    A model id is accepted when it contains the substring "Llama" or "Yi"
    (case-sensitive); both families are served by ``LlamaModel``.

    Args:
        model_name: Hugging Face model id.
        max_len: Passed to ``LlamaModel`` as ``max_length``.
        dtype: Passed through as ``dtype``.
        device: Passed to ``LlamaModel`` as ``device_map``.

    Returns:
        The constructed ``LlamaModel`` instance.

    Raises:
        ValueError: if ``model_name`` matches neither family.
    """
    is_supported = any(family in model_name for family in ('Llama', 'Yi'))
    if not is_supported:
        raise ValueError(f"Unsupported model: {model_name}")

    return LlamaModel(
        model_name,
        max_length=max_len,
        dtype=dtype,
        device_map=device,
    )


def load_task_data(task_name, model_name, context_len, test_dir):
    """Load the prompt and expected answer for one retrieval test task.

    Args:
        task_name: One of "multivalue", "passkey" or "needle".
        model_name: Hugging Face model id. For "needle", the component after
            the last "/" selects the per-model data directory.
        context_len: Context length whose needle file is loaded (only used
            by the "needle" task).
        test_dir: Directory that holds the test data files.

    Returns:
        Tuple ``(prompt, groundtruth)`` read from the task's JSON file.

    Raises:
        ValueError: if ``task_name`` is not one of the known tasks.
    """
    if task_name == "multivalue":
        rel_path, prompt_key, answer_key = "niah_multivalue_test.json", "input", "outputs"
    elif task_name == "passkey":
        rel_path, prompt_key, answer_key = "passkey_test.json", "context", "answer"
    elif task_name == "needle":
        # Last path component, so ids without an org prefix (or local paths)
        # still resolve instead of raising IndexError.
        real_model_name = model_name.split("/")[-1]
        rel_path = os.path.join("needle", "needle_test_data", real_model_name, f"{context_len}.json")
        prompt_key, answer_key = "input", "answer"
    else:
        raise ValueError(f"Unsupported task: {task_name}")

    # `with` guarantees the handle is closed; JSON is UTF-8 by specification.
    with open(os.path.join(test_dir, rel_path), encoding="utf-8") as f:
        data = json.load(f)

    if task_name == "needle":
        # Needle files contain a list of samples; only the first is used.
        data = data[0]

    return data[prompt_key], data[answer_key]


if __name__ == "__main__":
    args = parse_args()

    model_name = args.model_name
    batch_size = args.batch_size
    attn_type = args.attn_type
    dtype = torch.float16 if args.dtype == 'fp16' else torch.bfloat16
    device = args.device

    # Fail fast, before the (expensive) tokenizer/model loading.
    if batch_size < 1:
        sys.exit(f"--batch_size must be >= 1, got {batch_size}")

    prompt, groundtruth = load_task_data(args.task_name, model_name, args.context_len,
                                         os.path.join(PROJECT_ROOT, "test"))

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Tokenize the prompt once and replicate the row across the batch dimension.
    single_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = single_ids.repeat(batch_size, 1)

    input_len = input_ids.shape[1]
    gen_len = 10

    # Room for the whole prompt plus the generated tokens.
    max_len = input_len + gen_len

    llm = load_model(model_name, max_len, dtype, device)

    print(colored(f"Input length: {input_len}", 'yellow'))
    print(colored(f"Gen length: {gen_len}", 'yellow'))

    # Inputs are placed on the device reported by the model's first layer.
    out = llm.generate(attention_type=attn_type,
        inputs_ids=input_ids.to(llm.layers[0].device),
        max_new_length=gen_len)

    result = tokenizer.batch_decode(out, skip_special_tokens=True)

    print(groundtruth)
    print(result)