from size_animacy_dataset import SizeAnimacyDataset
from torchvision.models import resnet18
import torchvision.transforms as transforms
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from nesim.utils.hook import ForwardHook
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.normalization import z_score_normalize
import os
import argparse
from scipy.stats import pearsonr
from nesim.utils.json_stuff import dict_to_json

# Command-line interface: the layer to hook is mandatory, the checkpoint is
# optional.
parser = argparse.ArgumentParser(
    description="Script for handling forward hook module and checkpoint path."
)

# (flag, add_argument keyword arguments) pairs, registered in this order.
_cli_arguments = (
    (
        "--layer-name",
        dict(
            type=str,
            help="Name of the forward hook module (example: layer3.0.conv2)",
            required=True,
        ),
    ),
    (
        "--checkpoint-path",
        dict(
            type=str,
            default=None,
            help="Path to the checkpoint file (default: None)",
        ),
    ),
)
for _flag, _add_argument_kwargs in _cli_arguments:
    parser.add_argument(_flag, **_add_argument_kwargs)

args = parser.parse_args()

"""
Example --checkpoint-path value for the topo model:
"/research/XXXX-1/nesim_old/training/imagenet/resnet18/checkpoints/imagenet/shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps/best/best_model.ckpt"
"""

# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the script still works on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Output locations for the saved figure and the JSON results.
figures_folder = "./figures"
results_json_folder = "./results"

# A run with --checkpoint-path is tagged 'topo'; a run without it is tagged
# 'baseline'. The tag is used in plot titles and output file names.
model_name = 'topo' if args.checkpoint_path is not None else 'baseline'

# Per-channel statistics handed to transforms.Normalize.
_normalize_mean = [0.485, 0.456, 0.406]
_normalize_std = [0.229, 0.224, 0.225]

# Evaluation preprocessing: resize to 224x224, convert to a tensor, then
# normalize each channel.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_normalize_mean, std=_normalize_std),
]
eval_transforms = transforms.Compose(_preprocessing_steps)

# Only the baseline run needs the pretrained torchvision weights. When a
# checkpoint is supplied, load_state_dict (strict by default) replaces every
# parameter and buffer, so loading (and possibly downloading) the pretrained
# weights first would be wasted work.
pretrained_weights = "DEFAULT" if args.checkpoint_path is None else None
model = resnet18(weights=pretrained_weights)
model = model.eval().to(device)

if args.checkpoint_path is not None:
    # NOTE(review): load_and_filter_state_dict_keys is defined elsewhere; it is
    # assumed to return a state dict whose keys match this resnet18.
    state_dict = load_and_filter_state_dict_keys(checkpoint_filename=args.checkpoint_path)
    model.load_state_dict(state_dict)

# Look up the submodule named on the command line and attach a forward hook to
# it; the hook's `output` attribute is read after each forward pass below.
_hooked_module = get_module_by_name(module=model, name=args.layer_name)
forward_hook = ForwardHook(module=_hooked_module)

# Images and their labels are read from the "Size-Animacy" folder.
dataset = SizeAnimacyDataset(folder="Size-Animacy")

# Collect one spatially pooled activation vector per image from the hooked layer.
pooled_hook_outputs = []
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        item = dataset[i]
        image_tensor = eval_transforms(item["image"]).to(device).unsqueeze(0)
        # Invoke the module through __call__ rather than .forward(): PyTorch
        # dispatches a module's own hooks only through __call__. The logits are
        # not needed; only the hooked activation is read.
        model(image_tensor)
        # presumably ForwardHook.output holds the hooked module's latest output
        hook_output = forward_hook.output.cpu()
        if hook_output.ndim != 4:
            raise ValueError(
                f'Expected 4 dimensional hook output but got: {hook_output.ndim}'
            )
        ## take mean along last height and width dims
        ## (1, num_channels, height, width) -> (1, num_channels)
        ## Pooling per image keeps only a small vector per sample in memory
        ## instead of every full activation map.
        pooled_hook_outputs.append(hook_output.mean(-1).mean(-1))

## (dataset_index, num_channels)
all_hook_outputs = torch.cat(pooled_hook_outputs, dim=0)
## normalize along the dataset dimension (dim 0)
all_hook_outputs = z_score_normalize(all_hook_outputs, dim=0)

def plot_difference(
    model_name: str,
    layer_name: str,
    figures_folder: str,
    results_json_folder: str
):
    """Plot class-wise activation differences for animacy and size and save them.

    For each of the fields "animacy" and "size", the mean of the rows of the
    module-level ``all_hook_outputs`` with label 0 minus the mean of the rows
    with label 1 is computed, reshaped to the grid returned by
    ``find_rectangle_dimensions`` and drawn as a heat map. The Pearson
    correlation between the two flattened difference maps is shown in the
    figure title and written, together with its p-value, to a JSON file.

    Reads the module-level ``dataset`` (for ``dataset.labels``) and
    ``all_hook_outputs``.

    Args:
        model_name: Model tag used in the plot titles and output file names.
        layer_name: Hooked layer name used in the plot titles and output file
            names.
        figures_folder: Directory that receives ``{model_name}_{layer_name}.pdf``;
            created if it does not exist.
        results_json_folder: Directory that receives
            ``{model_name}_{layer_name}.json``; created if it does not exist.

    Raises:
        ValueError: If a field has no sample with label 0 or none with label 1.
    """
    # (name shown for label 0, name shown for label 1) per field. The delta is
    # "label 0 minus label 1", so positive values (drawn red by "coolwarm")
    # correspond to the first name and negative values (blue) to the second.
    legends = {
        "animacy": ("Animate", "Inanimate"),
        "size": ("Small", "Big"),
    }

    # Create the output folders up front so saving cannot fail on a fresh checkout.
    for folder in (figures_folder, results_json_folder):
        if folder:
            os.makedirs(folder, exist_ok=True)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
    flat_deltas = {}
    for axis, (field, (red_name, blue_name)) in zip(axes, legends.items()):
        labels = torch.tensor(dataset.labels[field])
        is_label_0 = labels == 0
        is_label_1 = labels == 1
        if not is_label_0.any() or not is_label_1.any():
            raise ValueError(
                f"Field '{field}' needs at least one sample with label 0 and one with label 1"
            )

        ## difference of the class-wise mean activations, reshaped to cortical sheet size
        delta = all_hook_outputs[is_label_0].mean(0) - all_hook_outputs[is_label_1].mean(0)
        grid = find_rectangle_dimensions(delta.shape[0])
        delta = delta.reshape(grid.height, grid.width)

        # Center the diverging colormap on zero so that red/blue encode the
        # sign of delta, as the title states. Non-finite entries are ignored
        # when picking the limit; with no usable limit matplotlib autoscales.
        limit = torch.nan_to_num(delta.abs(), nan=0.0, posinf=0.0).max().item()
        color_limits = {"vmin": -limit, "vmax": limit} if limit > 0 else {}

        delta = delta.detach().cpu().numpy()
        axis.imshow(delta, cmap="coolwarm", **color_limits)
        axis.set_title(
            f"Difference in mean activation\nRed: {red_name}\nBlue: {blue_name}\n"
            f"Model: {model_name}\nLayer: {layer_name}"
        )
        axis.axis("off")
        flat_deltas[field] = delta.reshape(-1)

    # Pearson correlation between the animacy and size difference maps
    pearsonr_corr, p_value = pearsonr(flat_deltas["animacy"], flat_deltas["size"])

    # Show the correlation in the figure title
    fig.suptitle(f"Pearson Correlation: {pearsonr_corr}\n P value: {p_value}", fontsize=14)
    fig.tight_layout()

    # Save the figure as a PDF, then release it
    pdf_file_path = os.path.join(
        figures_folder,
        f"{model_name}_{layer_name}.pdf"
    )
    fig.savefig(pdf_file_path, format="pdf")
    plt.close(fig)

    # Convert to plain Python floats: numpy scalar types are not all JSON-serializable.
    dict_to_json(
        {
            "layer_name": layer_name,
            "pearson_correlation": float(pearsonr_corr),
            "p_value": float(p_value)
        },
        filename=os.path.join(
            results_json_folder,
            f"{model_name}_{layer_name}.json"
        )
    )

# Produce the figure and the JSON results for the layer selected on the
# command line (arguments are passed in signature order).
plot_difference(model_name, args.layer_name, figures_folder, results_json_folder)