"""Generate answers with local models.

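This variant records baseline (one token per step) decoding traces: every
prompt in data/<kind>.txt is decoded --trials times and the traces, together
with their prompt labels, are written out as JSON.

Usage:
python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 --kind prompts --trials 1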
"""
import argparse
import json
import os
import random
import time
import shortuuid
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import string
import matplotlib.pyplot as plt
import seaborn as sns
from fastchat.llm_judge.common import load_questions
from fastchat.model import load_model, get_conversation_template


# Rest imports
import transformers

import sys
sys.path.append("../")

from rest.model.utils import *
from rest.model.rest_model import RestModel
from rest.model.kv_cache import initialize_past_key_values
import draftretriever


def baseline_forward(input_ids, model, tokenizer, max_new_token, temperature, top_p, max_steps=1024):
    """Decode one token per step with the base model and record a per-step trace.

    The prompt is pushed through ``model.base_model`` once to fill the KV
    cache. Each iteration then selects the next token (nucleus sampling when
    ``top_p > 0``, argmax otherwise), feeds it back through the base model and
    appends the character length of that token's decoded text to the trace.

    Args:
        input_ids: Prompt token ids; the first dimension must be 1.
        model: Object exposing ``base_model``. The KV-cache handles are stored
            on it as ``past_key_values``, ``past_key_values_data`` and
            ``current_length_data`` and reused (after a length reset) on
            later calls.
        tokenizer: Provides ``decode`` and ``eos_token_id``.
        max_new_token: Generation stops once *more than* this many tokens have
            been produced, i.e. after ``max_new_token + 1`` tokens.
            NOTE(review): looks like an off-by-one inherited from the
            speculative-decoding loop; kept so existing traces stay comparable.
        temperature: Divisor applied to the logits when sampling; values
            ``<= 0`` are treated as 1.
        top_p: Nucleus threshold. ``0`` (or less) selects argmax decoding;
            otherwise it must be ``< 1``.
        max_steps: Hard cap on the number of decoding iterations.

    Returns:
        list[int]: one entry per generated token, holding ``len()`` of the
        token's decoded text (special tokens skipped).

    Raises:
        AssertionError: if the batch size is not 1 or ``top_p >= 1``.
    """
    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
    use_sampling = top_p > 0
    if use_sampling:
        # Loop-invariant check, done once up front instead of on every step.
        assert top_p < 1, "top_p should between 0.0 and 1"

    # Reuse the KV cache stored on the model when there is one, otherwise
    # allocate it and stash the handles on the model for later calls.
    if hasattr(model, "past_key_values"):
        past_key_values = model.past_key_values
        # Rewind the cache so this call starts from an empty context.
        model.current_length_data.zero_()
    else:
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(model.base_model)
        model.past_key_values = past_key_values
        model.past_key_values_data = past_key_values_data
        model.current_length_data = current_length_data

    # David: the upstream implementation initialised
    #        `cur_length = input_len + 1` here, which made
    #        sum(accept_length_list) come out one short of the number of new
    #        tokens. That sum is not the list of really accepted lengths per
    #        iteration (REST mis-speculates roughly half the time, so a single
    #        miss would be implausible), hence the +1 is considered a bug and
    #        is intentionally left out.
    model.base_model.model.draft_mask = None
    outputs = model.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
    new_token = 0

    accept_length_list = []
    torch.cuda.synchronize()

    eos_token_id = tokenizer.eos_token_id
    logit_divisor = temperature if temperature > 0 else 1.

    for _ in range(max_steps):
        if use_sampling:
            # `top_p_filtering` and `F` are not defined in this file; they are
            # presumably provided by the star import from rest.model.utils.
            next_token_logits = outputs.logits[:, -1, :] / logit_divisor
            filtered_logits = top_p_filtering(next_token_logits, top_p=top_p)
            input_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            input_id = input_id.view(input_id.shape[0], 1)
        else:
            input_id = outputs.logits[:, -1:].argmax(dim=-1)
        outputs = model.base_model(input_id, use_cache=True, past_key_values=past_key_values)
        curr_out = tokenizer.decode(input_id[0], spaces_between_special_tokens=False, skip_special_tokens=True)

        new_token += 1
        accept_length_list.append(len(curr_out))
        # Exactly one token is produced per step and the loop exits on the
        # first EOS, so inspecting the newest token is equivalent to scanning
        # the whole generated suffix, without the per-step O(n) rescan.
        if input_id.item() == eos_token_id:
            break
        if new_token > max_new_token:
            break
    return accept_length_list

def set_up(
    model_path,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
    ):
    """Load the REST model and tokenizer, then run one warm-up generation.

    Args:
        model_path: Passed to ``RestModel.from_pretrained``.
        model_id: Conversation-template id forwarded to ``warmup``.
        max_new_token: Forwarded to ``warmup``.
        num_choices: Forwarded to ``warmup`` (not used by it).
        num_gpus_per_model: GPUs reserved per Ray task; must divide
            ``num_gpus_total``.
        num_gpus_total: Total GPUs; when it allows more than one model the
            warm-up is dispatched through Ray.
        max_gpu_memory: Forwarded to ``warmup`` (not used by it).
        temperature: Forwarded to ``warmup``.
        top_p: Forwarded to ``warmup``.

    Returns:
        tuple: ``(model, tokenizer)``; the model is in eval mode.

    Raises:
        AssertionError: if ``num_gpus_total`` is not a multiple of
            ``num_gpus_per_model``.
    """
    # Validate the GPU split before paying for the model load.
    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    model = RestModel.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    tokenizer = model.get_tokenizer()

    model.eval()
    print('Check model training state:',model.training)

    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print('CUDA VISIBLE DEVICES:', cuda_visible_devices)

    warmup_args = (
        model,
        tokenizer,
        model_id,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
    )

    if use_ray:
        # Imported here because the module-level import only happens when the
        # file is run as a script with a multi-GPU configuration.
        import ray

        remote_warmup = ray.remote(num_gpus=num_gpus_per_model)(warmup)
        # Block until the remote warm-up has finished; otherwise the handle is
        # dropped and set_up may return before the warm-up has run.
        ray.get(remote_warmup.remote(*warmup_args))
    else:
        warmup(*warmup_args)

    return model, tokenizer


@torch.inference_mode()
def warmup(
    model,
    tokenizer,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
    ):
    """Run one throw-away generation on a fixed prompt to warm the model up.

    Runs under ``torch.inference_mode()`` like ``qsstr_in_spec_tok_out`` so the
    warm-up forward passes do not record autograd state.

    Args:
        model: Model handed to ``baseline_forward``.
        tokenizer: Tokenizer used to encode the warm-up prompt.
        model_id: Id passed to ``get_conversation_template``.
        max_new_token: Generation budget forwarded to ``baseline_forward``.
        num_choices: Unused; kept for signature parity with the other entry points.
        num_gpus_per_model: Unused; kept for signature parity.
        max_gpu_memory: Unused; kept for signature parity.
        temperature: Forwarded to ``baseline_forward``.
        top_p: Forwarded to ``baseline_forward``.

    Returns:
        None. A ``RuntimeError`` raised during generation is reported on
        stdout and swallowed, so a failing warm-up does not abort the run.
    """
    # Fixed seed so the warm-up itself is reproducible.
    torch.manual_seed(0)
    conv = get_conversation_template(model_id)
    conv.append_message(conv.roles[0], "what is Socrates best known for?")
    conv.append_message(conv.roles[1], None)
    input_ids = tokenizer([conv.get_prompt()]).input_ids

    # some models may error out when generating long outputs
    try:
        baseline_forward(
            torch.as_tensor(input_ids).cuda(),
            model,
            tokenizer,
            max_new_token,
            temperature,
            top_p,
        )
    except RuntimeError as e:
        print(f"warm up errored out with {e}")
    print('Warmup done')



def execute(
    model,
    tokenizer,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
):
    """Generate a trace for one question, locally or through Ray.

    Args:
        model: Model forwarded to ``qsstr_in_spec_tok_out``.
        tokenizer: Tokenizer forwarded to ``qsstr_in_spec_tok_out``.
        model_id: Conversation-template id.
        question: Prompt text for the user turn.
        max_new_token: Generation budget.
        num_choices: Forwarded unchanged.
        num_gpus_per_model: GPUs reserved per Ray task; must divide
            ``num_gpus_total``.
        num_gpus_total: Total GPUs; when it allows more than one model the
            call is dispatched through Ray.
        max_gpu_memory: Forwarded unchanged.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The value returned by ``qsstr_in_spec_tok_out`` (the per-step trace
        list). On the Ray path the remote result is resolved before returning.

    Raises:
        AssertionError: if ``num_gpus_total`` is not a multiple of
            ``num_gpus_per_model``.
    """
    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    call_args = (
        model,
        tokenizer,
        model_id,
        question,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
    )

    if not use_ray:
        return qsstr_in_spec_tok_out(*call_args)

    # Imported here because the module-level import only happens when the
    # file is run as a script with a multi-GPU configuration.
    import ray

    remote_fn = ray.remote(num_gpus=num_gpus_per_model)(qsstr_in_spec_tok_out)
    # Resolve the handle: callers expect the trace itself (it is serialised
    # to JSON), not a Ray object reference.
    return ray.get(remote_fn.remote(*call_args))


@torch.inference_mode()
def qsstr_in_spec_tok_out(
    model,
    tokenizer,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
):
    """Turn a question string into a baseline decoding trace.

    The question is wrapped in the conversation template for ``model_id``
    (user turn filled, assistant turn left open), tokenized, moved to the GPU
    and handed to ``baseline_forward``. The RNG is re-seeded with 100 on every
    call.

    ``num_choices``, ``num_gpus_per_model`` and ``max_gpu_memory`` are accepted
    but not used here.

    Returns:
        Whatever ``baseline_forward`` returns for the encoded prompt.
    """
    torch.manual_seed(100)

    template = get_conversation_template(model_id)
    for role_index, message in ((0, question), (1, None)):
        template.append_message(template.roles[role_index], message)

    encoded = tokenizer([template.get_prompt()])
    prompt_ids = torch.as_tensor(encoded.input_ids).cuda()

    return baseline_forward(
        prompt_ids,
        model,
        tokenizer,
        max_new_token,
        temperature,
        top_p,
    )
    
def get_trace_lst(accept_length_tree, individual_token_times):
    """Return the per-iteration times, one for each entry of the trace.

    Only the length of ``accept_length_tree`` is used: the result holds
    ``individual_token_times[0]`` through
    ``individual_token_times[len(accept_length_tree) - 1]`` in order.

    Raises:
        IndexError: if ``individual_token_times`` is shorter than
            ``accept_length_tree``.
    """
    iteration_count = len(accept_length_tree)
    # Indexed on purpose: a too-short times sequence must raise, not truncate.
    return [individual_token_times[step] for step in range(iteration_count)]
        

    # if exp == 0:
    #     acc = []
    #     for i in range(len(accept_length_tree)):
    #         num_ele = accept_length_tree[i]
    #         avg = individual_token_times[i] / num_ele
    #         acc += [avg for _ in range(num_ele)]
    #     return acc, get_size
    # elif exp == 2:
    #     acc = []
    #     for i in range(len(accept_length_tree)):
    #         avg = individual_token_times[i] / 64
    #         acc += [avg for _ in range(64)]
    #     return acc, get_size
    # else:
    #     acc = []
    #     num_acc = 0
    #     time_acc = 0.0
    #     for i in range(len(accept_length_tree)):
    #         if (i + 1) % exp == 0:
    #             num_acc += accept_length_tree[i]
    #             time_acc += individual_token_times[i]
    #             avg = time_acc / num_acc
    #             acc += [avg for _ in range(num_acc)]
    #             num_acc = 0
    #             time_acc = 0.0
    #         else:
    #             num_acc += accept_length_tree[i]
    #             time_acc += individual_token_times[i]
    #     return acc, get_size

def plot_trace(lst: list):
    """Plot the values of ``lst`` against their index and save the figure.

    The image is written to ``./out_trace/iteration_traces.jpg``; the
    directory is created when it does not exist yet. The saved path is
    printed once the file has been written.
    """
    output_dir = './out_trace'
    output_path = os.path.join(output_dir, 'iteration_traces.jpg')

    # Make sure the target directory is there before saving.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    fig, ax = plt.subplots(figsize=(8, 6))

    # One marker per entry, x being the position in the list.
    ax.plot(range(len(lst)), lst, marker='o')
    ax.set_xlabel('Token Index')
    ax.set_ylabel('Time per token')
    ax.set_title('Time per Token Graph')
    ax.grid(True)

    fig.savefig(output_path)

    # Release the figure so it does not show up in later plotting calls.
    plt.close(fig)

    print(f"Plot saved to {output_path}")



def load_prompts_from_txt(file_path):
    """Read one prompt per line and normalise each to end in a single "?".

    For every line, surrounding whitespace and then surrounding double quotes
    are stripped, any run of trailing ASCII punctuation is removed and a
    question mark is appended. Lines that are empty after stripping are
    skipped so that blank lines do not turn into empty prompts.

    Args:
        file_path: Path of a UTF-8 text file with one prompt per line.

    Returns:
        list[str]: the normalised prompts in file order.
    """
    prompts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            prompt = line.strip().strip('"')
            if not prompt:
                continue
            # rstrip is a no-op when the prompt does not end in punctuation,
            # so this covers both the "replace" and the "append" case.
            prompts.append(prompt.rstrip(string.punctuation) + "?")
    return prompts


def pad_list(input_list, target_length, padding_value=0):
    """Return a copy of ``input_list`` that is exactly ``target_length`` long.

    Shorter lists are extended on the right with ``padding_value``; lists that
    are already long enough are cut with ``input_list[:target_length]``.

    Parameters:
    input_list (list): The list to pad or truncate; it is not modified.
    target_length (int): The desired length of the result.
    padding_value: The filler appended when the list is too short. Default is 0.

    Returns:
    list: A new list, padded or truncated as described.
    """
    shortfall = target_length - len(input_list)
    if shortfall > 0:
        return input_list + [padding_value] * shortfall
    # Long enough already: keep only the leading target_length items.
    return input_list[:target_length]


if __name__ == "__main__":
    # Script entry point: parse the CLI, load and warm up the model, run every
    # prompt of data/<kind>.txt `trials` times and dump traces + labels as JSON.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="rest",
        help="The name of the ouput_file.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-token",
        type=int,
        default=100,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="The threshold for nucleus sampling.",
    )

    # Experiment selection

    parser.add_argument(
        "--kind",
        type=str,
        required=False,
        help="The kind to run",
    )

    parser.add_argument(
        "--trials",
        type=int,
        help="The number of trials to run",
    )

    args = parser.parse_args()

    # Both values are needed further down (prompt file name, loop bound).
    # Reject a missing one now rather than after the model has been loaded.
    if args.kind is None or args.trials is None:
        parser.error("--kind and --trials are required")

    # Greedy decoding: nucleus sampling is switched off when temperature is 0.
    if args.temperature == 0:
        args.top_p = 0

    args.model_id = f"rest-{args.model_id}-temperature-{args.temperature}-top_p-{args.top_p}"
    if args.num_gpus_total // args.num_gpus_per_model > 1:
        import ray
        ray.init()

    kind = args.kind
    question_file = f"data/{kind}.txt"
    trials = args.trials

    # Honour --answer-file; otherwise fall back to the location the results
    # have always been written to.
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"M4/{kind}_temp_{args.temperature}_t_{trials}.out"

    # Create the output directory up front so a missing folder cannot make the
    # final dump fail after all generations have finished.
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    print(f"Output to {answer_file}")

    model, tokenizer = set_up(
        args.model_path,
        args.model_id,
        args.max_new_token,
        args.num_choices,
        args.num_gpus_per_model,
        args.num_gpus_total,
        args.max_gpu_memory,
        args.temperature,
        args.top_p,
        )

    prompts = load_prompts_from_txt(question_file)

    traces_0 = []
    labels = []
    for j, prompt in enumerate(prompts):
        for i in range(trials):
            acc_0 = execute(
                model,
                tokenizer,
                args.model_id,
                prompt,
                args.max_new_token,
                args.num_choices,
                args.num_gpus_per_model,
                args.num_gpus_total,
                args.max_gpu_memory,
                args.temperature,
                args.top_p,
            )

            # traces_0[k] and labels[k] describe the same run.
            traces_0.append(acc_0)
            labels.append(prompt)

            print(f"done promt num: {j}, iter: {i}")

    with open(answer_file, "w") as out:
        json.dump({
            "traces": traces_0,
            "labels": labels,
        }, out, separators=(",", ":"))

    print(f"Output done to {answer_file}")


    # print(traces_0)
    # traces_0 = [81, 77, 88, 101, 101, 101, 101, 101, 101, 101, 4, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 93, 101, 101, 101, 101, 94, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101]
    # plt.figure(figsize=(8, 6))
    # sns.histplot(traces_0, bins=100, kde=False, color="blue")

    # # Adding labels and title
    # plt.title("Histogram of Trace Lengths")
    # plt.xlabel("Length of acc_0")
    # plt.ylabel("Frequency")

    # # Save the plot as a PNG file
    # plt.savefig("histogram_plot.png", dpi=300, bbox_inches='tight')
    
