# %%
# model_name = 'microsoft/Phi-3-mini-4k-instruct'
# task = 'color'
# ag_news
# navigate
# color
def getInfo(model_name, task):
    if(model_name == 'microsoft/Phi-3-mini-4k-instruct'):
        layers = 12
        if(task == 'ag_news'):
            fileName = 'BaysianOptimization/error/phi-3/ag_news/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_v2_0.pkl'
        if(task == 'navigate'):
            fileName = 'BaysianOptimization/error/phi-3/navigate/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_v2_3.pkl'
        if(task == 'color'):
            fileName = 'BaysianOptimization/error/phi-3/color/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_v2_0.pkl'
        if(task == 'entailed_polarity'):
            fileName = 'BaysianOptimization/error/phi-3/entailed_polarity/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_v2_4.pkl'
        if(task == 'winowhy'):
            fileName = 'BaysianOptimization/error/phi-3/winowhy/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_0.pkl'
    if(model_name == 'meta-llama/Meta-Llama-3-8B-Instruct'):
        layers = 32
        if(task == 'ag_news'):
            fileName = 'BaysianOptimization/error/llama-3-8b/ag_news/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_0.pkl'
        if(task == 'navigate'):
            fileName = 'BaysianOptimization/error/llama-3-8b/navigate/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_4.pkl'
        if(task == 'color'):
            fileName = 'BaysianOptimization/error/llama-3-8b/color/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_3.pkl'
        if(task == 'entailed_polarity'):
            fileName = 'BaysianOptimization/error/llama-3-8b/entailed_polarity/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_1.pkl'
        if(task == 'winowhy'):
            fileName = 'BaysianOptimization/error/llama-3-8b/winowhy/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_1.pkl'
    if(model_name == 'Qwen/Qwen2-1.5B-Instruct'):
        layers = 10
        if(task == 'ag_news'):
            fileName = 'BaysianOptimization/error/qwen_2/ag_news/NN_kernel_layer_0_5_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_0.pkl'
        if(task == 'navigate'):
            fileName = 'BaysianOptimization/error/qwen_2/navigate/NN_kernel_layer_0_5_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_2.pkl'
        if(task == 'color'):
            fileName = 'BaysianOptimization/error/qwen_2/color/NN_kernel_layer_0_5_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_2.pkl'
        if(task == 'entailed_polarity'):
            fileName = 'BaysianOptimization/error/qwen_2/entailed_polarity/NN_kernel_layer_0_5_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_2.pkl'
        if(task == 'winowhy'):
            fileName = 'BaysianOptimization/error/qwen_2/winowhy/NN_kernel_layer_0_5_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_2.pkl'
    if(model_name == 'mistralai/Mistral-7B-Instruct-v0.1'):
        layers = 32
        if(task == 'ag_news'):
            fileName = 'BaysianOptimization/error/mixtral/ag_news/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_3.pkl'
        if(task == 'navigate'):
            fileName = 'BaysianOptimization/error/mixtral/navigate/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_4.pkl'
        if(task == 'color'):
            fileName = 'BaysianOptimization/error/mixtral/color/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_3.pkl'
        if(task == 'entailed_polarity'):
            fileName = 'BaysianOptimization/error/mixtral/entailed_polarity/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_1.pkl'
        if(task == 'winowhy'):
            fileName = 'BaysianOptimization/error/mixtral/winowhy/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_3.pkl'
    return layers, fileName
# %%
import pickle
from transformers import AutoTokenizer
modelNames = ['microsoft/Phi-3-mini-4k-instruct', 'meta-llama/Meta-Llama-3-8B-Instruct', 'Qwen/Qwen2-1.5B-Instruct', 'mistralai/Mistral-7B-Instruct-v0.1']
tasks = ['ag_news','navigate', 'color', 'entailed_polarity', 'winowhy']
for task in tasks:
    minlogit = []
    min_values = {}
    max_values = {}
    mean_values = {}
    for model_name in modelNames:
        layers, fileName = getInfo(model_name, task)
        with open(fileName, 'rb') as f:
            data = pickle.load(f)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model_name = model_name.split('/')[1]
        print(f'Processing {model_name} for task {task}', len(data))
        if(len(data) == 50):
            import torch
            layer_wise = [0] * layers
            layer_wise_error = [0] * layers

            count = 0
            from tqdm import tqdm
            rotated_accuracy = 0
            normal_accuracy = 0
            normal_model_values = {'min': [], 'max': [], 'mean': []}
            rotated_model_values = {'min': [], 'max': [], 'mean': []}
            for i in tqdm(range(len(data))):
                prompt = data[i]['prompt']
                label = data[i]['label']
                normal_token = data[i]['normal']
                rotated_token = data[i]['rotated']
                rotated_cache = data[i]['rotated_cache']
                normal_cache = data[i]['normal_cache']
                umembed = data[i]['umembed']
                answer_token = tokenizer.encode(label['complete'], add_special_tokens=False)[0]
                # print(answer_token, rotated_token, normal_token)
                if(label['complete'].startswith(rotated_token)):
                    
                    count += 1
                    j = layers - 1
                    
                        # print(j)
                    if(f'blocks.{j}.hook_resid_post' in rotated_cache):
                        key = f'blocks.{j}.hook_resid_post'
                    else:
                        key = f'blocks.{j}.hook_resid_mid'
                    normal_logits_layer = umembed(normal_cache[key].to('cuda')).detach().cpu()
                    rotated_logits_layer = umembed(rotated_cache[key].to('cuda')).detach().cpu()
                    
                    normal_logits_layer = normal_logits_layer[:, -1, :].squeeze()
                    rotated_logits_layer = rotated_logits_layer[:, -1, :].squeeze()
                    
                    normal_model_values['min'].append(normal_logits_layer)
                    normal_model_values['max'].append(normal_logits_layer)
                    normal_model_values['mean'].append(normal_logits_layer)
                    
                    rotated_model_values['min'].append(rotated_logits_layer)
                    rotated_model_values['max'].append(rotated_logits_layer)
                    rotated_model_values['mean'].append(rotated_logits_layer)
                    
                    
            min_values[model_name] = {'normal': torch.min(torch.stack(normal_model_values['min'], dim=0)), 'rotated': torch.min(torch.stack(rotated_model_values['min'], dim=0))}
            max_values[model_name] = {'normal': torch.max(torch.stack(normal_model_values['max'], dim=0)), 'rotated': torch.max(torch.stack(rotated_model_values['max'], dim=0))}
            mean_values[model_name] = {'normal': torch.mean(torch.stack(normal_model_values['mean'], dim=0)), 'rotated': torch.mean(torch.stack(rotated_model_values['mean'], dim=0))}
            
        elif(len(data) == 51):
            import torch
            layer_wise = [0] * layers
            layer_wise_error = [0] * layers
            umembed = data[-1]['unembed']
            count = 0
            from tqdm import tqdm
            rotated_accuracy = 0
            normal_accuracy = 0
            normal_model_values = {'min': [], 'max': [], 'mean': []}
            rotated_model_values = {'min': [], 'max': [], 'mean': []}
            for i in tqdm(range(len(data) - 1)):
                prompt = data[i]['prompt']
                label = data[i]['label']
                normal_token = data[i]['normal']
                rotated_token = data[i]['rotated']
                rotated_cache = data[i]['rotated_cache']
                normal_cache = data[i]['normal_cache']
                # umembed = data[i]['umembed']
                answer_token = tokenizer.encode(label['complete'], add_special_tokens=False)[0]
                # print(answer_token, rotated_token, normal_token)
                if(label['complete'].startswith(rotated_token)):
                    
                    count += 1
                    j = layers - 1
                    # for j in range(layers):
                        # print(j)
                    if(f'blocks.{j}.hook_resid_post' in rotated_cache):
                        key = f'blocks.{j}.hook_resid_post'
                    else:
                        key = f'blocks.{j}.hook_resid_mid'
                    normal_logits_layer = umembed(normal_cache[key].to('cuda')).detach().cpu()
                    rotated_logits_layer = umembed(rotated_cache[key].to('cuda')).detach().cpu()
                    
                    normal_logits_layer = normal_logits_layer[:, -1, :].squeeze()
                    rotated_logits_layer = rotated_logits_layer[:, -1, :].squeeze()
                    
                    normal_model_values['min'].append(normal_logits_layer)
                    normal_model_values['max'].append(normal_logits_layer)
                    normal_model_values['mean'].append(normal_logits_layer)
                    
                    rotated_model_values['min'].append(rotated_logits_layer)
                    rotated_model_values['max'].append(rotated_logits_layer)
                    rotated_model_values['mean'].append(rotated_logits_layer)
                    
            # breakpoint()
            min_values[model_name] = {'normal': torch.min(torch.stack(normal_model_values['min'], dim=0)), 'rotated': torch.min(torch.stack(rotated_model_values['min'], dim=0))}
            max_values[model_name] = {'normal': torch.max(torch.stack(normal_model_values['max'], dim=0)), 'rotated': torch.max(torch.stack(rotated_model_values['max'], dim=0))}
            mean_values[model_name] = {'normal': torch.mean(torch.stack(normal_model_values['mean'], dim=0)), 'rotated': torch.mean(torch.stack(rotated_model_values['mean'], dim=0))}
        # breakpoint()            
    # Extract the models and categories
    import numpy as np
    import matplotlib.pyplot as plt
    models = list(min_values.keys())
    # for i in range(len(models)):
    #     models[i] = models[i].split('/')[1]
    categories = ['normal', 'rotated']
    
    fig, ax = plt.subplots(figsize=(10, 6))

    # Define the positions for the models and categories
    bar_width = 0.35
    index = np.arange(len(models))

    # Extract values for 'normal' and 'rotated'
    normal_min = [min_values[model]['normal'] for model in models]
    rotated_min = [min_values[model]['rotated'] for model in models]

    normal_max = [max_values[model]['normal'] for model in models]
    rotated_max = [max_values[model]['rotated'] for model in models]

    # Plotting
    fig, ax = plt.subplots(figsize=(10, 6))

    # Adjust spacing between min and max values to align
    ax.bar(index - bar_width / 2, normal_min, bar_width, label='Base Min', color='blue', alpha=0.8)
    ax.bar(index - bar_width / 2, normal_max, bar_width, label='Base Max', color='lightblue', alpha=0.6)

    ax.bar(index + bar_width / 2, rotated_min, bar_width, label='TaRot Min', color='green', alpha=0.8)
    ax.bar(index + bar_width / 2, rotated_max, bar_width, label='TaRot Max', color='lightgreen', alpha=0.6)

    # Add labels and titles
    ax.set_xlabel('Models')
    ax.set_ylabel('Values')
    ax.set_title('Min and Max Values for Base and TaRot Models')
    ax.set_xticks(index)
    ax.set_xticklabels(models)  # Rotate for better readability
    ax.legend()

    plt.tight_layout()
    # normal_min = [min_values[model]['normal'] for model in models]
    # rotated_min = [min_values[model]['rotated'] for model in models]

    # normal_max = [max_values[model]['normal'] for model in models]
    # rotated_max = [max_values[model]['rotated'] for model in models]

    # normal_mean = [mean_values[model]['normal'] for model in models]
    # rotated_mean = [mean_values[model]['rotated'] for model in models]

    # # Create the plot
    # fig, ax = plt.subplots(figsize=(10, 6))

    # # Define bar width and positions
    # bar_width = 0.35
    # index = np.arange(len(models))

    # # Plot min, max, and mean for 'normal' category
    # ax.bar(index, normal_min, bar_width, label='Normal Min', color='blue', alpha=0.5)
    # ax.bar(index, np.array(normal_max) - np.array(normal_min), bar_width, bottom=normal_min, label='Normal Max', color='blue')
    # ax.scatter(index, normal_mean, color='red', label='Normal Mean', zorder=5, marker='o', s=100)

    # # Plot min, max, and mean for 'rotated' category (shifted to the right)
    # ax.bar(index + bar_width, rotated_min, bar_width, label='Rotated Min', color='green', alpha=0.5)
    # ax.bar(index + bar_width, np.array(rotated_max) - np.array(rotated_min), bar_width, bottom=rotated_min, label='Rotated Max', color='green')
    # ax.scatter(index + bar_width, rotated_mean, color='orange', label='Rotated Mean', zorder=5, marker='o', s=100)

    # # Add labels, title, and grid
    # ax.set_xlabel('Models')
    # ax.set_ylabel('Values')
    # ax.set_title('Min, Max, and Mean Values for Normal and Rotated Categories')
    # ax.set_xticks(index + bar_width / 2)
    # ax.set_xticklabels(models, rotation=45, ha="right")
    # ax.legend()

    # plt.tight_layout()
    # plt.show()
    plt.savefig(f'svg/Combined_Model_Task_{fileName.split("/")[-2]}.svg')
    plt.close()
    plt.clf()
    # breakpoint()
        # for j in range(layers):
        #     # Mean difference for each layer
        #     layer_wise[j] /= count
            
        #     # Calculate standard deviation or standard error
        #     variance = (layer_wise_error[j] / count) - (layer_wise[j] ** 2)
        #     layer_wise_error[j] = torch.sqrt(torch.tensor(variance / count))
        # import matplotlib.pyplot as plt

        # layers_range = range(layers)
        # plt.errorbar(layers_range, layer_wise, yerr=layer_wise_error, fmt='-o', capsize=5, label='Layer-wise Difference')
        # plt.xlabel('Layers')
        # plt.ylabel('Difference in Answer Token Probability')
        # plt.title(f'Layer wise Prob Difference for Model {model_name}, Task {fileName.split("/")[-2]}')
        # # plt.legend()
        # # plt.show()
        # # save pdf
        # model_save = model_name.replace('/', '_')
        # # plt.savefig(f'Model_{model_save}_Task_{fileName.split("/")[-2]}.pdf')
        # # save svg
        # plt.savefig(f'svg/Model_{model_save}_Task_{fileName.split("/")[-2]}.svg')
        # plt.close()
        # plt.clf()
        
# model_name = 'microsoft/Phi-3-mini-4k-instruct'

# layers = 12
# fileName = 'BaysianOptimization/error/phi-3/winowhy/NN_kernel_layer_0_6_angle_-0.7853981633974483_0.7853981633974483_reasoning_prob_mix_rotary_v2_0.pkl'

# %%
# BaysianOptimization/error/llama-3-8b/entailed_polarity/NN_kernel_layer_0_16_angle_-0.5235987755982988_0.5235987755982988_reasoning_prob_mix_rotary_v2_2.pkl

# %%
# from transformers import AutoTokenizer

# # %%
# # tokenizer = AutoTokenizer.from_pretrained(model_name)

# # %%
# with open(fileName, 'rb') as f:
#     data = pickle.load(f)

# # %%
# import torch.nn.functional as F

# def kl_divergence(input_logit, target_logit):
#     input_next_token_logit = input_logit[:, -1, :]
#     target_next_token_logit = target_logit[:, -1, :]

#     input_logProb = F.softmax(input_next_token_logit, dim=-1)
#     target_logProb = F.softmax(target_next_token_logit, dim=-1)

#     kl_div = F.kl_div(input_logProb, target_logProb,log_target=True, reduction="none").sum(dim=-1)
#     return kl_div.mean().detach().cpu()

# # %%
# len(data)

# # %%
# umembed = data[-1]['unembed']

# # %%
# import torch
# layer_wise = [0] * layers
# layer_wise_error = [0] * layers

# count = 0
# from tqdm import tqdm
# rotated_accuracy = 0
# normal_accuracy = 0
# for i in tqdm(range(len(data) - 1)):
#     prompt = data[i]['prompt']
#     label = data[i]['label']
#     normal_token = data[i]['normal']
#     rotated_token = data[i]['rotated']
#     rotated_cache = data[i]['rotated_cache']
#     normal_cache = data[i]['normal_cache']
#     # umembed = data[i]['umembed']
#     answer_token = tokenizer.encode(label['complete'], add_special_tokens=False)[0]
#     # print(answer_token, rotated_token, normal_token)
#     if(label['complete'].startswith(rotated_token)):
        
#         count += 1
#         for j in range(layers):
#             # print(j)
#             normal_logits_layer = umembed(normal_cache[f'blocks.{j}.hook_resid_post'].to('cuda')).detach().cpu()
#             rotated_logits_layer = umembed(rotated_cache[f'blocks.{j}.hook_resid_post'].to('cuda')).detach().cpu()
            
#             rotated_answer_token_prob = torch.nn.functional.softmax(rotated_logits_layer[:, -1, :], dim=-1)[0, answer_token].item()
#             normal_answer_token_prob = torch.nn.functional.softmax(normal_logits_layer[:, -1, :], dim=-1)[0, answer_token].item()
            
#             _, sorted_indices = torch.sort(rotated_logits_layer[:, -1, :][0], descending=True)
#             rotated_rank = (sorted_indices == answer_token).nonzero(as_tuple=True)[0].item()
            
#             _, sorted_indices = torch.sort(normal_logits_layer[:, -1, :][0], descending=True)
#             normal_rank = (sorted_indices == answer_token).nonzero(as_tuple=True)[0].item()
#             # print(rotated_rank, normal_rank)
#             # if(j == '31?otated_rank, normal_rank)
#             rotated_logit_normalised = (rotated_logits_layer[:, -1] - rotated_logits_layer[:, -1].min()) / (rotated_logits_layer[:, -1].max() - rotated_logits_layer[:, -1].min())
#             normal_logit_normalised = (normal_logits_layer[:, -1] - normal_logits_layer[:, -1].min()) / (normal_logits_layer[:, -1].max() - normal_logits_layer[:, -1].min())
            
#             prob_diff = normal_answer_token_prob - rotated_answer_token_prob
#             layer_wise[j] += prob_diff
#             layer_wise_error[j] += prob_diff ** 2

#         # layer_wise[j].append(rotated_answer_token_prob - normal_answer_token_prob)
    
        
        
        
        
        
        
    
    
    

# # %%
# count

# # %%
# for j in range(layers):
#     # Mean difference for each layer
#     layer_wise[j] /= count
    
#     # Calculate standard deviation or standard error
#     variance = (layer_wise_error[j] / count) - (layer_wise[j] ** 2)
#     layer_wise_error[j] = torch.sqrt(torch.tensor(variance / count))

# # %%
# import matplotlib.pyplot as plt

# layers_range = range(layers)
# plt.errorbar(layers_range, layer_wise, yerr=layer_wise_error, fmt='-o', capsize=5, label='Layer-wise Difference')
# plt.xlabel('Layers')
# plt.ylabel('Difference in Answer Token Probability')
# plt.title(f'Layer wise Prob Difference for Model {model_name}, Task {fileName.split("/")[-2]}')
# # plt.legend()
# # plt.show()
# # save pdf
# model_save = model_name.replace('/', '_')
# # plt.savefig(f'Model_{model_save}_Task_{fileName.split("/")[-2]}.pdf')
# # save svg
# plt.savefig(f'svg/Model_{model_save}_Task_{fileName.split("/")[-2]}.svg')

# # %%
# import numpy as np
# layer_wise_mean = [np.mean(layer) for layer in layer_wise]
# layer_wise_std = [np.std(layer) for layer in layer_wise]  # This will be used as the error margin


# # %%
# import matplotlib.pyplot as plt

# plt.figure(figsize=(10, 6))
# layers_range = range(layers)
# plt.errorbar(layers_range, layer_wise_mean, yerr=layer_wise_std, fmt='-o', capsize=5, label='Difference in Answer Token Probabilities')

# plt.xlabel('Layer')
# plt.ylabel('Difference in Answer Token Probability (Rotated - Normal)')
# plt.legend()
# plt.title(f'Layer wise Prob Difference for Model {model_name}, Task {fileName.split("/")[-2]}')


# # %%
# rotated_logits_layer[0, -1] - normal_logits_layer[0, -1]

# # %%
# print(rotated_accuracy, normal_accuracy, count)

# # %%
# for i in range(len(layer_wise)):
#     layer_wise[i] = layer_wise[i] / count

# # %%
# count

# # %%
# layer_wise

# # %%
# # plot the layer wise difference
# import matplotlib.pyplot as plt
# plt.plot(layer_wise)
# # xaxis
# plt.xlabel('Layer')
# # yaxis
# # plt.ylabel('KL Divergence')
# # plt.ylabel('logit difference')
# plt.ylabel('Prob Difference')
# # title
# plt.title(f'Layer wise Prob Difference for Model {model_name}, Task {fileName.split("/")[-2]}')
# # plt.title(f'Layer wise logit difference for Model {model_name}, Task {fileName.split("/")[-2]}')
# # plt.title(f'Layer wise KL Divergence for Model {model_name}, Task {fileName.split("/")[-2]}')
# # plt.title(f'Layer wise Rank Difference for Model {model_name}, Task {fileName.split("/")[-2]}')

# # %%
# torch.nn.functional.softmax(rotated_logits_layer[:, -1, :], dim=1).shape

# # %%
# import torch
# torch.nn.functional.softmax(normal_logits_layer, dim=-1).shape

# # %%



