import argparse
from fastchat.utils import str_to_torch_dtype
from fastchat.utils import str_to_torch_dtype
from fastchat.llm_judge.common import load_questions
from transformers import AutoModelForCausalLM, AutoTokenizer

from fastchat.model import get_conversation_template
from tqdm import tqdm
import json
import os
import time
import torch
import numpy as np
import shortuuid
import copy
from model.eaglev2.ea_model import EaModel
from model.eaglev2.kv_cache import initialize_past_key_values
from model.eaglev2.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from model.eaglev2.utils import *
from model.eaglev2.choices import *
import random
def reject_sample(
    candidate_input_ids,
    candidate_logits,
    candidate_length,
    new_logits,
):
    """Count how many leading draft tokens survive speculative rejection sampling.

    The last ``candidate_length`` ids of ``candidate_input_ids`` are treated as
    the drafted tokens. For each of them the draft ("assistant") probability
    ``q_i`` and the verifier ("model") probability ``p_i`` are read from the
    softmax of ``candidate_logits`` and ``new_logits`` respectively, and the
    token is accepted when a fresh uniform sample is ``<= p_i / q_i``.

    Args:
        candidate_input_ids: Token ids; indexed as a 2-D tensor with a leading
            dimension of 1 (the ``squeeze(0, 1)`` below relies on that).
        candidate_logits: Draft logits, indexed as (1, position, vocab).
        candidate_length: Number of drafted tokens to check.
        new_logits: Verifier logits, indexed as (1, position, vocab); only the
            first ``candidate_length`` positions are read.

    Returns:
        A 0-dim integer tensor: the number of accepted tokens before the first
        rejection (``n`` in algorithm 1 of the speculative sampling paper).
    """
    drafted_ids = candidate_input_ids[:, -candidate_length:]
    positions = torch.arange(candidate_length)

    # q_i: probability the draft model gave to each token it proposed.
    draft_probs = torch.softmax(candidate_logits, dim=-1)
    q_i = draft_probs[:, positions, drafted_ids].squeeze(0, 1)

    # p_i: probability the verifier gives to those same tokens.
    target_probs = torch.softmax(new_logits, dim=-1)
    p_i = target_probs[:, positions, drafted_ids].squeeze(0, 1)

    ratio = p_i / q_i
    uniform = torch.rand_like(ratio)
    # Written as "ratio >= uniform" so that a NaN ratio still counts as rejected.
    accepted = ratio >= uniform

    # A position belongs to the accepted prefix while no rejection has been
    # seen up to and including it.
    rejections_so_far = (~accepted).cumsum(dim=-1)
    n_matches = (rejections_so_far == 0).sum()

    return n_matches

def run_eval(
        tgt_model,
        model,
        tokenizer,
        model_id,
        question_file,
        question_begin,
        question_end,
        answer_file,
        max_new_tokens,
        num_choices,
        num_gpus_per_model,
        num_gpus_total,
        tree_choices,
        logits_processor,
        max_steps,
        is_llama3
):
    """Load the benchmark questions and generate answers for them.

    The questions are split into contiguous chunks and each chunk is handed to
    ``get_model_answers``. When more than one model replica fits on the
    available GPUs (``num_gpus_total // num_gpus_per_model > 1``) the chunks are
    dispatched as Ray remote tasks and this function blocks until all of them
    finish; otherwise the chunks are processed in-process, one after another.

    Args:
        tgt_model, model, tokenizer, model_id, answer_file, max_new_tokens,
        num_choices, tree_choices, logits_processor, max_steps, is_llama3:
            Forwarded unchanged to ``get_model_answers``.
        question_file, question_begin, question_end: Forwarded to
            ``load_questions`` to select the questions to answer.
        num_gpus_per_model: GPUs reserved for each worker.
        num_gpus_total: Total GPUs; must be a multiple of ``num_gpus_per_model``.

    Raises:
        AssertionError: If ``num_gpus_total`` is not a multiple of
            ``num_gpus_per_model``.
    """
    questions = load_questions(question_file, question_begin, question_end)

    # Split the question file into one chunk per model replica.
    assert num_gpus_total % num_gpus_per_model == 0
    num_workers = num_gpus_total // num_gpus_per_model
    use_ray = num_workers > 1

    if use_ray:
        import ray
        ray.init()
        get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
            get_model_answers
        ).remote
    else:
        get_answers_func = get_model_answers

    # Floor division yields 0 when there are fewer questions than workers (or
    # no questions at all), and range() rejects a zero step, so clamp to 1.
    chunk_size = max(1, len(questions) // num_workers)
    ans_handles = []
    for chunk_start in range(0, len(questions), chunk_size):
        ans_handles.append(
            get_answers_func(
                tgt_model,
                model,
                tokenizer,
                model_id,
                questions[chunk_start: chunk_start + chunk_size],
                answer_file,
                max_new_tokens,
                num_choices,
                tree_choices,
                logits_processor,
                max_steps,
                is_llama3
            )
        )

    if use_ray:
        # Block until every remote worker has written its answers.
        ray.get(ans_handles)


def _reset_or_init_kv_cache(owner, base_model):
    """Return ``owner``'s KV cache, allocating it on first use and rewinding it otherwise.

    The cache objects are stored as attributes on ``owner`` so later calls can
    reuse them. When they already exist only the length counters are zeroed.

    Returns:
        The tuple ``(past_key_values, past_key_values_data, current_length_data)``.
    """
    if hasattr(owner, "past_key_values"):
        # Reset the past key and value states.
        owner.current_length_data.zero_()
    else:
        (
            owner.past_key_values,
            owner.past_key_values_data,
            owner.current_length_data,
        ) = initialize_past_key_values(base_model)
    return owner.past_key_values, owner.past_key_values_data, owner.current_length_data


def _decode_answer(tokenizer, conv, output_ids):
    """Turn generated token ids into the cleaned-up answer text for one turn.

    Truncates at the first id listed in ``conv.stop_token_ids`` (if any),
    decodes, cuts at ``conv.stop_str`` when it occurs after position 0, removes
    every special token of the tokenizer and strips surrounding whitespace.
    """
    if conv.stop_token_ids:
        stop_positions = [
            pos
            for pos, token_id in enumerate(output_ids)
            if token_id in conv.stop_token_ids
        ]
        if len(stop_positions) > 0:
            output_ids = output_ids[: stop_positions[0]]

    output = tokenizer.decode(
        output_ids,
        spaces_between_special_tokens=False,
    )

    if conv.stop_str and output.find(conv.stop_str) > 0:
        output = output[: output.find(conv.stop_str)]
    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, list):
            for special_tok in special_token:
                output = output.replace(special_tok, "")
        else:
            output = output.replace(special_token, "")
    output = output.strip()

    if conv.name == "xgen" and output.startswith("Assistant:"):
        output = output.replace("Assistant:", "", 1).strip()
    return output


def get_model_answers(
        tgt_model,
        model,
        tokenizer,
        model_id,
        questions,
        answer_file,
        max_new_tokens,
        num_choices,
        tree_choices,
        logits_processor,
        max_steps,
        is_llama3
):
    """Answer ``questions`` with draft-then-verify decoding and append results to ``answer_file``.

    For every turn of every question, ``model`` (the draft model) repeatedly
    extends the sequence through its tree-decoding loop. Once at least
    ``verify_every`` unverified tokens have accumulated, ``tgt_model`` scores
    them in one forward pass and the draft is either kept or cut back and
    corrected with the target's arg-max tokens. One JSON line per question is
    appended to ``answer_file``.

    Args:
        tgt_model: Target model called as
            ``tgt_model(input_ids=..., past_key_values=...)``; its output must
            expose ``.logits``.
        model: Draft model; must expose ``ea_layer`` and ``base_model``.
        tokenizer: Tokenizer used for both encoding prompts and decoding answers.
        model_id: Stored verbatim in each answer record.
        questions: Sequence of dicts with ``question_id``, ``category`` and
            ``turns`` keys. May be empty.
        answer_file: Path of the JSONL file to append to (``~`` is expanded).
        max_new_tokens: Generation stops once the verified-token counter
            exceeds this value.
        num_choices: Number of independent completions per question; choice
            ``i`` seeds torch with ``i``.
        tree_choices: Unused by this function.
        logits_processor: Forwarded to the tree helpers; ``None`` is accepted.
        max_steps: Upper bound on decoding-loop iterations per turn.
        is_llama3: When true, ``<|eot_id|>`` is treated as an extra stop token.

    Raises:
        KeyError: If a question's ``category`` has no entry in the latency table.
    """
    tgt_model.eval()
    print('Check model training state:', tgt_model.training)

    model.eval()
    print('Check model training state:', model.training)

    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print('CUDA VISIBLE DEVICES:', cuda_visible_devices)

    # Question ids that are skipped unconditionally.
    skipped_question_ids = (481, 241, 97)
    # Number of unverified draft tokens that triggers a target-model check.
    verify_every = 10
    # Draws above this value use rejection sampling; the rest use exact
    # arg-max matching. NOTE: drawn from `random`, which torch.manual_seed
    # does not seed, so the choice is not reproducible across runs.
    greedy_check_threshold = 0.1
    # Per-category constant charged once per draft forward pass in place of the
    # measured draft time when computing the reported wall time.
    # NOTE(review): presumably the per-pass latency (seconds) of a quantized
    # draft model - confirm where these numbers come from.
    quan_time = {"writing": 0.00514933059, "roleplay": 0.00514933059, "reasoning": 0.00514933059, "math": 0.00514933059,
                 "coding": 0.00514933059, "extraction": 0.00514933059, "stem": 0.00514933059, "humanities": 0.00514933059,
                 "translation": 0.00503778338, "summarization": 0.00537634409, "qa": 0.00499251123, "math_reasoning": 0.00500250125,
                 "rag": 0.00542888165}

    accept_lengths_tree = []
    for question in tqdm(questions):
        if question["question_id"] in skipped_question_ids:
            continue
        choices = []
        for i in range(num_choices):
            cur_accept_lengths_tree = []
            torch.manual_seed(i)
            conv = get_conversation_template("vicuna")
            turns = []
            steps = []
            new_tokens = []
            wall_time = []
            for j in range(len(question["turns"])):
                qs = question["turns"][j]
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                conv.stop_str = "</s>"
                prompt = conv.get_prompt()

                inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
                input_ids_init = inputs.input_ids

                tree_decoding_cnt = 0
                tree_decoding_time = []
                evaluate_posterior_cnt = 0
                evaluate_posterior_time = []
                verify_cnt = 0
                verify_time = []
                update_inference_inputs_cnt = 0
                update_inference_inputs_time = []
                exit_cnt = 0
                exit_time = []

                # Defaults for everything the bookkeeping after the try block
                # reads, so a RuntimeError early in the turn cannot turn into a
                # NameError or leak values from the previous turn.
                accept_length_list = []
                new_token = 0
                idx = -1
                start_time = time.time()

                try:
                    torch.cuda.synchronize()
                    start_time = time.time()
                    input_ids = inputs.input_ids
                    # A single -1 appended to the draft tokens before gathering candidates.
                    padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
                    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
                    # Avoid modifying the input_ids in-place
                    input_ids = input_ids.clone()
                    model.ea_layer.reset_kv()
                    if is_llama3:
                        stop_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

                    past_key_values, past_key_values_data, current_length_data = (
                        _reset_or_init_kv_cache(model, model.base_model)
                    )
                    tgt_past_key_values, _tgt_past_key_values_data, tgt_current_length_data = (
                        _reset_or_init_kv_cache(tgt_model, tgt_model)
                    )

                    input_len = input_ids.shape[1]
                    cur_length = input_len   # length of input_ids including unverified tokens
                    tgt_length = input_len   # length of the prefix confirmed by tgt_model
                    reset_tree_mode(model)

                    # past_key_values is the KV cache of the draft's base_model.
                    candidate_logits = None
                    draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _logits, _hidden_state, _sample_token = initialize_tree(
                        input_ids, model, past_key_values, logits_processor
                    )
                    new_token = 0      # tokens confirmed so far
                    verify_token = 0   # draft tokens accepted since the last target check

                    for idx in range(max_steps):

                        model.base_model.model.tree_mask = tree_mask
                        draft_tokens = draft_tokens.to(input_ids.device)

                        # past_key_values is the KV cache of the draft's base_model.
                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        logits, hidden_state_new, _outputs = tree_decoding(
                            model,
                            draft_tokens,
                            past_key_values,
                            tree_position_ids,
                            input_ids,
                            retrieve_indices,
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start

                        tree_decoding_cnt += 1
                        tree_decoding_time.append(elapsed)

                        draft_tokens = torch.cat((draft_tokens, padding), dim=1)
                        candidates = draft_tokens[0, retrieve_indices]

                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        best_candidate, accept_length, sample_p = evaluate_posterior(
                            logits, candidates, logits_processor
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        evaluate_posterior_cnt += 1
                        evaluate_posterior_time.append(elapsed)

                        # Accumulate the draft logits of the accepted path; they
                        # are the "q" distribution for rejection sampling later.
                        accepted_logits = logits[None, best_candidate, :accept_length + 1]
                        if candidate_logits is None:
                            candidate_logits = accepted_logits
                        else:
                            candidate_logits = torch.cat((candidate_logits, accepted_logits), dim=1)

                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        if verify_token >= verify_every:
                            torch.cuda.synchronize()
                            start = time.perf_counter()
                            input_ids = torch.cat(
                                [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
                            )
                            # Feed the target only the tokens its cache has not seen yet.
                            tgt_prev_length = tgt_current_length_data[0].item()
                            with torch.no_grad():
                                tgt_outputs = tgt_model(
                                    input_ids=input_ids[:, tgt_prev_length:],
                                    past_key_values=tgt_past_key_values,
                                )
                            if tgt_length == input_len:
                                # First check of the turn: drop the logits of the
                                # prompt positions before the last prompt token.
                                new_logits = tgt_outputs.logits[:, tgt_length - 1:]
                            else:
                                new_logits = tgt_outputs.logits
                            candidate_length = input_ids.shape[1] - tgt_length
                            selected_tokens = new_logits.argmax(dim=-1)

                            # first_bad is the index of the first draft token that
                            # is not kept, or None when the whole draft is kept.
                            if random.random() > greedy_check_threshold:
                                n_matches = reject_sample(input_ids, candidate_logits, candidate_length, new_logits)
                                # int() keeps the length counters plain Python ints.
                                first_bad = None if n_matches == candidate_length else int(n_matches)
                            else:
                                predicted = selected_tokens[:, :-1]
                                drafted = input_ids[:, tgt_length:]
                                if torch.equal(predicted, drafted):
                                    first_bad = None
                                else:
                                    first_bad = int((predicted != drafted).nonzero(as_tuple=False)[0][1])

                            if first_bad is None:
                                # Keep the draft and append the target's next token.
                                input_ids = torch.cat([input_ids, selected_tokens[:, -1:]], dim=-1)
                                committed = selected_tokens.shape[-1]
                            else:
                                # Roll back to the confirmed prefix and continue
                                # with the target's arg-max tokens up to and
                                # including the first rejected position.
                                input_ids = torch.cat(
                                    [input_ids[:, :tgt_length], selected_tokens[:, :first_bad + 1]], dim=-1
                                )
                                committed = first_bad + 1
                            accept_length_list.append(int(committed))
                            tgt_length = tgt_length + committed
                            new_token = new_token + committed
                            cur_length = tgt_length
                            candidate_logits = None

                            # Restart drafting from the corrected sequence: clear
                            # the draft caches and rewind the target cache so the
                            # last confirmed token is re-fed on the next check.
                            model.ea_layer.reset_kv()
                            model.base_model.model.tree_mask = None
                            current_length_data.fill_(0)
                            tgt_current_length_data.fill_(cur_length - 1)
                            draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _logits, _hidden_state, _sample_token = initialize_tree(
                                input_ids, model, past_key_values, logits_processor
                            )
                            verify_token = 0
                            torch.cuda.synchronize()
                            elapsed = time.perf_counter() - start
                            exit_cnt += 1
                            exit_time.append(elapsed)
                            # Stop conditions are only evaluated on iterations
                            # that do not run a target check.
                            continue

                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        verify_cnt += 1
                        verify_time.append(elapsed)
                        # past_key_values_data changes whenever base_model's cache changes.
                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _cnt, _hidden_state, _sample_token = update_inference_inputs(
                            input_ids,
                            candidates,
                            best_candidate,
                            accept_length,
                            retrieve_indices,
                            logits_processor,
                            new_token,
                            past_key_values_data,
                            current_length_data,
                            model,
                            hidden_state_new,
                            sample_p
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        update_inference_inputs_cnt += 1
                        update_inference_inputs_time.append(elapsed)

                        accept_length_tree = input_ids.shape[1] - cur_length
                        cur_length = accept_length_tree + cur_length
                        verify_token = verify_token + accept_length_tree

                        generated = input_ids[0, input_len:].tolist()
                        hit_stop = (
                            (is_llama3 and stop_token_id in generated)
                            or tokenizer.eos_token_id in generated
                        )
                        if hit_stop:
                            # Count the still-unverified tail as accepted output.
                            accept_length_list.append(int(verify_token))
                            new_token = new_token + verify_token
                            break
                        if new_token > max_new_tokens:
                            break
                        if input_ids.shape[1] > 1960:
                            break

                    torch.cuda.synchronize()
                    total_time = time.time() - start_time
                    accept_lengths_tree.extend(accept_length_list)
                    output_ids = input_ids[0][len(input_ids_init[0]):]
                    output = _decode_answer(tokenizer, conv, output_ids)
                except RuntimeError as e:
                    print("ERROR question ID: ", question["question_id"])
                    print("RuntimeError:", e)
                    output = "ERROR"
                    total_time = time.time() - start_time

                # Sum of each timing list.
                tree_decoding_sum = sum(tree_decoding_time)
                affinequant_time = (tree_decoding_cnt + exit_cnt) * quan_time[question["category"]]
                evaluate_posterior_sum = sum(evaluate_posterior_time)
                verify_sum = sum(verify_time)
                update_inference_inputs_sum = sum(update_inference_inputs_time)
                # Estimated as (mean tree-decoding time) * exit_cnt rather than
                # sum(exit_time); guarded for turns where no draft pass completed.
                if tree_decoding_cnt:
                    exit_sum = tree_decoding_sum / tree_decoding_cnt * exit_cnt
                else:
                    exit_sum = 0.0
                print(f"tree_decoding_cnt: {tree_decoding_cnt}, tree_decoding_time sum: {tree_decoding_sum}")
                print(f"evaluate_posterior_cnt: {evaluate_posterior_cnt}, evaluate_posterior_time sum: {evaluate_posterior_sum}")
                print(f"verify_cnt: {verify_cnt}, verify_time sum: {verify_sum}")
                print(f"update_inference_inputs_cnt: {update_inference_inputs_cnt}, update_inference_inputs_time sum: {update_inference_inputs_sum}")
                print(f"exit_cnt: {exit_cnt}, exit_time sum: {exit_sum}")
                # Ratio of each sum to its count; the epsilon avoids division by zero.
                tree_decoding_avg = tree_decoding_sum / (tree_decoding_cnt + 1e-6)
                evaluate_posterior_avg = evaluate_posterior_sum / (evaluate_posterior_cnt + 1e-6)
                verify_avg = verify_sum / (verify_cnt + 1e-6)
                update_inference_inputs_avg = update_inference_inputs_sum / (update_inference_inputs_cnt + 1e-6)
                exit_avg = exit_sum / (exit_cnt + 1e-6)

                # Report the averages.
                print(f"Average tree_decoding_time: {tree_decoding_avg}")
                print(f"Average evaluate_posterior_time: {evaluate_posterior_avg}")
                print(f"Average verify_time: {verify_avg}")
                print(f"Average update_inference_inputs_time: {update_inference_inputs_avg}")
                print(f"Average exit_time: {exit_avg}")

                step = idx + 1
                turns.append(output)
                steps.append(int(step))
                new_tokens.append(int(new_token))
                # Reported wall time swaps the measured draft-pass time for the
                # per-category constant from quan_time.
                total_time = total_time - tree_decoding_sum - exit_sum + affinequant_time
                wall_time.append(total_time)
                cur_accept_lengths_tree.extend(accept_length_list)
                conv.messages[-1][-1] = output

            choices.append({"index": i, "turns": turns, "decoding_steps": steps, "new_tokens": new_tokens, "wall_time": wall_time,
                            "accept_lengths": cur_accept_lengths_tree})

        # Dump answers
        answer_path = os.path.expanduser(answer_file)
        answer_dir = os.path.dirname(answer_path)
        if answer_dir:
            # dirname is "" for a bare filename, which os.makedirs rejects.
            os.makedirs(answer_dir, exist_ok=True)
        with open(answer_path, "a") as fout:
            ans_json = {
                "question_id": question["question_id"],
                "category": question["category"],
                "answer_id": shortuuid.uuid(),
                "model_id": model_id,
                "choices": choices,
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")
    mean_accepted = np.mean(accept_lengths_tree) if accept_lengths_tree else float("nan")
    print("#Mean accepted tokens: ", mean_accepted)


def reorg_answer_file(answer_file):
    """Sort by question id and de-duplication.

    Rewrites ``answer_file`` (a JSONL file, ``~`` expanded) in place so that it
    holds exactly one record per ``question_id``, ordered by ascending id. When
    an id occurs more than once the record that appears last in the file wins.
    Blank lines are ignored, and every record is written with a terminating
    newline even if the input's final line lacked one.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``question_id`` field.
    """
    # Expand "~" the same way the writer in get_model_answers does.
    answer_path = os.path.expanduser(answer_file)

    answers = {}
    with open(answer_path, "r") as fin:
        for line in fin:
            if not line.strip():
                continue
            qid = json.loads(line)["question_id"]
            # Guarantee a line terminator so that reordering can never glue
            # two records onto the same line.
            answers[qid] = line if line.endswith("\n") else line + "\n"

    with open(answer_path, "w") as fout:
        for qid in sorted(answers):
            fout.write(answers[qid])




# Script entry point: parse the CLI, load the target and draft models, generate
# answers for the selected benchmark and tidy up the resulting answer file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the weights of the target model that verifies the draft tokens.",
    )
    parser.add_argument("--base-model-path", type=str, default="/home/lyh/weights/hf/llama2chat/70B/",
                        help="Path to the base model of the draft EaModel.")
    parser.add_argument(
        "--ea-model-path",
        type=str,
        default="down_checkpoints/LC70B",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    # NOTE: the value given here is currently overwritten right after parsing.
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end",
        type=int,
        help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--total-token",
        type=int,
        default=60,
        help="Passed to EaModel.from_pretrained as total_token.",
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=5,
        help="Passed to EaModel.from_pretrained as depth.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=10,
        help="Passed to EaModel.from_pretrained as top_k.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=2048,
        help="The maximum number of decoding-loop iterations per turn.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature; values <= 1e-5 disable the logits processor.",
    )
    parser.add_argument(
        "--tree-choices",
        type=str,
        default="mc_sim_7b_63",
        help="Python expression (normally the name of a predefined tree) evaluated to obtain the tree choices.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )

    args = parser.parse_args()

    # NOTE(review): this discards the user-supplied --model-id; the id is always
    # "vicuna-temperature-<T>". Kept as-is because output file names depend on it.
    args.model_id = "vicuna" + "-temperature-" + str(args.temperature)
    # SECURITY: eval() executes whatever expression is passed on the command
    # line. Acceptable only because the caller already controls the process; do
    # not feed it values from an untrusted source.
    args.tree_choices = eval(args.tree_choices)

    question_file = f"data/{args.bench_name}/question.jsonl"
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"

    print(f"Output to {answer_file}")

    torch_dtype = str_to_torch_dtype(args.dtype)

    # Target model: verifies the tokens proposed by the draft model.
    tgt_model = KVLlamaForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    # Draft model: proposes tokens through its tree-decoding loop.
    model = EaModel.from_pretrained(
        base_model_path=args.base_model_path,
        ea_model_path=args.ea_model_path,
        total_token=args.total_token,
        depth=args.depth,
        top_k=args.top_k,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        # load_in_8bit=True,
        device_map="auto"
    )

    tokenizer = model.get_tokenizer()
    if args.temperature > 1e-5:
        logits_processor = prepare_logits_processor(temperature=args.temperature)
    else:
        logits_processor = None
    # Llama-3 is detected from the base model path alone.
    is_llama3 = 'llama-3' in args.base_model_path.lower()

    run_eval(
        tgt_model=tgt_model,
        model=model,
        tokenizer=tokenizer,
        model_id=args.model_id,
        question_file=question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        answer_file=answer_file,
        max_new_tokens=args.max_new_tokens,
        is_llama3=is_llama3,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
        tree_choices=args.tree_choices,
        logits_processor=logits_processor,
        max_steps=args.max_steps,
    )

    reorg_answer_file(answer_file)
