# %%
import json
# Load the training split: one JSON object per line (JSONL).
# The explicit encoding keeps parsing independent of the platform locale, and
# whitespace-only lines are skipped instead of crashing json.loads.
with open('dataset/modified_arithmetic/train.jsonl', 'r', encoding='utf-8') as f:
    train = [json.loads(line) for line in f if line.strip()]

# %%
# Split every record into its prompt text and its first reference answer.
# NOTE: `input` shadows the builtin of the same name; the name is kept because
# later cells read it.
input = [record['inputs'] for record in train]
output = [record['targets'][0] for record in train]

# %%
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import numpy as np
from skopt import gp_minimize
from skopt.space import Real
import torch
import transformer_lens.utils as utils
from functools import partial


# %%
def loadTransformerLensModel(modelPath):
    """Load a Hugging Face causal LM and wrap it as a TransformerLens model.

    The wrapped model is created on the CPU with layer-norm folding, centering
    of the writing weights and centering of the unembedding all switched off.

    Parameters:
    modelPath: Identifier passed to every `from_pretrained` call below.

    Returns:
    tuple: `(model, tokenizer)` - the HookedTransformer and the tokenizer handed to it.
    """
    tokenizer = AutoTokenizer.from_pretrained(modelPath)
    hf_weights = AutoModelForCausalLM.from_pretrained(modelPath, low_cpu_mem_usage=True)

    hooked_kwargs = {
        "hf_model": hf_weights,
        "device": 'cpu',
        "fold_ln": False,
        "center_writing_weights": False,
        "center_unembed": False,
        "tokenizer": tokenizer,
    }
    return HookedTransformer.from_pretrained(modelPath, **hooked_kwargs), tokenizer

# %%
# Checkpoint identifier handed to loadTransformerLensModel.
MODEL_PATH = 'meta-llama/Meta-Llama-3-8B'
# presumably an alternative checkpoint to try: meta-llama/Llama-2-7b-hf
model, tokenizer = loadTransformerLensModel(MODEL_PATH)

# %%
# The loader builds the model on the CPU; it is moved to the target device here.
# `device` is also read by the run* helpers when they place the token ids.
device = 'cuda' # hard-coded, so a CUDA-capable GPU is required
model = model.to(device)

# %%
import torch
import numpy as np

import torch

def create_rope_rotation_matrix(angles):
    """
    Create a RoPE (Rotary Position Embedding) style rotation matrix for a vector of angles.

    The result is block diagonal: pair ``i`` (rows/columns ``2i`` and ``2i + 1``)
    holds the 2x2 rotation ``[[cos a_i, -sin a_i], [sin a_i, cos a_i]]``.

    Parameters:
    angles (torch.Tensor): A 1D PyTorch tensor of rotation angles in radians, one per dimension pair.

    Returns:
    torch.Tensor: A 2D tensor of shape ``(2 * len(angles), 2 * len(angles))`` in
    torch's default floating dtype, allocated on the same device as ``angles``.

    Raises:
    ValueError: If ``angles`` is not one-dimensional.
    """
    if angles.dim() != 1:
        raise ValueError(f"angles must be a 1D tensor, got shape {tuple(angles.shape)}")

    n = angles.shape[0]  # Number of angles
    d = 2 * n  # Dimension of the square matrix (always even)
    rope_matrix = torch.zeros((d, d), device=angles.device)

    # torch.cos / torch.sin promote integer input to floating point; the cast
    # keeps the indexed assignments dtype-consistent with the output matrix.
    cos_angles = torch.cos(angles).to(rope_matrix.dtype)
    sin_angles = torch.sin(angles).to(rope_matrix.dtype)

    # Fill all 2x2 blocks at once instead of writing 4 * n scalars in a Python loop.
    even = torch.arange(0, d, 2, device=angles.device)
    odd = even + 1
    rope_matrix[even, even] = cos_angles
    rope_matrix[even, odd] = -sin_angles
    rope_matrix[odd, even] = sin_angles
    rope_matrix[odd, odd] = cos_angles

    return rope_matrix

# Example usage:
angles = torch.tensor([0, 0])  # Two zero angles -> 2 dimension pairs -> a 4x4 matrix (the identity)
rope_matrix = create_rope_rotation_matrix(angles)
print(rope_matrix)

# %%
# Dimensions of the rotation-angle tensor that is optimized further down.
# L: number of layers that receive rotation hooks (layers 0 .. L-1).
L = 1
# H: number of attention heads hooked in each of those layers.
H = 32
# N: angles per head; each angle rotates one dimension pair, so a head gets a
# 2N x 2N (128 x 128) rotation matrix. Presumably chosen to match the model's
# head dimension - confirm when switching checkpoints.
N = 64

# %%
from tqdm import tqdm
# One Real(0, pi) dimension for each of the L * H * N rotation angles.
search_space = [Real(0, torch.pi) for _ in tqdm(range(L * H * N))]

# %%
nShots = 8
# Candidate few-shot demonstrations: the first `nShots` (input, target) pairs.
nShots_example = list(zip(input[:nShots], output[:nShots]))
import random
random.shuffle(nShots_example)
# Only the first four shuffled candidates are written into the prompt prefix.
Template = "".join(
    f"Input: {nShots_example[k][0]}\noutput: {nShots_example[k][1]}\n\n"
    for k in range(4)
)


# %%
print(Template)

# %%
number_of_examples = 5
# Evaluation items start at index `nShots`, directly after the few-shot
# candidates, so they never overlap with the demonstrations.
prompts = [
    Template + f"Input: {input[k]}\noutput: "
    for k in range(nShots, nShots + number_of_examples)
]
labels = [output[k] for k in range(nShots, nShots + number_of_examples)]

# %%
prompt = prompts[0]

# %%
def runCache(model, input_id):
    """Run `model` on `input_id` and collect the attention key activations.

    A forward hook is attached to the "k" attention hook point of each of the
    first `L` layers (`L` is the module-level layer count). Every hooked
    activation is detached and moved to the CPU.

    Parameters:
    model: Model exposing `reset_hooks` and `run_with_hooks`.
    input_id: Input forwarded unchanged to `model.run_with_hooks`.

    Returns:
    dict: Maps each hook name to its detached CPU activation tensor.
    """
    model.reset_hooks()
    cache = {}

    def storeHookCache(value, hook):
        # Detach and move to the CPU directly. A NumPy round-trip adds nothing
        # here and fails for dtypes NumPy cannot represent (e.g. bfloat16).
        cache[hook.name] = value.detach().cpu()

    fwd_hooks_list = [
        (utils.get_act_name("k", layer, "attn"), storeHookCache)
        for layer in range(L)
    ]

    # Only activations are needed, so skip building the autograd graph.
    with torch.no_grad():
        model.run_with_hooks(input_id, return_type=None, fwd_hooks=fwd_hooks_list)

    return cache

# Quick tokenizer check: print how many tokens the sample sentence yields
# (length of the first row returned by `model.to_tokens`).
test_prompt = "The quick brown fox jumped over the lazy dog"
print("Num tokens:", len(model.to_tokens(test_prompt)[0]))

def print_name_shape_hook_function(activation, hook):
    """Debug hook: print the hook point's name followed by the activation's shape.

    Returns None, so the activation itself is not replaced.
    """
    hook_name, activation_shape = hook.name, activation.shape
    print(hook_name, activation_shape)

def not_in_late_block_filter(name):
    """Return True for hook names inside block 0 or outside the numbered blocks."""
    if name.startswith("blocks.0."):
        return True
    return not name.startswith("blocks")


# Print the name and shape of every activation selected by the filter.
model.run_with_hooks(
    test_prompt,
    return_type=None,
    fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],
)
# %%
def runDefaultModel(model, tokenizer, prompt):
    """Predict the next token for `prompt` with the unmodified model (greedy argmax).

    All hooks are removed before the forward pass, and the pass runs without
    autograd because only the logits are needed.

    Parameters:
    model: Callable model exposing `reset_hooks`; its output is indexed as
        `[batch, position, vocab]` below.
    tokenizer: Tokenizer providing `encode` and `decode`.
    prompt: Text to continue.

    Returns:
    tuple: `(text, token_id)` - the decoded top-scoring token at the last
    position of the first batch row, and its integer id.
    """
    # NOTE(review): add_special_tokens is left at the tokenizer default here, while
    # runRotatedModel passes add_special_tokens=False, so the two may encode the
    # same prompt differently - confirm which is intended.
    input_id = tokenizer.encode(prompt, return_tensors='pt').to(device)
    model.reset_hooks()
    # Inference only: skipping autograd bookkeeping saves memory on every call.
    with torch.no_grad():
        output = model(input_id)
    predicted_token = torch.argmax(output[:, -1, :], dim=1)[0]
    return str(tokenizer.decode(predicted_token, skip_special_tokens=True)), predicted_token.item()

# %%
from jaxtyping import Float, Int

def rotateMatrix(
        clean_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
        hook,
        head_index,
        dimension_array):
    """Forward hook that rotates the vectors of one attention head in place.

    A block-diagonal rotation matrix is built from `dimension_array` and every
    vector `v` of head `head_index` is replaced by `v @ R`.

    Parameters:
    clean_head_vector: Activation indexed as `[batch, pos, head_index, d_head]`;
        it is modified in place.
    hook: Hook point object supplied by the hooking framework (unused).
    head_index: Index of the head to rotate.
    dimension_array: 1D tensor of rotation angles; the resulting matrix has
        twice as many rows/columns as there are angles, so that size must
        equal the last dimension of `clean_head_vector`.

    Returns:
    The same `clean_head_vector` tensor, with the selected head rotated.
    """
    # Match the activation's device AND dtype: the matrix is created in torch's
    # default dtype, which would make the matmul fail for fp16/bf16 activations.
    rotaryMatrix = create_rope_rotation_matrix(dimension_array).to(
        device=clean_head_vector.device, dtype=clean_head_vector.dtype)
    clean_head_vector[:, :, head_index, :] = clean_head_vector[:, :, head_index, :] @ rotaryMatrix
    return clean_head_vector
# Token ids of the first evaluation prompt (no special tokens added), placed on `device`.
encoded_prompt = tokenizer.encode(
    prompt, add_special_tokens=False, return_tensors='pt'
).to(device)

# cache = runCache(model, encoded_prompt)
# breakpoint()
# %%
def runRotatedModel(model, tokenizer, prompt, D, answer_token):
    """Run the model with per-head rotations applied and score one answer token.

    For every layer in `range(L)` and head in `range(H)` a `rotateMatrix` hook
    is attached to that layer's attention "z" hook point, using the angle
    vector `D[layer, head]`. The forward pass runs without autograd because
    only the logits are needed.

    Parameters:
    model: Model exposing `run_with_hooks`.
    tokenizer: Tokenizer providing `encode` and `decode`.
    prompt: Text to continue.
    D: Angle tensor indexed as `D[layer, head]`, yielding one 1D angle vector per head.
    answer_token: Vocabulary index whose probability is reported.

    Returns:
    tuple: `(token, answer_token_prob)` - the decoded top-scoring token at the
    last position, and the softmax probability of `answer_token` there (float).
    """
    # NOTE(review): special tokens are disabled here, while runDefaultModel uses the
    # tokenizer default, so the two may encode the same prompt differently - confirm.
    encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt').to(device)

    list_fwd_hooks = [
        (
            utils.get_act_name("z", layer, "attn"),
            partial(rotateMatrix, head_index=head, dimension_array=D[layer, head]),
        )
        for layer in range(L)
        for head in range(H)
    ]

    # Inference only: skipping autograd bookkeeping saves memory on every call.
    with torch.no_grad():
        rotated_logits = model.run_with_hooks(encoded_prompt, return_type="logits", fwd_hooks=list_fwd_hooks)

    last_position_logits = rotated_logits[:, -1, :]
    predicted_token = torch.argmax(last_position_logits, dim=1)[0]
    answer_token_prob = torch.nn.functional.softmax(last_position_logits, dim=1)[0, answer_token].item()
    token = tokenizer.decode(predicted_token, skip_special_tokens=True)
    return token, answer_token_prob
    

# %%
def getOriginalAccuracy(model, tokenizer, prompts, labels):
    """Evaluate the unmodified model and record the token it predicts per prompt.

    Parameters:
    model: Model forwarded to `runDefaultModel`.
    tokenizer: Tokenizer forwarded to `runDefaultModel`.
    prompts: Sequence of prompt strings.
    labels: Expected outputs, aligned with `prompts`; a prediction counts as
        correct when the decoded token equals the label exactly.

    Returns:
    tuple: `(accuracy, answer_tokens)` - the fraction of `prompts` answered
    correctly and the list of predicted token ids, one per evaluated prompt.
    For empty `prompts` the result is `(0.0, [])`.
    """
    total = len(prompts)
    if total == 0:
        # Nothing to evaluate; also avoids dividing by zero below.
        return 0.0, []

    from tqdm import tqdm

    correct = 0
    answer_tokens = []
    # A single progress bar, closed by the context manager when the loop ends.
    with tqdm(zip(prompts, labels), desc="Original Accuracy", total=total) as pbar:
        for prompt, label in pbar:
            output, answer_token = runDefaultModel(model, tokenizer, prompt)
            if output == label:
                correct += 1
            answer_tokens.append(answer_token)
            pbar.set_description(f"Original Accuracy: {correct / total}")
    return correct / total, answer_tokens

# Baseline run of the unmodified model. `answer_tokens` holds the token id the
# model itself predicted for each prompt (not the label); the objective below
# scores the probability of exactly these tokens.
normalAccuracy, answer_tokens = getOriginalAccuracy(model, tokenizer, prompts, labels)
print(f"Normal Accuracy: {normalAccuracy}")
# %%
def objective(params):
    """Score one candidate set of rotation angles for the optimizer.

    `params` is reshaped to `(L, H, N)`, i.e. one N-length angle vector per
    (layer, head). The rotated model is run on every module-level prompt and
    the probability it assigns to the matching entry of the module-level
    `answer_tokens` is averaged.

    Parameters:
    params: Flat sequence of `L * H * N` angles.

    Returns:
    float: The negated mean answer-token probability, so that minimizing this
    value maximizes that probability.

    Raises:
    ValueError: If there is no prompt to evaluate.
    """
    from tqdm import tqdm

    # params is the flattened (L, H, N) angle tensor
    D = torch.tensor(params).reshape((L, H, N))

    accuracy = 0
    count = 0
    prob = 0
    # A single progress bar, closed by the context manager when the loop ends.
    with tqdm(zip(prompts, labels, answer_tokens), desc="Rotated Accuracy", total=len(prompts)) as pbar:
        for prompt, label, answer_token in pbar:
            predicted_output, answer_token_prob = runRotatedModel(model, tokenizer, prompt, D, answer_token)
            if predicted_output == label:
                accuracy += 1
            count += 1
            prob += answer_token_prob
            pbar.set_description(f"Rotated answer token prob: {prob / count}")

    if count == 0:
        raise ValueError("objective needs at least one prompt to evaluate")

    print(f"Accuracy: {accuracy / count}")
    print(f"Answer token prob: {prob / count}")
    return -1 * (prob / count)  # Return the score

# %%

from math import pi
print("PI: ", pi)
print("Generating Initial Points")
# Seed the search with the all-zero angle vector: zero angles give identity
# rotation matrices, i.e. the model is left unrotated.
all_zeros = [0] * (L * H * N)
initial_points = [all_zeros]

# %%
print("Running Baysian Optimization")
result = gp_minimize(
    func=objective,          # Function to minimize (negated mean answer-token probability)
    dimensions=search_space, # One Real(0, pi) dimension per rotation angle
    n_calls=200,              # Number of function evaluations
    x0=initial_points,       # User-supplied starting points (here: the all-zero vector)
    n_initial_points=9,     # Number of generated initialization points
    acq_func='EI',           # Acquisition function 'Expected Improvement'
    initial_point_generator='random',  # Initialization points are drawn at random (not a Sobol sequence)
    verbose=True             # Verbose mode for detailed output
)


# %%

import pickle
# Persist the complete optimization result returned by gp_minimize.
# NOTE(review): the result object may carry a reference to `objective`; confirm the
# pickle can be loaded outside this script.
with open('optimized_params_results.pkl', 'wb') as f:
    pickle.dump(result, f)

# `result.x` reshaped to one N-length angle vector per (layer, head).
# optimized_params = np.array(result.x).reshape((L, H, N))
optimized_params = torch.tensor(result.x).reshape((L, H, N))

