import torch
import os
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.experiments.gpt_neo_125m import get_checkpoint
import numpy as np
import matplotlib.pyplot as plt


from evaluation import evaluate_model_on_list_of_text
from nesim.utils.json_stuff import load_json_as_dict
import argparse


# Command-line interface: which checkpoint to lesion and which topic's
# test set / selectivity maps to use.
parser = argparse.ArgumentParser(
    description=(
        "Lesion the most topic-selective units of a GPT-Neo-125M checkpoint "
        "and plot the resulting test loss against a shuffled-mask control."
    )
)
# Both arguments are required: without them the script would only fail later
# with an unrelated error (missing file or failed lookup).
parser.add_argument(
    '--eval-topic',
    type=str,
    required=True,
    help='Topic for evaluation',
)
parser.add_argument(
    '--model',
    type=str,
    required=True,
    help='Checkpoint to lesion: "untrained", "baseline" or "topo_<scale>"',
)

args = parser.parse_args()

checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
device = "cuda:0"
global_step = 10500
topo_scales = [1, 5, 10, 50]

# Maps the value accepted by --model to the checkpoint path handed to
# get_checkpoint(); "untrained" maps to None.
checkpoints_map = {
    "untrained": None,
    # "pretrained": "pretrained",
    "baseline": get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=0,
        global_step=global_step,
    ),
}

# One entry per topographic scale, keyed "topo_<scale>".
for topo_scale in topo_scales:
    checkpoints_map[f"topo_{topo_scale}"] = get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=topo_scale,
        global_step=global_step,
    )

# Lesion sizes swept by the experiment: 0, 25, ..., 300.
lesion_sizes = range(0, 325, 25)

# Validate user input explicitly; an assert would be stripped under `python -O`.
if args.model not in checkpoints_map:
    parser.error(
        f"--model must be one of {sorted(checkpoints_map)}, got {args.model!r}"
    )
checkpoint_path = checkpoints_map[args.model]

def lesion_model(
    model,
    lesion_mask_2d,
    factor=0.1,
    layer_index=0
):
    """
    Scale selected weight rows of one ``mlp.c_fc`` layer of ``model`` in place.

    Parameters:
    - model: The model containing modules named ``transformer.h.<i>.mlp.c_fc``.
    - lesion_mask_2d: Mask that is flattened and cast to bool; True entries
      select the rows (first dimension) of the layer weight to scale.
      Presumably it has one entry per weight row -- TODO confirm.
    - factor (float): Multiplier applied to the selected rows (0.0 zeroes them).
    - layer_index (int): Index into the 12 candidate layer names.

    Returns:
    - The same ``model`` object, with the layer's weights modified.
    """
    candidate_layers = [f"transformer.h.{i}.mlp.c_fc" for i in range(12)]
    target_name = candidate_layers[layer_index]
    target_layer = get_module_by_name(module=model, name=target_name)

    # Flatten the 2D mask so it can index along the weight's first dimension.
    row_selector = lesion_mask_2d.reshape(-1).bool()

    weights = target_layer.weight.data
    weights[row_selector, :] = weights[row_selector, :] * factor

    # Re-attach the edited layer under its original name.
    setattr_pytorch_model(model, target_name, target_layer)
    return model


def load_map(
    checkpoint_name,
    layer_index,
    category
):
    """
    Load a saved ``.npy`` map for one checkpoint, layer and category.

    The file is read from
    ``assets/<checkpoint_name>/transformer.h.<layer_index>.mlp.c_fc/<category>.npy``,
    relative to the current working directory.

    Returns:
    - torch.Tensor: The contents of the file, copied into a new tensor.
    """
    layer_dir = f"transformer.h.{layer_index}.mlp.c_fc"
    npy_path = os.path.join("assets", checkpoint_name, layer_dir, f"{category}.npy")
    array = np.load(npy_path)
    return torch.tensor(array)

def threshold_map_binary(x, threshold):
    """
    Binarise ``x`` in place: 0 where ``x < threshold``, 1 everywhere else.

    Parameters:
    - x: Array/tensor supporting boolean-mask assignment. It is modified in
      place and also returned.
    - threshold: Scalar cut-off.

    Returns:
    - The same object ``x``, now holding only 0s and 1s.
    """
    # Compute the comparison once, before any write. Keying the second step
    # on ``x != 0`` would leave entries that are exactly 0 (and pass a
    # non-positive threshold) at 0 instead of 1.
    below = x < threshold
    x[below] = 0.
    x[~below] = 1.
    return x

def top_n_percent_mask(x, n):
    """
    Returns a binary float mask for the top n% highest values in the map x.

    Parameters:
    - x (np.ndarray or torch.Tensor): The input map to be masked. Must be non-empty.
    - n (float): The percentage of highest values to keep (0 < n <= 100).

    Returns:
    - A float32 mask of the same shape and container type as x, with 1s where
      the value is >= the cut-off and 0s elsewhere. Values tied with the
      cut-off are all kept, so slightly more than n% may be selected.

    Raises:
    - ValueError: If n is outside (0, 100] or x has no elements.
    """
    if not (0 < n <= 100):
        raise ValueError("n should be between 0 and 100")

    # Flatten the array to work with it easily
    flattened_x = x.flatten()
    if len(flattened_x) == 0:
        raise ValueError("x should contain at least one element")

    # Determine the number of top values to select (at least one)
    num_values = int(np.ceil(len(flattened_x) * n / 100))

    # Get the threshold value for the top n% values
    threshold = np.partition(flattened_x, -num_values)[-num_values]

    # Create the binary mask; ndarray has no .float(), so cast per container type
    selected = x >= threshold
    if isinstance(selected, torch.Tensor):
        return selected.float()
    return selected.astype(np.float32)

import numpy as np

def top_n_mask(x, n):
    """
    Build a boolean mask marking the n largest entries of the map x.

    Parameters:
    - x: The input map. Must support ``flatten()`` and ``>=``; this script
      passes torch tensors returned by ``load_map``.
    - n (int): How many of the highest values to keep (1 <= n <= x's size).

    Returns:
    - A boolean mask shaped like x, True where the value is >= the n-th
      largest value. Entries tied with that value are all True, so the mask
      can contain more than n True entries.

    Raises:
    - ValueError: If n is not positive or exceeds the number of elements.
    """
    if n <= 0:
        raise ValueError("n should be a positive integer")

    values = x.flatten()
    total = len(values)
    if n > total:
        raise ValueError("n should not be greater than the number of elements in x")

    # After partitioning, position -n holds the n-th largest value.
    kth = -n
    cutoff = np.partition(values, kth)[kth]

    return x >= cutoff

def shuffle_tensor_2d(x: torch.Tensor) -> torch.Tensor:
    """
    Return a copy of the 2D tensor x with its elements randomly permuted.

    All elements are shuffled jointly (not per row or column), so the result
    has the same shape and the same multiset of values as x. The input is
    left unmodified. Randomness comes from torch's global RNG.
    """
    flattened = x.reshape(-1)

    # Draw one random ordering over every element position.
    permutation = torch.randperm(flattened.shape[0])
    permuted = torch.index_select(flattened, 0, permutation)

    rows, cols = x.shape[0], x.shape[1]
    return permuted.reshape(rows, cols)
    

losses = []
losses_control = []  # losses obtained with shuffled (location-randomised) masks


def _evaluate_loss(model, tokenizer):
    """Return the test loss of ``model`` on the --eval-topic test set."""
    return evaluate_model_on_list_of_text(
        model=model,
        dataset=load_json_as_dict(f"datasets/test/{args.eval_topic}.json"),
        tokenizer=tokenizer,
        device=device,
        progress=False
    )


def _lesion_all_layers(model, masks):
    """Zero the masked weight rows in every layer; ``masks[i]`` is applied to layer i."""
    for layer_index, layer_mask in enumerate(masks):
        model = lesion_model(
            model=model,
            layer_index=layer_index,
            lesion_mask_2d=layer_mask,
            factor=0.0
        )
    return model


# Reference point: loss of the unlesioned checkpoint.
model, tokenizer = get_checkpoint(
    checkpoint_path, device=device
)
original_loss = _evaluate_loss(model, tokenizer)
print(f"Original Loss: {original_loss}")

for lesion_size in lesion_sizes:

    # Build one lesion mask per layer from that layer's selectivity map.
    masks = []
    for layer_index in range(12):
        selectivity_map = load_map(
            args.model, layer_index=layer_index, category=args.eval_topic
        )
        if lesion_size == 0:
            ## all-False mask: lesion nothing
            mask = torch.zeros_like(selectivity_map, dtype=torch.bool)
        else:
            mask = top_n_mask(selectivity_map, n=lesion_size)
        masks.append(mask)

    # Targeted lesion. Reload first, because lesion_model edits weights in place.
    model, tokenizer = get_checkpoint(
        checkpoint_path, device=device
    )
    model = _lesion_all_layers(model, masks)
    eval_loss = _evaluate_loss(model, tokenizer)
    losses.append(eval_loss)

    # Control lesion: the same per-layer masks with their entries shuffled, so
    # every layer loses the same number of rows at random positions.
    # Previously only layer 11 was lesioned here (with the last layer's
    # shuffled mask), so the control removed far fewer rows than the targeted
    # lesion despite the legend claiming an equal count.
    model, tokenizer = get_checkpoint(
        checkpoint_path, device=device
    )
    model = _lesion_all_layers(
        model, [shuffle_tensor_2d(layer_mask) for layer_mask in masks]
    )
    eval_loss_control = _evaluate_loss(model, tokenizer)
    losses_control.append(eval_loss_control)
    print(f"Lesion size: {lesion_size} Loss: {eval_loss} Control: {eval_loss_control}")

fig = plt.figure()
plt.title(f"Eval topic: {args.eval_topic}")
plt.plot(lesion_sizes, losses, label="lesion mask")
plt.scatter(lesion_sizes, losses)

plt.plot(lesion_sizes, losses_control, label="random mask with same number of lesioned items")
plt.scatter(lesion_sizes, losses_control)

plt.legend()
plt.xlabel("Lesion size ->")
plt.ylabel("Test Loss ->")
plt.grid()

# Make sure the output directory exists so the finished run is not lost at
# the final step.
os.makedirs("results", exist_ok=True)
# The figure is named after the --model value; the undefined name
# `checkpoint_name` used here before raised NameError.
fig.savefig(os.path.join("results", f"{args.model}_{args.eval_topic}.png"))