import torchvision.models as models
import os
from neuro.models.intermediate_output_extractor.model import IntermediateOutputExtractor
from neuro.utils.model_building import make_convmapper_config_from_size_sequences

from neuro.trainers.regression_trainer import (
    RegressionTrainer,
    RegressionTrainerConfig,
    SchedulerConfig,
)
from neuro.datasets.image_encoding_dataset.builder import ImageEncodingDatasetBuilder
from neuro.datasets.murty_185 import Murty185Dataset
import torchvision.transforms as transforms
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper
from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from neuro.utils.getting_modules import get_module_by_name
from lightning.pytorch import seed_everything

seed_everything(1)

# Runtime placement and location of the Murty-185 data on openmind.
device = "cuda:0"
openmind_data_root = "/mindhive/nklab3/users/XXXX-1/repos/neuro/data/murty_185/data"
data_root = openmind_data_root

# Experiment configuration, read from JSON files in the current working
# directory. Only the "ours" entry of the layer-config file is used here.
_all_intermediate_layer_configs = load_json_as_dict("intermediate_layer_configs.json")
intermediate_layer_configs = _all_intermediate_layer_configs["ours"]
conv_mapper_config_args = load_json_as_dict("conv_mapper_config_args.json")
hyperparams = load_json_as_dict("hyperparams.json")


# Step 0: Load our model checkpoint.
model = models.resnet18(weights=None)

# XXXX
run_name = "torchvision_recipe_shrink_factor_[5.0]_loss_scale_50_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_30_steps_apply_sorted_weights_init_filename_None"
state_dict = load_and_filter_state_dict_keys(
    checkpoint_filename=f"../../training/imagenet/resnet18/checkpoints/imagenet/{run_name}/best/best_model-v3.ckpt"
)

# Loading with strict=True fails with errors such as
#   Unexpected key(s) in state_dict: "_feature_blocks.conv1.weight",...
# so the checkpoint is loaded non-strictly. strict=False, however, also
# silently skips *missing* keys, which would leave those parameters at their
# random initialization -- so inspect what was actually loaded.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    if len(load_result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"No model parameters were found in the checkpoint for run {run_name!r}; "
            "check the key filtering done by load_and_filter_state_dict_keys."
        )
    print(
        f"WARNING: {len(load_result.missing_keys)} model key(s) missing from the "
        f"checkpoint and left at random init: {load_result.missing_keys}"
    )

model.to(device)
# A freshly constructed torch module is in training mode. Switch to eval so
# BatchNorm uses the loaded running statistics (and does not update them)
# while image encodings are extracted below.
# NOTE(review): IntermediateOutputExtractor may already call eval(); calling it
# here as well is harmless.
model.eval()
model.transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Per-layer results are written into ./results; make sure it exists.
os.makedirs("results", exist_ok=True)

for config in intermediate_layer_configs:
    layer_name = config["layer_name"]
    image_encodings_folder = f"./datasets/ours/intermediate_image_encodings/{layer_name}"
    image_filenames_and_labels_folder = (
        f"./datasets/ours/image_filenames_and_labels/{layer_name}"
    )

    # Create the output folders and clear stale encodings from an earlier run.
    # This uses the os module instead of shelling out (`mkdir -p`, `rm folder/*`),
    # so unusual characters in a layer name cannot break or inject into a shell
    # command, and an already-empty folder is not reported as an error.
    os.makedirs(image_encodings_folder, exist_ok=True)
    os.makedirs(image_filenames_and_labels_folder, exist_ok=True)
    with os.scandir(image_encodings_folder) as entries:
        for entry in entries:
            # Same scope as the shell glob `*`: regular files only, dotfiles
            # and sub-directories are left alone.
            if entry.is_file() and not entry.name.startswith("."):
                os.remove(entry.path)

    # Step 1: build intermediate image encoding dataset.
    # NOTE(review): a new extractor is attached to the same `model` for every
    # layer; confirm IntermediateOutputExtractor detaches its forward hook,
    # otherwise hooks registered for earlier layers keep firing.
    intermediate_output_extractor = IntermediateOutputExtractor(
        model=model,
        forward_hook_layer=get_module_by_name(module=model, name=layer_name),
    )

    dataset = Murty185Dataset(
        image_folder=os.path.join(data_root, "images_185/"),
        fmri_response_filename=os.path.join(data_root, "all_data.pickle"),
        transforms=intermediate_output_extractor.model.transforms,  ## resize and normalize
    )

    builder = ImageEncodingDatasetBuilder(
        intermediate_output_extractor=intermediate_output_extractor,
        image_encodings_folder=image_encodings_folder,
        image_filenames_and_labels_folder=image_filenames_and_labels_folder,
        device=device,
    )
    builder.build(dataset=dataset)

    # Step 2: build and train convmapper.
    size_sequences = config["conv_mapper_size_sequence"]
    conv_mapper_config = make_convmapper_config_from_size_sequences(
        conv_layer_size_sequence=size_sequences["conv_layer_size_sequence"],
        linear_layer_size_sequence=size_sequences["linear_layer_size_sequence"],
        reduce_fn=conv_mapper_config_args["reduce_fn"],
        activation=conv_mapper_config_args["activation"],
        conv_layer_kernel_size=conv_mapper_config_args["conv_layer_kernel_size"],
    )

    scheduler_hparams = hyperparams["scheduler"]
    scheduler_config = SchedulerConfig(
        step_size=scheduler_hparams["step_size"],
        gamma=scheduler_hparams["gamma"],
        verbose=scheduler_hparams["verbose"],
    )

    ## very good performance on this config
    trainer_hparams = hyperparams["regression_trainer"]
    # NOTE(review): every layer shares checkpoint_folder="./checkpoints/ours";
    # confirm the trainer does not overwrite checkpoints from earlier layers.
    regression_trainer_config = RegressionTrainerConfig(
        checkpoint_folder="./checkpoints/ours",
        image_filenames_and_labels_folder=image_filenames_and_labels_folder,
        image_encodings_folder=image_encodings_folder,
        scheduler_config=scheduler_config,
        train_on_fold=trainer_hparams["train_on_fold"],
        num_folds=trainer_hparams["num_folds"],
        progress=trainer_hparams["progress"],
        quiet=trainer_hparams["quiet"],
        wandb_log=trainer_hparams["wandb_log"],
        num_epochs=trainer_hparams["num_epochs"],
        batch_size=trainer_hparams["batch_size"],
        learning_rate=trainer_hparams["learning_rate"],
        momentum=trainer_hparams["momentum"],
        weight_decay=trainer_hparams["weight_decay"],
        device_ids=trainer_hparams["device_ids"],
    )
    trainer = RegressionTrainer(
        config=regression_trainer_config,
    )

    # Train a fresh ConvMapper on every fold and collect the per-fold results.
    results_combined = [
        trainer.train_single_fold(
            fold_idx=fold_idx, model=ConvMapper(config=conv_mapper_config)
        )
        for fold_idx in range(trainer_hparams["num_folds"])
    ]

    dict_to_json(results_combined, f"results/ours_{layer_name}.json")

print("Done :)")
