"""
Hooks are None for layer2.0.conv1.conv layer.
"""

from nesim.eval.resnet import load_resnet18_checkpoint, load_resnet50_checkpoint
from nesim.eval.eshed import load_eshed_checkpoint
from helpers.load_images import load_images
from helpers.preprocess_image import preprocess
from helpers.load_brain_data import load_brain_data
from helpers.hook import get_intermediate_layer_activations

import torch
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from einops import rearrange, reduce
import scipy
import numpy as np
import json
import os

def find_all_files(directory):
    """Recursively collect ``*.torch`` checkpoint files under *directory*, skipping SimCLR ones.

    Each file gets a nickname made of its parent directory name and file name
    (``"<parent>/<file>.torch"``). Files whose nickname contains ``"simclr"``
    are excluded. Directories and files are visited in sorted order, so the
    result is deterministic regardless of the filesystem's listing order.
    Callers that take the "first" checkpoint therefore always get the same one.

    Args:
        directory: Root directory to search. A missing directory yields no files.

    Returns:
        A ``(file_paths, nicknames)`` pair of parallel lists.
    """
    file_paths = []
    nicknames = []
    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place makes os.walk descend in a stable order.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith(".torch"):
                continue
            full_path = os.path.join(root, file_name)
            # Normalise separators so the nickname is "<parent>/<file>" on every OS.
            nickname = "/".join(full_path.replace(os.sep, "/").split("/")[-2:])
            # Exclude SimCLR checkpoints.
            if "simclr" in nickname:
                continue
            file_paths.append(full_path)
            nicknames.append(nickname)
    return file_paths, nicknames

device = "cuda" if torch.cuda.is_available() else "cpu"
train_test_split_index = 800  # Should be under 1000

# Layer names for resnet18, resnet50, and tdann

rn18_all_conv_layers_except_first = [
    "layer1.0.conv1",
    "layer1.0.conv2",
    "layer1.1.conv1",
    "layer1.1.conv2",
    "layer2.0.conv1", #
    "layer2.0.conv2",
    "layer2.0.downsample.0", #
    "layer2.1.conv1", 
    "layer2.1.conv2", 
    "layer3.0.conv1", #
    "layer3.0.conv2",
    "layer3.0.downsample.0", #
    "layer3.1.conv1", 
    "layer3.1.conv2", 
    "layer4.0.conv1", # 
    "layer4.0.conv2",
    "layer4.0.downsample.0", #
    "layer4.1.conv1", 
    "layer4.1.conv2"
]
rn18_last_conv_layers_in_each_block =  ["layer1.1.conv2","layer2.1.conv2", "layer3.1.conv2",  "layer4.1.conv2"]

rn50_all_conv_layers_except_first = [
        "layer1.0.conv1",
        "layer1.0.conv2",
        "layer1.0.conv3",
        "layer1.1.conv1",
        "layer1.1.conv2",
        "layer1.1.conv3",
        "layer1.2.conv1",
        "layer1.2.conv2",
        "layer1.2.conv3",
        "layer2.0.conv1",
        "layer2.0.conv2", #sus
        "layer2.0.conv3",
        "layer2.0.downsample.0", #
        "layer2.1.conv1",
        "layer2.1.conv2",
        "layer2.1.conv3",
        "layer2.2.conv1",
        "layer2.2.conv2",
        "layer2.2.conv3",
        "layer2.3.conv1",
        "layer2.3.conv2",
        "layer2.3.conv3",
        "layer3.0.conv1",
        "layer3.0.conv2", #sus
        "layer3.0.conv3",
        "layer3.0.downsample.0", #
        "layer3.1.conv1",
        "layer3.1.conv2",
        "layer3.1.conv3",
        "layer3.2.conv1",
        "layer3.2.conv2",
        "layer3.2.conv3",
        "layer3.3.conv1",
        "layer3.3.conv2",
        "layer3.3.conv3",
        "layer3.4.conv1",
        "layer3.4.conv2",
        "layer3.4.conv3",
        "layer3.5.conv1",
        "layer3.5.conv2",
        "layer3.5.conv3",
        "layer4.0.conv1",
        "layer4.0.conv2",
        "layer4.0.conv3",
        "layer4.0.downsample.0", #
        "layer4.1.conv1",
        "layer4.1.conv2",
        "layer4.1.conv3",
        "layer4.2.conv1",
        "layer4.2.conv2",
        "layer4.2.conv3"
    ]
rn50_last_conv_layers_in_each_block = [
        "layer1.2.conv3",
        "layer2.3.conv3",
        "layer3.5.conv3",
        "layer4.2.conv3"
    ]

all_tdann_layers = [
    # 'base_model.conv1', 
    'base_model.layer1.0.conv1', 
    'base_model.layer1.0.conv2', 
    'base_model.layer1.1.conv1', 
    'base_model.layer1.1.conv2', 
    'base_model.layer2.0.conv1', 
    'base_model.layer2.0.conv2', 
    'base_model.layer2.1.conv1', 
    'base_model.layer2.1.conv2', 
    'base_model.layer3.0.conv1', 
    'base_model.layer3.0.conv2', 
    'base_model.layer3.1.conv1', 
    'base_model.layer3.1.conv2', 
    'base_model.layer4.0.conv1', 
    'base_model.layer4.0.conv2', 
    'base_model.layer4.1.conv1', 
    'base_model.layer4.1.conv2'
]

# Root folders of the topographic ResNet-18/50 checkpoints.
# NOTE(review): confirm these names are actually referenced; evaluation code may
# hard-code the same paths instead.
resnet18_checkpoints_folder = "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet18"
resnet50_checkpoints_folder = "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet50"

# Directory searched recursively for TDANN "*.torch" checkpoints.
tdann_directory = "/home/XXXX-5/repos/nesim/training/brainmodels/eshed_checkpoints"
# NOTE(review): this walks the filesystem at import time. run_experiment performs
# its own search into local variables, so these module-level results may be redundant.
tdann_checkpoint_paths, tdann_file_nicknames = find_all_files(tdann_directory)

# Layers probed per model family. ResNet entries are keyed by topo mode, with the
# same keys as checkpoint_names; "baseline" models are probed at every conv
# layer except the stem. "tdann" is a flat list with no topo-mode level.
layer_names = {
    "resnet18": {
        "baseline": rn18_all_conv_layers_except_first,
        "all_conv_layers_except_first": rn18_all_conv_layers_except_first,
        "last_conv_layers_in_each_block": rn18_last_conv_layers_in_each_block,
        
    },
    "resnet50": {
        "baseline": rn50_all_conv_layers_except_first,
        "all_conv_layers_except_first": rn50_all_conv_layers_except_first,
        "last_conv_layers_in_each_block": rn50_last_conv_layers_in_each_block,
    },
    "tdann": all_tdann_layers
}

# Checkpoint (model) names per ResNet family and topo mode. They are used as
# result keys and are presumably the checkpoint names on disk under the
# toponet_checkpoints folders.
# NOTE(review): scale formatting is inconsistent ("scale_1" vs "scale_1.0",
# "scale_5" vs "scale_5.0"); confirm each name matches an existing checkpoint.
checkpoint_names = {
    "resnet18": {
        "last_conv_layers_in_each_block": ["end_topo_scale_1.0_shrink_factor_3.0", "end_topo_scale_5.0_shrink_factor_3.0",
                                           "end_topo_scale_10.0_shrink_factor_3.0", "end_topo_scale_50.0_shrink_factor_3.0"],
        "all_conv_layers_except_first": ["all_topo_scale_1_shrink_factor_3.0", "all_topo_scale_5_shrink_factor_3.0",
                                         "all_topo_scale_10.0_shrink_factor_3.0", "all_topo_scale_50.0_shrink_factor_3.0"],
        "baseline": ["baseline_scale_None_shrink_factor_3.0"]
    },
    "resnet50": {
        "last_conv_layers_in_each_block": ["end_topo_scale_1_shrink_factor_3.0"],
        "all_conv_layers_except_first": ["all_topo_scale_1_shrink_factor_3.0"],
        "baseline": ["baseline_scale_None_shrink_factor_3.0"]
    }
}

# Helper function to evaluate models
def evaluate_model(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, model_type, topo_mode):
    """Fit an OLS encoding model on one layer's activations and record its held-out scores.

    Activations of ``layer_name`` are extracted for every image and flattened to
    one row per image. The first ``train_test_split_index`` rows are used to fit a
    ``LinearRegression`` onto ``brain_data``. The remaining rows are predicted and
    scored with MSE and Pearson r, which are written into ``eval_results`` in
    place and printed.

    Args:
        model: Network to probe (assumed to be a torch module). It is switched
            to eval mode before extraction.
        layer_name: Name of the layer whose activations are used as regressors.
        model_name: Checkpoint name or nickname, used as a key in ``eval_results``.
        brain_data: NumPy array with one row per stimulus, aligned with
            ``preprocessed_images`` (``load_data`` yields ``(n_stim, n_voxels)``).
        preprocessed_images: Image tensor, moved to ``device`` as float32.
        eval_results: Nested results dict, mutated in place. The layout is
            ``[model_type][topo_mode][model_name][layer_name]`` for ResNets and
            ``[model_type][model_name][layer_name]`` when ``topo_mode == "tdann"``.
        model_type: First-level key in ``eval_results`` (e.g. "resnet18", "tdann").
        topo_mode: Second-level key for ResNets; "tdann" selects the flat layout.

    Raises:
        ValueError: If no activations were captured for ``layer_name``.
    """
    # Extract features in inference mode so training-only behaviour (e.g.
    # BatchNorm batch statistics) does not affect them; a no-op if already eval.
    model.eval()

    activations = get_intermediate_layer_activations(
        model=model,
        layer_name=layer_name,
        input_images=preprocessed_images.to(device, dtype=torch.float),
    )
    if activations is None:
        # Hooks are known to yield None for some layers. Fail with a clear
        # message instead of an opaque einops error below.
        raise ValueError(f"No activations were captured for layer {layer_name!r} of {model_name!r}")

    # Flatten every non-batch dimension (c, h, w for conv outputs) into one feature axis.
    activations = rearrange(activations, "b ... -> b (...)").cpu().numpy()

    split = train_test_split_index
    X_train, X_test = activations[:split], activations[split:]
    y_train, y_test = brain_data[:split], brain_data[split:]

    # Train regression model
    regression_model = LinearRegression()
    print(f"training regression model...  X_train {X_train.shape} y_train {y_train.shape}")
    regression_model.fit(X_train, y_train)
    y_pred = regression_model.predict(X_test)

    # Calculate metrics. Pearson r is a single correlation over all held-out
    # stimuli and voxels pooled together, not an average of per-voxel correlations.
    mse = mean_squared_error(y_test, y_pred)
    pearson_r = scipy.stats.pearsonr(y_pred.flatten(), y_test.flatten())[0]

    # TDANN results have no topo-mode level in the results dict.
    if topo_mode == "tdann":
        layer_results = eval_results[model_type][model_name][layer_name]
    else:
        layer_results = eval_results[model_type][topo_mode][model_name][layer_name]
    layer_results["mse"] = float(mse)
    layer_results["pearson_r"] = float(pearson_r)

    print(f"Model {model_type} {topo_mode} {model_name} {layer_name} OLS brainmodel\nVal set mse: {mse}, pearson_r: {pearson_r}")

# Load data (this is common for all models)
def load_data(dataset_name="NSD1000", subject="1", roi="ffa", dataset_root="./"):
    """Load trial-averaged brain responses and the matching preprocessed images.

    The defaults reproduce the original fixed setup (NSD1000, subject "1", FFA,
    data under "./"). They can be overridden to probe other subjects or ROIs.

    Args:
        dataset_name: Dataset passed to both the brain-data and image loaders.
        subject: Subject identifier passed to ``load_brain_data``.
        roi: Region of interest passed to ``load_brain_data``.
        dataset_root: Root directory for both loaders.

    Returns:
        A ``(brain_data, preprocessed_images)`` tuple. ``brain_data`` is a NumPy
        array of shape ``(n_stim, voxels)``, with the trial axis averaged out.
        ``preprocessed_images`` is whatever ``preprocess`` returns.
    """
    brain_data = load_brain_data(
        dataset_name=dataset_name,
        subject=subject,
        roi=roi,
        dataset_root=dataset_root,
        averaging=False,
    )
    # Average across repeated trials of each stimulus (the original notes 3 trials).
    brain_data = reduce(brain_data, "n_stim trials voxels -> n_stim voxels", "mean")

    images = load_images(
        dataset_name=dataset_name,
        dataset_root=dataset_root,
    )
    # Standard ImageNet channel mean/std normalisation.
    preprocessed_images = preprocess(
        images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    return brain_data.cpu().numpy(), preprocessed_images

def _build_eval_results(tdann_file_nicknames):
    """Return an empty nested results dict with one slot per (model, checkpoint, layer).

    ResNet slots are keyed ``[model_type][topo_mode][checkpoint][layer]``.
    TDANN slots are keyed ``["tdann"][nickname][layer]``.
    """
    eval_results = {}
    for model_type in ("resnet18", "resnet50"):
        eval_results[model_type] = {
            topo_mode: {
                checkpoint: {layer: {} for layer in layer_names[model_type][topo_mode]}
                for checkpoint in checkpoints
            }
            for topo_mode, checkpoints in checkpoint_names[model_type].items()
        }
    eval_results["tdann"] = {
        nickname: {layer: {} for layer in layer_names["tdann"]}
        for nickname in tdann_file_nicknames
    }
    return eval_results


def _evaluate_resnet_checkpoints(brain_data, preprocessed_images, eval_results):
    """Evaluate every configured ResNet-18/50 checkpoint at each of its configured layers."""
    resnet_sources = {
        "resnet18": (load_resnet18_checkpoint, "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet18/"),
        "resnet50": (load_resnet50_checkpoint, "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet50/"),
    }
    for model_type, (loader, checkpoints_folder) in resnet_sources.items():
        for topo_mode, model_list in checkpoint_names[model_type].items():
            for model_name in model_list:
                # Move to `device` like the TDANN path, since the inputs are moved there too.
                model = loader(
                    checkpoints_folder=checkpoints_folder,
                    model_name=model_name,
                    epoch="final",
                ).to(device)
                for layer_name in layer_names[model_type][topo_mode]:
                    evaluate_model(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, model_type, topo_mode)


def run_experiment(
    tdann_directory="/home/XXXX-5/repos/nesim/training/brainmodels/eshed_checkpoints",
    max_checkpoints=1,
    max_layers=1,
    output_path="eval_results_tdann.json",
    evaluate_resnets=False,
):
    """Evaluate TDANN checkpoints (and optionally ResNets) and save the results as JSON.

    Args:
        tdann_directory: Directory searched recursively for TDANN checkpoints.
        max_checkpoints: Number of TDANN checkpoints to evaluate, in discovery
            order; ``None`` evaluates all of them.
        max_layers: Number of TDANN layers (from ``all_tdann_layers``) to probe
            per checkpoint; ``None`` probes all of them.
        output_path: JSON file the full results dict is written to.
        evaluate_resnets: Also evaluate the configured ResNet-18/50 checkpoints.
    """
    import pprint

    tdann_checkpoint_paths, tdann_file_nicknames = find_all_files(tdann_directory)
    eval_results = _build_eval_results(tdann_file_nicknames)

    # Load brain data and preprocessed images (shared by all models).
    brain_data, preprocessed_images = load_data()

    if evaluate_resnets:
        _evaluate_resnet_checkpoints(brain_data, preprocessed_images, eval_results)

    selected = list(zip(tdann_checkpoint_paths, tdann_file_nicknames))[:max_checkpoints]
    for tdann_checkpoint_path, tdann_nickname in selected:
        for layer_name in all_tdann_layers[:max_layers]:
            print(layer_name)
            # A freshly loaded model per layer, so state left behind by one
            # extraction (e.g. hooks) cannot leak into the next.
            model = load_eshed_checkpoint(tdann_checkpoint_path).to(device)
            evaluate_model(model, layer_name, tdann_nickname, brain_data, preprocessed_images, eval_results, "tdann", "tdann")
            pprint.pprint(eval_results["tdann"])

    # Save results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(eval_results, f, indent=4)

# Script entry point: runs the experiment (and writes the JSON results) whenever
# this module is executed or imported.
# NOTE(review): consider guarding with `if __name__ == "__main__":` so importing
# the module does not trigger a full run.
run_experiment()
