from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time 
import os, sys
import pandas as pd 
from datasets import load_dataset
from typing import List
import argparse
# import matplotlib.pyplot as plt
import json
# import openpyxl
import string
import numpy as np
import argparse

def plot_and_save_graph(x, y, xlabel, ylabel, graph_header, file_name):
    """
    Plots y against x as a line graph and saves it as a PNG under "Outputs/".

    Parameters:
    x (list): List of x-axis items
    y (list): List of y-axis variables (same length as x)
    xlabel (str): Label for the x-axis
    ylabel (str): Label for the y-axis
    graph_header (str): Title of the graph
    file_name (str): Base name of the output file; the image is written to
        "Outputs/<file_name>.png"

    Raises:
    ImportError: If matplotlib is not installed.
    """
    # Imported lazily so the module still loads when matplotlib is not
    # installed; only callers that actually plot need it.
    import matplotlib.pyplot as plt

    out_path = "Outputs/" + file_name + ".png"
    # savefig does not create missing directories, so make sure the target
    # directory (including any sub-directory embedded in file_name) exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(x, y, marker='o', linestyle='-', color='b')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(graph_header)
        plt.grid(True)
        plt.savefig(out_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()


def setup(model_name, torch_device):
    """
    Loads the tokenizer and the causal language model for the given checkpoint.

    Parameters:
    model_name (str): Checkpoint identifier passed to both from_pretrained calls
    torch_device: Value forwarded as device_map when loading the model

    Returns:
    tuple: (model, tokenizer); the tokenizer is also attached to the model
        as model.tokenizer
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    # Weights are loaded in half precision and placed via device_map.
    load_kwargs = {"torch_dtype": torch.float16, "device_map": torch_device}
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    lm.tokenizer = tok

    return lm, tok


def warmup(model, tokenizer, torch_device, temperature: float):
    """
    Runs a single one-token generation on a fixed prompt to warm up the model.

    Parameters:
    model: Object exposing generate(**kwargs)
    tokenizer: Callable that turns the prompt text into model inputs; its
        result must support .to(device) and ** unpacking
    torch_device: Device the tokenized inputs are moved to
    temperature (float): Sampling temperature. Values <= 0 select greedy
        decoding, because sampling requires a strictly positive temperature.
    """
    prompt = "How do you fine tune a large language model?"
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )

    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # NOTE(review): nothing in this file reads CHAT; presumably it is consumed
    # by the decoding library - confirm before removing.
    os.environ["CHAT"] = "1"

    if temperature > 0:
        decode_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 0.8}
    else:
        # do_sample=True with temperature 0.0 (the script's default) is rejected
        # by transformers, so fall back to deterministic greedy decoding.
        decode_kwargs = {"do_sample": False}

    model.generate(**model_inputs, max_new_tokens=1, **decode_kwargs)

def executions(model, tokenizer, torch_device, prompt, temperature: float, exp=0):
    """
    Generates up to 100 new tokens for the prompt and returns the lade trace.

    The name `lade` must already be bound at module level when this is called;
    the script only imports it when the LOAD_LADE environment variable is set.

    Parameters:
    model: Object exposing generate(**kwargs)
    tokenizer: Callable that turns the prompt text into model inputs; its
        result must support .to(device) and ** unpacking
    torch_device: Device the tokenized inputs are moved to
    prompt (str): User question inserted into the chat template
    temperature (float): Sampling temperature. Values <= 0 select greedy
        decoding, because sampling requires a strictly positive temperature.
    exp: Unused; kept for backward compatibility with existing callers

    Returns:
    tuple: (all_len, all_tok) - the "len_out" and "num_tok" fields of every
        entry returned by lade.utils.get_all_log(). The log is cleared
        before returning.
    """
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a domain expert.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )
    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # NOTE(review): nothing in this file reads CHAT; presumably it is consumed
    # by the decoding library - confirm before removing.
    os.environ["CHAT"] = "1"

    if temperature > 0:
        decode_kwargs = {"do_sample": True, "temperature": temperature, "top_p": 0.8}
    else:
        # do_sample=True with temperature 0.0 (the script's default) is rejected
        # by transformers, so fall back to deterministic greedy decoding.
        decode_kwargs = {"do_sample": False}

    model.generate(**model_inputs, max_new_tokens=100, **decode_kwargs)

    log_out = lade.utils.get_all_log()
    all_len = [entry["len_out"] for entry in log_out]
    all_tok = [entry["num_tok"] for entry in log_out]
    # Reset the log so the next call only sees its own entries.
    lade.utils.clear_log()
    return all_len, all_tok
    

def load_prompts_from_txt(file_path):
    """
    Reads one prompt per line from a text file and normalizes each to a question.

    For every line, surrounding whitespace and double quotes are removed, any
    trailing punctuation is dropped, and a single "?" is appended. Lines that
    are empty after stripping are kept as empty strings.

    Parameters:
    file_path (str): Path of the text file to read

    Returns:
    list: One normalized prompt per line of the file, in file order
    """
    def _as_question(raw_line):
        text = raw_line.strip().strip('"')
        if not text:
            return text
        # rstrip is a no-op when there is no trailing punctuation, so this
        # covers both the "replace punctuation" and the "just append" cases.
        return text.rstrip(string.punctuation) + "?"

    with open(file_path, "r") as handle:
        return [_as_question(raw_line) for raw_line in handle]


def load_xlsx(file_path, max_row=51, min_row=2, min_col=13, max_col=13):
    """
    Reads cell values from the active sheet of an .xlsx workbook and prints them.

    With the defaults this reads column M (the 13th column), rows 2 through 51.

    Parameters:
    file_path (str): Path of the workbook to open
    max_row (int): Last row to read (inclusive)
    min_row (int): First row to read (inclusive)
    min_col (int): First column to read (inclusive, 1-based)
    max_col (int): Last column to read (inclusive, 1-based)

    Returns:
    list: The cell values in row-major order

    Raises:
    ImportError: If openpyxl is not installed.
    """
    # Imported lazily so the module still loads when openpyxl is not
    # installed; only callers that actually read spreadsheets need it.
    import openpyxl

    workbook = openpyxl.load_workbook(file_path)
    sheet = workbook.active

    selected_rows = sheet.iter_rows(
        min_row=min_row, max_row=max_row, min_col=min_col, max_col=max_col
    )
    prompts = [cell.value for row in selected_rows for cell in row]

    # Echo every collected prompt, each followed by a blank line.
    for prompt in prompts:
        print(f"{prompt}\n")

    return prompts


def pad_list(input_list, target_length, padding_value=0):
    """
    Returns a copy of the list that is padded or truncated to the target length.

    Parameters:
    input_list (list): The list to be padded.
    target_length (int): The desired length of the returned list.
    padding_value: The value appended when the list is too short. Default is 0.

    Returns:
    list: A new list; the input list is not modified.
    """
    shortfall = target_length - len(input_list)
    if shortfall > 0:
        return input_list + [padding_value] * shortfall
    # Already long enough: cut down to the requested length.
    return input_list[:target_length]


if __name__ == "__main__":
    # Script entry point: load TinyLlama with lade enabled, warm it up, trace
    # one generation and write the trace to trace_out/f2.out as compact JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )
    args = parser.parse_args()

    # executions() collects its trace through lade.utils, and `lade` is only
    # bound by the import below. Stop here instead of hitting a NameError
    # after the model has been loaded and a generation has already run.
    if not int(os.environ.get("LOAD_LADE", 0)):
        raise RuntimeError(
            "LOAD_LADE must be set to a non-zero integer: the trace is collected through lade."
        )

    import lade
    lade.augment_all()
    # For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7
    lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=0, POOL_FROM_PROMPT=True)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA device is required to run this script.")

    torch_device = 0
    temperature = args.temperature
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model, tokenizer = setup(model_name, torch_device)
    warmup(model, tokenizer, torch_device, temperature)

    # Prompts traced by earlier revisions of this script (two runs each for the
    # first two, written to trace_out/p_<id>_t_<run>.out; the third to
    # trace_out/f1.out):
    #   "What is a narcissistic personality disorder?"
    #   "What are the symptoms of cancer?"
    #   "How to fine-tune a large language model?"
    all_len_0, all_tok_0 = executions(model, tokenizer, torch_device, "I found a lump, should I be worried?", temperature)

    # open() does not create the output directory.
    os.makedirs("trace_out", exist_ok=True)
    with open("trace_out/f2.out", "w") as out:
        json.dump({
            "lens": all_len_0,
            "toks": all_tok_0,
        }, out, separators=(",", ":"))

    
    