from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time 
import os, sys
import pandas as pd 
from datasets import load_dataset
from typing import List
import argparse
import matplotlib.pyplot as plt
import json
import openpyxl


def plot_and_save_graph(x, y, xlabel, ylabel, graph_header, file_name):
    """
    Plot ``y`` against ``x`` as a single line chart and save it as PNG and PDF.

    Parameters:
    x (list): List of x-axis items
    y (list): List of y-axis values (same length as x)
    xlabel (str): Label for the x-axis
    ylabel (str): Label for the y-axis
    graph_header (str): Title of the graph
    file_name (str): Output path without extension; ``.png`` and ``.pdf``
        are appended to produce the two saved files.
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        plt.plot(x, y, marker='o', linestyle='-', color='b', linewidth=2)
        # graph_header used to be accepted but never drawn on the figure.
        plt.title(graph_header, fontsize=35)
        plt.xlabel(xlabel, fontsize=35)
        plt.ylabel(ylabel, fontsize=35)
        plt.xticks(fontsize=30)
        plt.yticks(fontsize=30)
        plt.savefig(file_name + ".png")
        plt.savefig(file_name + ".pdf", format="pdf")
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)


def setup(model_name, torch_device, torch_dtype=None):
    """
    Load the tokenizer and causal-LM weights for ``model_name``.

    Parameters:
    model_name (str): Hugging Face hub id or local path of the model.
    torch_device (torch.device | str): Passed as ``device_map`` to
        ``from_pretrained``.
    torch_dtype (torch.dtype | None): Dtype for the weights. ``None``
        (default) keeps float16, except on an explicit CPU device where
        float32 is used: half precision on CPU is slow and, depending on the
        PyTorch version, unsupported for some kernels.

    Returns:
    tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch_dtype is None:
        torch_dtype = torch.float32 if str(torch_device) == "cpu" else torch.float16

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=torch_device)
    # NOTE(review): the tokenizer is attached to the model, presumably because
    # LADE's patched generation reads it (e.g. for POOL_FROM_PROMPT) — confirm.
    model.tokenizer = tokenizer

    return model, tokenizer


def warmup(model, tokenizer, torch_device):
    """
    Run a single one-token generation so later measured runs start warm.

    Parameters:
    model: Causal LM returned by ``setup``.
    tokenizer: Tokenizer matching ``model``.
    torch_device (torch.device): Device the tokenized inputs are moved to.

    Side effects:
    Sets ``os.environ["CHAT"] = "1"``. If LADE has been imported, its log
    is cleared afterwards so the warm-up does not pollute later traces.
    """
    prompt = "How do you fine tune a large language model?"
    input_text = (
        f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )

    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    # NOTE(review): CHAT is presumably read by LADE's patched generate — confirm.
    os.environ["CHAT"] = "1"

    model.generate(**model_inputs, max_new_tokens=1)

    # `lade` is only imported by the __main__ block when LOAD_LADE is set, so a
    # bare `lade` reference raises NameError otherwise. Look it up instead and
    # skip the clear when LADE is not loaded (there is no log to clear).
    lade_module = sys.modules.get("lade")
    if lade_module is not None:
        lade_module.utils.clear_log()


def executions(model, tokenizer, torch_device, prompt):
    """
    Greedily generate a response to ``prompt`` and return LADE's per-step log.

    Parameters:
    model: Causal LM returned by ``setup`` (expected to be LADE-augmented).
    tokenizer: Tokenizer matching ``model``.
    torch_device (torch.device): Device the tokenized inputs are moved to.
    prompt (str): User message inserted into the chat template.

    Returns:
    list: The ``num_tok`` field of every entry in LADE's log for this run
        (presumably tokens produced per decoding step — confirm against LADE).

    Raises:
    RuntimeError: If LADE has not been imported (``LOAD_LADE`` unset); the
        result is read from LADE's log, so there is nothing to return.
    """
    # Check before generating so a missing LADE fails fast instead of after a
    # full (expensive) generation.
    lade_module = sys.modules.get("lade")
    if lade_module is None:
        raise RuntimeError("executions() requires LADE; run with LOAD_LADE=1")

    input_text = (
        f"<|system|>\nYou are a chatbot who always responds in the style of a domain expert.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>"
    )
    model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)

    os.environ["CHAT"] = "1"
    # The generated tokens are not needed; only LADE's side-effect log is.
    model.generate(**model_inputs, do_sample=False)

    all_len = [entry["num_tok"] for entry in lade_module.utils.get_all_log()]
    lade_module.utils.clear_log()
    return all_len


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ChatExperiment script")
    # NOTE(review): --data_size is parsed but not read anywhere in this script.
    parser.add_argument('--data_size', type=int, default=1, help='Size of the data to be collected')
    parser.add_argument(
        "--G",
        type=int,
        default=10,
        help="LADE guess-set size (GUESS_SET_SIZE).",
    )
    parser.add_argument(
        "--N",
        type=int,
        default=4,
        help="LADE level (LEVEL).",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=6,
        help="LADE window size (WINDOW_SIZE).",
    )
    args = parser.parse_args()
    torch.cuda.empty_cache()

    G = args.G
    P = 4  # number of sentences used to probe LADE (only used for the graph title and file names)
    N = args.N

    # executions() reads its result from LADE's log, so the run cannot produce
    # a trace without LADE; stop before spending time loading the model.
    if not int(os.environ.get("LOAD_LADE", 0)):
        sys.exit("LOAD_LADE=1 must be set: the trace is read from LADE's log.")

    # Imported at module scope (inside this if-block) so `lade` is a global.
    import lade
    lade.augment_all()
    # For a 7B model, set LEVEL=5 (N), WINDOW_SIZE=7 (W), GUESS_SET_SIZE=7 (G)
    lade.config_lade(LEVEL=N, WINDOW_SIZE=args.W, GUESS_SET_SIZE=G, DEBUG=0, POOL_FROM_PROMPT=True)

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"Using GPU: {gpu_name}")
        torch_device = torch.device("cuda")
    else:
        print("CUDA is not available. Using CPU.")
        torch_device = torch.device("cpu")

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model, tokenizer = setup(model_name, torch_device)

    warmup(model, tokenizer, torch_device)

    prompt = """Repeat the following sentences 5 times: 
    I run to the coffee shop. I run as fast as cars. I run with my best friend. I run from a lecture hall. I run to the coffee shop. I run as fast as cars. I run with my best friend. I run from a lecture hall.
    """
    # Alternative probe prompts used in other runs:
    # prompt = """Repeat the following sentences 5 times:
    # I run to the large shop. I run with my best friend. I run down the road side. I run to the large shop. I run with my best friend. I run down the road side.
    # """
    # prompt = """Repeat the following sentences 10 times:
    # I run to a large shop. I run with my best friend. I run to a large shop. I run with my best friend.
    # """
    # prompt = """Repeat the following sentences 2 times:
    # I run to the big shop. I run down the road side. I run to the big shop. I run down the road side.
    # """
    # prompt = """Repeat the following sentences 2 times:
    # I run to the big shop. I run to the big shop.
    # """
    # prompt = "What disease do I have if I experience severe headaches, dizziness, and blurred vision?"

    # executions() builds its own chat-formatted input from the raw prompt.
    out = executions(model, tokenizer, torch_device, prompt)
    with open(f'G={G}, P={P}.json', 'w', encoding="utf-8") as file:
        json.dump(out, file)

    # Plot at most the first 30 log entries; size x to match so a shorter
    # trace does not make plt.plot fail on mismatched lengths.
    trace = out[:30]
    plot_and_save_graph(list(range(len(trace))), trace, "Token ID", "Time", f"Trace Pattern for G={G}, |prompt|={P}", f"G={G}_P={P}")
    
