from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time 
import os, sys
import pandas as pd 
from datasets import load_dataset
from typing import List
import argparse
# import matplotlib.pyplot as plt
import json
# import openpyxl
import string
import numpy as np
import argparse

def plot_and_save_graph(x, y, xlabel, ylabel, graph_header, file_name, output_dir="Outputs"):
    """
    Plots a graph using the given x and y data, labels, and header, and saves it to a file.

    The figure is written to ``<output_dir>/<file_name>.png``. The containing
    directory is created if it does not exist. The figure is always closed,
    even if plotting or saving raises, so repeated calls do not accumulate
    open figures.

    NOTE(review): ``plt`` is not bound in this file (the module-level
    ``import matplotlib.pyplot as plt`` is commented out), so calling this
    raises NameError until that import is restored.

    Parameters:
    x (list): List of x-axis items
    y (list): List of y-axis variables (same length as x)
    xlabel (str): Label for the x-axis
    ylabel (str): Label for the y-axis
    graph_header (str): Title of the graph
    file_name (str): File name (without extension) to save the graph under
    output_dir (str): Directory to write the PNG into. Default "Outputs".
    """
    plt.figure(figsize=(10, 6))
    try:
        plt.plot(x, y, marker='o', linestyle='-', color='b')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(graph_header)
        plt.grid(True)

        out_path = os.path.join(output_dir, file_name + ".png")
        parent = os.path.dirname(out_path)
        if parent:
            # savefig does not create missing directories.
            os.makedirs(parent, exist_ok=True)
        plt.savefig(out_path)
    finally:
        plt.close()


def setup(model_name, torch_device):
    """Load the tokenizer and a float16 causal language model for ``model_name``.

    The model is placed according to ``device_map=torch_device``. The tokenizer
    is also attached to the model object as ``model.tokenizer``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=torch_device,
    )
    # Expose the tokenizer on the model itself; presumably consumed by code that
    # only receives the model (e.g. a patched generate) -- TODO confirm.
    lm.tokenizer = tok

    return lm, tok


def warmup(model, tokenizer, torch_device, temperature: float):
    """Run one single-token generation on a fixed chat prompt before the measured runs.

    Parameters:
    model: Object exposing a Hugging Face style ``generate(**inputs, ...)``.
    tokenizer: Callable returning encoded inputs with a ``.to(device)`` method.
    torch_device: Device the encoded inputs are moved to.
    temperature (float): Sampling temperature. ``0`` selects greedy decoding
        (``do_sample=False``), because transformers rejects a temperature of 0
        when sampling is enabled. Any other value samples with ``top_p=0.8``.
    """
    prompt = "How do you fine tune a large language model?"
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )

    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # Presumably read by the (optionally LADE-patched) generation code -- TODO confirm.
    os.environ["CHAT"] = "1"

    if temperature == 0:
        # Temperature 0 is the greedy limit; transformers raises if it is paired
        # with do_sample=True, so decode greedily and pass no sampling knobs.
        sampling = {"do_sample": False}
    else:
        sampling = {"do_sample": True, "temperature": temperature, "top_p": 0.8}

    model.generate(**model_inputs, max_new_tokens=1, **sampling)


def executions(model, tokenizer, torch_device, prompt, temperature: float, exp=0):
    """Generate up to 100 new tokens for ``prompt`` and return LADE's logged token counts.

    After generation, the ``"num_tok"`` field of every entry in
    ``lade.utils.get_all_log()`` is collected. The LADE log is then cleared,
    so each call only reports its own generation. (Presumably the per-step
    token counts of the lookahead decoder -- TODO confirm against lade.)

    Parameters:
    model: Object exposing a Hugging Face style ``generate(**inputs, ...)``.
    tokenizer: Callable returning encoded inputs with a ``.to(device)`` method.
    torch_device: Device the encoded inputs are moved to.
    prompt (str): User turn inserted into the chat template.
    temperature (float): Sampling temperature. ``0`` selects greedy decoding
        (transformers rejects temperature 0 with ``do_sample=True``). Any other
        value samples with ``top_p=0.8``.
    exp: Unused; kept so existing callers keep working.

    Returns:
    list: The ``"num_tok"`` values from the LADE log, in log order.

    Raises:
    RuntimeError: If the ``lade`` module has not been imported (the script only
        imports it when ``LOAD_LADE`` is set). This is checked before the costly
        generation.
    """
    # The script imports lade conditionally, so a bare global reference would be a
    # NameError (and only after generating). Look it up explicitly and fail early.
    lade_module = sys.modules.get("lade")
    if lade_module is None:
        raise RuntimeError(
            "executions() reads lade's generation log; run with LOAD_LADE=1 so lade is imported"
        )

    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a domain expert.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )
    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # Presumably read by the LADE-patched generation code -- TODO confirm.
    os.environ["CHAT"] = "1"

    if temperature == 0:
        # Greedy limit of sampling; avoids transformers' temperature-0 ValueError.
        sampling = {"do_sample": False}
    else:
        sampling = {"do_sample": True, "temperature": temperature, "top_p": 0.8}

    model.generate(**model_inputs, max_new_tokens=100, **sampling)
    # print("Sample output: ", tokenizer.decode(output[0], skip_special_tokens=False))

    # all_len = [x["len_out"] for x in lade.utils.get_all_log()]
    # lade.utils.clear_log()
    # return all_len

    all_len = [entry["num_tok"] for entry in lade_module.utils.get_all_log()]
    lade_module.utils.clear_log()
    return all_len


    ####################################################################################
    # Archived Experiments for Mitigation for section 6
    ####################################################################################
    # acc_1 = []
    # acc_5 = []
    # acc_6 = []
    # acc_7 = []
    # acc_8 = []
    # acc_9 = []

    # num_acc = 0
    # for i, leng in enumerate(all_len):
    #     if (i + 1) % 3 == 0:
    #         num_acc += leng
    #         acc_6.append(num_acc)
    #         num_acc = 0
    #     else:
    #         num_acc += leng

    # num_acc = 0
    # for i, leng in enumerate(all_len):
    #     if (i + 1) % 5 == 0:
    #         num_acc += leng
    #         acc_7.append(num_acc)
    #         num_acc = 0
    #     else:
    #         num_acc += leng
    
    # num_acc = 0
    # for i, leng in enumerate(all_len):
    #     if (i + 1) % 10 == 0:
    #         num_acc += leng
    #         acc_8.append(num_acc)
    #         num_acc = 0
    #     else:
    #         num_acc += leng
    
    # num_acc = 0
    # for i, leng in enumerate(all_len):
    #     if (i + 1) % 20 == 0:
    #         num_acc += leng
    #         acc_9.append(num_acc)
    #         num_acc = 0
    #     else:
    #         num_acc += leng

    # lade.utils.clear_log()

    # return all_len, [10 for _ in range(len(all_len))], acc_6, acc_7, acc_8, acc_9


def load_prompts_from_txt(file_path):
    """Read one prompt per line from a text file and normalise each into a question.

    For every line:
    1. Surrounding whitespace and double quotes are removed, then any whitespace
       that was inside the quotes.
    2. Any trailing run of punctuation is dropped.
    3. A single "?" is appended.

    Lines that are empty after stripping are skipped, rather than becoming ""
    prompts that would still be sent to the model.

    Parameters:
    file_path (str): Path to a UTF-8 text file.

    Returns:
    list: The normalised prompt strings, in file order.
    """
    prompts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            # Re-strip after removing quotes so '" Why? "' normalises to 'Why?'.
            prompt = line.strip().strip('"').strip()
            if not prompt:
                continue
            # rstrip is a no-op when the text doesn't end in punctuation, so this
            # covers both "replace trailing punctuation" and "append" cases.
            prompts.append(prompt.rstrip(string.punctuation) + "?")
    return prompts


def load_xlsx(file_path, max_row=51, min_row=2, min_col=13, max_col=13):
    """Collect cell values from a rectangular range of the workbook's active sheet.

    With the defaults this reads column 13 (M), rows 2 through 51. Values are
    gathered row by row, left to right. Each one is printed followed by a blank
    line, and the full list is returned.

    NOTE(review): relies on a module-level ``openpyxl`` name; that import is
    commented out in this file, so this raises NameError until it is restored.
    """
    active_sheet = openpyxl.load_workbook(file_path).active

    cell_values = [
        cell.value
        for row in active_sheet.iter_rows(min_row, max_row, min_col, max_col)
        for cell in row
    ]

    # Echo every collected value (not just the first 50) for a quick visual check.
    for value in cell_values:
        print(f"{value}\n")

    return cell_values


def pad_list(input_list, target_length, padding_value=0):
    """
    Return a copy of ``input_list`` forced to exactly ``target_length`` items.

    Longer lists are truncated. Shorter lists are extended on the right with
    ``padding_value``.

    Parameters:
    input_list (list): The list to pad or truncate (not modified).
    target_length (int): The desired length of the result.
    padding_value: The filler appended when the list is too short. Default is 0.

    Returns:
    list: A new list of length ``target_length``.
    """
    shortfall = target_length - len(input_list)
    if shortfall <= 0:
        # Already long enough: slicing both truncates and copies.
        return input_list[:target_length]
    return input_list + [padding_value] * shortfall


if __name__ == "__main__":
    # Script entry point: parse CLI options, optionally enable LADE, run every
    # prompt `trials` times, and dump the collected LADE traces as JSON.
    # Kept at module level (not in a function) so `import lade` below binds a
    # module-global name that helpers in this file can rely on.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )

    # Required: prompts are read from data/<kind>.txt, so a missing value used to
    # fail late trying to open "data/None.txt".
    parser.add_argument(
        "--kind",
        type=str,
        required=True,
        help="The kind to run",
    )

    # Default of 1: without one, `range(None)` crashed after the model had loaded.
    parser.add_argument(
        "--trials",
        type=int,
        default=1,
        help="The number of trials to run",
    )

    args = parser.parse_args()

    if int(os.environ.get("LOAD_LADE", 0)):
        import lade
        lade.augment_all()
        # For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7
        lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=0, POOL_FROM_PROMPT=True)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        sys.exit("CUDA is required but torch.cuda.is_available() returned False")

    torch_device = 0
    kind = args.kind
    question_file = f"data/{kind}.txt"
    trials = args.trials
    temperature = args.temperature

    output_file = f"T3/{kind}_temp_{args.temperature}_t_{trials}.out"
    # Create the output directory before the slow generation loop, so a missing
    # directory cannot discard all results at the final open().
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model, tokenizer = setup(model_name, torch_device)
    prompts = load_prompts_from_txt(question_file)

    # prompts = [prompts[0]]
    # trails = 1

    # file_path = "data/prompts.txt"

    # file_path = "data/medical_prompts.xlsx"
    # num_prompt = 50
    # prompts = load_xlsx(file_path, 1 + num_prompt)
    # prompts = [
    #     "How do you fine tune a large language model?",
    #     "Which methods did Socrates employ to challenge the prevailing thoughts of his time?",
    #     "Can you explain the concept of blockchain and how it works?",
    #     "What are the key principles of effective project management?"
    # ]

    warmup(model, tokenizer, torch_device, temperature)

    # executions(model, tokenizer, torch_device, prompts[0])

    # traces[k] is the LADE trace of one run; labels[k] is the prompt that produced it.
    traces = []
    labels = []
    for j, prompt in enumerate(prompts):
        for i in range(trials):
            acc_1 = executions(model, tokenizer, torch_device, prompt, temperature)
            traces.append(acc_1)
            labels.append(prompt)
            print(f"done promt num: {j}, iter: {i}")

    # with open(f"Newexp/temp_{temperature}_{kind}_trace_{trials}.out", "w") as out:
    # with open(f"exp1NewExp/temp_{temperature}_{kind}_trace_{trials}.out", "w") as out:
    # # with open(f"exp2NewExp/temp_{temperature}_{kind}_trace_{trials}.out", "w") as out:
    # with open(f"M0/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    with open(output_file, "w") as out:
        json.dump({
            "traces": traces,
            "labels": labels,
        }, out, separators=(",", ":"))

    #############################################################################
    # Archived Experiments for Mitigation Section 6
    #############################################################################
    # with open(f"M2/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #         "total_size": total_time1
    #     }, out, separators=(",", ":"))

    # with open(f"M3/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #         "total_size": total_time2
    #     }, out, separators=(",", ":"))

    # with open(f"M4/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #         "total_size": total_time3
    #     }, out, separators=(",", ":"))

    # with open(f"M5/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))
    # with open(f"M6/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M7/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M8/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # with open(f"M9/{kind}_temp_{args.temperature}_t_{trials}.out", "w") as out:
    #     json.dump({
    #         "traces": traces,
    #         "labels": labels,
    #     }, out, separators=(",", ":"))

    # Inside the __main__ guard: at module level this raised NameError on
    # `kind` whenever the file was imported instead of run.
    print(f"done {kind}, trials {trials}, temp {temperature}")
    
