"""Generate answers with local models.

Usage:
python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
"""
import argparse
import json
import os
import random
import time
import shortuuid
import torch
import numpy as np
from tqdm import tqdm
import string
import matplotlib.pyplot as plt
import seaborn as sns
from fastchat.llm_judge.common import load_questions
from fastchat.model import load_model, get_conversation_template


# Rest imports
import transformers

import sys
sys.path.append("../")

from rest.model.utils import *
from rest.model.rest_model import RestModel
from rest.model.kv_cache import initialize_past_key_values
import draftretriever


def rest_forward(input_ids, model, tokenizer, max_new_token, temperature, top_p, datastore, num_draft, token_spans, max_steps=1024):
    """Run REST (retrieval-based speculative decoding) on a single prompt.

    Each step retrieves draft candidates from ``datastore``, verifies them with
    one tree-decoding forward pass of the target model, and appends the
    accepted tokens to ``input_ids``. Decoding stops when ``eos_token_id``
    appears in the generated part, when ``new_token`` exceeds
    ``max_new_token``, or after ``max_steps`` steps.

    Args:
        input_ids: Prompt token ids. Batch size must be 1 (asserted). The
            tensor is cloned, so the caller's tensor is not modified.
        model: REST model. Its KV-cache buffers are stored on the model object
            (``past_key_values``, ``past_key_values_data``,
            ``current_length_data``). They are created on the first call and
            reset on later calls.
        tokenizer: Only ``tokenizer.eos_token_id`` is used, as a stop condition.
        max_new_token: Generation stops once ``new_token`` exceeds this.
        temperature, top_p: Forwarded to candidate generation and posterior
            evaluation.
        datastore: Retrieval datastore passed to
            ``generate_candidates_and_draft_buffer``.
        num_draft: Maximum number of draft tokens (``max_num_draft``).
        token_spans: Suffix lengths used for retrieval.
        max_steps: Upper bound on the number of decoding steps.

    Returns:
        A tuple ``(input_ids, new_token, idx, accept_length_list,
        individual_token_times)``:

        - the final ids (prompt plus generated tokens);
        - the token counter maintained by ``update_inference_inputs``;
        - the index of the last step executed;
        - the number of tokens appended to ``input_ids`` at each step;
        - the wall-clock seconds spent in each step.

    NOTE(review): ``idx`` is unbound (UnboundLocalError at the return) if
    ``max_steps == 0``.
    NOTE(review): ``end_time`` is read without an explicit
    ``torch.cuda.synchronize()``. Per-step times are only accurate if the
    helpers force a device sync (e.g. via ``.item()``/``.tolist()``); confirm.
    """
    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
    # Avoid modifying the input_ids in-place
    input_ids = input_ids.clone()
    # Number of tokens appended to input_ids in each decoding step.
    accept_length_list = []

    # Initialize the past key and value states, or reuse the buffers cached
    # on the model from a previous call.
    if hasattr(model, "past_key_values"):
        past_key_values = model.past_key_values
        past_key_values_data = model.past_key_values_data
        current_length_data = model.current_length_data
        # Reset the past key and value states
        current_length_data.zero_()
    else:
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(model.base_model)
        model.past_key_values = past_key_values
        model.past_key_values_data = past_key_values_data
        model.current_length_data = current_length_data

    input_len = input_ids.shape[1]

    # NOTE (David): the original REST code starts from
    #   cur_length = input_len + 1.
    # With that start, the first entry of accept_length_list is one less than
    # the number of tokens that step actually appended. As a result,
    # sum(accept_length_list) == total new tokens - 1.
    # Starting from input_len makes every entry equal to the number of tokens
    # appended in that step, so the entries sum to the total number of new
    # tokens. The author considers the "+ 1" a bug.
    # cur_length = input_len + 1
    cur_length = input_len
    model.base_model.model.draft_mask = None
    logits = initialize_logits(
            input_ids, model, past_key_values
    )
    new_token = 0
    # Wall-clock seconds spent in each decoding step.
    individual_token_times = []
    # Drain pending GPU work so the first step's timer starts from an idle device.
    torch.cuda.synchronize()
    
    for idx in range(max_steps): 
        start_time = time.time()
        # David: Match and retrieve speculation candidates from the datastore.
        candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer(
                logits,
                input_ids,
                datastore,
                token_spans,
                top_p,
                temperature,
                max_num_draft=num_draft,
                device=model.base_model.device
            )
        model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"]

        # David: One parallel verification pass over all candidates using the
        #        real target model.
        logits, outputs = tree_decoding(
                model,
                tree_candidates,
                past_key_values,
                draft_buffers["draft_position_ids"],
                input_ids,
                draft_buffers["retrieve_indices"],
            )
        

        # David: Use the target model's logits to find its actual output at
        #        each level of the speculation tree. Then walk down the tree,
        #        checking at each node whether that token is what the target
        #        model produced. If it is, the token is accepted.

        # David: The longest accepted path from the root becomes best_candidate.
        #        accept_length is the number of draft tokens actually accepted
        #        in this step.
        best_candidate, accept_length = evaluate_posterior(
                logits, candidates, temperature, top_p
            )
        
        # David: Besides best_candidate, one more token (the target model's own
        #        prediction) is taken. In the worst case this makes the step
        #        equivalent to one autoregressive step.
        input_ids, logits, new_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                draft_buffers["retrieve_indices"],
                outputs,
                logits,
                new_token,
                past_key_values_data,
                current_length_data,
            )
        end_time = time.time()
        individual_token_times.append(end_time - start_time)

        # Tokens appended to input_ids in this step.
        accept_length_tree = input_ids.shape[1] - cur_length
        cur_length = accept_length_tree + cur_length
        accept_length_list.append(accept_length_tree)
        # Stop once EOS appears anywhere in the generated (non-prompt) part.
        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break
        if new_token > max_new_token:
            break
    return input_ids, new_token, idx, accept_length_list, individual_token_times

def set_up(
    model_path,
    datastore_path,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
    num_draft,
    max_token_span,
    ):
    """Load the REST datastore, model and tokenizer, then run one warmup generation.

    Args:
        model_path: Weights folder or Hugging Face repo id, loaded with
            ``RestModel.from_pretrained`` in float16 with ``device_map="auto"``.
        datastore_path: Index file opened with ``draftretriever.Reader``.
        model_id: Conversation-template id used for the warmup prompt.
        max_new_token: Generation budget for the warmup run.
        num_choices, max_gpu_memory: Forwarded to ``warmup``; not otherwise
            used here.
        num_gpus_per_model, num_gpus_total: ``num_gpus_total`` must be a
            multiple of ``num_gpus_per_model``. When their ratio is > 1, the
            warmup is dispatched through Ray.
        temperature, top_p: Sampling parameters for the warmup run.
        num_draft: Maximum number of draft tokens per decoding step.
        max_token_span: Longest retrieval suffix. Spans
            ``max_token_span, ..., 2`` are passed to the decoder, longest first.

    Returns:
        ``(datastore, model, tokenizer)``.
    """
    print("loading the datastore ...")
    datastore = draftretriever.Reader(index_file_path=datastore_path)
    print("datastore loaded!")

    model = RestModel.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    tokenizer = model.get_tokenizer()

    model.eval()
    print('Check model training state:', model.training)

    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
    # Retrieval suffix lengths, longest first: [max_token_span, ..., 2].
    token_spans = list(range(2, max_token_span + 1))[::-1]

    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    warmup_args = (
        model,
        tokenizer,
        model_id,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    if use_ray:
        import ray

        # .remote() only schedules the task and returns a handle. Block on it
        # so the warmup has really finished (and any error surfaces) before
        # the caller starts timing.
        handle = ray.remote(num_gpus=num_gpus_per_model)(warmup).remote(*warmup_args)
        ray.get(handle)
    else:
        warmup(*warmup_args)

    return datastore, model, tokenizer


# no_grad (rather than inference_mode) keeps the KV-cache buffers that
# rest_forward stores on the model as ordinary tensors, so later callers that
# are not in inference mode can still update them in place.
@torch.no_grad()
def warmup(
    model,
    tokenizer,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
    datastore,
    num_draft,
    token_spans,
    ):
    """Run one throwaway REST generation so later timed runs start warm.

    The generation runs with autograd disabled, like the timed path, so no
    gradient bookkeeping is built during warmup. Errors of type
    ``RuntimeError`` are reported and swallowed, because warmup is best-effort.

    Args:
        model, tokenizer, datastore: As returned by the loading code in
            ``set_up``.
        model_id: Conversation-template id used to build the prompt.
        max_new_token, temperature, top_p, num_draft, token_spans: Forwarded
            to ``rest_forward``.
        num_choices, num_gpus_per_model, max_gpu_memory: Unused. Accepted only
            so the signature matches the other per-question entry points.
    """
    torch.manual_seed(0)
    conv = get_conversation_template(model_id)
    conv.append_message(conv.roles[0], "what is Socrates best known for?")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer([prompt]).input_ids

    # Some models may error out when generating long outputs; warmup is
    # best-effort, so report the error and continue.
    try:
        rest_forward(
            torch.as_tensor(input_ids).cuda(),
            model,
            tokenizer,
            max_new_token,
            temperature,
            top_p,
            datastore,
            num_draft,
            token_spans,
        )
    except RuntimeError as e:
        print(f"warm up errored out with {e}")
    print('Warmup done')



def execute(
    model,
    tokenizer,
    datastore,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
    num_draft,
    max_token_span,
    exp=0
):
    """Decode one question with REST and return its latency traces.

    Args:
        model, tokenizer, datastore: As returned by ``set_up``.
        model_id: Conversation-template id used to build the prompt.
        question: The user question text.
        max_new_token, temperature, top_p, num_draft: Decoding parameters
            forwarded to ``qsstr_in_spec_tok_out``.
        num_choices, max_gpu_memory: Forwarded unchanged.
        num_gpus_per_model, num_gpus_total: ``num_gpus_total`` must be a
            multiple of ``num_gpus_per_model``. When their ratio is > 1, the
            work is dispatched through Ray.
        max_token_span: Longest retrieval suffix. Spans
            ``max_token_span, ..., 2`` are used, longest first.
        exp: Forwarded to ``get_trace_lst``.

    Returns:
        The ``(acc_0, acc_1)`` pair produced by ``get_trace_lst``:
        per-step times and per-token times.
    """
    # Retrieval suffix lengths, longest first: [max_token_span, ..., 2].
    token_spans = list(range(2, max_token_span + 1))[::-1]

    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    call_args = (
        model,
        tokenizer,
        model_id,
        question,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    if use_ray:
        import ray

        # .remote() returns an ObjectRef, not the function's result. It must be
        # resolved with ray.get before it can be unpacked into two values.
        handle = ray.remote(num_gpus=num_gpus_per_model)(
            qsstr_in_spec_tok_out
        ).remote(*call_args)
        accept_length_tree, individual_token_times = ray.get(handle)
    else:
        accept_length_tree, individual_token_times = qsstr_in_spec_tok_out(*call_args)

    return get_trace_lst(accept_length_tree, individual_token_times, exp)


@torch.inference_mode()
def qsstr_in_spec_tok_out(
    model,
    tokenizer,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
    datastore,
    num_draft,
    token_spans,
):
    """Run REST decoding on one question string and return per-step statistics.

    The seed is fixed (``torch.manual_seed(100)``) so repeated runs of the same
    question are comparable.

    Args:
        model, tokenizer, datastore: Loaded REST model, its tokenizer, and the
            retrieval datastore.
        model_id: Conversation-template id used to build the prompt.
        question: The user question text.
        max_new_token, temperature, top_p, num_draft, token_spans: Forwarded
            to ``rest_forward``.
        num_choices, num_gpus_per_model, max_gpu_memory: Unused. Accepted so
            the signature matches the other per-question entry points.

    Returns:
        ``(accept_length_tree, individual_token_times)`` as produced by
        ``rest_forward``: the number of tokens appended in each decoding step,
        and the wall-clock seconds spent in each step.
    """
    torch.manual_seed(100)
    conv = get_conversation_template(model_id)
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer([prompt]).input_ids

    # Only the per-step statistics are needed for tracing. The generated ids,
    # the token counter and the last step index are discarded.
    _, _, _, accept_length_tree, individual_token_times = rest_forward(
        torch.as_tensor(input_ids).cuda(),
        model,
        tokenizer,
        max_new_token,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    return accept_length_tree, individual_token_times
    
def get_trace_lst(accept_length_tree, individual_token_times, exp=0):
    """Expand per-step timings into a per-step list and a per-token list.

    Args:
        accept_length_tree: Number of tokens produced in each decoding step.
        individual_token_times: Seconds spent in each decoding step, aligned
            index-by-index with ``accept_length_tree``.
        exp: Unused; kept for call compatibility.

    Returns:
        A tuple ``(step_times, token_times)``. ``step_times`` is a copy of the
        first len(accept_length_tree) step times. ``token_times`` repeats each
        step's average time per token once for every token the step produced.
    """
    step_times = []
    token_times = []
    for step, produced in enumerate(accept_length_tree):
        elapsed = individual_token_times[step]
        step_times.append(elapsed)
        # Spread the step's cost evenly over the tokens it produced.
        token_times.extend([elapsed / produced] * produced)
    return step_times, token_times
     

def plot_trace(lst: list, output_dir: str = './out_trace', filename: str = 'iteration_traces.jpg'):
    """Plot a per-token latency trace and save it as an image.

    Args:
        lst: Time per token, in order. The x-axis is the index into ``lst``.
        output_dir: Directory to write into; created if it does not exist.
        filename: Image file name inside ``output_dir``. Repeated calls with
            the same name overwrite the previous image.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)

    plt.figure(figsize=(8, 6))
    try:
        plt.plot(list(range(len(lst))), lst, marker='o')
        plt.xlabel('Token Index')
        plt.ylabel('Time per token')
        plt.title('Time per Token Graph')
        plt.grid(True)
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # figures do not accumulate across calls.
        plt.close()

    print(f"Plot saved to {output_path}")



def load_prompts_from_txt(file_path, encoding="utf-8"):
    """Read one prompt per line and normalize each to end with a single "?".

    Surrounding whitespace and double quotes are stripped from every line.
    Then any run of trailing punctuation is replaced by "?", or "?" is appended
    when there is none. Blank lines yield an empty string, so there is exactly
    one entry per input line.

    Args:
        file_path: Path of the text file.
        encoding: Text encoding of the file. Defaults to UTF-8 instead of the
            platform's locale encoding, so non-ASCII prompts decode the same
            way on every machine.

    Returns:
        list[str]: One normalized prompt per input line.
    """
    prompts = []
    with open(file_path, "r", encoding=encoding) as f:
        for line in f:
            prompt = line.strip().strip('"')
            # rstrip is a no-op when there is no trailing punctuation, so both
            # cases reduce to "drop trailing punctuation, then add '?'".
            if prompt:
                prompt = prompt.rstrip(string.punctuation) + "?"
            prompts.append(prompt)
    return prompts


# def pad_list(input_list, target_length, padding_value=0):
#     """
#     Pads the input list with the specified padding value until it reaches the target length.
    
#     Parameters:
#     input_list (list): The list to be padded.
#     target_length (int): The desired length of the list after padding.
#     padding_value: The value to use for padding. Default is 0.
    
#     Returns:
#     list: The padded list.
#     """
#     current_length = len(input_list)
#     if current_length >= target_length:
#         return input_list[:target_length]  # Return truncated list if it's longer than target length
#     else:
#         padding_needed = target_length - current_length
#         return input_list + [padding_value] * padding_needed


if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="rest",
        help="The name of the ouput_file.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-token",
        type=int,
        default=100,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="The threshold for nucleus sampling.",
    )

    # REST's hyperparameters
    parser.add_argument(
        "--datastore-path",
        type=str,
        required=True,
        help="The path of the datastore for retrival.",
    )

    parser.add_argument(
        "--num-draft",
        type=int,
        default=64,
        help="The maximum number of draft tokens.",
    )
    parser.add_argument(
        "--max-token-span",
        type=int,
        default=32,
        help="The maximum length of suffix for retrieval.",
    )

    parser.add_argument(
        "--kind",
        type=str,
        required=False,
        help="The kind to run",
    )

    parser.add_argument(
        "--trials",
        type=int,
        help="The number of trials to run",
    )

    args = parser.parse_args()

    if args.temperature == 0:
        args.top_p = 0
        

    args.model_id = "rest-" + args.model_id+"-temperature-"+str(args.temperature)+"-top_p-"+str(args.top_p)
    if args.num_gpus_total // args.num_gpus_per_model > 1:
        import ray
        ray.init()

    # kind = "original"
    # kind = "semantics_sim"
    # kind = "structural_sim"
    # kind = "prompts"
    # kind = "similar_prompt"
    kind = args.kind
    question_file = f"data/{kind}.txt"

    # trials = 30 if kind == 'original' else 5
    # trials = 30
    # exp = 2
    trials = args.trials
    
    
    
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"out_trace/temp_{args.temperature}_{kind}_trace_{trials}.out"

    print(f"Output to {answer_file}")

    datastore, model, tokenizer = set_up(
        args.model_path,
        args.datastore_path, 
        args.model_id,
        args.max_new_token,
        args.num_choices,
        args.num_gpus_per_model,
        args.num_gpus_total,
        args.max_gpu_memory,
        args.temperature,
        args.top_p,
        args.num_draft,
        args.max_token_span,
        )

    prompts = load_prompts_from_txt(question_file)
    # execute(
    #     model,
    #     tokenizer,
    #     datastore,
    #     args.model_id,
    #     prompts[3],
    #     args.max_new_token,
    #     args.num_choices,
    #     args.num_gpus_per_model,
    #     args.num_gpus_total,
    #     args.max_gpu_memory,
    #     args.temperature,
    #     args.top_p,
    #     args.num_draft,
    #     args.max_token_span,
    # )
    
    traces_1 = []
    traces_0 = []
    labels = []
    for j in range(len(prompts)):
        acc_0, acc_1 = execute(
            model,
            tokenizer,
            datastore,
            args.model_id,
            prompts[j],
            args.max_new_token,
            args.num_choices,
            args.num_gpus_per_model,
            args.num_gpus_total,
            args.max_gpu_memory,
            args.temperature,
            args.top_p,
            args.num_draft,
            args.max_token_span,
            # exp
        )
        traces_1.append(len(acc_1))
        # traces_0.append(acc_0)
        # labels.append(prompts[j])
        print(f"done promt num: {j}")

    # with open(f"T0/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_0,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    # with open(f"T1/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_1,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # print(f"Output done to {answer_file}")
    print(traces_1)
    # traces_1 = [90, 133, 76, 272, 300, 284, 158, 282, 294, 224, 4, 346, 414, 385, 383, 220, 384, 308, 380, 107, 326, 355, 279, 157, 277, 187, 276, 322, 103, 277, 368, 229, 513, 326, 264, 325, 378, 326, 241, 351, 390, 215, 364, 364, 199, 301, 180, 327, 378, 258]
    plt.figure(figsize=(8, 6))
    sns.histplot(traces_1, bins=100, kde=False, color="blue")

    # Adding labels and title
    plt.title("Histogram of Trace Lengths")
    plt.xlabel("Length of acc_1")
    plt.ylabel("Frequency")

    # Save the plot as a PNG file
    plt.savefig("histogram_plot.png", dpi=300, bbox_inches='tight')
    
