import os
# Expose only CUDA device 0 to this process and cap OpenMP at 4 threads.
# Both variables are set before `torch` is imported further down, presumably
# so they are in place when CUDA / OpenMP initialise.
# NOTE(review): confirm the import-order dependency before moving these lines.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["OMP_NUM_THREADS"] = "4"

import argparse
import json
import logging
import time
from typing import Literal, Tuple

import torch
from inference.generate import Generator, BaseGenerator, SpeculativeGenerator, WETAPGenerator
from model.llama_tree_attn.modeling_llama import LlamaForCausalLM
from model.llama_tree_attn.tokenization_llama import LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaConfig
from tqdm import tqdm
import random
import numpy as np
from collections import OrderedDict
from utils import extract_first_function
from safetensors.torch import load_file
from utils import get_score
import subprocess
from torch.profiler import profile, record_function, ProfilerActivity

import json
import numpy as np
from pathlib import Path


def ensure_json_file_exists(file_path, default_content=None):
    """Prepare the location of a JSON file and report whether it already exists.

    The parent directory of ``file_path`` is always created. If the file is
    missing and ``default_content`` is given, that content is written to the
    file as JSON. With ``default_content=None`` the file itself is not
    created; only the directory is prepared and the caller is expected to
    write the file.

    Args:
        file_path: Location of the JSON file (``str`` or path-like).
        default_content: JSON-serialisable object used to seed a missing
            file. Ignored when the file already exists.

    Returns:
        pathlib.Path: ``file_path`` converted to a ``Path``.
    """
    path = Path(file_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    if path.exists():
        # Never overwrite data that an earlier run left behind.
        print(f"Have Existed: {file_path}")
        return path

    if default_content is not None:
        with path.open("w", encoding="utf-8") as f:
            json.dump(default_content, f)

    print(f"Have Established: {file_path}")
    return path


class MetricsLogger:
    """Store per-sample metrics in a JSON file and average them on demand.

    The file contains a JSON list of ``{"sample_id": ..., "metrics": {...}}``
    records. It is initialised to an empty list when it does not exist yet;
    an existing file is left as is, so records accumulate across runs.
    """

    def __init__(self, json_file='metrics.json'):
        self.json_file = json_file

        if os.path.exists(json_file):
            return  # reuse the records already on disk

        ensure_json_file_exists(json_file)
        with open(json_file, 'w') as handle:
            json.dump([], handle)

    def log_sample_metrics(self, sample_id, metrics):
        """Append one record; the whole file is read back and rewritten."""
        with open(self.json_file, 'r') as handle:
            records = json.load(handle)

        records.append({'sample_id': sample_id, 'metrics': metrics})

        with open(self.json_file, 'w') as handle:
            json.dump(records, handle, indent=2)

    def calculate_mean_metrics(self):
        """Return ``{metric_name: mean value}`` over all records, or ``{}`` if none."""
        with open(self.json_file, 'r') as handle:
            records = json.load(handle)

        if not records:
            return {}

        # Group values by metric name; a sample may omit some metrics.
        grouped = {}
        for record in records:
            for name, value in record['metrics'].items():
                grouped.setdefault(name, []).append(value)

        return {name: np.mean(values) for name, values in grouped.items()}


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Ensure deterministic behavior in PyTorch operations (if possible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Setup logging: configure the root logger once at import time with
# timestamped records at INFO level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# Module-level logger used by the evaluation functions in this file.
logger = logging.getLogger(__name__)
# Limit PyTorch's CPU thread count to 4 (same value as OMP_NUM_THREADS above).
torch.set_num_threads(4)

def get_dataset(dataset_name, path_to_data="./data"):
    """Load the prompts of a supported dataset from a JSON file.

    The file is read from ``{path_to_data}/{dataset_name}.json``. Two record
    layouts are supported:

    * ``oasst_prompts``, ``wikitext_prompts``, ``mtbench_prompts``: every
      record is a sequence and element ``1`` is taken.
    * ``humaneval_prompts``, ``gsm8k_prompts``: every record is a mapping
      and ``record["turns"][0]`` is taken.

    Args:
        dataset_name: One of the dataset names listed above.
        path_to_data: Directory that contains the JSON files.

    Returns:
        list: One extracted entry per record, in file order.

    Raises:
        FileNotFoundError: If the dataset name is not supported, or the JSON
            file does not exist. The same exception type is used for both
            cases so existing callers keep working.
    """
    indexed_datasets = ("oasst_prompts", "wikitext_prompts", "mtbench_prompts")
    turn_datasets = ("humaneval_prompts", "gsm8k_prompts")

    file_path = f"{path_to_data}/{dataset_name}.json"

    if dataset_name not in indexed_datasets + turn_datasets:
        raise FileNotFoundError(
            f"Unsupported dataset `{dataset_name}`: no loader for `{file_path}`."
        )

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            records = json.load(f)
    except FileNotFoundError as err:
        # Report the path that was actually tried, not a hard-coded `data/`.
        raise FileNotFoundError(f"Missing `{file_path}` file.") from err

    if dataset_name in indexed_datasets:
        return [record[1] for record in records]
    return [record["turns"][0] for record in records]

def run_eval(
    draft_model,
    target_model,
    target_model_name,
    tokenizer,
    dataset,
    dataloader,
    k_config: Tuple[int],
    max_new_tokens: int = 1024,
    replacement=False,
    speculative_sampling=True,
    tree_attn=True,
    sampling_type: Literal["argmax", "sampling"] = "sampling",
    disable_tqdm: bool = False,
    wetap: bool = True,
    trim_type: Literal["prob", "prob&len"] = "prob",
    dynwidth :bool = False,
    width_list = None,
    beam_width: int = 64,
    max_budget: int = 64,
    accept_thres: float = 0.5,
    top_k: int = 10,
    top_p: float = 0.9,
):
    """Evaluate draft/target generation over ``dataloader`` and print a summary.

    A ``WETAPGenerator`` (``wetap=True``) or a ``SpeculativeGenerator`` is
    built from the two models. Every prompt is tokenized (truncated to 2048
    tokens), moved to ``cuda:0`` and generated. Per-sample metrics
    (``gen_rate``, ``speed``, ``score``) are appended to a JSON file via
    ``MetricsLogger``; averages are printed at the end.

    Args:
        draft_model: Draft model passed to the generator; put in eval mode.
        target_model: Target model passed to the generator and to
            ``get_score``; put in eval mode.
        target_model_name: Must contain one of ``"Llama-2-7b"``,
            ``"Llama-2-13b"``, ``"Vicuna-7b"``, ``"Vicuna-13b"``; selects
            the sub-directory of the metrics file.
        tokenizer: Tokenizer used for encoding prompts and decoding outputs.
        dataset: Dataset name; part of the metrics path and decides whether
            decoded predictions are collected.
        dataloader: Indexable sequence of prompt strings.
        k_config: Forwarded to the generator; its length is part of the
            metrics file name.
        max_new_tokens, replacement, speculative_sampling, top_k, top_p:
            Forwarded to the generator.
        tree_attn: Forwarded to ``SpeculativeGenerator`` only.
        sampling_type: ``"argmax"`` uses temperature 0 for both models,
            ``"sampling"`` uses temperature 1.
        disable_tqdm: Iterate without a progress bar when true.
        wetap: Select ``WETAPGenerator`` instead of ``SpeculativeGenerator``.
        trim_type, dynwidth, width_list, beam_width, max_budget,
            accept_thres: Forwarded to ``WETAPGenerator`` only.

    Returns:
        list: ``extract_first_function`` applied to each decoded output when
        ``dataset`` is ``humaneval_prompts``, ``mtbench_prompts`` or
        ``gsm8k_prompts``; otherwise an empty list.

    Raises:
        ValueError: If ``sampling_type`` is invalid or ``target_model_name``
            matches none of the known model names.
    """
    if sampling_type not in ["argmax", "sampling"]:
        raise ValueError(
            f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"'
        )

    # Both models always share the same temperature.
    temperature = 0 if sampling_type == "argmax" else 1
    target_model_temp = temperature
    draft_model_temp = temperature

    if wetap:
        generator = WETAPGenerator(
            draft_model,
            target_model,
            eos_token_id=tokenizer.eos_token_id,
            k_config=k_config,
            beam_width=beam_width,
            max_budget=max_budget,
            max_new_tokens=max_new_tokens,
            accept_thres=accept_thres,
            trim_type=trim_type,
            dynwidth=dynwidth,
            width_list=width_list,
            draft_model_temp=draft_model_temp,
            target_model_temp=target_model_temp,
            replacement=replacement,
            speculative_sampling=speculative_sampling,
            top_k=top_k,
            top_p=top_p,
        )
    else:
        generator = SpeculativeGenerator(
            draft_model,
            target_model,
            eos_token_id=tokenizer.eos_token_id,
            k_config=k_config,
            max_new_tokens=max_new_tokens,
            draft_model_temp=draft_model_temp,
            target_model_temp=target_model_temp,
            replacement=replacement,
            speculative_sampling=speculative_sampling,
            tree_attn=tree_attn,
            top_k=top_k,
            top_p=top_p,
        )

    draft_model.eval()
    target_model.eval()

    logger.info("evaluation start.")
    start_time = time.time()

    pred_seq = []
    score_list = []
    gen_rate_per_ques = []
    speed_per_ques = []

    gen_type = "wetap" if wetap else "spd"

    # Substring of the model path -> directory name used in the metrics path.
    mapping = {
        "Llama-2-7b": "llama2-7b",
        "Llama-2-13b": "llama2-13b",
        "Vicuna-7b": "vicuna-7b",
        "Vicuna-13b": "vicuna-13b",
    }
    model_name = next(
        (alias for key, alias in mapping.items() if key in target_model_name),
        None,
    )
    if model_name is None:
        raise ValueError(f"Unknown model type: {target_model_name}")

    # NOTE(review): the output root is hard-coded to /root/CKPT/WETAP.
    loggerpath = f'/root/CKPT/WETAP/{gen_type}/{model_name}/{dataset}/len{len(k_config)}-{draft_model_temp}-{target_model_temp}.json'
    metrics_logger = MetricsLogger(loggerpath)

    collect_predictions = dataset in ['humaneval_prompts', 'mtbench_prompts', 'gsm8k_prompts']
    iterator = range(len(dataloader))

    with torch.no_grad():
        for sample_idx in iterator if disable_tqdm else tqdm(iterator):
            print(f'Now is Num.{sample_idx+1} question')
            prompt_text = dataloader[sample_idx]
            inputs = tokenizer(prompt_text, return_tensors="pt", max_length=2048, truncation=True).to("cuda:0")
            input_ids = inputs.input_ids
            input_len = input_ids.size(-1)

            generation_start = time.time()
            output = generator.generate(input_ids)
            runtime = time.time() - generation_start

            # Presumably the number of generated tokens: accepted draft
            # tokens plus one token per target invocation. TODO confirm.
            token_total = output.acceptance_count + output.invocation_count
            gen_rate_per_ques.append(token_total / output.invocation_count)
            speed_per_ques.append(token_total / runtime)

            if collect_predictions:
                string = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
                pred_seq.append(extract_first_function(string))

            score = get_score(output.sequences, target_model, input_len)
            score_value = score.item()
            score_list.append(score_value)

            sample_metrics = {
                "gen_rate": round(gen_rate_per_ques[-1], 2),
                "speed": round(speed_per_ques[-1], 2),
                "score": round(score_value, 2),
            }
            metrics_logger.log_sample_metrics(f'sample_{sample_idx+1}', sample_metrics)

            torch.cuda.empty_cache()

    end_time = time.time()

    logger.info("evaluation complete.")

    run_time = end_time - start_time

    print("Running time: {:.2f} s".format(run_time))
    print("Avg Score: {:.4f}".format(np.mean(score_list)))
    # Treats the score as an average log-likelihood. TODO confirm against get_score.
    print("PPL: {:.4f}".format(np.exp(-np.mean(score_list))))
    print("Avg generate rate: {:.2f}".format(np.mean(gen_rate_per_ques)))
    print("Avg speed: {:.2f} token/s".format(np.mean(speed_per_ques)))

    return pred_seq


def run_baseline_eval(
    target_model,
    tokenizer,
    dataloader,
    max_new_tokens: int = 1024,
    sampling_type: Literal["argmax", "sampling"] = "sampling",
    disable_tqdm: bool = False,
    top_k: int = 10,
    top_p: float = 0.9,
):
    """Run plain target-model generation over ``dataloader`` and log latency.

    Each prompt is tokenized, moved to ``cuda`` and generated with a
    ``BaseGenerator``. The total wall time and the wall time per target
    invocation are written to the module logger.

    Args:
        target_model: Model passed to ``BaseGenerator``; put in eval mode.
        tokenizer: Tokenizer used to encode the prompts.
        dataloader: Sequence of prompt strings.
        max_new_tokens, top_k, top_p: Forwarded to ``BaseGenerator``.
        sampling_type: ``"argmax"`` uses temperature 0, ``"sampling"`` uses 1.
        disable_tqdm: Iterate without a progress bar when true.

    Raises:
        ValueError: If ``sampling_type`` is not one of the two accepted values.
    """
    if sampling_type not in ["argmax", "sampling"]:
        raise ValueError(
            f'`sampling_type` can be either `"argmax"` or `"sampling"`, but received "{sampling_type}"'
        )
    target_model_temp = 0 if sampling_type == "argmax" else 1

    generator = BaseGenerator(
        target_model,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        temp=target_model_temp,
        top_k=top_k,
        top_p=top_p,
    )

    target_model.eval()

    logger.info("evaluation start.")
    start_time = time.time()

    invocation_count = 0

    with torch.no_grad():
        for prompt_text in dataloader if disable_tqdm else tqdm(dataloader):
            inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
            output = generator.generate(inputs.input_ids)
            invocation_count += output.invocation_count
    end_time = time.time()

    logger.info("evaluation complete.")

    run_time = end_time - start_time
    logger.info("Running time: {:.2f} s".format(run_time))

    if invocation_count == 0:
        # Empty dataloader (or no invocations reported): avoid dividing by zero.
        logger.warning("No target model invocations recorded; token latency is undefined.")
        return

    latency = run_time / invocation_count
    logger.info("Token latency: {:.2f} ms".format(latency * 1000))


def main(args):
    """Load the dataset, tokenizer and models, then run the selected evaluation.

    With ``args.run_baseline`` only the target model is loaded and
    ``run_baseline_eval`` is called; otherwise both models are loaded and
    ``run_eval`` is called.

    Args:
        args: Parsed command-line namespace. ``disable_tree_attn``,
            ``beam_width`` and ``accept_thres`` are optional attributes;
            when absent they default to ``False``, ``64`` and ``0.5``
            (the latter two match the ``run_eval`` defaults).

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of
            ``mtbench_prompts``, ``humaneval_prompts`` or ``gsm8k_prompts``.
    """
    set_seed(args.seed)  # Set a fixed seed

    # NOTE(review): `--fp16` currently has no effect; both models are always
    # loaded in float16 below. Honouring the flag would switch the default
    # to float32, so that change is left for an explicit decision.

    # The namespace may not carry this attribute; treat a missing one as False.
    disable_tree_attn = getattr(args, "disable_tree_attn", False)
    if args.auto_model and not disable_tree_attn:
        logger.warning(
            "Tree Attn is currently not supported for models other than LLaMA. Therefore, "
            "when using '--auto-model', Tree Attn will be disabled."
        )
        disable_tree_attn = True
        args.disable_tree_attn = True

    if args.dataset == 'mtbench_prompts':
        dataloader = get_dataset('mtbench_prompts')
        print(f'Have loaded，Length:{len(dataloader)}')
    elif args.dataset == 'humaneval_prompts':
        dataloader = get_dataset('humaneval_prompts')
        print(f'Have loaded，Length: {len(dataloader)}')
    elif args.dataset == 'gsm8k_prompts':
        dataloader = get_dataset('gsm8k_prompts')
        print(f'Have loaded，Length: {len(dataloader)}')
    else:
        raise NotImplementedError

    ModelLoader = AutoModelForCausalLM if args.auto_model else LlamaForCausalLM

    # `legacy` was previously misspelled as `lefacy`, so it was not passed
    # to the tokenizer under its real name.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, legacy=False)

    model_kwargs = dict(
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
        use_flash_attention_2=bool(args.flash_attn),
    )

    # The baseline run never touches the draft model, so skip loading it.
    draft_model = None
    if not args.run_baseline:
        logger.info("Loading draft model: {}".format(args.draft_model))
        draft_model = ModelLoader.from_pretrained(args.draft_model, **model_kwargs)

    logger.info("Loading target model: {}".format(args.target_model))
    target_model = ModelLoader.from_pretrained(args.target_model, **model_kwargs)

    if args.run_baseline:
        run_baseline_eval(
            target_model,
            tokenizer=tokenizer,
            dataloader=dataloader,
            max_new_tokens=args.max_new_tokens,
            sampling_type=args.sampling_type,
            disable_tqdm=args.disable_tqdm,
            top_k=args.top_k,
            top_p=args.top_p,
        )
    else:
        run_eval(
            draft_model,
            target_model,
            args.target_model,
            tokenizer=tokenizer,
            dataset=args.dataset,
            dataloader=dataloader,
            k_config=args.k_config,
            beam_width=getattr(args, "beam_width", 64),
            max_budget=args.max_budget,
            accept_thres=getattr(args, "accept_thres", 0.5),
            max_new_tokens=args.max_new_tokens,
            replacement=args.replacement,
            speculative_sampling=not args.naive_sampling,
            # Stays True unless tree attention was disabled above.
            tree_attn=not disable_tree_attn,
            sampling_type=args.sampling_type,
            disable_tqdm=args.disable_tqdm,
            wetap=args.wetap,
            trim_type=args.trim_type,
            dynwidth=args.dynwidth,
            width_list=args.width_list,
            top_k=args.top_k,
            top_p=args.top_p,
        )



if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="dataset: mt-bench, human_eval or gsm8k")
    parser.add_argument(
        "--draft-model", type=str, required=True, help="Draft model path."
    )
    parser.add_argument(
        "--target-model", type=str, required=True, help="Target model path."
    )
    parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path.")
    # NOTE(review): main() loads both models in float16 regardless of this flag.
    parser.add_argument("--fp16", action="store_true", help="use float16 dtype.")

    parser.add_argument(
        "--k-config",
        type=lambda x: tuple(map(int, x.split(","))),
        required=True,
        help="Use comma separations, e.g. `--k-config 4,1,1`.",
    )

    # NOTE(review): required, but main() never reads `args.datapath`.
    parser.add_argument(
        "--datapath", type=str, required=True, help="The json data file."
    )
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument(
        "--replacement",
        action="store_true",
        help="Sampling with replacement.",
    )

    parser.add_argument(
        "--naive-sampling",
        action="store_true",
        help="Use multi-candidate naive sampling.",
    )

    parser.add_argument(
        "--sampling-type", type=str, default="sampling", choices=["argmax", "sampling"]
    )

    parser.add_argument("--disable-tqdm", action="store_true")

    parser.add_argument("--auto-model", action="store_true")
    parser.add_argument("--run-baseline", action="store_true")

    parser.add_argument("--flash-attn", action="store_true")
    # main() reads `args.disable_tree_attn`; without this option that lookup
    # raises AttributeError whenever --auto-model is given.
    parser.add_argument(
        "--disable-tree-attn",
        action="store_true",
        help="Disable tree attention.",
    )

    #wetap parameters
    parser.add_argument("-wetap", action="store_true")
    parser.add_argument("--max_budget", type=int, default=64)
    # main() reads `args.beam_width` and `args.accept_thres`; the defaults
    # mirror the run_eval() defaults.
    parser.add_argument("--beam_width", type=int, default=64)
    parser.add_argument("--accept_thres", type=float, default=0.5)
    parser.add_argument("--trim_type", type=str, default="prob",choices=["prob","prob&len","prob&quota"])
    parser.add_argument("--dynwidth", action="store_true")
    parser.add_argument("--width_list", type=int, nargs="+", help="Width range")

    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--top-k", type=int, default=10)

    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    # Fall back to the draft model's tokenizer when none is given.
    if args.tokenizer is None:
        args.tokenizer = args.draft_model
    main(args)
