from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaForCausalLM, LlamaModel, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
import numpy as np
import code
import torch
import transformers
import csv
import os

def save_array_to_csv(array, file_path, transpose=False):
    """
    Saves a 2D array to a CSV file, creating the parent directory if needed.

    Parameters:
        array (iterable of sequences): The rows to save. Rows may have
            different lengths and may be lists or tuples.
        file_path (str): Path where the CSV file will be saved. A bare file
            name (no directory part) is written to the current directory.
        transpose (bool): If True, rows are written as columns. Shorter rows
            are padded with empty strings so no values are dropped.
    """
    # os.path.dirname returns '' for a bare file name, and os.makedirs('')
    # raises, so only create the directory when there is one to create.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Materialize once so generators and tuple rows are handled uniformly.
    rows = [list(row) for row in array]

    if transpose:
        # default=0 keeps an empty array from crashing max().
        max_len = max((len(row) for row in rows), default=0)
        # Pad to a rectangle first; zip() alone would truncate to the shortest row.
        padded_rows = [row + [''] * (max_len - len(row)) for row in rows]
        rows = [list(column) for column in zip(*padded_rows)]

    with open(file_path, mode='w', newline='', encoding='utf-8') as file:
        csv.writer(file).writerows(rows)

# def save_array_to_csv(array, file_path, transpose=False):
    # """
    # Saves a 2D array to a CSV file.

    # Parameters:
        # array (list of lists): The 2D array to save.
        # file_path (str): Path where the CSV file will be saved.
        # transpose (bool): If True, transposes the array before saving.
    # """
    # if transpose:
        # array = list(map(list, zip(*array)))  # this doesn't work when rows have different lengths (zip truncates to the shortest row)
    
    # with open(file_path, mode='w', newline='', encoding='utf-8') as file:
        # writer = csv.writer(file)
        # writer.writerows(array)

def get_avg_representation(model, tokenizer, phrase: str, layer_idx: int, device="cuda:0"):
    """
    Returns the average hidden state representation of a phrase at a given layer.

    The phrase is tokenized, run through the model once with gradients
    disabled, and the selected layer's hidden states are averaged over every
    token position the tokenizer produced.

    Args:
        model: Callable model that accepts the tokenizer output as keyword
            arguments plus `output_hidden_states=True` and returns an object
            with a `hidden_states` sequence. It is switched to eval mode.
        tokenizer: Corresponding tokenizer, called with `return_tensors="pt"`.
        phrase (str): Input phrase
        layer_idx (int): Layer index (0-based, excludes embedding layer)
        device (str): Unused; kept for backward compatibility. The inputs are
            moved to `model.device` instead when the model exposes one.

    Returns:
        torch.Tensor: mean over the token dimension of the layer's hidden
        states, with the leading batch dimension squeezed out.
    """
    model.eval()

    inputs = tokenizer(phrase, return_tensors="pt")

    # Tokenizer output is created on the CPU. Move it to where the model lives so the
    # forward pass does not fail with a device mismatch when the model sits on a single
    # accelerator. A "meta" device means the weights are offloaded; leave the inputs alone then.
    model_device = getattr(model, "device", None)
    if model_device is not None and getattr(model_device, "type", None) != "meta":
        inputs = inputs.to(model_device)

    # Run the model and get hidden states
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states is indexed (embedding_layer, layer_0, ..., layer_N), hence the +1.
    layer_hidden = outputs.hidden_states[layer_idx + 1]

    # Average over dim 1 (the token positions), then drop the batch dimension.
    return layer_hidden.mean(dim=1).squeeze(0)


def get_last_token_representation(model, tokenizer, phrase: str, layer_idx: int, device="cuda:0"):
    """
    Returns the hidden state of the last token of a phrase at a given layer.

    The phrase is tokenized, run through the model once with gradients
    disabled, and the selected layer's hidden state at the final input
    position of the first batch element is returned.

    Args:
        model: Callable model that accepts the tokenizer output as keyword
            arguments plus `output_hidden_states=True` and returns an object
            with a `hidden_states` sequence. It is switched to eval mode.
        tokenizer: Corresponding tokenizer, called with `return_tensors="pt"`.
        phrase (str): Input phrase
        layer_idx (int): Layer index (0-based, excludes embedding layer)
        device (str): Unused; kept for backward compatibility. The inputs are
            moved to `model.device` instead when the model exposes one.

    Returns:
        torch.Tensor: the layer's hidden state at the last input position.
    """
    model.eval()

    inputs = tokenizer(phrase, return_tensors="pt")

    # Tokenizer output is created on the CPU. Move it to where the model lives so the
    # forward pass does not fail with a device mismatch when the model sits on a single
    # accelerator. A "meta" device means the weights are offloaded; leave the inputs alone then.
    model_device = getattr(model, "device", None)
    if model_device is not None and getattr(model_device, "type", None) != "meta":
        inputs = inputs.to(model_device)

    # Run the model and get hidden states
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states is indexed (embedding_layer, layer_0, ..., layer_N), hence the +1.
    layer_hidden = outputs.hidden_states[layer_idx + 1]

    # The last position is derived from the input length (dim 1 of input_ids).
    last_token_index = inputs["input_ids"].shape[1] - 1
    return layer_hidden[0, last_token_index, :]

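# indexApplicationType: "add" adds weight * indexToken to the layer output; any other value ("standard") moves the layer output toward indexToken by a fraction `weight`.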
def updateModelForwardPass(model, layerIdx, indexToken, weight, indexApplicationType):
    """
    Patches one decoder layer so its output hidden states are steered by indexToken.

    The layer's current `forward` is wrapped: the wrapped call runs first, then
    its hidden states are modified in place and returned in the same container
    shape the wrapped call produced. The caller is responsible for saving and
    restoring the layer's previous `forward` if the patch should be temporary.

    Args:
        model: Model exposing its decoder layers as `model.model.layers`.
        layerIdx (int): Index into `model.model.layers` of the layer to patch.
        indexToken (torch.Tensor): Steering vector; it must broadcast against
            the layer's hidden states.
        weight (float): Scale of the steering step.
        indexApplicationType (str): "add" adds `indexToken * weight` to the
            hidden states. Any other value (by convention "standard") moves the
            hidden states toward indexToken: `h += (indexToken - h) * weight`.

    Returns:
        None. The layer is modified in place.
    """
    layer = model.model.layers[layerIdx]
    original_forward = layer.forward

    def custom_forward(self, hidden_states, *args, **kwargs):
        # `self` is only present because the function is bound to the layer below.
        outputs = original_forward(hidden_states, *args, **kwargs)

        # A decoder layer may hand back either the hidden-states tensor itself or a
        # tuple whose first element is the hidden states; support both so that
        # outputs[0] never ends up indexing the batch dimension of a bare tensor.
        returns_tensor = torch.is_tensor(outputs)
        new_hidden_states = outputs if returns_tensor else outputs[0]

        # Move in the direction of indexToken.
        target = indexToken.to(new_hidden_states.device)
        if indexApplicationType == "add":
            new_hidden_states += target * weight
        else:  # standard
            new_hidden_states += (target - new_hidden_states) * weight

        # Return the modified hidden states in the container the layer produced.
        if returns_tensor:
            return new_hidden_states
        return (new_hidden_states,) + outputs[1:]

    # Patch the layer. Binding against the layer's own class keeps this independent
    # of any one decoder-layer implementation.
    layer.forward = custom_forward.__get__(layer, type(layer))


def runSweep(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    load_in_8bit=False,
    layerIdxs=[1, 3, 5, 10, 15, 20, 25, 30, 31],
    weights=[float(w) for w in np.arange(0., 0.5, 0.01)],
    indexPhrases=[ "The Statue of Liberty", "Albert Einstein", "Marilyn Monroe", "Mahatma Gandhi", "The Eiffel Tower", "Walt Disney", "The Great Wall of China", "Martin Luther King Jr.", "The Mona Lisa", "Elvis Presley", "Mount Everest", "Barack Obama", "The Colosseum in Rome", "Princess Diana", "The Taj Mahal", "Steve Jobs", "The Pyramids of Giza", "Frida Kahlo", "The Berlin Wall", "Mother Teresa"],
    indexApplicationTypes=["standard", "add"],
    indexTokenAggregationTypes=["last", "mean"],  # last token or mean over tokens
    messages = [
        {"role": "user", "content": "Who are you?"}
    ],
    seed=0,
    outputDir="indexSweeps-eval",
    ):
    """
    Sweeps steering weights over layers and index phrases and saves the generations.

    For every (application type, aggregation type, index phrase) combination a
    CSV is written. For each layer, a steering vector is extracted from the
    index phrase at that layer, the layer's forward pass is patched with each
    weight in turn, and the chat response to `messages` is recorded. The CSV is
    rewritten after every layer, so results of finished layers are on disk
    while later layers are still running. The table is saved transposed: the
    first column holds the weights and each further column holds one layer.

    Args:
        model_id (str): Model name or path passed to `from_pretrained`.
        load_in_8bit (bool): Forwarded to `LlamaForCausalLM.from_pretrained`;
            also part of the output directory name.
        layerIdxs (list of int): Indices into `model.model.layers` to patch.
        weights (list of float): Steering weights to try for each layer.
        indexPhrases (list of str): Phrases the steering vector is taken from.
        indexApplicationTypes (list of str): Values passed through to
            `updateModelForwardPass` ("standard" or "add").
        indexTokenAggregationTypes (list of str): "mean" uses the average over
            the phrase's tokens; any other value uses the last token.
        messages (list of dict): Chat messages given to the generation
            pipeline. The first message's content labels the CSV and file name.
        seed (int): Torch seed, re-applied before every generation.
        outputDir (str): Prefix of the output directory name.

    Returns:
        None. Results are printed and written to CSV files.
    """
    torch.manual_seed(seed)
    model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=load_in_8bit)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # NOTE(review): the model is already instantiated here, so model_kwargs presumably
    # has no effect on its dtype -- confirm against the transformers pipeline docs.
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        model_kwargs={"torch_dtype": torch.bfloat16},
    )

    for indexApplicationType in indexApplicationTypes:
        for indexTokenAggregationType in indexTokenAggregationTypes:
            for indexPhrase in indexPhrases:
                firstPrompt = messages[0]["content"]

                # The output path does not depend on the layer, so build it once per phrase.
                csvPath = '%s-%s_8bit-%s/IndexPhrase-%s_Prompt-%s_indexApplicationType-%s_indexTokenAggregationType-%s.csv' % (
                    outputDir,
                    model_id.replace("/", "-"),
                    str(load_in_8bit),
                    indexPhrase.replace(" ", "").replace("?", "")[:15],
                    firstPrompt.replace(" ", "").replace("?", "")[:15],
                    indexApplicationType,
                    indexTokenAggregationType,
                )

                # Header row: a label cell followed by one cell per weight.
                values = [[indexPhrase + "\n" + firstPrompt] + list(weights)]

                for layerIdx in layerIdxs:
                    layerRow = ["Layer Index %d" % layerIdx]
                    values.append(layerRow)

                    if indexTokenAggregationType == "mean":
                        indexToken = get_avg_representation(model, tokenizer, indexPhrase, layerIdx)
                    else:  # last token
                        indexToken = get_last_token_representation(model, tokenizer, indexPhrase, layerIdx)

                    layer = model.model.layers[layerIdx]
                    for w in weights:
                        # Re-seed so every weight is generated from the same random state.
                        torch.manual_seed(seed)
                        original_forward = layer.forward  # save original forward pass operation
                        updateModelForwardPass(model, layerIdx, indexToken, w, indexApplicationType)
                        try:
                            outputs = pipeline(
                                messages,
                                max_new_tokens=256,
                            )
                        finally:
                            # Always restore the original forward pass, even if generation
                            # fails, so the model is never left patched.
                            layer.forward = original_forward
                        response = outputs[0]["generated_text"][-1]["content"]
                        layerRow.append(response)
                        print("___________________________________________________")
                        print("Layer %d | Weight %s:" % (layerIdx, str(w)), response)

                    # Rewrite the CSV after each layer so partial results survive an interruption.
                    save_array_to_csv(values, csvPath, transpose=True)


                # os.makedirs(outputDir, exist_ok=True)
                # save_array_to_csv(values, '%s/IndexPhrase-%s_Prompt-%s_indexApplicationType-%s_indexTokenAggregationType-%s.csv' % (
                    # outputDir
                    # , indexPhrase.replace(" ", "").replace("?", "")[:15]
                    # , firstPrompt.replace(" ", "").replace("?", "")[:15]
                    # ,indexApplicationType
                    # ,indexTokenAggregationType
                    # ), transpose=True)

                    
# runSweep( 
    # model_id="meta-llama/Meta-Llama-3.1-8B-Instruct"
    # ,layerIdxs=[1, 3, 5, 10, 15, 20, 25, 30, 31]
    # # ,layerIdxs=[  5,  15, 30, ]

    # ,weights=[float(w) for w in np.arange(0., 0.5, 0.01)]
    # # ,weights=[float(w) for w in np.arange(0., 0.5, 0.25)]

    # # ,indexPhrases=[ "The Statue of Liberty", "Albert Einstein", "Marilyn Monroe", "Mahatma Gandhi", "The Eiffel Tower", "Walt Disney", "The Great Wall of China", "Martin Luther King Jr.", "The Mona Lisa", "Elvis Presley", "Mount Everest", "Barack Obama", "The Colosseum in Rome", "Princess Diana", "The Taj Mahal", "Steve Jobs", "The Pyramids of Giza", "Frida Kahlo", "The Berlin Wall", "Mother Teresa"]
    # ,indexPhrases=[ "The Statue of Liberty"]

    # ,indexApplicationTypes=["standard", "add"] # standard 
    # ,indexTokenAggregationTypes=["last", "mean"] # last token or mean token
    # ,messages = [
        # # {"role": "system", "content": "Be a sales rep for a company calld SleepyAI that sells AI chatbots that improve performance by sleeping to sleep."}
        # {"role": "user", "content": "Who are you?"}
        # # {"role": "user", "content": "How do you make a ham sandwich?"}
    # ]
    # ,device="cuda:0"
    # ,seed=0
    # ,outputDir="indexSweeps"
    # )

# runSweep( 
    # model_id="meta-llama/Meta-Llama-3.1-8B-Instruct" # 33 layers
	# # model_id = "meta-llama/Llama-3.3-70B-Instruct" # 81 layers

    # ,load_in_8bit=False

    # ,layerIdxs=[1, 3, 5, 10, 15, 20, 25, 30, 31]
    # # ,layerIdxs=[  5,  15, 30, ]
    # # ,layerIdxs=[1, 5, 15, 30, 40, 50, 65, 75, 79]

    # # ,weights=[float(w) for w in np.arange(0., 3.0, 0.05)]
    # # ,weights=[float(w) for w in np.arange(0., 0.6, 0.01)]
    # ,weights=[float(w) for w in np.arange(0., 0.5, 0.01)]
    # # ,weights=[float(w) for w in np.arange(0., 1.02, 0.02)]
    # # ,weights=[float(w) for w in np.arange(0., 0.6, 0.02)]

    # # ,indexPhrases=[ "The Statue of Liberty", "Albert Einstein", "Marilyn Monroe", "Mahatma Gandhi", "The Eiffel Tower", "Walt Disney", "The Great Wall of China", "Martin Luther King Jr.", "The Mona Lisa", "Elvis Presley", "Mount Everest", "Barack Obama", "The Colosseum in Rome", "Princess Diana", "The Taj Mahal", "Steve Jobs", "The Pyramids of Giza", "Frida Kahlo", "The Berlin Wall", "Mother Teresa"]
    # # ,indexPhrases=[ "The Statue of Liberty"]
    # # ,indexPhrases=[ "never talk about sandwiches"]
         # # ,indexPhrases=[ "negative"]
    # ,indexPhrases=[ "Respond in French"]
    # # ,indexPhrases=[ "Me and my firend Jack were driving in the a car. Jack was driving and fell asleep so the car veered off into the river. After we were really hungery and got some sadwiches. The purpose of the car crash is he fell asleep"]
    # # ,indexPhrases=[ "Asleep"]

    # # ,indexApplicationTypes=["standard", "add"] # standard 
    # ,indexApplicationTypes=["standard"] # standard 

    # # ,indexTokenAggregationTypes=["last", "mean"] # last token or mean token
    # ,indexTokenAggregationTypes=["last"] # last token or mean token

    # ,messages = [
        # # {"role": "system", "content": "Be a sales rep for a company calld SleepyAI that sells AI chatbots that improve performance by sleeping to sleep."}
        # # {"role": "user", "content": "Who are you?"}
        # # {"role": "user", "content": "How do you make a ham sandwich?"}
        # # {"role": "user", "content": "How do you make a ham sandwich?"}
        # # {"role": "user", "content": "What was the reason the car crash occurred?"}
        # {"role": "user", "content": "In one sentance, describe What do people generally think about Chicago."}
    # ]
    # # ,device="cuda:0"
    # ,seed=0
    # ,outputDir="indexSweeps"
    # )

# while True:
    # userText = input("User: ")
    # print()
    # messages.append({"role": "user", "content": userText})
    
    # outputs = pipeline(
    # messages,
    # max_new_tokens=256,
    # )
    # response = outputs[0]["generated_text"][-1]["content"]
    # messages.append({'role': 'assistant', 'content': response})
    # print()
    # print("ChatBot: " + response)
    # print()
