import argparse
from fastchat.utils import str_to_torch_dtype
from fastchat.utils import str_to_torch_dtype
from fastchat.llm_judge.common import load_questions
from transformers import AutoModelForCausalLM, AutoTokenizer

from fastchat.model import get_conversation_template
from tqdm import tqdm
import json
import os
import time
import torch
import numpy as np
import shortuuid
import copy
from model.eaglev2.ea_model import EaModel
from model.eaglev2.kv_cache import initialize_past_key_values
from model.eaglev2.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from model.eaglev2.utils import *
from model.eaglev2.choices import *
import random
def reject_sample(
    candidate_input_ids,
    candidate_logits,
    candidate_length,
    new_logits,
):
    """Count the leading candidate tokens that pass the acceptance test.

    The last ``candidate_length`` ids of ``candidate_input_ids`` are treated as
    the drafted tokens. For each of them the probability under
    ``candidate_logits`` (the assistant, ``q``) and under ``new_logits`` (the
    model, ``p``) is looked up, and the token is accepted when a uniform random
    draw is ``<= p / q``.

    Args:
        candidate_input_ids: 2-D tensor of token ids; only the trailing
            ``candidate_length`` columns are used.
        candidate_logits: Assistant logits, indexed as ``[:, position, token]``.
        candidate_length: Number of drafted tokens to examine.
        new_logits: Model logits, indexed the same way; only the first
            ``candidate_length`` positions are read.

    Returns:
        A 0-dim tensor holding the number of consecutive accepted tokens
        counted from the first drafted position (``n`` in algorithm 1).
    """
    positions = torch.arange(candidate_length)
    drafted_ids = candidate_input_ids[:, -candidate_length:]

    # Probability each distribution assigns to the drafted token at each position.
    assistant_probs = candidate_logits.softmax(dim=-1)
    assistant_token_prob = assistant_probs[:, positions, drafted_ids].squeeze(0, 1)
    model_probs = new_logits.softmax(dim=-1)
    model_token_prob = model_probs[:, positions, drafted_ids].squeeze(0, 1)

    ratio = model_token_prob / assistant_token_prob
    draws = torch.rand_like(ratio)
    # Written as a negated `<=` so that a NaN ratio counts as a rejection.
    rejected = ~(draws <= ratio)

    # A position still counts while no rejection has occurred at or before it.
    rejections_so_far = rejected.cumsum(dim=-1)
    n_matches = (rejections_so_far < 1).sum()
    return n_matches

def run_eval(
        tgt_model,
        model,
        tokenizer,
        model_id,
        question_file,
        question_begin,
        question_end,
        answer_file,
        max_new_tokens,
        num_choices,
        num_gpus_per_model,
        num_gpus_total,
        tree_choices,
        logits_processor,
        max_steps,
        is_llama3
):
    """Load the questions and run ``get_model_answers`` over them in chunks.

    The questions are split into contiguous chunks. When more than one worker
    is available (``num_gpus_total // num_gpus_per_model > 1``) each chunk is
    dispatched as a ray remote task and the call blocks until all of them have
    finished; otherwise the chunks are processed sequentially in-process.

    Args:
        tgt_model, model, tokenizer, model_id, answer_file, max_new_tokens,
        num_choices, tree_choices, logits_processor, max_steps, is_llama3:
            Forwarded unchanged to ``get_model_answers``.
        question_file: Path handed to ``load_questions``.
        question_begin: Begin index handed to ``load_questions``.
        question_end: End index handed to ``load_questions``.
        num_gpus_per_model: GPUs requested per ray task.
        num_gpus_total: Total GPUs; must be a multiple of ``num_gpus_per_model``.

    Raises:
        AssertionError: If ``num_gpus_total`` is not a multiple of
            ``num_gpus_per_model``.
    """
    questions = load_questions(question_file, question_begin, question_end)

    # Split the question file into `num_gpus` files
    assert num_gpus_total % num_gpus_per_model == 0
    num_workers = num_gpus_total // num_gpus_per_model
    use_ray = num_workers > 1

    if use_ray:
        import ray
        ray.init()
        get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
            get_model_answers
        ).remote
    else:
        get_answers_func = get_model_answers

    # With fewer questions than workers (or no questions at all) the integer
    # division yields 0, and range() rejects a zero step; clamp to 1.
    chunk_size = max(1, len(questions) // num_workers)
    ans_handles = []
    for start in range(0, len(questions), chunk_size):
        ans_handles.append(
            get_answers_func(
                tgt_model,
                model,
                tokenizer,
                model_id,
                questions[start: start + chunk_size],
                answer_file,
                max_new_tokens,
                num_choices,
                tree_choices,
                logits_processor,
                max_steps,
                is_llama3
            )
        )

    if use_ray:
        # Block until every remote task has completed.
        ray.get(ans_handles)


def get_model_answers(
        tgt_model,
        model,
        tokenizer,
        model_id,
        questions,
        answer_file,
        max_new_tokens,
        num_choices,
        tree_choices,
        logits_processor,
        max_steps,
        is_llama3
):
    """Generate answers for ``questions`` and append them to ``answer_file`` as JSON lines.

    For every question and every choice index, each conversation turn is
    decoded in a loop of at most ``max_steps`` iterations: ``model`` proposes
    and scores a draft tree (``tree_decoding`` / ``evaluate_posterior``), and
    once at least 10 tokens have been accepted since the last check,
    ``tgt_model`` is run on the not-yet-verified tokens. The sequence is then
    cut at the first position where the target model's argmax disagrees and
    the target model's token is used there instead.

    Per-turn timing statistics are printed. The recorded ``wall_time`` is the
    measured time with the tree-decoding time (and an estimate for the tree
    re-initialisations) replaced by a fixed per-call cost looked up by
    question category.

    Args:
        tgt_model: Verifying model; called as
            ``tgt_model(input_ids=..., past_key_values=...)`` and read via
            ``.logits``.
        model: Drafting model; must expose ``ea_layer`` and ``base_model``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model_id: Stored verbatim in each output record.
        questions: Sequence of dicts with ``"turns"``, ``"question_id"`` and
            ``"category"`` keys.
        answer_file: Output path; records are appended.
        max_new_tokens: Decoding stops once the verified token count exceeds
            this value.
        num_choices: Number of completions per question (seeded with the
            choice index).
        tree_choices: Not used by this function.
        logits_processor: Forwarded to ``initialize_tree``,
            ``evaluate_posterior`` and ``update_inference_inputs``.
        max_steps: Upper bound on decoding iterations per turn.
        is_llama3: When true, ``"<|eot_id|>"`` is also treated as a stop token.

    Raises:
        KeyError: If a question's category has no entry in the latency table.
    """
    tgt_model.eval()
    print('Check model training state:', tgt_model.training)

    model.eval()
    print('Check model training state:', model.training)

    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print('CUDA VISIBLE DEVICES:', cuda_visible_devices)

    # Loop-invariant values, built once.
    sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    # Fixed cost (seconds) charged per tree-decoding / re-initialisation call,
    # keyed by question category.
    # NOTE(review): presumably pre-measured latencies of a quantized model
    # (the consumer below is named `affinequant_time`) - confirm the source.
    quan_time = {"writing":0.00567260941,"roleplay":0.00567260941,"reasoning":0.00567260941,"math":0.00567260941,
                 "coding":0.00567260941,"extraction":0.00567260941,"stem":0.00567260941,"humanities":0.00567260941,
                 "translation":0.00538026373,"summarization":0.00547867299, "qa": 0.00555285665,"math_reasoning":0.00534214391,
                 "rag":0.00561797753}

    accept_lengths_tree = []
    for question in tqdm(questions):
        choices = []
        for i in range(num_choices):
            cur_accept_lengths_tree = []
            torch.manual_seed(i)
            conv = get_conversation_template("llama-2-chat")
            conv.system_message = sys_p
            turns = []
            steps = []
            new_tokens = []
            wall_time = []
            for qs in question["turns"]:
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt() + " "

                inputs = tokenizer([prompt],return_tensors="pt").to("cuda")
                input_ids = inputs.input_ids
                input_ids_init = input_ids

                # Per-turn profiling counters and timings.
                tree_decoding_cnt=0
                tree_decoding_time=[]
                evaluate_posterior_cnt=0
                evaluate_posterior_time=[]
                verify_cnt=0
                verify_time=[]
                update_inference_inputs_cnt=0
                update_inference_inputs_time=[]
                exit_cnt=0
                exit_time=[]

                # Per-turn fallbacks so the bookkeeping after the try block is
                # well-defined (and not stale from a previous turn) when a
                # RuntimeError interrupts decoding early.
                idx = -1
                new_token = 0
                accept_length_list = []
                start_time = time.time()

                try:
                    torch.cuda.synchronize()
                    start_time = time.time()
                    padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device)
                    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
                    # Avoid modifying the input_ids in-place
                    input_ids = input_ids.clone()
                    model.ea_layer.reset_kv()
                    accept_length_list = []
                    if is_llama3:
                        stop_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
                    # Reuse the KV-cache buffers attached to `model` by an
                    # earlier turn, or allocate and attach them on first use.
                    if hasattr(model, "past_key_values"):
                        past_key_values = model.past_key_values
                        past_key_values_data = model.past_key_values_data
                        current_length_data = model.current_length_data
                        # Reset the past key and value states
                        current_length_data.zero_()
                    else:
                        (
                            past_key_values,
                            past_key_values_data,
                            current_length_data,
                        ) = initialize_past_key_values(model.base_model)
                        model.past_key_values = past_key_values
                        model.past_key_values_data = past_key_values_data
                        model.current_length_data = current_length_data

                    # Same cache handling for the verifying model.
                    if hasattr(tgt_model, "past_key_values"):
                        tgt_past_key_values = tgt_model.past_key_values
                        tgt_past_key_values_data = tgt_model.past_key_values_data
                        tgt_current_length_data = tgt_model.current_length_data
                        # Reset the past key and value states
                        tgt_current_length_data.zero_()
                    else:
                        (
                            tgt_past_key_values,
                            tgt_past_key_values_data,
                            tgt_current_length_data,
                        ) = initialize_past_key_values(tgt_model)
                        tgt_model.past_key_values = tgt_past_key_values
                        tgt_model.past_key_values_data = tgt_past_key_values_data
                        tgt_model.current_length_data = tgt_current_length_data
                    input_len = input_ids.shape[1]
                    # cur_length: current sequence length including unverified tokens.
                    # tgt_length: length of the prefix already checked by tgt_model.
                    cur_length = input_len
                    tgt_length = input_len
                    reset_tree_mode(model)

                    # past_key_values is the base_model's KV cache
                    candidate_logits=None
                    draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
                                input_ids, model, past_key_values, logits_processor
                            )
                    # new_token counts verified tokens; verify_token counts
                    # tokens accepted since the last tgt_model check.
                    new_token = 0
                    verify_token = 0

                    for idx in range(max_steps):

                        model.base_model.model.tree_mask = tree_mask
                        draft_tokens=draft_tokens.to(input_ids.device)

                        # past_key_values is the base_model's KV cache
                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        logits, hidden_state_new, outputs = tree_decoding(
                            model,
                            draft_tokens,
                            past_key_values,
                            tree_position_ids,
                            input_ids,
                            retrieve_indices,
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start

                        tree_decoding_cnt+=1
                        tree_decoding_time.append(elapsed)

                        draft_tokens=torch.cat((draft_tokens,padding),dim=1)
                        candidates=draft_tokens[0,retrieve_indices]

                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        best_candidate, accept_length, sample_p = evaluate_posterior(
                            logits, candidates, logits_processor
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        evaluate_posterior_cnt+=1
                        evaluate_posterior_time.append(elapsed)
                        # Accumulate the drafting logits of the accepted path
                        # (only consumed by the sampling-based check below).
                        if candidate_logits is None:
                            candidate_logits=logits[None,best_candidate,:accept_length+1]
                        else:
                            candidate_logits=torch.cat((candidate_logits,logits[None,best_candidate,:accept_length+1]),dim=1)

                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        if verify_token>=10 :
                            # Enough unverified tokens accumulated: append this
                            # step's accepted path and check with tgt_model.
                            input_ids = torch.cat(
                                [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
                            )
                            # Feed only the tokens beyond the target cache length.
                            tgt_prev_length=tgt_current_length_data[0].item()
                            with torch.no_grad():
                                outputs = tgt_model(
                                    input_ids=input_ids[:,tgt_prev_length:] ,
                                    past_key_values=tgt_past_key_values,
                                )
                            # First check of the turn (nothing verified yet):
                            # drop logits for positions before the last prompt token.
                            if tgt_length == input_len:
                                new_logits = outputs.logits[:, tgt_length-1:]
                            else:
                                new_logits = outputs.logits
                            candidate_length=input_ids.shape[1]-tgt_length
                            selected_tokens = new_logits.argmax(dim=-1)

                            # NOTE(review): random.random() returns values in
                            # [0.0, 1.0), so `r > 1` is never true and the
                            # sampling-based branch is effectively disabled;
                            # the greedy comparison in the else-branch always runs.
                            r=random.random()
                            if r>1:
                                n_matches = reject_sample(input_ids,candidate_logits,candidate_length,new_logits)
                                if n_matches==candidate_length:
                                    input_ids = torch.cat([input_ids,selected_tokens[:,-1:]], dim=-1)
                                    accept_length_list.append(int(selected_tokens.shape[-1]))
                                    tgt_length = tgt_length + selected_tokens.shape[-1]
                                    new_token = new_token + selected_tokens.shape[-1]
                                    cur_length = tgt_length
                                    candidate_logits=None
                                else:
                                    index=n_matches
                                    input_ids = input_ids[:,:tgt_length]
                                    input_ids = torch.cat([input_ids,selected_tokens[:,:index+1]], dim=-1)
                                    accept_length_list.append(int(index+1))
                                    tgt_length = tgt_length + index + 1
                                    new_token = new_token + index + 1
                                    cur_length = tgt_length
                                    candidate_logits=None
                            else:
                                candidate_logits=None
                                if not torch.equal(selected_tokens[:,:-1],input_ids[:, tgt_length:]):
                                    # Keep tokens up to the first disagreement and
                                    # take the target model's token at that position.
                                    index = (selected_tokens[:,:-1] != input_ids[:, tgt_length:]).nonzero(as_tuple=False)[0][1]
                                    input_ids = input_ids[:,:tgt_length]
                                    input_ids = torch.cat([input_ids,selected_tokens[:,:index+1]], dim=-1)
                                    accept_length_list.append(int(index+1))
                                    tgt_length = tgt_length + index.item() + 1
                                    new_token = new_token + index.item() + 1
                                    cur_length = tgt_length
                                else:
                                    # Everything matched: also append the target
                                    # model's next token.
                                    input_ids = torch.cat([input_ids,selected_tokens[:,-1:]], dim=-1)
                                    accept_length_list.append(selected_tokens.shape[-1])
                                    tgt_length = tgt_length + selected_tokens.shape[-1]
                                    new_token = new_token + selected_tokens.shape[-1]
                                    cur_length = tgt_length

                            # Reset the drafting side and rebuild the tree from
                            # the verified sequence.
                            model.ea_layer.reset_kv()
                            model.base_model.model.tree_mask = None
                            current_length_data.fill_(0)
                            tgt_current_length_data.fill_(cur_length - 1)
                            torch.cuda.synchronize()
                            start = time.perf_counter()
                            draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
                                input_ids, model, past_key_values, logits_processor
                            )
                            verify_token=0
                            torch.cuda.synchronize()
                            elapsed = time.perf_counter() - start
                            exit_cnt+=1
                            exit_time.append(elapsed)
                            continue

                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        verify_cnt+=1
                        verify_time.append(elapsed)
                        # past_key_values_data changes along with the base_model
                        torch.cuda.synchronize()
                        start = time.perf_counter()
                        input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, cnt, hidden_state, sample_token = update_inference_inputs(
                            input_ids,
                            candidates,
                            best_candidate,
                            accept_length,
                            retrieve_indices,
                            logits_processor,
                            new_token,
                            past_key_values_data,
                            current_length_data,
                            model,
                            hidden_state_new,
                            sample_p
                        )
                        torch.cuda.synchronize()
                        elapsed = time.perf_counter() - start
                        update_inference_inputs_cnt+=1
                        update_inference_inputs_time.append(elapsed)

                        accept_length_tree = input_ids.shape[1] - cur_length
                        cur_length = accept_length_tree + cur_length
                        verify_token = verify_token + accept_length_tree

                        # On a stop token the still-unverified tokens are
                        # counted as generated and decoding ends.
                        if is_llama3:
                            if stop_token_id in input_ids[0, input_len:].tolist():
                                accept_length_list.append(int(verify_token))
                                new_token=new_token+verify_token
                                break
                        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                            accept_length_list.append(int(verify_token))
                            new_token=new_token+verify_token
                            break
                        if new_token > max_new_tokens:
                            break
                        if input_ids.shape[1] > 1960:
                            break


                    torch.cuda.synchronize()
                    total_time = time.time() - start_time
                    accept_lengths_tree.extend(accept_length_list)
                    output_ids = input_ids[0][len(input_ids_init[0]):]

                    if conv.stop_token_ids:
                        stop_token_ids_index = [
                            pos
                            for pos, tok_id in enumerate(output_ids)
                            if tok_id in conv.stop_token_ids
                        ]
                        if len(stop_token_ids_index) > 0:
                            output_ids = output_ids[: stop_token_ids_index[0]]


                    output = tokenizer.decode(
                        output_ids,
                        spaces_between_special_tokens=False,
                    )
                    conv.stop_str = "</s>"
                    if conv.stop_str and output.find(conv.stop_str) > 0:
                        output = output[: output.find(conv.stop_str)]
                    for special_token in tokenizer.special_tokens_map.values():
                        if isinstance(special_token, list):
                            for special_tok in special_token:
                                output = output.replace(special_tok, "")
                        else:
                            output = output.replace(special_token, "")
                    output = output.strip()

                    if conv.name == "xgen" and output.startswith("Assistant:"):
                        output = output.replace("Assistant:", "", 1).strip()
                except RuntimeError as e:
                    print("ERROR question ID: ", question["question_id"])
                    print("ERROR detail: ", repr(e))
                    output = "ERROR"
                    # Record the time spent before the failure so the
                    # statistics below stay defined.
                    total_time = time.time() - start_time

                # Sum of each timing list
                tree_decoding_sum = sum(tree_decoding_time)
                affinequant_time= (tree_decoding_cnt+exit_cnt)*quan_time[question["category"]]
                evaluate_posterior_sum = sum(evaluate_posterior_time)
                verify_sum = sum(verify_time)
                update_inference_inputs_sum = sum(update_inference_inputs_time)
                # The re-initialisation cost is estimated as the average
                # tree-decoding time per call; the measured `exit_time`
                # samples are not used here.
                exit_sum = tree_decoding_sum / (tree_decoding_cnt + 1e-6) * exit_cnt
                print(f"tree_decoding_cnt: {tree_decoding_cnt}, tree_decoding_time sum: {tree_decoding_sum}")
                print(f"evaluate_posterior_cnt: {evaluate_posterior_cnt}, evaluate_posterior_time sum: {evaluate_posterior_sum}")
                print(f"verify_cnt: {verify_cnt}, verify_time sum: {verify_sum}")
                print(f"update_inference_inputs_cnt: {update_inference_inputs_cnt}, update_inference_inputs_time sum: {update_inference_inputs_sum}")
                print(f"exit_cnt: {exit_cnt}, exit_time sum: {exit_sum}")
                # Per-call averages (the 1e-6 avoids division by zero)
                tree_decoding_avg = tree_decoding_sum / (tree_decoding_cnt + 1e-6)
                evaluate_posterior_avg = evaluate_posterior_sum / (evaluate_posterior_cnt + 1e-6)
                verify_avg = verify_sum / (verify_cnt + 1e-6)
                update_inference_inputs_avg = update_inference_inputs_sum / (update_inference_inputs_cnt + 1e-6)
                exit_avg = exit_sum / (exit_cnt + 1e-6)

                # Print the per-call averages
                print(f"Average tree_decoding_time: {tree_decoding_avg}")
                print(f"Average evaluate_posterior_time: {evaluate_posterior_avg}")
                print(f"Average verify_time: {verify_avg}")
                print(f"Average update_inference_inputs_time: {update_inference_inputs_avg}")
                print(f"Average exit_time: {exit_avg}")
                step=idx+1
                turns.append(output)
                steps.append(int(step))
                new_tokens.append(int(new_token))
                # Replace the measured drafting time with the fixed per-call cost.
                total_time=total_time - tree_decoding_sum - exit_sum + affinequant_time
                wall_time.append(total_time)
                cur_accept_lengths_tree.extend(accept_length_list)
                # Store the answer so the next turn's prompt includes it.
                conv.messages[-1][-1] = output
            # torch.cuda.empty_cache()

            choices.append({"index": i, "turns": turns, "decoding_steps": steps, "new_tokens": new_tokens, "wall_time": wall_time,
                            "accept_lengths": cur_accept_lengths_tree})

        # Dump answers
        answer_path = os.path.expanduser(answer_file)
        answer_dir = os.path.dirname(answer_path)
        # A bare file name has no directory part; os.makedirs("") would raise.
        if answer_dir:
            os.makedirs(answer_dir, exist_ok=True)
        with open(answer_path, "a") as fout:
            ans_json = {
                "question_id": question["question_id"],
                "category": question["category"],
                "answer_id": shortuuid.uuid(),
                "model_id": model_id,
                "choices": choices,
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")
    # np.mean([]) would emit a RuntimeWarning; report NaN directly instead.
    mean_accepted = np.mean(accept_lengths_tree) if accept_lengths_tree else float("nan")
    print("#Mean accepted tokens: ", mean_accepted)


def reorg_answer_file(answer_file):
    """Sort the answer file by question id and drop duplicate question ids.

    The file is read as JSON lines and rewritten in place. When several
    records share a ``question_id`` the last one in the file wins. Blank lines
    are ignored, and every kept record is written with a trailing newline so
    that a final line lacking one cannot be fused with the record that follows
    it after sorting.

    Args:
        answer_file: Path of the JSON-lines file to rewrite.
    """
    latest_by_qid = {}
    with open(answer_file, "r") as fin:
        for line in fin:
            if not line.strip():
                continue
            if not line.endswith("\n"):
                line += "\n"
            latest_by_qid[json.loads(line)["question_id"]] = line

    with open(answer_file, "w") as fout:
        fout.writelines(latest_by_qid[qid] for qid in sorted(latest_by_qid))




if __name__ == "__main__":
    # Command-line entry point: parse options, load the target and drafting
    # models, generate the answers and finally sort/de-duplicate the output.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path of the weights loaded as the target model (KVLlamaForCausalLM).",
    )
    parser.add_argument("--base-model-path", type=str, default="/home/lyh/weights/hf/llama2chat/70B/",
                        help="Base model path passed to EaModel.from_pretrained.")
    parser.add_argument(
        "--ea-model-path",
        type=str,
        default="down_checkpoints/LC70B",
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end",
        type=int,
        help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--total-token",
        type=int,
        default=60,
        help="The total_token setting passed to EaModel.from_pretrained.",
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=5,
        help="The depth setting passed to EaModel.from_pretrained.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=10,
        help="The top_k setting passed to EaModel.from_pretrained.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=2048,
        help="The maximum number of decoding steps per turn.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for medusa sampling.",
    )
    parser.add_argument(
        "--tree-choices",
        type=str,
        default="mc_sim_7b_63",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )

    args = parser.parse_args()

    # NOTE(review): the required --model-id value is overwritten here, so the
    # id given on the command line never reaches the output.
    args.model_id = "llama2chat" + "-temperature-" + str(args.temperature)
    # SECURITY NOTE: eval() executes the --tree-choices string as Python code.
    # It is presumably meant to resolve a name star-imported from
    # model.eaglev2.choices; only run this script with trusted arguments.
    args.tree_choices = eval(args.tree_choices)

    question_file = f"data/{args.bench_name}/question.jsonl"
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"

    print(f"Output to {answer_file}")

    # Llama-3 is detected from the base model path only.
    is_llama3 = 'llama-3' in args.base_model_path.lower()

    tgt_model = KVLlamaForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=str_to_torch_dtype(args.dtype),
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    model = EaModel.from_pretrained(
        base_model_path=args.base_model_path,
        ea_model_path=args.ea_model_path,
        total_token=args.total_token,
        depth=args.depth,
        top_k=args.top_k,
        torch_dtype=str_to_torch_dtype(args.dtype),
        low_cpu_mem_usage=True,
        # load_in_8bit=True,
        device_map="auto"
    )

    tokenizer = model.get_tokenizer()
    # A temperature at or below 1e-5 disables the logits processor.
    if args.temperature > 1e-5:
        logits_processor = prepare_logits_processor(temperature=args.temperature)
    else:
        logits_processor = None

    run_eval(
        tgt_model=tgt_model,
        model=model,
        tokenizer=tokenizer,
        model_id=args.model_id,
        question_file=question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        answer_file=answer_file,
        max_new_tokens=args.max_new_tokens,
        is_llama3=is_llama3,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
        tree_choices=args.tree_choices,
        logits_processor=logits_processor,
        max_steps=args.max_steps,
    )

    reorg_answer_file(answer_file)
