"""Generate answers with local models.

Usage:
python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
"""
import argparse
import json
import os
import random
import time
import shortuuid
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import string

from fastchat.llm_judge.common import load_questions
from fastchat.model import load_model, get_conversation_template


# Rest imports
import transformers

import sys
sys.path.append("../")

from rest.model.utils import *
from rest.model.rest_model import RestModel
from rest.model.kv_cache import initialize_past_key_values
import draftretriever


def rest_forward(input_ids, model, tokenizer, max_new_token, temperature, top_p, datastore, num_draft, token_spans, max_steps=1024):
    """Decode a single prompt with REST retrieval-based speculative decoding.

    Each iteration retrieves draft continuations from ``datastore``
    (``generate_candidates_and_draft_buffer``) and verifies them with one
    tree-structured forward pass of the target model (``tree_decoding``). It
    then selects the best candidate (``evaluate_posterior``) and appends the
    accepted tokens to ``input_ids`` (``update_inference_inputs``).

    Args:
        input_ids: Prompt token ids. Batch size must be 1 (asserted). Presumably
            a ``(1, seq_len)`` integer tensor already on the model's device;
            TODO confirm.
        model: REST model wrapper exposing ``base_model``. The KV-cache buffers
            are cached on it as ``past_key_values`` / ``past_key_values_data``
            / ``current_length_data`` and reused across calls.
        tokenizer: Decodes the tokens accepted in each iteration and supplies
            ``eos_token_id`` for the stop check.
        max_new_token: Decoding stops once the ``new_token`` counter returned
            by ``update_inference_inputs`` exceeds this value.
        temperature: Forwarded to candidate generation and posterior evaluation.
        top_p: Forwarded to candidate generation and posterior evaluation.
        datastore: Retrieval index (a ``draftretriever.Reader`` when called
            from this script).
        num_draft: Maximum number of draft tokens (``max_num_draft``).
        token_spans: Suffix lengths used for retrieval. This script passes them
            longest-first.
        max_steps: Upper bound on the number of decoding iterations.

    Returns:
        ``(input_ids, new_token, idx, accepted_len)``:
        ``input_ids`` is the prompt plus all accepted tokens, ``new_token`` is
        the counter from ``update_inference_inputs``, ``idx`` is the 0-based
        index of the last iteration run, and ``accepted_len`` has one entry per
        iteration (see the note at ``len_out`` below).

    NOTE(review): if ``max_steps <= 0`` the loop never runs and ``idx`` is
    unbound, so the return statement raises ``UnboundLocalError``.
    """
    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
    # Avoid modifying the input_ids in-place
    input_ids = input_ids.clone()

    # Initialize the past key and value states
    if hasattr(model, "past_key_values"):
        # Reuse the KV-cache buffers that an earlier call stored on the model.
        past_key_values = model.past_key_values
        past_key_values_data = model.past_key_values_data
        current_length_data = model.current_length_data
        # Reset the past key and value states
        current_length_data.zero_()
    else:
        # First call: allocate the cache and stash it on the model so later
        # calls take the branch above.
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(model.base_model)
        model.past_key_values = past_key_values
        model.past_key_values_data = past_key_values_data
        model.current_length_data = current_length_data

    input_len = input_ids.shape[1]

    # David: The unmodified implementation initialised
    #            cur_length = input_len + 1
    #        which made sum(accept_length_list) come out as
    #        (total number of new tokens - 1). Since accept_length_list does
    #        not hold the real per-iteration accepted lengths, there is no
    #        reason for that +1, so it is treated as a bug and dropped here.
    #
    # David: REST mis-speculates a large share of tokens (roughly half). If
    #        accept_length_list did hold the real per-iteration accepted
    #        lengths, the -1 would imply that only the first token was ever
    #        mis-speculated and every other token was speculated correctly,
    #        which is impossible.
    #
    # cur_length marks where this iteration's newly appended tokens begin in
    # input_ids. It starts at the end of the prompt.
    cur_length = input_len
    # Clear any draft attention mask left over from a previous call before
    # initialize_logits runs (presumably the prompt prefill; TODO confirm).
    model.base_model.model.draft_mask = None
    logits = initialize_logits(
            input_ids, model, past_key_values
    )
    new_token = 0
    # NOTE(review): nothing is timed around this synchronize; it only blocks
    # until queued CUDA work has finished. Confirm it is still needed.
    torch.cuda.synchronize()
    accepted_len = []

    for idx in range(max_steps):
        # David: Match and retrieve speculation candidates from the datastore.
        candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer(
                logits,
                input_ids,
                datastore,
                token_spans,
                top_p,
                temperature,
                max_num_draft=num_draft,
                device=model.base_model.device
            )
        # The tree-shaped attention mask for this iteration's drafts is read by
        # the model during tree_decoding.
        model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"]

        # David: One parallel verification pass over all candidates using the
        #        real target model.
        logits, outputs = tree_decoding(
                model,
                tree_candidates,
                past_key_values,
                draft_buffers["draft_position_ids"],
                input_ids,
                draft_buffers["retrieve_indices"],
            )

        # David: Take the target model's output, use argmax on the logits to
        #        find its true output at each level of the speculation tree,
        #        and walk down the tree asking each token "are you the token
        #        the target model outputs?". If yes, that token is accepted.
        #        (This describes the greedy case; TODO confirm how
        #        evaluate_posterior behaves when temperature > 0.)
        #
        # David: The longest accepted path down from the root is then chosen.
        #        This is best_candidate, and accept_length is the real
        #        accepted length of this iteration.
        best_candidate, accept_length = evaluate_posterior(
                logits, candidates, temperature, top_p
            )

        # David: Besides best_candidate, one more token (the target model's
        #        own prediction) is taken. In the worst case this makes each
        #        iteration at least as productive as plain autoregression.
        input_ids, logits, new_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                draft_buffers["retrieve_indices"],
                outputs,
                logits,
                new_token,
                past_key_values_data,
                current_length_data,
            )
        # Decode only the tokens appended in this iteration, i.e. those after
        # cur_length. Special tokens are skipped.
        output = tokenizer.decode(
            input_ids[0][cur_length:],
            spaces_between_special_tokens=False,
            skip_special_tokens=True
        )
        # NOTE(review): this records the CHARACTER length of the decoded text,
        # not the number of tokens accepted (that is accept_length_tree below).
        # Decoding a slice in isolation may also differ from incremental
        # decoding by leading whitespace. Confirm this is the intended metric.
        len_out = len(output)
        accepted_len.append(len_out)
        # Number of tokens appended this iteration; only used to advance
        # cur_length to the new end of input_ids.
        accept_length_tree = input_ids.shape[1] - cur_length
        cur_length = accept_length_tree + cur_length
        # Stop on EOS anywhere in the generated region, or once the new-token
        # budget has been exceeded.
        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break
        if new_token > max_new_token:
            break
    return input_ids, new_token, idx, accepted_len

def set_up(
    model_path,
    datastore_path,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
    num_draft,
    max_token_span,
    ):
    """Load the retrieval datastore, the REST model and its tokenizer, then warm up.

    Args:
        model_path: Weights location passed to ``RestModel.from_pretrained``.
        datastore_path: Index file opened with ``draftretriever.Reader``.
        model_id: Used to select the conversation template during warmup.
        max_new_token: Generation budget for the warmup run.
        num_choices: Forwarded to ``warmup``.
        num_gpus_per_model: GPUs per model. It must divide ``num_gpus_total``
            (asserted).
        num_gpus_total: Total GPUs. When it is more than one model's worth,
            the warmup runs as a Ray remote task; this requires ``ray`` to
            have been imported and initialised as a module global.
        max_gpu_memory: Forwarded to ``warmup``.
        temperature: Sampling setting for the warmup run.
        top_p: Sampling setting for the warmup run.
        num_draft: Maximum number of draft tokens.
        max_token_span: Longest retrieval suffix. Spans from this value down
            to 2 are used.

    Returns:
        ``(datastore, model, tokenizer)``.
    """
    print("loading the datastore ...")
    datastore = draftretriever.Reader(
                index_file_path=datastore_path,
            )
    print("datastore loaded!")

    model = RestModel.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    tokenizer = model.get_tokenizer()

    model.eval()
    print('Check model training state:',model.training)

    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
    # Retrieval suffix lengths, longest first: [max_token_span, ..., 2].
    token_spans = list(range(max_token_span, 1, -1))

    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    warmup_args = (
        model,
        tokenizer,
        model_id,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    if use_ray:
        # `.remote()` returns an ObjectRef immediately. Block on it so the
        # warmup has really finished (and any remote error surfaces) before
        # the caller starts measured runs.
        ray.get(ray.remote(num_gpus=num_gpus_per_model)(warmup).remote(*warmup_args))
    else:
        warmup(*warmup_args)

    return datastore, model, tokenizer


@torch.inference_mode()
def warmup(
    model,
    tokenizer,
    model_id,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
    datastore,
    num_draft,
    token_spans,
    ):
    """Run one throwaway REST generation so that later, measured runs start warm.

    The function runs under ``torch.inference_mode()``, the same mode as the
    measured runner ``qsstr_in_spec_tok_out``. Without it, autograd would
    track every op of the warmup generation. The KV cache that ``rest_forward``
    allocates on its first call is therefore also created in inference mode.

    A ``RuntimeError`` raised during generation is printed and swallowed: a
    failed warmup is reported but is not fatal.

    ``num_choices``, ``num_gpus_per_model`` and ``max_gpu_memory`` are accepted
    so the signature matches the other runners, but they are unused.
    """
    torch.manual_seed(0)
    conv = get_conversation_template(model_id)
    conv.append_message(conv.roles[0], "what is Socrates best known for?")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer([prompt]).input_ids
    # Some models may error out when generating long outputs.
    try:
        rest_forward(
            torch.as_tensor(input_ids).cuda(),
            model,
            tokenizer,
            max_new_token,
            temperature,
            top_p,
            datastore,
            num_draft,
            token_spans,
        )
    except RuntimeError as e:
        print(f"warm up errored out with {e}")
    print('Warmup done')



def execute(
    model,
    tokenizer,
    datastore,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    max_gpu_memory,
    temperature,
    top_p,
    num_draft,
    max_token_span,
    exp=0
):
    """Decode ``question`` with REST and return its per-iteration trace.

    Args:
        model, tokenizer, datastore: Objects returned by ``set_up``.
        model_id: Selects the conversation template.
        question: User message to answer.
        max_new_token: Generation budget.
        num_choices, max_gpu_memory: Forwarded unchanged to the runner.
        num_gpus_per_model: GPUs per model. It must divide ``num_gpus_total``
            (asserted).
        num_gpus_total: When it is more than one model's worth, the run is
            dispatched as a Ray remote task (``ray`` must be a module global).
        temperature, top_p: Sampling settings.
        num_draft: Maximum number of draft tokens.
        max_token_span: Longest retrieval suffix. Spans from this value down
            to 2 are used.
        exp: Ignored; kept for backward compatibility with older callers.

    Returns:
        The trace list returned by ``qsstr_in_spec_tok_out``, on both the
        local and the Ray path.
    """
    # Retrieval suffix lengths, longest first: [max_token_span, ..., 2].
    token_spans = list(range(max_token_span, 1, -1))

    assert num_gpus_total % num_gpus_per_model == 0
    use_ray = num_gpus_total // num_gpus_per_model > 1

    runner_args = (
        model,
        tokenizer,
        model_id,
        question,
        max_new_token,
        num_choices,
        num_gpus_per_model,
        max_gpu_memory,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    if use_ray:
        # `.remote()` yields an ObjectRef, not the trace. Resolve it so callers
        # always receive the actual (JSON-serialisable) list.
        return ray.get(
            ray.remote(num_gpus=num_gpus_per_model)(qsstr_in_spec_tok_out).remote(*runner_args)
        )
    return qsstr_in_spec_tok_out(*runner_args)
    # return get_trace_lst(accept_len)

    # print(f"expect length: {sum(accept_length_tree)}")
    # print(f"get length: {len(trace_out)}")
    # plot_trace(trace_out)
    
    # if use_ray:
    #     ray.get(ans_handles)

    # return acc_0, acc_1, acc_2, acc_3, acc_4, acc_5, acc_6, acc_7, acc_8, acc_9, time1, time2, time3


@torch.inference_mode()
def qsstr_in_spec_tok_out(
    model,
    tokenizer,
    model_id,
    question,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    max_gpu_memory,
    temperature,
    top_p,
    datastore,
    num_draft,
    token_spans,
):
    """Wrap ``question`` in the model's chat template and decode it with REST.

    The torch RNG is reseeded with 100 on every call, so repeated trials of
    the same question start from the same random state.

    ``num_choices``, ``num_gpus_per_model`` and ``max_gpu_memory`` are part of
    the shared runner signature but are unused here.

    Returns:
        Only the per-iteration trace (the last element of ``rest_forward``'s
        result). The generated ids and counters are discarded.
    """
    torch.manual_seed(100)

    template = get_conversation_template(model_id)
    template.append_message(template.roles[0], question)
    # An empty (None) assistant turn presumably leaves the reply open for generation.
    template.append_message(template.roles[1], None)
    token_ids = tokenizer([template.get_prompt()]).input_ids

    _, _, _, trace = rest_forward(
        torch.as_tensor(token_ids).cuda(),
        model,
        tokenizer,
        max_new_token,
        temperature,
        top_p,
        datastore,
        num_draft,
        token_spans,
    )
    return trace
        # output_ids, new_token, idx, accept_length_tree, individual_token_times = rest_forward(
        #     torch.as_tensor(input_ids).cuda(),
        #     model,
        #     tokenizer,
        #     max_new_token,
        #     temperature,
        #     top_p,
        #     datastore,
        #     num_draft,
        #     token_spans,
        # )
        
        
        # output_ids = output_ids[0][len(input_ids[0]) :]
        
        # # be consistent with the template's stop_token_ids
        # if conv.stop_token_ids:
        #     stop_token_ids_index = [
        #         i
        #         for i, id in enumerate(output_ids)
        #         if id in conv.stop_token_ids
        #     ]
        #     if len(stop_token_ids_index) > 0:
        #         output_ids = output_ids[: stop_token_ids_index[0]]

        # output = tokenizer.decode(
        #     output_ids,
        #     spaces_between_special_tokens=False,
        # )
        # if conv.stop_str and output.find(conv.stop_str) > 0:
        #     output = output[: output.find(conv.stop_str)]

        # if conv.name == "xgen" and output.startswith("Assistant:"):
        #     output = output.replace("Assistant:", "", 1).strip()
    # except RuntimeError as e:
    #     print("ERROR question : ", question)
    #     output = "ERROR"

    # print(output)
    # print("\n\n")
    # print(f"collected total: {sum(accept_length_tree)}")
    # print(f"real total: {new_token}")
    # print(f"token iterations: {len(accept_length_tree)}")
    # print(f"time iterations: {len(individual_token_times)}")
    # return accept_length_tree, individual_token_times

    
    
# def get_trace_lst(accept_length_tree, individual_token_times, exp=0):
#     time1, time2, time3 = 0, 0, 0
#     acc_0 = []
#     acc_1 = []
#     acc_2 = []
#     acc_3 = []
#     acc_4 = []
#     acc_5 = []
#     acc_6 = []
#     acc_7 = []
#     acc_8 = []
#     acc_9 = []
#     for i in range(len(accept_length_tree)):

#         avg_0 = individual_token_times[i]
#         acc_0 += [avg_0]

#         num_ele = accept_length_tree[i]
#         avg_1 = individual_token_times[i] / num_ele
#         acc_1 += [avg_1 for _ in range(num_ele)]

#         random_float = np.random.uniform(0, 0.04)
#         time1 += random_float
#         num_ele = accept_length_tree[i]
#         avg_2 = (individual_token_times[i] + random_float) / num_ele
#         acc_2 += [avg_2 for _ in range(num_ele)]

#         random_float = np.random.uniform(0, 0.08)
#         time2 += random_float
#         num_ele = accept_length_tree[i]
#         avg_3 = (individual_token_times[i] + random_float) / num_ele
#         acc_3 += [avg_3 for _ in range(num_ele)]

#         random_float = np.random.uniform(0, 0.12)
#         time3 += random_float
#         num_ele = accept_length_tree[i]
#         avg_4 = (individual_token_times[i] + random_float) / num_ele
#         acc_4 += [avg_4 for _ in range(num_ele)]

#         avg_5 = individual_token_times[i] / 64
#         acc_5 += [avg_5 for _ in range(64)]

#     num_acc = 0
#     time_acc = 0.0
#     for i in range(len(accept_length_tree)):
#         if (i + 1) % 3 == 0:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]
#             avg_6 = time_acc / num_acc
#             acc_6 += [avg_6 for _ in range(num_acc)]
#             num_acc = 0
#             time_acc = 0.0
#         else:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]

#     num_acc = 0
#     time_acc = 0.0
#     for i in range(len(accept_length_tree)):
#         if (i + 1) % 5 == 0:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]
#             avg_7 = time_acc / num_acc
#             acc_7 += [avg_7 for _ in range(num_acc)]
#             num_acc = 0
#             time_acc = 0.0
#         else:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]

#     num_acc = 0
#     time_acc = 0.0
#     for i in range(len(accept_length_tree)):
#         if (i + 1) % 10 == 0:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]
#             avg_8 = time_acc / num_acc
#             acc_8 += [avg_8 for _ in range(num_acc)]
#             num_acc = 0
#             time_acc = 0.0
#         else:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]

#     num_acc = 0
#     time_acc = 0.0
#     for i in range(len(accept_length_tree)):
#         if (i + 1) % 20 == 0:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]
#             avg_9 = time_acc / num_acc
#             acc_9 += [avg_9 for _ in range(num_acc)]
#             num_acc = 0
#             time_acc = 0.0
#         else:
#             num_acc += accept_length_tree[i]
#             time_acc += individual_token_times[i]

#     return acc_0, acc_1, acc_2, acc_3, acc_4, acc_5, acc_6, acc_7, acc_8, acc_9, time1, time2, time3
        

    # if exp == 0:
    #     acc = []
    #     for i in range(len(accept_length_tree)):
    #         num_ele = accept_length_tree[i]
    #         avg = individual_token_times[i] / num_ele
    #         acc += [avg for _ in range(num_ele)]
    #     return acc, get_size
    # elif exp == 2:
    #     acc = []
    #     for i in range(len(accept_length_tree)):
    #         avg = individual_token_times[i] / 64
    #         acc += [avg for _ in range(64)]
    #     return acc, get_size
    # else:
    #     acc = []
    #     num_acc = 0
    #     time_acc = 0.0
    #     for i in range(len(accept_length_tree)):
    #         if (i + 1) % exp == 0:
    #             num_acc += accept_length_tree[i]
    #             time_acc += individual_token_times[i]
    #             avg = time_acc / num_acc
    #             acc += [avg for _ in range(num_acc)]
    #             num_acc = 0
    #             time_acc = 0.0
    #         else:
    #             num_acc += accept_length_tree[i]
    #             time_acc += individual_token_times[i]
    #     return acc, get_size

def plot_trace(lst: list, output_dir='./out_trace', filename='iteration_traces.jpg'):
    """Plot ``lst`` against its indices and save the figure to disk.

    Args:
        lst: Values to plot. The x-axis is each value's position in the list
            (labelled "Token Index").
        output_dir: Directory to write into. It is created (including parents)
            if it does not exist. Defaults to ``'./out_trace'``.
        filename: Image file name inside ``output_dir``. Matplotlib infers the
            format from the extension. Defaults to ``'iteration_traces.jpg'``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)

    plt.figure(figsize=(8,6))
    try:
        plt.plot(list(range(len(lst))), lst, marker='o')
        plt.xlabel('Token Index')
        plt.ylabel('Time per token')
        plt.title('Time per Token Graph')
        plt.grid(True)
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if saving fails, so figures do not
        # accumulate across calls.
        plt.close()

    print(f"Plot saved to {output_path}")


# def get_trace_lst(accept_len):
#     acc_6 = []
#     acc_7 = []
#     acc_8 = []
#     acc_9 = []

#     num_acc = 0
#     for i in range(len(accept_len)):
#         if (i + 1) % 3 == 0:
#             num_acc += accept_len[i]
#             acc_6.append(num_acc)
#             num_acc = 0
#         else:
#             num_acc += accept_len[i]

#     num_acc = 0
#     for i in range(len(accept_len)):
#         if (i + 1) % 5 == 0:
#             num_acc += accept_len[i]
#             acc_7.append(num_acc)
#             num_acc = 0
#         else:
#             num_acc += accept_len[i]
    
#     num_acc = 0
#     for i in range(len(accept_len)):
#         if (i + 1) % 10 == 0:
#             num_acc += accept_len[i]
#             acc_8.append(num_acc)
#             num_acc = 0
#         else:
#             num_acc += accept_len[i]
    
#     num_acc = 0
#     for i in range(len(accept_len)):
#         if (i + 1) % 20 == 0:
#             num_acc += accept_len[i]
#             acc_9.append(num_acc)
#             num_acc = 0
#         else:
#             num_acc += accept_len[i]

#     return accept_len, [10 for _ in range(len(accept_len))], acc_6, acc_7, acc_8, acc_9


def load_prompts_from_txt(file_path, encoding="utf-8"):
    """Read one prompt per line from a text file and normalise each into a question.

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes. Any run of trailing ASCII punctuation (``string.punctuation``)
    is replaced by a single ``"?"``. A non-empty line with no trailing
    punctuation simply gets ``"?"`` appended. Blank lines are kept as empty
    strings.

    Args:
        file_path: Path of the text file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 instead of the
            platform's locale encoding, so results do not depend on the machine.

    Returns:
        A list with one normalised prompt per line of the file.
    """
    prompts = []
    with open(file_path, "r", encoding=encoding) as f:
        for line in f:
            prompt = line.strip().strip('"')
            if prompt:
                # rstrip is a no-op when the last character is not punctuation,
                # so one expression covers both the "replace" and "append" cases.
                prompt = prompt.rstrip(string.punctuation) + "?"
            prompts.append(prompt)
    return prompts


def pad_list(input_list, target_length, padding_value=0):
    """Fit ``input_list`` to ``target_length`` items by truncating or padding.

    For a non-negative ``target_length``:

    - a list that is too long is cut to its first ``target_length`` items;
    - a list that is too short gets ``padding_value`` appended until it
      reaches that length.

    The input list itself is never modified.

    Args:
        input_list (list): The list to fit.
        target_length (int): The desired length of the result.
        padding_value: Filler used when padding. Defaults to 0.

    Returns:
        list: A new list.
    """
    shortfall = target_length - len(input_list)
    if shortfall <= 0:
        # Already long enough: keep only the leading items.
        return input_list[:target_length]
    return input_list + [padding_value] * shortfall


if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="rest",
        help="The name of the ouput_file.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-token",
        type=int,
        default=100,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=0.0,
        help="The threshold for nucleus sampling.",
    )

    # REST's hyperparameters
    parser.add_argument(
        "--datastore-path",
        type=str,
        required=True,
        help="The path of the datastore for retrival.",
    )

    parser.add_argument(
        "--num-draft",
        type=int,
        default=64,
        help="The maximum number of draft tokens.",
    )
    parser.add_argument(
        "--max-token-span",
        type=int,
        default=32,
        help="The maximum length of suffix for retrieval.",
    )

    parser.add_argument(
        "--kind",
        type=str,
        required=False,
        help="The kind to run",
    )

    parser.add_argument(
        "--trials",
        type=int,
        help="The number of trials to run",
    )

    args = parser.parse_args()

    if args.temperature == 0:
        args.top_p = 0
        

    args.model_id = "rest-" + args.model_id+"-temperature-"+str(args.temperature)+"-top_p-"+str(args.top_p)
    if args.num_gpus_total // args.num_gpus_per_model > 1:
        import ray
        ray.init()

    # kind = "original"
    # kind = "semantics_sim"
    # kind = "structural_sim"
    # kind = "prompts"
    # kind = "similar_prompt"
    kind = args.kind
    question_file = f"data/{kind}.txt"

    # trials = 30 if kind == 'original' else 5
    # trials = 30
    # exp = 2
    trials = args.trials
    
    
    
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"out_trace/temp_{args.temperature}_{kind}_trace_{trials}.out"

    print(f"Output to {answer_file}")

    datastore, model, tokenizer = set_up(
        args.model_path,
        args.datastore_path, 
        args.model_id,
        args.max_new_token,
        args.num_choices,
        args.num_gpus_per_model,
        args.num_gpus_total,
        args.max_gpu_memory,
        args.temperature,
        args.top_p,
        args.num_draft,
        args.max_token_span,
        )

    prompts = load_prompts_from_txt(question_file)


    # prompts = [prompts[0]]
    # trails = 1



    # execute(
    #     model,
    #     tokenizer,
    #     datastore,
    #     args.model_id,
    #     prompts[3],
    #     args.max_new_token,
    #     args.num_choices,
    #     args.num_gpus_per_model,
    #     args.num_gpus_total,
    #     args.max_gpu_memory,
    #     args.temperature,
    #     args.top_p,
    #     args.num_draft,
    #     args.max_token_span,
    # )
    
    traces_1 = []
    # traces_0 = []
    # traces_2 = []
    # traces_3 = []
    # traces_4 = []
    # traces_5 = []
    # traces_6 = []
    # traces_7 = []
    # traces_8 = []
    # traces_9 = []
    # total_time1 = []
    # total_time2 = []
    # total_time3 = []
    labels = []
    # total_size = []
    for j in range(len(prompts)):
        for i in range(trials):
            acc_1 = execute(
            # acc_0, acc_1, acc_2, acc_3, acc_4, acc_5, acc_6, acc_7, acc_8, acc_9, time1, time2, time3 = execute(
                model,
                tokenizer,
                datastore,
                args.model_id,
                prompts[j],
                args.max_new_token,
                args.num_choices,
                args.num_gpus_per_model,
                args.num_gpus_total,
                args.max_gpu_memory,
                args.temperature,
                args.top_p,
                args.num_draft,
                args.max_token_span,
                # exp
            )
            # pad_out_1 = pad_list(acc_1, 110)
            traces_1.append(acc_1)
            # pad_out_0 = pad_list(acc_0, 110)
            # traces_0.append(acc_0)
            # pad_out_2 = pad_list(acc_2, 110)
            # traces_2.append(acc_2)
            # pad_out_3 = pad_list(acc_3, 110)
            # traces_3.append(acc_3)
            # # pad_out_4 = pad_list(acc_4, 110)
            # traces_4.append(acc_4)
            # pad_out_5 = pad_list(acc_5, 110)
            # traces_5.append(acc_5)
            # pad_out_6 = pad_list(acc_6, 110)
            # traces_6.append(acc_6)
            # pad_out_7 = pad_list(acc_7, 110)
            # traces_7.append(acc_7)
            # pad_out_8 = pad_list(acc_8, 110)
            # traces_8.append(acc_8)
            # pad_out_9 = pad_list(acc_9, 110)
            # traces_9.append(acc_9)
            labels.append(prompts[j])
            # total_time1.append(time1)
            # total_time2.append(time2)
            # total_time3.append(time3)
            print(f"done promt num: {j}, iter: {i}")

    # with open(f"out_trace/temp_{args.temperature}_{kind}_trace_{trials}.out", "w") as out:
    # with open(f"ex1Out/temp_{args.temperature}_{kind}_trace_{trials}.out", "w") as out:
    # with open(f"ex2Out/temp_{args.temperature}_{kind}_trace_{trials}.out", "w") as out:
    # with open(f"M0/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_0,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    with open(f"M3/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
        json.dump({
            "traces": traces_1,
            "labels": labels,
        }, out, separators=(",", ":"))

    # with open(f"M2/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_2,
    #         "labels": labels,
    #         "total_size": total_time1
    #     }, out, separators=(",", ":"))

    # with open(f"M3/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_3,
    #         "labels": labels,
    #         "total_size": total_time2
    #     }, out, separators=(",", ":"))

    # with open(f"M4/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_4,
    #         "labels": labels,
    #         "total_size": total_time3
    #     }, out, separators=(",", ":"))

    # with open(f"M5/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_5,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    # with open(f"M6/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_6,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M7/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_7,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M8/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_8,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M9/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces_9,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    print(f"Output done to {answer_file}")
    
