from preprocessing import *
from word_swap import *
from activations import *
import torch
from transformers import AutoTokenizer,AutoModel,pipeline
from nltk.corpus import brown
import matplotlib
import matplotlib.pyplot as plt
from einops import rearrange, reduce
from nesim.utils.grid_size import find_rectangle_dimensions
# Load a custom colormap for better visualization
from scipy.io import loadmat
from tqdm import tqdm

# Build a matplotlib colormap from the RGB table stored under the "cmap"
# variable of a MATLAB .mat file. The path is relative, so it is resolved
# against the current working directory.
colormap = matplotlib.colors.ListedColormap(
    loadmat('colormap-custom-lightblue-to-yellow1.mat')['cmap']
)

from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.experiments.gpt_neo_125m import get_checkpoint, get_untrained_model_and_tokenizer

# Scales for which a trained checkpoint is looked up (keyed "topo_<scale>").
topo_scales = [1, 5, 10, 50]
# Training step whose checkpoint is loaded for every run.
global_step = 10500

checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
device = "cuda:0"

# Run name -> checkpoint path. "untrained" has no checkpoint path (None), and
# "baseline" is the run looked up with topo_scale=0. A "pretrained" entry
# (value "pretrained") is currently disabled.
checkpoints_map = {"untrained": None}
checkpoints_map["baseline"] = get_checkpoint_path_gpt_neo_125m(
    checkpoints_dir=checkpoint_dir,
    global_step=global_step,
    topo_scale=0,
)

# Add one entry per configured scale, in the order listed in `topo_scales`.
for topo_scale in topo_scales:
    checkpoints_map["topo_" + str(topo_scale)] = get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        global_step=global_step,
        topo_scale=topo_scale,
    )


from activations import get_activations, get_activations_hacked
from analysis import *
from plots import *

def calculate_differences(swapped_seqs, original_seqs, tokenizer, model, device='cuda', hacked=False):
    """Average the absolute activation differences between swapped and original sequences.

    Each (swapped, original) pair is processed as follows:
      1. Every swapped variant and the original sequence are converted with
         ``text_to_df``.
      2. Activations are extracted for both.
      3. ``|swapped_activations - original_activations|`` is summed across pairs.

    The sum is divided by the number of pairs that produced activations.
    Pairs whose processing raises ``ValueError`` are reported and skipped.

    Args:
        swapped_seqs (list): list of lists of swapped sequences (one inner
            list per original sequence)
        original_seqs (list): list of original sequences
        tokenizer: Huggingface tokenizer
        model: Huggingface model
        device (str): device to use for the model (e.g., 'cpu', 'cuda')
        hacked (bool): if True, activations are extracted with
            ``get_activations_hacked``; otherwise with ``get_activations``.

    Returns:
        out (torch.Tensor): tensor of differences with shape (n_layers, swap position, measured position, n_features)

    Raises:
        ValueError: if no pair produced activations, so there is nothing to
            average (previously this surfaced as an opaque ``None / 0`` TypeError).
    """
    # Choose the extractor once instead of branching on every pair.
    extract = get_activations_hacked if hacked else get_activations

    out = None
    n = 0
    for zz, (swapped, original) in tqdm(enumerate(zip(swapped_seqs, original_seqs))):
        try:
            swapped_dfs = [text_to_df(s, tokenizer=tokenizer) for s in swapped]
            original_df = text_to_df(original, tokenizer=tokenizer)
            swapped_activations = extract(swapped_dfs, model=model, device=device)
            original_activations = extract([original_df], model=model, device=device)

            # Pairs for which an extractor returned None are skipped.
            if (swapped_activations is not None) and (original_activations is not None):
                # The single-sequence original activations are presumably
                # broadcast against all swapped variants. TODO: confirm shapes.
                difference = torch.abs(swapped_activations - original_activations)
                if out is None:
                    out = difference
                else:
                    out += difference
                n += 1
        except ValueError:
            print("Error: ",zz,original)
    print("Finished calculating difference tensor for ",n," sequences")
    if n == 0:
        raise ValueError(
            "calculate_differences: no sequence pair produced activations; "
            "cannot average an empty set of differences"
        )
    return out/n


def average_over_units_indiv_plots(layers, fitobj, D_delta, max_window_size=21):
    """Plot per-unit and averaged median difference curves, one subplot per layer.

    Each subplot contains two kinds of curves:
      * One thin gray curve per array ``u`` in ``D_delta``, computed as
        ``np.median(u[:, :, layer], axis=1)``.
      * One thick black curve ("Mean"), computed the same way from the
        element-wise mean of all ``D_delta`` arrays.

    The y-axis is fixed to [0, 1.2] and the x-axis is inverted.

    Args:
        layers: sequence of indices into the third axis of each ``D_delta``
            array; one subplot is drawn per entry.
        fitobj: unused; kept so existing call sites keep working.
        D_delta: sequence of equally shaped arrays with at least 3 axes. Axis 0
            must have length ``max_window_size`` so it matches the x values.
        max_window_size: number of x positions plotted (0 .. max_window_size-1).
            Defaults to 21, the previously hard-coded value.

    Returns:
        The created matplotlib Figure. The original returned None, so callers
        that ignore the result are unaffected.
    """
    t = np.arange(0, max_window_size)
    D_delta_all = np.stack(D_delta, 0).mean(0)
    # squeeze=False keeps `axes` a 2-D array even for a single layer. The
    # default returns a bare Axes there, which has no `.flat`.
    fig, axes = plt.subplots(1, len(layers), squeeze=False)
    for k, layer in enumerate(layers):
        ax = axes.flat[k]

        for i, u in enumerate(D_delta):
            # Label only the first unit curve so a legend shows a single entry.
            ax.plot(t, np.median(u[:, :, layer], axis=1), color='gray',
                    label='Indiv.' '\n' 'Units' if i == 0 else '_nolegend_',
                    linewidth=.5)
        ax.plot(t, np.median(D_delta_all[:, :, layer], axis=1), color='black',
                linestyle="-", linewidth=2, label="Mean")
        ax.set_ylim([0, 1.2])
        ax.set_xlim([0, max_window_size - 1])

        ax.set_title(f"Layer {layer}")
        ax.grid(False)
        ax.invert_xaxis()
        if k == 0:
            ax.set_ylabel(r"$\theta_{norm}[\Delta]$")
    fig.tight_layout()
    return fig


import os

checkpoint_names = [
    "baseline",
    "topo_1",
    "topo_5",
    "topo_10",
    "topo_50"
]

# np.save and fig.savefig below write into these relative directories.
# Create them up front so a fresh checkout doesn't fail with FileNotFoundError.
os.makedirs("raw_data", exist_ok=True)
os.makedirs("assets", exist_ok=True)

for checkpoint_name in tqdm(checkpoint_names):
    print(f"{checkpoint_name}")
    model, tokenizer = get_checkpoint(checkpoints_map[checkpoint_name], device=device)
    tokenizer.add_prefix_space = True
    overall_integration_corpus = Corpus(brown,single_token_words=False,tokenizer=tokenizer)
    overall_integration_swapper = RandomPosWordSwap(overall_integration_corpus.word_lookup,
                                                    overall_integration_corpus.pos_dict,
                                                    tokenizer)
    natural_sequences_40 = overall_integration_corpus.get_natural_sequences_of_length(40)
    print("number of natural sequences of length 40: ", len(natural_sequences_40))

    overall_integration_swapper(natural_sequences_40)

    # Use the same device the model was loaded onto (previously hard-coded 'cuda').
    differences = calculate_differences(overall_integration_swapper.swapped,
                                        overall_integration_swapper.original_sequences,
                                        tokenizer, model, device=device, hacked=True)

    print("differences shape: ",differences.shape)
    # .numpy() rejects CUDA tensors and tensors that require grad.
    # detach()/cpu() are no-ops otherwise.
    D = np.transpose(differences.detach().cpu().numpy(), (3, 1, 2, 0))
    print("D shape: ",D.shape)
    n_features, n_stim_time, n_model_time, n_layers = D.shape
    all_fits, all_D_delta = fit_curves(D)
    stacked_fits = [np.stack(fit) for fit in all_fits]
    stacked_fits = np.stack(stacked_fits)
    # average_over_units_indiv_plots([1,3,6,9,11],all_fits,all_D_delta)

    for parameter_index, parameter_name in zip([2,0,1],['c','a','b']):

        # One subplot per layer on a 3x4 grid (12 layers plotted).
        fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))
        fig.suptitle(f"{checkpoint_name}")
        # Flatten the 2D array of axes for easy iteration
        axes = axes.flatten()

        for layer_index in range(12):
            c_values = stacked_fits[:, layer_index, parameter_index]
            size = find_rectangle_dimensions(c_values.shape[0])

            # Plot in the corresponding subplot with aspect ratio preserved
            ax = axes[layer_index]
            img = c_values.reshape(size.height, size.width)
            im = ax.imshow(img, aspect='equal', vmin=0, vmax=1, cmap = "coolwarm")
            np_filename = f"raw_data/{checkpoint_name}_layer_{layer_index}_param_{parameter_name}.npy"
            # Save the plotted 2-D map. The original call omitted the array
            # argument, which always raises TypeError.
            np.save(np_filename, img)
            print(f"saved: {np_filename}")

            # Disable the grid and remove ticks
            ax.grid(False)
            ax.set_xticks([])
            ax.set_yticks([])

            fig.colorbar(im, ax=ax)
            ax.set_title(f"Layer: {layer_index}")

        # Adjust layout to avoid overlap
        fig.tight_layout()
        filename = f"assets/{checkpoint_name}_{parameter_name}_values.pdf"
        # Save before show(): some backends clear or destroy the figure once
        # the show() window is closed.
        fig.savefig(filename)
        print(f"saved: {filename}")
        plt.show()
        # Release the figure; otherwise 15 large figures accumulate across the loop.
        plt.close(fig)