import torchvision.models as models
import torch
import os
from typing import Dict
import torch.nn as nn
from torchvision.models.resnet import resnet18
from neuro.models.intermediate_output_extractor.model import IntermediateOutputExtractor
from neuro.utils.model_building import make_convmapper_config_from_size_sequences

from neuro.trainers.regression_trainer import (
    RegressionTrainer,
    RegressionTrainerConfig,
    SchedulerConfig,
)
from neuro.datasets.image_encoding_dataset.builder import ImageEncodingDatasetBuilder
from neuro.datasets.murty_185 import Murty185Dataset
import torchvision.transforms as transforms
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper
from nesim.utils.json_stuff import dict_to_json
from lightning.pytorch import seed_everything
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict

# Seed global RNGs (via lightning) before any model is constructed so runs repeat.
seed_everything(1)

# Compute device and the (machine-specific) root of the Murty-185 data.
device = "cuda:0"
openmind_data_root = "/mindhive/nklab3/users/XXXX-1/repos/neuro/data/murty_185/data"
data_root = openmind_data_root

# Experiment settings. The JSON files are resolved relative to the working directory.
# Only the "tdann" entry of the layer-config file is used by this script.
intermediate_layer_configs = load_json_as_dict(
    "intermediate_layer_configs.json"
)["tdann"]
conv_mapper_config_args = load_json_as_dict("conv_mapper_config_args.json")
hyperparams = load_json_as_dict("hyperparams.json")


# Taken and modified from: XXXX
# @register_model_trunk("eshednet")
class EshedNet(nn.Module):
    """ResNet-18 trunk (torchvision ``resnet18``) with its classification head removed.

    The wrapped network is built with ``weights=None`` (random initialisation) and is
    stored as ``self.base_model``; its ``fc`` layer is replaced by ``nn.Identity``.
    """

    def __init__(self):
        """Build the trunk. Takes no arguments."""
        super().__init__()
        trunk = resnet18(weights=None)
        # The FC head is never used by forward(); neutralise it.
        trunk.fc = nn.Identity()
        self.base_model = trunk

    # Signature kept as forward(x); the original notes VISSL requires it.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem -> layer1..layer4 -> avgpool and return the output flattened from dim 1.

        Within each stage only blocks ``[0]`` and ``[1]`` are called, explicitly by index.
        """
        net = self.base_model
        features = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            features = stage[1](stage[0](features))
        return torch.flatten(net.avgpool(features), 1)


# Step 0: load the TDANN model checkpoint into an EshedNet trunk.
tdann_model = EshedNet()

# map_location="cpu": the model is moved to `device` below, so nothing needs to be
# placed on the GPU at load time. Only the trunk weights are kept afterwards.
checkpoint = torch.load("tdann_checkpoint.pth", map_location="cpu")
state_dict = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
del checkpoint

# Strict loading fails on this checkpoint with
#   Unexpected key(s) in state_dict: "_feature_blocks.conv1.weight", ...
# so it is loaded non-strictly. Non-strict loading also silently skips every model key
# the checkpoint does not provide, so verify what was actually loaded: if nothing
# matched, the model would silently run with random (weights=None) parameters.
load_result = tdann_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    if set(load_result.missing_keys) >= set(tdann_model.state_dict().keys()):
        raise RuntimeError(
            "tdann_checkpoint.pth trunk matched none of EshedNet's parameters; "
            "the model would run with randomly initialised weights. "
            f"First unexpected checkpoint keys: {load_result.unexpected_keys[:5]}"
        )
    print(
        f"WARNING: {len(load_result.missing_keys)} EshedNet keys were not found "
        f"in the checkpoint: {load_result.missing_keys}"
    )

tdann_model.to(device)
# Extract features with BatchNorm running statistics rather than per-batch ones.
# NOTE(review): harmless if IntermediateOutputExtractor / ImageEncodingDatasetBuilder
# already switch the model to eval mode -- confirm.
tdann_model.eval()

# Per-image preprocessing; read back as `intermediate_output_extractor.model.transforms`
# when the dataset is built. Mean/std are the usual ImageNet channel statistics.
tdann_model.transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def _prepare_output_folders(image_encodings_folder, image_filenames_and_labels_folder):
    """Create both per-layer output folders and delete stale files from the encodings one.

    Replaces shell `mkdir -p` / `rm folder/*` calls that interpolated the
    JSON-provided layer name into a shell command unquoted.
    """
    os.makedirs(image_encodings_folder, exist_ok=True)
    os.makedirs(image_filenames_and_labels_folder, exist_ok=True)
    # Mirror `rm folder/*`: remove non-directory entries only (rm without -r leaves
    # directories alone) and skip dotfiles (the shell glob `*` does not match them).
    # The filenames/labels folder is intentionally not cleared, as before.
    with os.scandir(image_encodings_folder) as entries:
        stale_paths = [
            entry.path
            for entry in entries
            if not entry.name.startswith(".")
            and not entry.is_dir(follow_symlinks=False)
        ]
    for path in stale_paths:
        os.remove(path)


def _resolve_hook_layer(model, layer_name):
    """Return the submodule named by `layer_name` to attach the forward hook to.

    `layer_name` is treated as a dotted `get_submodule` path (e.g. "layer3.1"), looked up
    first under `model.base_model` and then under `model`. If it cannot be resolved, fall
    back to `model.base_model.layer3[1]` (previously hooked for every config) and warn.
    """
    if layer_name:
        for root in (model.base_model, model):
            try:
                return root.get_submodule(layer_name)
            except AttributeError:
                continue
    print(
        f"WARNING: could not resolve layer {layer_name!r}; "
        "hooking base_model.layer3[1] instead"
    )
    return model.base_model.layer3[1]


scheduler_hparams = hyperparams["scheduler"]
trainer_hparams = hyperparams["regression_trainer"]
# dict_to_json below writes into ./results/; make sure the folder exists.
os.makedirs("results", exist_ok=True)

for config in intermediate_layer_configs:
    layer_name = config["layer_name"]
    image_encodings_folder = (
        f"./datasets/tdann/intermediate_image_encodings/{layer_name}"
    )
    image_filenames_and_labels_folder = (
        f"./datasets/tdann/image_filenames_and_labels/{layer_name}"
    )
    _prepare_output_folders(image_encodings_folder, image_filenames_and_labels_folder)

    # Step 1: build the intermediate image encoding dataset.
    # The hook follows config["layer_name"]; previously it was hard-coded to
    # base_model.layer3[1], so every config produced the same encodings.
    # NOTE(review): a new extractor (and presumably a new hook) is created per config;
    # confirm IntermediateOutputExtractor does not leave earlier hooks registered.
    intermediate_output_extractor = IntermediateOutputExtractor(
        model=tdann_model,
        forward_hook_layer=_resolve_hook_layer(tdann_model, layer_name),
    )
    dataset = Murty185Dataset(
        image_folder=os.path.join(data_root, "images_185/"),
        fmri_response_filename=os.path.join(data_root, "all_data.pickle"),
        transforms=intermediate_output_extractor.model.transforms,  # resize and normalize
    )

    builder = ImageEncodingDatasetBuilder(
        intermediate_output_extractor=intermediate_output_extractor,
        image_encodings_folder=image_encodings_folder,
        image_filenames_and_labels_folder=image_filenames_and_labels_folder,
        device=device,
    )
    builder.build(dataset=dataset)

    # Step 2: build and train the conv mapper on this layer's encodings.
    size_sequences = config["conv_mapper_size_sequence"]
    conv_mapper_config = make_convmapper_config_from_size_sequences(
        conv_layer_size_sequence=size_sequences["conv_layer_size_sequence"],
        linear_layer_size_sequence=size_sequences["linear_layer_size_sequence"],
        reduce_fn=conv_mapper_config_args["reduce_fn"],
        activation=conv_mapper_config_args["activation"],
        conv_layer_kernel_size=conv_mapper_config_args["conv_layer_kernel_size"],
    )

    scheduler_config = SchedulerConfig(
        step_size=scheduler_hparams["step_size"],
        gamma=scheduler_hparams["gamma"],
        verbose=scheduler_hparams["verbose"],
    )

    # very good performance on this config
    # NOTE(review): every layer shares ./checkpoints/tdann; if RegressionTrainer names
    # checkpoints only by fold, later layers overwrite earlier ones -- confirm.
    regression_trainer_config = RegressionTrainerConfig(
        checkpoint_folder="./checkpoints/tdann",
        image_filenames_and_labels_folder=image_filenames_and_labels_folder,
        image_encodings_folder=image_encodings_folder,
        scheduler_config=scheduler_config,
        train_on_fold=trainer_hparams["train_on_fold"],
        num_folds=trainer_hparams["num_folds"],
        progress=trainer_hparams["progress"],
        quiet=trainer_hparams["quiet"],
        wandb_log=trainer_hparams["wandb_log"],
        num_epochs=trainer_hparams["num_epochs"],
        batch_size=trainer_hparams["batch_size"],
        learning_rate=trainer_hparams["learning_rate"],
        momentum=trainer_hparams["momentum"],
        weight_decay=trainer_hparams["weight_decay"],
        device_ids=trainer_hparams["device_ids"],
    )
    trainer = RegressionTrainer(config=regression_trainer_config)

    # A fresh ConvMapper per fold so folds do not share weights.
    results_combined = [
        trainer.train_single_fold(
            fold_idx=fold_idx, model=ConvMapper(config=conv_mapper_config)
        )
        for fold_idx in range(trainer_hparams["num_folds"])
    ]

    dict_to_json(results_combined, f"results/tdann_{layer_name}.json")

print("Done :)")
