import shy
shy.err_hook()

import os
import json
import time
import argparse
from collections import defaultdict

from accelerate.utils import set_seed

import torch
import torch.nn as nn

from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset

from models import LlamaDraftForCausalLM, LlamaForCausalLM
from utils import Timer
from preprocess import get_tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache


# System prompt prepended to every conversation. It never changes, so it is
# defined once instead of being rebuilt on each loop iteration.
SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."


def build_chat_messages(conversation, role_mapper, system_prompt):
    """Convert one raw conversation into a chat-template message list.

    Args:
        conversation: Iterable of turns, each a mapping with a ``"from"``
            (speaker tag) and a ``"value"`` (text) entry.
        role_mapper: Mapping from dataset speaker tags to chat roles.
        system_prompt: Text of the system message placed first.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that starts with the
        system message. A conversation-opening assistant turn is dropped, and
        a trailing assistant turn is removed so that the list ends where the
        model is expected to answer.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for idx, turn in enumerate(conversation):
        role = role_mapper.get(turn["from"])
        if role is None:
            # Speaker tag missing from role_mapper: skip the turn rather than
            # handing a ``None`` role to the chat template.
            continue
        if idx == 0 and role == "assistant":
            # A conversation must not open with an assistant message.
            continue
        messages.append({"role": role, "content": turn["value"]})

    # The final assistant reply is what the model is asked to generate.
    if messages[-1]["role"] == "assistant":
        messages.pop()
    return messages


def continuation_rates(counts):
    """Return, per sorted key, the share of the tail lying beyond that key.

    For the i-th smallest key the value is
    ``sum(counts of larger keys) / sum(counts of this and larger keys)``.

    Args:
        counts: Mapping from sortable keys to numeric counts.

    Returns:
        A list with one rate per key, in ascending key order. A position whose
        tail sums to zero yields ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    ordered = [counts[key] for key in sorted(counts)]
    remaining = sum(ordered)  # running total of counts[i:]
    rates = []
    for count in ordered:
        beyond = remaining - count  # total of counts[i + 1:]
        rates.append(beyond / remaining if remaining else float("nan"))
        remaining = beyond
    return rates


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ea-model-path",
        type=str,
        default="checkpoint-60000",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="The model ID",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode",
    )

    args = parser.parse_args()

    with Timer("Loading model..."):
        model = LlamaForCausalLM.from_pretrained(
            args.model_id,
            torch_dtype="float16",
            device_map="cuda",
        )
        base_model = model
        if not args.debug:
            # Outside debug mode the draft model drives generation; it is
            # wired to the base model's backbone and shares its output head.
            model = LlamaDraftForCausalLM.from_pretrained(
                args.ea_model_path,
                torch_dtype="float16",
                device_map="cuda",
            )
            model.set_base_model(base_model.model)
            model.lm_head = base_model.lm_head

        # Sanity check: print the dtypes of both models.
        print(base_model.dtype)
        print(model.dtype)

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(args.model_id)

    model.eval()
    print('Check model training state:', model.training)

    with Timer("Loading dataset..."):
        # Evaluate on the last 10% of the file (last 1% in debug mode).
        dataset = load_dataset(
            "Aeala/ShareGPT_Vicuna_unfiltered",
            data_files={
                "valid": "ShareGPT_V4.3_unfiltered_cleaned_split.json",
            },
            split="valid[90%:]" if not args.debug else "valid[99%:]",
        )

    conversations = dataset["conversations"]
    role_mapper = {
        "human": "user",
        "gpt": "assistant",
    }

    # Running totals across all conversations processed so far.
    total_input_tokens = 0
    total_output_tokens = 0
    total_iterations = 0
    cnts = defaultdict(int)

    # Untimed warm-up passes followed by timed passes for each conversation.
    num_warmup = 1
    num_iter = 1

    for ith_conv, conversation in enumerate(conversations):
        messages = build_chat_messages(conversation, role_mapper, SYSTEM_PROMPT)

        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The chat template already emitted the special tokens as text, so
        # the tokenizer must not add them a second time.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            add_special_tokens=False
        ).to(model.device)

        print(input_text)

        print(inputs.input_ids.shape)

        for it in range(num_warmup + num_iter):
            if it == num_warmup:
                # Start timing only after the warm-up passes have finished on
                # the GPU. perf_counter is monotonic, unlike time.time().
                torch.cuda.synchronize()
                start_time = time.perf_counter()
            if args.debug:
                # Baseline: plain greedy decoding with the base model.
                output_ids = model.generate(
                    **inputs,
                    do_sample=False,
                    max_new_tokens=128,
                    temperature=None,
                )
                num_iteration = 0
                draft_cnts = {}
            else:
                # NOTE(review): top_base / top_draft / top_node / depth are
                # forwarded as-is; their semantics are defined by eagenerate.
                output_ids, num_iteration, draft_cnts = model.eagenerate(
                    **inputs,
                    max_new_tokens=2048,
                    eos_token_id=tokenizer.eos_token_id,
                    tokenizer=tokenizer,
                    temperature=0.0,
                    top_base=60,
                    top_draft=8,
                    top_node=8,
                    depth=6,
                    return_iter=True,
                    return_stats=True,
                )

        # Wait for outstanding GPU work before stopping the clock.
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        print('Time:', (end_time - start_time) / num_iter)
        print(len(output_ids[0]))
        print(output_ids)
        draft_output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        print(draft_output_text)

        total_input_tokens += inputs.input_ids.numel()
        # NOTE(review): if output_ids also contains the prompt tokens, this
        # total (and the per-iteration average below) includes them - confirm.
        total_output_tokens += len(output_ids[0])
        total_iterations += num_iteration
        for k, v in draft_cnts.items():
            cnts[k] += v

        print(f"Conversation {ith_conv + 1}")
        print(f"Total input tokens: {total_input_tokens}")
        print(f"Total output tokens: {total_output_tokens}")
        print(f"Total iterations: {total_iterations}")
        if total_iterations:
            print(f"Average output tokens per iter: {total_output_tokens / total_iterations}")
        else:
            # Debug mode records zero iterations; avoid dividing by zero.
            print("Average output tokens per iter: n/a")
        print(continuation_rates(cnts))
