"""
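Compute correlation-vs-distance smoothness scores for ResNet-50 checkpoints
on NSD stimuli.

Download the stimuli file first:

    wget -O datasets/nsd_stimuli.hdf5 XXXX

NOTE: the script reads the stimuli from the `hdf5_filename` passed to
`NSDStimuli` below, so either save the download there or update that path.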
"""

from nesim.eval.nsd import NSDStimuli
import torchvision.transforms as transforms
from nesim.eval.resnet import EvalSuite
from torch.utils.data import DataLoader
from nesim.eval.resnet import EvalSuite, load_resnet50_checkpoint
from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
from tqdm import tqdm
from einops import rearrange
from utils import correlation_matrix, euclidean_distance_tensor
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
import torch

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Passed as `max_num_batches` when collecting hook outputs for each model.
MAX_NUM_BATCHES = 30

# Device that 4-D hook outputs are moved to before correlations are computed.
device = "cuda:0"

# Checkpoints to evaluate, in the order they are processed.
model_names = [
    "baseline_scale_None_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_30_shrink_factor_3.0",
]

# Groups of layer names; the evaluation loop picks one group per model.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet50/layer_names.json"
)

# ---------------------------------------------------------------------------
# Stimuli and evaluation suite
# ---------------------------------------------------------------------------

# Per-channel statistics handed to `transforms.Normalize`.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every stimulus: to-tensor, resize, normalize.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
imagenet_transforms = transforms.Compose(_preprocessing_steps)

dataset = NSDStimuli(
    hdf5_filename="/research/datasets/nsd/nsd_stimuli.hdf5",
    transform=imagenet_transforms,
)

# NOTE(review): with shuffle=True and a cap on the number of batches, each
# model is presumably evaluated on a different random subset of the stimuli
# -- confirm this is intended.
_stimuli_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)
eval_suite = EvalSuite(dataloader=_stimuli_loader)

# model name -> layer name -> mean correlation per unique distance;
# populated by the evaluation loop further down.
correlation_vs_distance_data = {}
def get_distances_and_mean_correlations(topo_layer_names: list[str], hook_outputs: dict):
    """
    Summarise, per layer, how pairwise correlations vary with distance.

    For each layer the hooked activations are turned into a correlation
    matrix via `correlation_matrix`, a distance matrix is built via
    `euclidean_distance_tensor` for a grid whose area equals the number of
    rows of the correlation matrix, and the strictly-lower-triangular entries
    of both matrices are paired up. Correlations are then averaged within
    each distinct distance value.

    Args:
        topo_layer_names: names of the layers to process; each one must be a
            key of `hook_outputs`.
        hook_outputs: mapping from layer name to an activation tensor. 4-D
            tensors are flattened with "b c h w -> (b h w) c", moved to the
            module-level `device` and cast to float; tensors of any other
            rank are handed to `correlation_matrix` as they are.

    Returns:
        A dict keyed by layer name. Each value is a dict of plain lists:
            "distances": the sorted unique distance values,
            "correlations": mean correlation at each of those distances,
            "lower_triangle_correlations": every below-diagonal correlation,
            "lower_triangle_distances": the distance paired with each of
                those correlations.
    """
    data = {}
    # Iterating the sequence directly (rather than an `enumerate` wrapper)
    # lets tqdm infer the total and show real progress.
    for layer_name in tqdm(topo_layer_names, desc="computing dist vs corr"):
        activations = hook_outputs[layer_name]

        if activations.ndim == 4:
            activations = (
                rearrange(activations, "b c h w -> (b h w) c")
                .to(device)
                .float()
            )
        # NOTE(review): non-4-D outputs are neither moved to `device` nor cast
        # to float -- confirm that is intended.

        corr_matrix = correlation_matrix(tensor=activations)
        size = find_rectangle_dimensions(area=corr_matrix.shape[0])

        # Presumably pairwise distances between positions on a
        # size.height x size.width grid -- confirm against
        # `euclidean_distance_tensor`. It is indexed with a mask shaped like
        # `corr_matrix`, so both matrices must have the same shape.
        distance_matrix = euclidean_distance_tensor(
            height=size.height, width=size.width
        ).to(corr_matrix.device)

        # Strictly-lower-triangular mask, built once and shared by both
        # matrices so each unordered pair is counted exactly once and the
        # diagonal (self-correlation) is excluded.
        lower_triangle = torch.tril(
            torch.ones_like(corr_matrix), diagonal=-1
        ).bool()
        all_correlation_values = corr_matrix[lower_triangle]
        distances = distance_matrix[lower_triangle]

        # Move to host memory once instead of once per unique distance.
        correlations_np = all_correlation_values.cpu().numpy()
        distances_np = distances.cpu().numpy()

        unique_distances = np.unique(distances_np)
        mean_correlations = [
            float(np.mean(correlations_np[distances_np == d]))
            for d in unique_distances
        ]

        data[layer_name] = {
            "distances": unique_distances.tolist(),
            "correlations": mean_correlations,
            "lower_triangle_correlations": correlations_np.tolist(),
            "lower_triangle_distances": distances_np.tolist(),
        }

    return data

def pouya_smoothness_metric(correlations: "np.ndarray | list[float]") -> float:
    """
    Return the spread (maximum minus minimum) of the given correlations.

    Args:
        correlations: correlation values as a NumPy array or any sequence
            NumPy can convert to one (the script passes a list of floats).

    Returns:
        `max(correlations) - min(correlations)` as a built-in float. NaN
        entries propagate to the result.

    Raises:
        ValueError: if `correlations` is empty (raised by NumPy's reduction).
    """
    values = np.asarray(correlations)
    return float(values.max() - values.min())

# Outputs accumulated over all models, keyed by model name:
#   results  -> {layer name: smoothness score}
#   raw_data -> full per-layer output of get_distances_and_mean_correlations
results = {}
raw_data = {}

for model_name in model_names:
    model = load_resnet50_checkpoint(
        checkpoints_folder="/research/XXXX-1/toponets_resnet50_imagenet_checkpoints",
        model_name=model_name,
        epoch="final",
    )
    model.eval()

    # "end_topo" models use only the last conv layer of each block; every
    # other model (baseline, all_topo, ...) uses all conv layers but the first.
    if model_name.startswith("end_topo"):
        topo_layer_names = layer_names["last_conv_layers_in_each_block"]
    else:
        topo_layer_names = layer_names["all_conv_layers_except_first"]

    if model_name.startswith(("baseline", "all_topo")):
        # Drop a trailing ".conv" from each layer name. `removesuffix` removes
        # exactly that suffix; `rstrip(".conv")` would instead strip any
        # trailing run of the characters '.', 'c', 'o', 'n', 'v' and could eat
        # into the real layer name.
        topo_layer_names = [
            name.removesuffix(".conv") for name in topo_layer_names
        ]

    hook_outputs, _ = eval_suite.get_hook_outputs(
        model=model,
        layer_names=topo_layer_names,
        progress=True,
        max_num_batches=MAX_NUM_BATCHES,
        # spatial_pooling='max' ## following margalit2024's technique i.e spatial maxpooling
    )

    data = get_distances_and_mean_correlations(
        topo_layer_names=topo_layer_names,
        hook_outputs=hook_outputs,
    )
    # Per-model dump of the full correlation/distance data.
    dict_to_json(
        data,
        filename=f"raw_results_{model_name}.json"
    )

    correlation_vs_distance_data[model_name] = {
        layer_name: layer_data["correlations"]
        for layer_name, layer_data in data.items()
    }

    # One score per layer: spread of the mean correlation across distances.
    smoothness_scores = {
        layer_name: pouya_smoothness_metric(
            correlations=data[layer_name]["correlations"]
        )
        for layer_name in topo_layer_names
    }

    print(model_name)
    print(smoothness_scores)
    print("\n\n\n")

    results[model_name] = smoothness_scores
    raw_data[model_name] = data

# Aggregated outputs across all models.
dict_to_json(
    dictionary=results,
    filename="results.json"
)
dict_to_json(
    dictionary=correlation_vs_distance_data,
    filename="correlation_vs_distance_data.json"
)
dict_to_json(
    dictionary=raw_data,
    filename="raw_data.json"
)