import os
import json
import time
import argparse
from pathlib import Path
from collections import defaultdict

from accelerate.utils import set_seed

import torch
import torch.nn as nn

from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset

from models import LlamaDraftForCausalLM, LlamaForCausalLM
from utils import Timer
from preprocess import get_tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache


def load_questions(question_file: str, begin=None, end=None):
    """Load JSON-Lines questions from ``question_file`` and return ``[begin:end]``.

    Each non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped.

    Args:
        question_file: Path to a ``.jsonl`` file holding one JSON value per line.
        begin: Start index of the returned slice (``None`` = first record).
        end: Exclusive end index of the returned slice (``None`` = through the last).

    Returns:
        The decoded records in file order, sliced to ``[begin:end]``.
    """
    with open(question_file, "r", encoding="utf-8") as ques_file:
        # Lines read from a file keep their trailing "\n", so a plain
        # truthiness test never skips a blank line and json.loads("\n") would
        # raise. Strip before testing.
        questions = [json.loads(line) for line in ques_file if line.strip()]
    return questions[begin:end]


_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."


def _parse_args():
    """Build the command-line parser for the benchmark and return parsed args."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ea-model-path",
        type=str,
        default="checkpoint-60000",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="The model ID",
    )
    parser.add_argument(
        "--dataset_id",
        type=str,
        choices=["mt_bench", "alpaca", "gsm8k", "humaneval", "qa", "sum", "sharegpt"],
        default="mt_bench",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode",
    )
    parser.add_argument(
        "--original",
        action="store_true",
        help="Original mode",
    )
    parser.add_argument(
        "--vanilla",
        action="store_true",
        help="vanilla mode",
    )
    parser.add_argument(
        "--vanilla-model-path",
        type=str,
        default="JackFram/llama-68m",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument(
        "--eagle",
        action="store_true",
        help="eagle mode",
    )
    parser.add_argument(
        "--output_file",
        type=Path,
        default="output/output.txt",
    )
    parser.add_argument(
        "--top_base",
        type=int,
        default=60,
    )
    parser.add_argument(
        "--top_draft",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--top_node",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=7,
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.0,
    )
    parser.add_argument(
        "--token_freq_json",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--token_freq",
        type=float,
        default=0.8,
    )
    return parser.parse_args()


def _select_valid_tokens(token_freq_json, coverage):
    """Return the most frequent tokens whose cumulative count share reaches ``coverage``.

    ``token_freq_json`` is read as a JSON object mapping token -> count.
    Tokens are ranked by descending ``(count, token)`` and appended until
    their cumulative count divided by the total count is ``>= coverage``.
    Summary statistics are printed.

    Raises:
        ValueError: if the file holds no tokens or all counts sum to zero.
    """
    with open(token_freq_json) as fh:
        token_freq = json.load(fh)
    total_token_cnts = sum(token_freq.values())
    if not token_freq or total_token_cnts == 0:
        # Every ratio below divides by these quantities.
        raise ValueError(f"No token counts found in {token_freq_json}")

    ranked = sorted(((v, k) for k, v in token_freq.items()), reverse=True)
    valid_tokens = []
    token_cnts = 0
    for cnt, token in ranked:
        token_cnts += cnt
        valid_tokens.append(token)
        if token_cnts / total_token_cnts >= coverage:
            break

    print(f"Total tokens: {len(token_freq)}")
    print(f"Valid tokens: {len(valid_tokens)}")
    print(f"Total token cnts: {total_token_cnts}")
    print(f"Valid token cnts: {token_cnts}")
    print(f"Valid token count ratio: {token_cnts / total_token_cnts}")
    print(f"Valid token ratio: {len(valid_tokens) / len(token_freq)}")
    return valid_tokens


def _generate(args, model, inputs, tokenizer, assistant_model=None, assistant_tokenizer=None):
    """Generate one response in the selected mode.

    Returns:
        ``(output_ids, num_iteration, draft_cnts)``. ``output_ids`` holds only
        the newly generated tokens in the original/vanilla modes, which report
        a single iteration and no draft statistics. In the default mode all
        three values come straight from ``model.eagenerate``.
    """
    if args.original:
        output_ids = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=1024,
            top_k=None,
            top_p=None,
            temperature=args.temperature,
        )
        # generate() returns prompt + continuation; keep only the continuation.
        return output_ids[:, inputs.input_ids.size(1):], 1, {}

    if args.vanilla:
        assistant_model.generation_config.num_assistant_tokens = args.depth
        # The assistant tokenizer is passed only when the vocab sizes differ
        # (presumably to enable transformers' universal assisted decoding).
        output_ids = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=1024,
            top_k=None,
            top_p=None,
            temperature=args.temperature,
            assistant_model=assistant_model,
            tokenizer=tokenizer,
            assistant_tokenizer=assistant_tokenizer if tokenizer.vocab_size != assistant_tokenizer.vocab_size else None,
            use_cache=True,
        )
        return output_ids[:, inputs.input_ids.size(1):], 1, {}

    output_ids, num_iteration, draft_cnts = model.eagenerate(
        **inputs,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        temperature=args.temperature,
        threshold=args.threshold,
        tokenizer=tokenizer,
        top_base=args.top_base,
        top_draft=args.top_draft,
        top_node=args.top_node,
        depth=args.depth,
        return_iter=True,
        return_stats=True,
        eagle=args.eagle,
    )
    return output_ids, num_iteration, draft_cnts


def main():
    """Run the benchmark and write one JSON line of results per conversation."""
    args = _parse_args()

    valid_tokens = None
    if args.token_freq_json is not None:
        # NOTE(review): valid_tokens is computed and reported but never passed
        # to any generation call in this script; confirm whether that is intended.
        valid_tokens = _select_valid_tokens(args.token_freq_json, args.token_freq)

    with Timer("Loading dataset..."):
        dataset_file_path = f"data/{args.dataset_id}/question.jsonl"
        dataset = load_questions(dataset_file_path, None, None)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    assistant_model = None
    assistant_tokenizer = None
    with Timer("Loading model..."):
        model = LlamaForCausalLM.from_pretrained(
            args.model_id,
            torch_dtype="float16",
            device_map=device,
        )
        base_model = model
        # The draft model is used only in the default mode (none of the flags set).
        if not args.debug and not args.original and not args.vanilla:
            model = LlamaDraftForCausalLM.from_pretrained(
                args.ea_model_path,
                torch_dtype="float16",
                device_map=device,
            )
            # Reuse the base model's backbone and output head in the draft model.
            model.set_base_model(base_model.model)
            model.lm_head = base_model.lm_head
            model.draft_lm_head = base_model.lm_head

        print(base_model.dtype)
        print(model.dtype)

        if args.vanilla:
            assistant_model = LlamaForCausalLM.from_pretrained(
                args.vanilla_model_path,
                torch_dtype="float16",
                device_map=device,
            )
            assistant_tokenizer = AutoTokenizer.from_pretrained(
                args.vanilla_model_path,
            )

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(args.model_id)

    model.eval()
    print('Check model training state:', model.training)

    total_input_tokens = 0
    total_output_tokens = 0
    total_iterations = 0
    total_avg_output_tokens = []  # per-turn "new tokens / iterations"
    # NOTE(review): cnts accumulates draft statistics but is never reported or saved.
    cnts = defaultdict(int)

    args.output_file.parent.mkdir(parents=True, exist_ok=True)
    with args.output_file.open("w") as f:
        for ith_conv, conversation in enumerate(dataset):
            data = {
                "conversation": conversation["turns"],
                "responses": [],
                "times": [],
                "iterations": [],
                "num_tokens": [],
            }
            messages = [{
                "role": "system",
                "content": _SYSTEM_PROMPT,
            }]
            for idx, turn in enumerate(conversation["turns"]):
                messages.append({
                    "role": "user",
                    "content": turn,
                })

                input_text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                # The chat template already inserted the special tokens.
                inputs = tokenizer(
                    input_text,
                    return_tensors="pt",
                    add_special_tokens=False
                ).to(model.device)

                # Synchronize so the timed interval covers only completed GPU work.
                if device == "cuda":
                    torch.cuda.synchronize()
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start_time = time.perf_counter()

                output_ids, num_iteration, draft_cnts = _generate(
                    args, model, inputs, tokenizer, assistant_model, assistant_tokenizer
                )

                if device == "cuda":
                    torch.cuda.synchronize()
                end_time = time.perf_counter()

                num_new_tokens = len(output_ids[0])
                total_input_tokens += inputs.input_ids.numel()
                total_output_tokens += num_new_tokens
                total_iterations += num_iteration
                for k, v in draft_cnts.items():
                    cnts[k] += v

                output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

                print(f"Turn {idx + 1} of conversation {ith_conv + 1}")
                print(f"Total input tokens: {total_input_tokens}")
                print(f"Total output tokens: {total_output_tokens}")
                print(f"Total iterations: {total_iterations}")
                print(f"Average output tokens per iter: {total_output_tokens / total_iterations}")
                total_avg_output_tokens.append(num_new_tokens / num_iteration)
                print(f"Average output tokens: {sum(total_avg_output_tokens) / len(total_avg_output_tokens)}")

                # Feed the response back so the next turn sees the full dialogue.
                messages.append({
                    "role": "assistant",
                    "content": output_text,
                })

                data["responses"].append(output_text)
                data["times"].append(end_time - start_time)
                data["iterations"].append(num_iteration)
                data["num_tokens"].append(num_new_tokens)

            f.write(json.dumps(data) + "\n")
            f.flush()


if __name__ == "__main__":
    main()
