from nesim.eval.resnet import load_resnet18_checkpoint, load_resnet50_checkpoint
from nesim.eval.eshed import load_eshed_checkpoint
from helpers.load_images import load_images
from helpers.preprocess_image import preprocess
from helpers.load_brain_data import load_brain_data
from helpers.hook import get_intermediate_layer_activations

import torch
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from einops import rearrange, reduce
import scipy
import numpy as np
import json
import os
from tqdm import tqdm

# Load data (this is common for all models)
def load_data(subj, roi, dataset_root="./"):
    """Load NSD1000 brain responses for one subject/ROI and the matching images.

    Args:
        subj: subject identifier forwarded to ``load_brain_data``.
        roi: region-of-interest name forwarded to ``load_brain_data``.
        dataset_root: directory containing the NSD1000 data; used for both
            the brain data and the images (default ``"./"``).

    Returns:
        A tuple ``(brain_data, preprocessed_images)`` where ``brain_data`` is a
        NumPy array of shape ``(n_stim, voxels)`` (trial-averaged responses) and
        ``preprocessed_images`` is whatever ``preprocess`` returns for the
        NSD1000 images.
    """
    brain_data = load_brain_data(
        dataset_name="NSD1000",
        subject=subj,
        roi=roi,
        dataset_root=dataset_root,
        averaging=False,
    )
    # Average over the trials axis -> one response per (stimulus, voxel).
    # (Presumably 3 repetitions per image — not checked here.)
    brain_data = reduce(brain_data, "n_stim trials voxels -> n_stim voxels", "mean")

    images = load_images(
        dataset_name="NSD1000",
        dataset_root=dataset_root,
    )
    # ImageNet channel mean/std (torchvision convention) for the pretrained nets.
    preprocessed_images = preprocess(
        images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    return brain_data.cpu().numpy(), preprocessed_images

# ResNet-18 conv module names passed as `layer_name` to
# get_intermediate_layer_activations. Only the layer3 and layer4 convs are
# active; the layer1.* and layer2.* convs (including layer2.0.downsample.0)
# are currently excluded.
rn18_all_conv_layers_except_first = [
    # layer3
    "layer3.0.conv1",
    "layer3.0.conv2",
    "layer3.0.downsample.0",
    "layer3.1.conv1",
    "layer3.1.conv2",
    # layer4
    "layer4.0.conv1",
    "layer4.0.conv2",
    "layer4.0.downsample.0",
    "layer4.1.conv1",
    "layer4.1.conv2",
]

# Layer set for the "last_conv_layers_in_each_block" topo_mode. It is empty
# for now; the disabled entries are:
#     layer1.1.conv2, layer2.1.conv2, layer3.1.conv2, layer4.1.conv2
rn18_last_conv_layers_in_each_block = []

# Disabled debug override: probe only "layer4.1" for both layer sets, i.e.
#     rn18_all_conv_layers_except_first = ["layer4.1"]
#     rn18_last_conv_layers_in_each_block = ["layer4.1"]

# Directory holding the ResNet-18 checkpoints. The default is machine-specific,
# so it can be overridden with the RESNET18_CHECKPOINTS_FOLDER environment
# variable instead of editing this file.
resnet18_checkpoints_folder = os.environ.get(
    "RESNET18_CHECKPOINTS_FOLDER",
    "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet18",
)

# Layers to probe, keyed by model type and then topo_mode. run_experiment
# looks up layer_names["resnet18"][topo_mode] for every topo_mode listed in
# checkpoint_names["resnet18"], so those keys must exist here.
# (A "baseline" entry reusing rn18_all_conv_layers_except_first is disabled.)
layer_names = dict(
    resnet18=dict(
        all_conv_layers_except_first=rn18_all_conv_layers_except_first,
        last_conv_layers_in_each_block=rn18_last_conv_layers_in_each_block,
    ),
)

# Checkpoints to evaluate, keyed by model type and then topo_mode. Each name
# is passed verbatim as `model_name` to load_resnet18_checkpoint, so the
# inconsistent scale spelling ("1"/"5" vs "10.0"/"50.0") is kept as-is —
# presumably it matches the on-disk names; TODO confirm.
checkpoint_names = {
    "resnet18": {
        # Disabled: end_topo_scale_{1.0,5.0,10.0,50.0}_shrink_factor_3.0
        "last_conv_layers_in_each_block": [],
        "all_conv_layers_except_first": [
            f"all_topo_scale_{scale}_shrink_factor_3.0"
            for scale in ("1", "5", "10.0", "50.0")
        ],
        # Disabled: "baseline": ["baseline_scale_None_shrink_factor_3.0"]
    },
}

# Everything runs on the CPU. To use a GPU when one is available, switch to:
#     device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Stimulus rows [0, train_test_split_index) train the regression; the rest are
# held out for scoring. Keep it below the number of stimuli (1000).
train_test_split_index = 800

def evaluate_model(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, model_type, topo_mode):
    """Fit and score a linear encoding model from one layer to brain responses.

    The activations of ``layer_name`` are extracted for all images and
    flattened per image. Rows are split at the module-level
    ``train_test_split_index``, and ``brain_data`` is split at the same
    index. An ordinary least-squares ``LinearRegression`` is fit on the
    training rows and evaluated on the held-out rows.

    Metrics are written in place to
    ``eval_results[model_type][topo_mode][model_name][layer_name]`` under
    ``"mse"``, ``"mean_pearson_r"`` and ``"pearson_r_individual"`` (one
    Pearson r per column of ``brain_data``). That nested dict must already
    exist.

    Args:
        model: network passed to ``get_intermediate_layer_activations``.
        layer_name: module name whose (4-D) activations are used as features.
        model_name: checkpoint name; only used as a key into ``eval_results``.
        brain_data: 2-D array; rows are stimuli aligned with
            ``preprocessed_images``, columns are voxels.
        preprocessed_images: image tensor fed to the model.
        eval_results: nested dict receiving the metrics.
        model_type: first-level key into ``eval_results``.
        topo_mode: second-level key into ``eval_results``.

    Raises:
        ValueError: if activation and brain-data row counts differ, or if the
            split leaves fewer than two held-out rows (Pearson r undefined).
    """
    # Inference only: skip building an autograd graph (saves memory and makes
    # sure `.numpy()` below never sees a tensor that requires grad).
    with torch.no_grad():
        activations = get_intermediate_layer_activations(
            model=model,
            layer_name=layer_name,
            input_images=preprocessed_images.to(device, dtype=torch.float),
        )

    activations = rearrange(activations, "b c h w -> b (c h w)")  # Flatten
    activations = activations.cpu().numpy()

    n_rows = activations.shape[0]
    if n_rows != len(brain_data):
        raise ValueError(
            f"Got {n_rows} activation rows but {len(brain_data)} brain-data rows; "
            "they must be aligned one stimulus per row."
        )
    if n_rows - train_test_split_index < 2:
        raise ValueError(
            f"train_test_split_index={train_test_split_index} leaves "
            f"{max(n_rows - train_test_split_index, 0)} held-out rows out of {n_rows}; "
            "at least 2 are needed to compute Pearson r."
        )

    X_train, X_test = activations[:train_test_split_index], activations[train_test_split_index:]
    y_train, y_test = brain_data[:train_test_split_index], brain_data[train_test_split_index:]

    # Train regression model
    regression_model = LinearRegression()
    regression_model.fit(X_train, y_train)
    y_pred = regression_model.predict(X_test)

    # Calculate metrics (plain floats so the results JSON-serialize cleanly).
    mse = float(mean_squared_error(y_test, y_pred))
    # One Pearson r per voxel (column), computed across the held-out stimuli.
    pearson_r_individual = [
        float(scipy.stats.pearsonr(y_test[:, voxel], y_pred[:, voxel])[0])
        for voxel in range(y_test.shape[1])
    ]
    mean_pearson_r = float(np.mean(pearson_r_individual))

    layer_results = eval_results[model_type][topo_mode][model_name][layer_name]
    layer_results["mse"] = mse
    layer_results["mean_pearson_r"] = mean_pearson_r
    layer_results["pearson_r_individual"] = pearson_r_individual

    print(f"Val set mse: {mse}, mean_pearson_r: {mean_pearson_r}")

def run_experiment(subj, roi):
    """Evaluate every configured ResNet-18 checkpoint and layer on one subject/ROI.

    For each topo_mode in ``checkpoint_names["resnet18"]``, each listed
    checkpoint is loaded (epoch ``"final"``) and scored with ``evaluate_model``
    on every layer in ``layer_names["resnet18"][topo_mode]``. The metrics are
    collected as ``{"resnet18": {topo_mode: {checkpoint: {layer: metrics}}}}``
    and written to ``./eval_jsons_3/resnet18_eval_subj0{subj}_{roi}.json``.

    Args:
        subj: subject identifier passed to ``load_data`` and used in the
            output file name.
        roi: region-of-interest name passed to ``load_data`` and used in the
            output file name.
    """
    output_dir = "./eval_jsons_3"
    # Create the output directory before the (long) evaluation so the final
    # json.dump cannot fail on a missing directory and lose all results.
    os.makedirs(output_dir, exist_ok=True)

    # Pre-build the nested structure that evaluate_model writes into.
    eval_results = {
        "resnet18": {
            topo_mode: {
                checkpoint: {layer: {} for layer in layer_names["resnet18"][topo_mode]}
                for checkpoint in checkpoint_names["resnet18"][topo_mode]
            }
            for topo_mode in checkpoint_names["resnet18"]
        },
    }

    brain_data, preprocessed_images = load_data(subj, roi)

    total_runs = sum(
        len(layer_names["resnet18"][topo_mode]) * len(model_list)
        for topo_mode, model_list in checkpoint_names["resnet18"].items()
    )
    runs_left = total_runs

    for topo_mode, model_list in checkpoint_names["resnet18"].items():
        for model_name in model_list:
            model = load_resnet18_checkpoint(
                # Joining with "" keeps the trailing path separator.
                checkpoints_folder=os.path.join(resnet18_checkpoints_folder, ""),
                model_name=model_name,
                epoch="final",
            )
            model.to(device)
            # Evaluate with inference behaviour (e.g. BatchNorm running
            # statistics); a no-op if the loader already set eval mode.
            model.eval()
            for layer_name in layer_names["resnet18"][topo_mode]:
                evaluate_model(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, "resnet18", topo_mode)

                runs_left -= 1
                print(f"Runs left: {runs_left} total: {total_runs}")

    # Save results
    output_path = os.path.join(output_dir, f"resnet18_eval_subj0{subj}_{roi}.json")
    with open(output_path, "w") as f:
        json.dump(eval_results, f, indent=4)

import argparse

# Only run the (expensive) experiment when executed as a script, not on import.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiment with voxels from ROI.")
    parser.add_argument('--roi', type=str, required=True, help='Region of interest (e.g., "ppa")')
    parser.add_argument('--subj', type=int, default=1, help="Subject number (default: 1)")
    args = parser.parse_args()
    run_experiment(subj=args.subj, roi=args.roi)