from nesim.eval.resnet import load_resnet18_checkpoint, load_resnet50_checkpoint
from nesim.eval.eshed import load_eshed_checkpoint
from helpers.load_images import load_images
from helpers.preprocess_image import preprocess
from helpers.load_brain_data import load_brain_data
from helpers.hook import get_intermediate_layer_activations

import torch
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from einops import rearrange, reduce
import scipy
import numpy as np
import json
import os
from tqdm import tqdm

# Load data (this is common for all models)
def load_data(subj, roi, dataset_root="./"):
    """Load NSD1000 brain responses for one subject/ROI and the matching images.

    Args:
        subj: Subject identifier forwarded to ``load_brain_data``.
        roi: Region-of-interest name forwarded to ``load_brain_data``.
        dataset_root: Root directory handed to both ``load_brain_data`` and
            ``load_images`` (defaults to ``"./"``).

    Returns:
        A ``(brain_data, preprocessed_images)`` tuple.

        ``brain_data`` is a NumPy array of shape ``(n_stim, voxels)``. It is
        obtained by averaging the per-trial responses over the trial axis.

        ``preprocessed_images`` is whatever ``preprocess`` returns for the
        NSD1000 images.
    """
    brain_data = load_brain_data(
        dataset_name="NSD1000",
        subject=subj,
        roi=roi,
        dataset_root=dataset_root,
        averaging=False,  # keep individual trials; we average them ourselves below
    )
    # Average over the trial axis so each stimulus has one response per voxel.
    brain_data = reduce(brain_data, "n_stim trials voxels -> n_stim voxels", "mean")

    images = load_images(
        dataset_name="NSD1000",
        dataset_root=dataset_root,
    )
    # Normalize with the standard ImageNet per-channel mean/std.
    preprocessed_images = preprocess(
        images=images, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    return brain_data.cpu().numpy(), preprocessed_images

# Conv layers of the ResNet-18 backbone probed in each TDANN checkpoint:
# every stage (layer1..layer4), both residual blocks (0, 1), both convs (conv1, conv2),
# in stage -> block -> conv order. The stem layer 'base_model.conv1' is not included.
all_tdann_layers = [
    f"base_model.layer{stage}.{block_idx}.conv{conv_idx}"
    for stage in (1, 2, 3, 4)
    for block_idx in (0, 1)
    for conv_idx in (1, 2)
]

resnet18_checkpoints_folder = "/home/XXXX-5/repos/nesim/training/brainmodels/toponet_checkpoints/resnet18"

# Model family -> names of the layers whose activations are extracted.
layer_names = dict(tdann=all_tdann_layers)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Should be under 1000 (presumably because NSD1000 has 1000 stimuli).
train_test_split_index = 800

def save_activation_as_mat(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, model_type, topo_mode, output_dir="./mat_files"):
    """Extract one layer's activations for all images and save them as a ``.mat`` file.

    The activations of ``model`` at ``layer_name`` for ``preprocessed_images`` are
    flattened to ``(n_images, features)``. They are stored under the key
    ``'mydata'`` in ``<output_dir>/tdann_<layer_name>_<model_name>.mat``.

    Args:
        model: Network passed to ``get_intermediate_layer_activations``.
        layer_name: Name of the layer whose output is captured.
        model_name: Tag embedded in the output filename.
        brain_data: Not used here; kept for call-site compatibility.
        preprocessed_images: Image batch; moved to ``device`` as float32 before
            the forward pass.
        eval_results: Not used here; kept for call-site compatibility.
        model_type: Not used here; kept for call-site compatibility.
        topo_mode: Not used here; kept for call-site compatibility.
        output_dir: Directory the ``.mat`` file is written to; created if missing.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.io`` on older SciPy releases, so ``scipy.io.savemat`` could fail.
    import scipy.io

    activations = get_intermediate_layer_activations(
        model=model,
        layer_name=layer_name,
        input_images=preprocessed_images.to(device, dtype=torch.float),
    )

    # The pattern requires a 4-D tensor (batch, channels, height, width), e.g.
    # [1000, 64, 56, 56]; flatten everything except the batch axis.
    activations = rearrange(activations, "b c h w -> b (c h w)")
    # detach() is a no-op for tensors without autograd history and keeps
    # .numpy() from failing if the hook returns a graph-attached tensor.
    activations = activations.detach().cpu().numpy()

    # savemat does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, f"tdann_{layer_name}_{model_name}.mat")

    print(activations.shape, "saving ...", filename)
    scipy.io.savemat(filename, {'mydata': activations})


def run_experiment(subj, roi):
    """Dump per-layer activations of every listed TDANN checkpoint to ``.mat`` files.

    Each checkpoint below (relative to ``tdann_directory``) is loaded once. Then,
    for every layer in ``all_tdann_layers``, its activations on the NSD1000
    images are written out via ``save_activation_as_mat``.

    Args:
        subj: Subject forwarded to ``load_data``.
        roi: Region of interest forwarded to ``load_data``.
    """
    tdann_directory = "/home/XXXX-5/repos/nesim/training/brainmodels/eshed_checkpoints/"

    tdann_file_nicknames = [
        'supervised_spatial_resnet18_swappedon_SineGrating2019_lwx5_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_spatial_resnet18_swappedon_SineGrating2019_lwx2_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_resnet18_lw0_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_spatial_resnet18_swappedon_SineGrating2019_lwx10_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_spatial_resnet18_swappedon_SineGrating2019_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_spatial_resnet18_swappedon_SineGrating2019_lw01_checkpoints/model_final_checkpoint_phase199.torch',
        'supervised_spatial_resnet18_swappedon_SineGrating2019_lwx100_checkpoints/model_final_checkpoint_phase199.torch',

        # "alex_tdann_res18_simclr_checkpoint.torch",
        # "alex_tdann_res18_simclr_spatial_sinegrating.torch"
    ]

    tdann_checkpoint_paths = [tdann_directory + nickname for nickname in tdann_file_nicknames]

    # Nested container (family -> checkpoint nickname -> layer -> dict). It is
    # handed to save_activation_as_mat; nothing in this function fills it.
    eval_results = {
        "tdann": {
            name: {layer: {} for layer in layer_names["tdann"]}
            for name in tdann_file_nicknames
        },
    }

    brain_data, preprocessed_images = load_data(subj, roi)

    topo_mode = "tdann"
    total_runs_tdann = len(tdann_checkpoint_paths) * len(all_tdann_layers)
    runs_left_tdann = total_runs_tdann

    for tdann_checkpoint_path, tdann_nickname in zip(tdann_checkpoint_paths, tdann_file_nicknames):
        # Load each checkpoint once and reuse it for all layers instead of
        # re-reading the same file from disk for every layer.
        # NOTE(review): this relies on get_intermediate_layer_activations not
        # leaving the model modified (e.g. it should remove its forward hook), and
        # on the loader/hook choosing the intended train/eval mode — confirm.
        model = load_eshed_checkpoint(tdann_checkpoint_path).to(device)
        # The checkpoint's directory name identifies the run in the output filename.
        model_name = tdann_nickname.split("/")[0]

        for layer_name in all_tdann_layers:
            save_activation_as_mat(model, layer_name, model_name, brain_data, preprocessed_images, eval_results, "tdann", topo_mode)
            runs_left_tdann -= 1
            print(f"Runs left: {runs_left_tdann} total {total_runs_tdann}")

        # Release this checkpoint before the next one is loaded.
        del model

import argparse

# Only run the experiment when executed as a script. This keeps the module
# importable without triggering argument parsing and a full extraction run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiment with voxels from ROI.")
    parser.add_argument('--roi', type=str, required=True, help='Region of interest (e.g., "ppa")')
    parser.add_argument('--subj', type=int, default=1, help="Subject number passed to run_experiment (default: 1)")
    args = parser.parse_args()
    run_experiment(subj=args.subj, roi=args.roi)