from nesim.eval.resnet import load_resnet18_checkpoint, load_resnet50_checkpoint
from helpers.load_images import load_images
from helpers.preprocess_image import preprocess
from helpers.load_brain_data import load_brain_data
from helpers.hook import get_intermediate_layer_activations

import torch

# Device string used by the `.to(device)` calls further down: the GPU when one
# is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from einops import rearrange
from einops import reduce
import scipy
import numpy as np
import json
import os

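"""
Names of the topographic layers in the "end topo" models
(last_conv_layers_in_each_block):
XXXX

Names of the topographic layers in the "all topo" models
(all_conv_layers_except_first):
XXXX
"""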
# last_conv_layers_in_each_block and all_conv_layers_except_first and baseline are topo modes
resnet18_checkpoint_names = {
    "last_conv_layers_in_each_block":[
    "end_topo_scale_1.0_shrink_factor_3.0",
    "end_topo_scale_5.0_shrink_factor_3.0",
    "end_topo_scale_10.0_shrink_factor_3.0",
    "end_topo_scale_50.0_shrink_factor_3.0",
],
    "all_conv_layers_except_first":[
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0"
],
    "baseline":[
    "baseline_scale_None_shrink_factor_3.0"
]
}

resnet50_checkpoint_names = {
    "last_conv_layers_in_each_block": ["end_topo_scale_1_shrink_factor_3.0",],
    "all_conv_layers_except_first": ["all_topo_scale_1_shrink_factor_3.0",],
    "baseline":["baseline_scale_None_shrink_factor_3.0"],
}

def find_all_files(directory):
    """Recursively collect every ``.pth`` checkpoint under *directory*.

    The walk is sorted (sub-directories and file names), so the returned order,
    and therefore the evaluation order and the key order of the results JSON,
    is reproducible regardless of the filesystem's listing order.

    Args:
        directory: root folder to search.

    Returns:
        ``(file_paths, nicknames)``: two parallel lists. ``file_paths`` holds
        ``os.path.join(root, file)`` for each checkpoint; ``nicknames`` holds
        the short label ``"<parent folder>/<file name>"``. Both lists are empty
        when no ``.pth`` file is found (note that ``os.walk`` with its default
        ``onerror=None`` also yields nothing for a missing directory).
    """
    file_paths = []
    nicknames = []
    for root, dirs, files in os.walk(directory):
        # Top-down walk: sorting `dirs` in place fixes the order sub-dirs are visited.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".pth"):
                continue
            full_path = os.path.join(root, file)
            file_paths.append(full_path)
            # Normalise separators so the nickname is "parent/file" on every OS.
            path_parts = full_path.replace(os.sep, "/").split("/")
            nicknames.append("/".join(path_parts[-2:]))
    return file_paths, nicknames


# Every TDANN checkpoint (*.pth) found under this folder is evaluated further below.
tdann_directory = "/home/XXXX-5/repos/nesim/training/brainmodels/eshed_checkpoints"
tdann_checkpoint_paths, tdann_file_nicknames = find_all_files(tdann_directory)

# Topo mode -> ResNet-18 layer names associated with that mode.
all_layers_to_try = {
    "all_conv_layers_except_first": [
        "layer1.0.conv1",
        "layer1.0.conv2",
        "layer1.1.conv1",
        "layer1.1.conv2",
        "layer2.0.conv1.conv",
        "layer2.0.conv2",
        "layer2.0.downsample.0.conv",
        "layer2.1.conv1",
        "layer2.1.conv2",
        "layer3.0.conv1.conv",
        "layer3.0.conv2",
        "layer3.0.downsample.0.conv",
        "layer3.1.conv1",
        "layer3.1.conv2",
        "layer4.0.conv1.conv",
        "layer4.0.conv2",
        "layer4.0.downsample.0.conv",
        "layer4.1.conv1",
        "layer4.1.conv2",
    ],
    "last_conv_layers_in_each_block": [
        "layer1.1.conv2",
        "layer2.1.conv2",
        "layer3.1.conv2",
        "layer4.1.conv2",
    ],
}

# Single layer probed per model family. The TDANN module path carries a
# "base_model." prefix in front of the ResNet layer name.
resnet_18_layer_name = "layer4.1.conv2"
resnet_50_layer_name = "layer4.1.conv2"
tdann_layer_name = "base_model.layer4.1.conv2"

# Layer keys of the results skeleton below. all_layers_to_try is a dict of
# lists, so it is flattened (first-occurrence order, duplicates dropped);
# iterating the dict itself would give topo-mode names, not layer names.
resnet18_layer_names = list(
    dict.fromkeys(
        layer for layers in all_layers_to_try.values() for layer in layers
    )
)
# resnet_50_layer_name is the only layer the ResNet-50 evaluation probes.
resnet50_layer_names = [resnet_50_layer_name]
# TODO ("fix this" in the original): only one TDANN layer is evaluated. The
# key is the bare layer name, i.e. tdann_layer_name without "base_model.".
tdann_layer_names = ["layer4.1.conv2"]

# eval_results[family][model][layer] -> {"mse": ..., "pearson_r": ...}
eval_results = {
    "resnet18": {
        name: {layer: {} for layer in resnet18_layer_names}
        for name in resnet18_checkpoint_names
    },
    "resnet50": {
        name: {layer: {} for layer in resnet50_layer_names}
        for name in resnet50_checkpoint_names
    },
    "tdann": {
        name: {layer: {} for layer in tdann_layer_names}
        for name in tdann_file_nicknames
    },
}

# Stimuli [0, split) train the regression, the rest are held out. Must be
# strictly between 0 and the number of stimuli (1000 for NSD1000).
train_test_split_index = 900

# ---------------------------------------------------------------------------
# ResNet-18 checkpoints: layer-wise OLS encoding model of NSD1000 FFA voxels.
# ---------------------------------------------------------------------------
# The stimuli and brain responses do not depend on the checkpoint or the layer,
# so they are loaded (and split) once here instead of once per iteration.
brain_data = load_brain_data(
    dataset_name="NSD1000",
    subject="1",
    roi="ffa",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
    averaging=False,
)
brain_data = reduce(
    brain_data, "n_stim trials voxels -> n_stim voxels", "mean"
)  # mean over trials
brain_data = brain_data.cpu().numpy()
if not 0 < train_test_split_index < len(brain_data):
    raise ValueError(
        f"train_test_split_index must lie strictly between 0 and {len(brain_data)}, "
        f"got {train_test_split_index}"
    )
y_train, y_test = (
    brain_data[:train_test_split_index],
    brain_data[train_test_split_index:],
)

images = load_images(
    dataset_name="NSD1000",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
)
preprocessed_images = preprocess(
    images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# all_layers_to_try maps each topo mode to a list of layer names; evaluate the
# union of those lists (first-occurrence order, duplicates dropped). Iterating
# the dict directly would yield topo-mode keys, which are not layer names.
resnet18_layers_to_eval = list(
    dict.fromkeys(
        layer for layers in all_layers_to_try.values() for layer in layers
    )
)

for model_name in resnet18_checkpoint_names:
    # NOTE(review): model_name is a topo-mode key of resnet18_checkpoint_names
    # (e.g. "baseline"), not one of the run names listed under it; confirm
    # load_resnet18_checkpoint expects the key.
    # NOTE(review): confirm the loader returns the model in eval() mode;
    # BatchNorm in train mode would change the activations.
    # Each checkpoint is loaded once and then probed at every layer.
    backend_model = load_resnet18_checkpoint(
        ## I kept the checkpoints here
        # checkpoints_folder="/home/XXXX-5/toponets_checkpoints/resnet18",
        checkpoints_folder="/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet18/",
        model_name=model_name,
        epoch="final",  ## dont change this
    ).to(device)  # inputs are moved to `device` below, so the model must be there too

    for resnet_18_layer_name in resnet18_layers_to_eval:
        activations = get_intermediate_layer_activations(
            model=backend_model,
            layer_name=resnet_18_layer_name,  # FOR RN18, could not find layer in XXXX
            input_images=preprocessed_images.to(device, dtype=torch.float),
        )
        # Flatten each stimulus's feature map into one regressor vector.
        activations = rearrange(activations, "b c h w -> b (c h w)")
        activations = activations.cpu().numpy()
        X_train, X_test = (
            activations[:train_test_split_index],
            activations[train_test_split_index:],
        )

        # Ordinary least squares from features to all voxels jointly.
        regression_model = LinearRegression()
        regression_model.fit(X_train, y_train)
        y_pred = regression_model.predict(X_test)
        mse = mean_squared_error(y_test, y_pred)
        # axis=0: one correlation per voxel (column) across held-out stimuli;
        # requires a SciPy whose pearsonr accepts `axis`.
        pearson_r = scipy.stats.pearsonr(y_pred, y_test, axis=0)[0]
        mean_pearson_r = float(np.mean(pearson_r))
        print(
            f"model Resnet 18 {model_name} layer {resnet_18_layer_name} OLS brainmodel\n"
            f" val set mse: {mse}\n pearson_r: {mean_pearson_r}"
        )
        # Keyed by layer so results from different layers do not overwrite
        # each other.
        layer_results = eval_results["resnet18"][model_name].setdefault(
            resnet_18_layer_name, {}
        )
        layer_results["mse"] = float(mse)
        layer_results["pearson_r"] = mean_pearson_r

"""
Names of the topographic layers in the "all topo" models: XXXX

Names of the topographic layers in the "end topo" models: XXXX
"""


# ---------------------------------------------------------------------------
# ResNet-50 checkpoints: OLS encoding model of NSD1000 FFA voxels, one layer.
# ---------------------------------------------------------------------------
# Stimuli and brain responses are identical for every checkpoint, so they are
# loaded (and split) once here instead of once per checkpoint.
brain_data = load_brain_data(
    dataset_name="NSD1000",
    subject="1",
    roi="ffa",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
    averaging=False,
)
brain_data = reduce(
    brain_data, "n_stim trials voxels -> n_stim voxels", "mean"
)  # mean over trials
brain_data = brain_data.cpu().numpy()
if not 0 < train_test_split_index < len(brain_data):
    raise ValueError(
        f"train_test_split_index must lie strictly between 0 and {len(brain_data)}, "
        f"got {train_test_split_index}"
    )
y_train, y_test = (
    brain_data[:train_test_split_index],
    brain_data[train_test_split_index:],
)

images = load_images(
    dataset_name="NSD1000",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
)
preprocessed_images = preprocess(
    images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

for model_name in resnet50_checkpoint_names:
    # NOTE(review): model_name is a topo-mode key of resnet50_checkpoint_names,
    # not one of the run names listed under it; confirm the loader expects it.
    backend_model = load_resnet50_checkpoint(
        ## I kept the checkpoints here
        # NOTE(review): stale alternative below points at resnet18, not resnet50.
        # checkpoints_folder="/home/XXXX-5/toponets_checkpoints/resnet18",
        checkpoints_folder="/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet50/",
        model_name=model_name,
        epoch="final",  ## dont change this
    ).to(device)  # inputs are moved to `device` below, so the model must be there too

    activations = get_intermediate_layer_activations(
        model=backend_model,
        layer_name=resnet_50_layer_name,
        input_images=preprocessed_images.to(device, dtype=torch.float),
    )
    # Flatten each stimulus's feature map into one regressor vector.
    activations = rearrange(activations, "b c h w -> b (c h w)")
    activations = activations.cpu().numpy()
    X_train, X_test = (
        activations[:train_test_split_index],
        activations[train_test_split_index:],
    )

    # Ordinary least squares from features to all voxels jointly.
    regression_model = LinearRegression()
    regression_model.fit(X_train, y_train)
    y_pred = regression_model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    # axis=0: one correlation per voxel (column) across held-out stimuli;
    # requires a SciPy whose pearsonr accepts `axis`.
    pearson_r = scipy.stats.pearsonr(y_pred, y_test, axis=0)[0]
    mean_pearson_r = float(np.mean(pearson_r))
    print(
        f"model Resnet 50 {model_name} layer {resnet_50_layer_name} OLS brainmodel\n"
        f" val set mse: {mse}\n pearson_r: {mean_pearson_r}"
    )
    # Stored under the probed layer, matching the eval_results[family][model][layer] layout.
    layer_results = eval_results["resnet50"][model_name].setdefault(
        resnet_50_layer_name, {}
    )
    layer_results["mse"] = float(mse)
    layer_results["pearson_r"] = mean_pearson_r


"""
Next, load the Eshed (TDANN) checkpoints for evaluation.

Download them first, e.g.:
wget -O eshed_checkpoints/supervised_resnet18_lw0.pth XXXX
... and download the other checkpoints the same way.

model_final_checkpoint_phase199.torch
XXXX
"""
from nesim.eval.eshed import load_eshed_checkpoint

# ---------------------------------------------------------------------------
# TDANN checkpoints: OLS encoding model of NSD1000 FFA voxels, one layer.
# ---------------------------------------------------------------------------
# Activations are read from tdann_layer_name, but results are stored under the
# bare layer name (without the "base_model." prefix) so they line up with the
# keys of tdann_layer_names and with the ResNet entries.
_tdann_module_prefix = "base_model."
if tdann_layer_name.startswith(_tdann_module_prefix):
    tdann_result_layer_key = tdann_layer_name[len(_tdann_module_prefix):]
else:
    tdann_result_layer_key = tdann_layer_name

if not tdann_checkpoint_paths:
    print(f"No TDANN checkpoints (*.pth) found under {tdann_directory}")

# Stimuli and brain responses are identical for every checkpoint, so they are
# loaded (and split) once here instead of once per checkpoint.
brain_data = load_brain_data(
    dataset_name="NSD1000",
    subject="1",
    roi="ffa",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
    averaging=False,
)
brain_data = reduce(
    brain_data, "n_stim trials voxels -> n_stim voxels", "mean"
)  # mean over trials
brain_data = brain_data.cpu().numpy()
if not 0 < train_test_split_index < len(brain_data):
    raise ValueError(
        f"train_test_split_index must lie strictly between 0 and {len(brain_data)}, "
        f"got {train_test_split_index}"
    )
y_train, y_test = (
    brain_data[:train_test_split_index],
    brain_data[train_test_split_index:],
)

images = load_images(
    dataset_name="NSD1000",
    dataset_root="/home/XXXX-5/repos/robust-brainmodels/temp_root/",
)
preprocessed_images = preprocess(
    images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

for tdann_checkpoint_path, tdann_nickname in zip(
    tdann_checkpoint_paths, tdann_file_nicknames
):
    backend_model = load_eshed_checkpoint(tdann_checkpoint_path).to(device)

    activations = get_intermediate_layer_activations(
        model=backend_model,
        layer_name=tdann_layer_name,
        input_images=preprocessed_images.to(device, dtype=torch.float),
    )
    # Flatten each stimulus's feature map into one regressor vector.
    activations = rearrange(activations, "b c h w -> b (c h w)")
    activations = activations.cpu().numpy()
    X_train, X_test = (
        activations[:train_test_split_index],
        activations[train_test_split_index:],
    )

    # Ordinary least squares from features to all voxels jointly.
    regression_model = LinearRegression()
    regression_model.fit(X_train, y_train)
    y_pred = regression_model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    # axis=0: one correlation per voxel (column) across held-out stimuli;
    # requires a SciPy whose pearsonr accepts `axis`.
    pearson_r = scipy.stats.pearsonr(y_pred, y_test, axis=0)[0]
    mean_pearson_r = float(np.mean(pearson_r))
    print(
        f"model TDANN (Resnet 18) {tdann_nickname} layer {tdann_layer_name} OLS brainmodel\n"
        f" val set mse: {mse}\n pearson_r: {mean_pearson_r}"
    )
    # Stored under the bare layer name, matching eval_results[family][model][layer].
    layer_results = eval_results["tdann"][tdann_nickname].setdefault(
        tdann_result_layer_key, {}
    )
    layer_results["mse"] = float(mse)
    layer_results["pearson_r"] = mean_pearson_r


# Persist all collected metrics as indented JSON in the working directory.
with open("eval_results.json", mode="w") as results_file:
    json.dump(eval_results, results_file, indent=4)
