from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time 
import os, sys
import pandas as pd 
from datasets import load_dataset
from typing import List
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import json
# import openpyxl
import string
import numpy as np
import argparse

def plot_and_save_graph(x, y, xlabel, ylabel, graph_header, file_name):
    """
    Plots y against x as a single line graph and saves it as a PNG file.

    The image is written to ``Outputs/<file_name>.png``. The target directory
    is created when it does not exist yet, so the call also works on a fresh
    checkout.

    Parameters:
    x (list): List of x-axis items
    y (list): List of y-axis variables (same length as x)
    xlabel (str): Label for the x-axis
    ylabel (str): Label for the y-axis
    graph_header (str): Title of the graph
    file_name (str): File name (without the ".png" extension) to save the graph
    """
    output_path = "Outputs/" + file_name + ".png"
    # savefig() does not create missing directories, so make sure the target
    # directory exists before doing any plotting work.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(x, y, marker='o', linestyle='-', color='b')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(graph_header)
        plt.grid(True)
        plt.savefig(output_path)
    finally:
        # Close the figure even when plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close()


def setup(model_name, torch_device):
    """
    Loads the tokenizer and the causal language model identified by model_name.

    Parameters:
    model_name (str): Model identifier passed to both ``from_pretrained`` calls
    torch_device: Value forwarded as ``device_map`` when loading the model

    Returns:
    tuple: ``(model, tokenizer)``; the tokenizer is also attached to the model
    as ``model.tokenizer``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    # Weights are loaded in half precision and placed via device_map.
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=torch_device,
    )
    lm.tokenizer = tok

    return lm, tok


def warmup(model, tokenizer, torch_device, temperature: float):
    """
    Runs one throw-away, single-token generation on a fixed prompt.

    The generated output is discarded; presumably this exists so one-off
    initialization cost is not attributed to later measured runs.

    Parameters:
    model: Model object exposing ``generate``
    tokenizer: Callable tokenizer returning tensors that support ``.to(device)``
    torch_device: Device the tokenized inputs are moved to
    temperature (float): Sampling temperature. A non-positive value selects
        greedy decoding instead of sampling (see the note in the body).
    """
    prompt = "How do you fine tune a large language model?"
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )

    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # NOTE(review): presumably read by lade to select its chat mode - confirm.
    os.environ["CHAT"] = "1"

    # Sampling requires a strictly positive temperature; the script's default
    # of 0.0 would otherwise be rejected by the sampling path, so fall back to
    # greedy decoding in that case.
    if temperature is not None and temperature <= 0:
        generation_kwargs = {"do_sample": False}
    else:
        generation_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 0.8}

    model.generate(**model_inputs, max_new_tokens=1, **generation_kwargs)


def executions(model, tokenizer, torch_device, prompt, temperature: float, exp=0):
    """
    Generates up to 100 new tokens for one prompt and returns the lade log lengths.

    Parameters:
    model: Model object exposing ``generate``
    tokenizer: Callable tokenizer returning tensors that support ``.to(device)``
    torch_device: Device the tokenized inputs are moved to
    prompt (str): User prompt inserted into the chat template
    temperature (float): Forwarded to ``generate`` (see the note in the body)
    exp: Unused; kept so existing callers that pass it keep working

    Returns:
    list: The ``"len_out"`` value of every entry returned by
    ``lade.utils.get_all_log()`` after the generation.
    """
    # Local import: the script only imports lade inside its __main__ guard, so
    # without this the function raises NameError when the module is imported
    # and called from elsewhere.
    import lade

    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a domain expert.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )
    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # NOTE(review): presumably read by lade to select its chat mode - confirm.
    os.environ["CHAT"] = "1"

    # NOTE(review): do_sample=False requests greedy decoding, so `temperature`
    # most likely has no effect here - confirm whether sampling was intended.
    # NOTE(review): the log is not cleared before this call, so entries left by
    # an earlier generate() (e.g. a warm-up) may end up in the first result -
    # confirm against lade's logging behaviour.
    model.generate(**model_inputs, max_new_tokens=100, do_sample=False, temperature=temperature)

    all_len = [entry["len_out"] for entry in lade.utils.get_all_log()]
    # Reset the log so the next call only sees its own entries.
    lade.utils.clear_log()
    return all_len


    # all_diffs = lade.utils.get_all_log()
    # time_diffs = [x["time"] for x in all_diffs][1:]

    # full_size = output.size(1) - input_size
    
    # acc_0 = []
    # for i in range(len(time_diffs)):
    #     acc_0.append(time_diffs[i])


    # lade.utils.clear_log()

    # return acc_0
    # return acc_1, acc_6, acc_7, acc_8, acc_9


    # if exp == 0:
    #     acc = []
    #     for i in range(len(time_diffs)):
    #         prev = all_tokens[i]
    #         last = all_tokens[i + 1]
    #         num_tokens = last - prev
    #         if num_tokens == 0:
    #             avg = time_diffs[i]
    #         else:
    #             avg = time_diffs[i] / num_tokens
    #         curr = [avg for _ in range(num_tokens)]
    #         acc += curr
    # elif exp == 1:
    #     acc = []
    #     for i in range(len(time_diffs)):
    #         random_float = np.random.uniform(0, 0.004)
    #         get_size += random_float
    #         prev = all_tokens[i]
    #         last = all_tokens[i + 1]
    #         num_tokens = last - prev
    #         if num_tokens == 0:
    #             avg = time_diffs[i] + random_float
    #         else:
    #             avg = (time_diffs[i] + random_float) / num_tokens
    #         curr = [avg for _ in range(num_tokens)]
    #         acc += curr
    # elif exp == 2:
    #     acc = []
    #     for i in range(len(time_diffs)):
    #         # random_value = np.random.choice([0, 1])
    #         avg = time_diffs[i] / 6
    #         curr = [avg for _ in range(6)]
    #         acc += curr
    # else:
    #     acc = []
    #     num_acc = 0
    #     time_acc = 0.0
    #     for i in range(len(time_diffs)):
    #         if (i + 1) % exp == 0:
    #             prev = all_tokens[i]
    #             last = all_tokens[i + 1]
    #             num_tokens = last - prev
    #             num_acc += num_tokens
    #             time_acc += time_diffs[i]
    #             avg = time_acc / num_acc
    #             acc += [avg for _ in range(num_acc)]
    #             num_acc = 0
    #             time_acc = 0.0
    #         else:
    #             prev = all_tokens[i]
    #             last = all_tokens[i + 1]
    #             num_tokens = last - prev
    #             num_acc += num_tokens
    #             time_acc += time_diffs[i]

    # acc = []
    # boo_acc = []
    # for i in range(len(time_diffs)):
    #     prev = all_tokens[i]
    #     last = all_tokens[i + 1]
    #     num_tokens = last - prev
    #     if num_tokens == 0:
    #         avg = time_diffs[i]
    #     else:
    #         avg = time_diffs[i] / num_tokens
    #     curr = [avg for _ in range(num_tokens)]
    #     acc += curr
    #     if num_tokens > 1:
    #         boo_curr = [0 for _ in range(num_tokens)]
    #     else:
    #         boo_curr = [1]
    #     # boo_curr += [1]
    #     boo_acc += boo_curr

    # lade.utils.clear_log()

    # return acc, get_size


def load_prompts_from_txt(file_path):
    """
    Reads one prompt per line from a text file and normalizes each to end in "?".

    For every line, surrounding whitespace and surrounding double quotes are
    removed, any trailing punctuation is stripped, and a single question mark
    is appended. Lines that are empty after stripping whitespace and quotes
    are skipped, so blank or trailing empty lines do not become empty prompts.

    Parameters:
    file_path (str): Path of the UTF-8 text file to read

    Returns:
    list: The normalized prompts, in file order.
    """
    prompts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            # Strip the surrounding double quotes and any leading/trailing whitespace
            prompt = line.strip().strip('"')
            if not prompt:
                continue

            # rstrip() is a no-op when the prompt does not end in punctuation,
            # so this covers both "replace trailing punctuation" and "append".
            prompts.append(prompt.rstrip(string.punctuation) + "?")
    return prompts


def load_xlsx(file_path, max_row=51, min_row=2, min_col=13, max_col=13):
    """
    Reads cell values from a rectangular range of a workbook's active sheet.

    With the defaults this reads column 13 (column M) from row 2 to row 51.
    Every value read is also printed. Empty cells are returned as ``None``.

    Parameters:
    file_path (str): Path of the .xlsx workbook
    max_row (int): Last row to read (inclusive)
    min_row (int): First row to read
    min_col (int): First column to read (1-based)
    max_col (int): Last column to read (inclusive)

    Returns:
    list: The cell values, row by row.
    """
    # Imported lazily: the module-level import is commented out, so without
    # this the function fails with NameError. Keeping it local also keeps
    # openpyxl optional for runs that never read a workbook.
    import openpyxl

    workbook = openpyxl.load_workbook(file_path)
    sheet = workbook.active

    # Bounds are passed by keyword because this function's own parameter order
    # (max_row before min_row) differs from iter_rows' order.
    prompts = [
        cell.value
        for row in sheet.iter_rows(
            min_row=min_row, max_row=max_row, min_col=min_col, max_col=max_col
        )
        for cell in row
    ]

    # Display every prompt that was read
    for prompt in prompts:
        print(f"{prompt}\n")

    return prompts


def pad_list(input_list, target_length, padding_value=0):
    """
    Returns a copy of input_list that is cut or filled to target_length.

    Parameters:
    input_list (list): The list to adjust; it is not modified.
    target_length (int): The length the result should have.
    padding_value: The filler appended when the list is too short. Default is 0.

    Returns:
    list: A new list, truncated when the input is at least target_length long
    and right-padded with padding_value otherwise.
    """
    shortfall = target_length - len(input_list)
    if shortfall > 0:
        # Too short: append exactly as many fillers as are missing.
        return input_list + [padding_value] * shortfall
    # Long enough already: keep only the leading target_length items.
    return input_list[:target_length]


if __name__ == "__main__":
    # Script entry point: for every prompt in data/<kind>.txt, run `trials`
    # generations and dump the collected traces plus their prompt labels as
    # JSON to M4/<kind>_temp_<temperature>_t_<trials>.out.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )

    # --kind and --trials have no usable default (the input path and the loop
    # bound are derived from them), so require them and let argparse report a
    # clear usage error instead of failing later on "data/None.txt" or
    # range(None).
    parser.add_argument(
        "--kind",
        type=str,
        required=True,
        help="The kind to run",
    )

    parser.add_argument(
        "--trials",
        type=int,
        required=True,
        help="The number of trials to run",
    )

    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; this script requires a GPU.")

    # Imported here, at module scope of the script, so that `lade` becomes a
    # module-level name visible to the functions defined in this file.
    import lade
    lade.augment_generate()
    lade.config_lade(DEBUG=0, POOL_FROM_PROMPT=True)

    torch_device = 0
    kind = args.kind
    question_file = f"data/{kind}.txt"
    trials = args.trials
    temperature = args.temperature

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Prepare the output location and read the prompts before the expensive
    # model load and generation loop, so a missing directory or input file is
    # reported immediately rather than after all the work is done.
    output_path = f"M4/{kind}_temp_{args.temperature}_t_{trials}.out"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    prompts = load_prompts_from_txt(question_file)

    model, tokenizer = setup(model_name, torch_device)

    warmup(model, tokenizer, torch_device, temperature)

    traces_0 = []
    labels = []
    for j, prompt in enumerate(prompts):
        for i in range(trials):
            acc_0 = executions(model, tokenizer, torch_device, prompt, temperature)

            # traces_0[k] and labels[k] describe the same run.
            traces_0.append(acc_0)
            labels.append(prompt)

            print(f"done promt num: {j}, iter: {i}")

    with open(output_path, "w") as out:
        json.dump({
            "traces": traces_0,
            "labels": labels,
        }, out, separators=(",", ":"))

    print(f"done {kind}, trials {trials}, temp {temperature}")

    # new_t = []
    # for trace in traces_0:
    #     new_t.append(len(trace))

    # print(new_t)
    # new_t = [99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99]
    # plt.figure(figsize=(8, 6))
    # sns.histplot(new_t, bins=100, kde=False, color="blue")

    # # Adding labels and title
    # plt.title("Histogram of Trace Lengths")
    # plt.xlabel("Length of acc_1")
    # plt.ylabel("Frequency")

    # # Save the plot as a PNG file
    # plt.savefig("histogram_plot.png", dpi=300, bbox_inches='tight')
    
