from nesim.experiments.brain_model.nesim_regression_trainer import (
    NesimRegressionTrainer,
    NesimRegressionTrainerConfig,
    SchedulerConfig,
)
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper

from neuro.datasets.image_encoding_dataset.builder import ImageEncodingDatasetBuilder
from neuro.models.intermediate_output_extractor.model import IntermediateOutputExtractor
from neuro.models.image_encoder import get_image_encoder_from_name
from neuro.datasets.murty_185 import Murty185Dataset
import os
import torch
import wandb
import argparse
from lightning.pytorch import seed_everything
from nesim.utils.json_stuff import load_json_as_dict
from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)
from neuro.utils.model_building import make_convmapper_config_from_size_sequences

# Seed the RNGs through lightning's helper before anything else is constructed.
seed_everything(1)

### parse CLI args
parser = argparse.ArgumentParser(
    description="Trains a convmapper on the murty185 dataset with nesim loss"
)

parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
parser.add_argument("--fold", type=int, help="Fold index to train on")
# The config file is loaded unconditionally right after parsing, so a missing
# --config should be reported as a usage error instead of a late failure.
parser.add_argument("--config", type=str, required=True, help="config file")
parser.add_argument(
    "--resume-from-checkpoint",
    type=str,
    help="Path of a checkpoint to load into the model before training",
)

args = parser.parse_args()
########## CONSTANTS ###################

# Experiment settings come from the JSON file given via --config.
config = load_json_as_dict(filename=args.config)

# The dataset lives at a different absolute path on each machine; use the
# barlow path when it exists locally, otherwise fall back to the openmind path.
barlow_data_root = "/research/XXXX-1/neuro/data/murty_185/data/"
openmind_data_root = "/mindhive/nklab3/users/XXXX-1/repos/neuro/data/murty_185/data"
data_root = barlow_data_root if os.path.exists(barlow_data_root) else openmind_data_root

device = "cuda" if torch.cuda.is_available() else "cpu"

# Encodings are stored per intermediate layer, so each layer name gets its
# own pair of folders under ./datasets/.
image_encodings_folder = (
    f'./datasets/{config["model"]["intermediate_layer_name"]}/image_encodings_folder'
)
image_filenames_and_labels_folder = f'./datasets/{config["model"]["intermediate_layer_name"]}/image_filenames_and_labels_folder'
checkpoint_folder = config["checkpoint_folder"]

# os.makedirs takes the path literally (no shell word-splitting or expansion)
# and raises on failure instead of silently returning a non-zero exit status.
os.makedirs(checkpoint_folder, exist_ok=True)
os.makedirs(config["save_checkpoint_every_n_steps_folder"], exist_ok=True)

# Image encoder whose intermediate activations are used as the regression inputs.
image_encoder = get_image_encoder_from_name(name="clip_rn50").to(device)
# Wrap the encoder so that the output of the configured layer can be extracted.
intermediate_output_extractor = IntermediateOutputExtractor.from_module_name(
    model=image_encoder, name=config["model"]["intermediate_layer_name"]
)
dataset = Murty185Dataset(
    image_folder=os.path.join(data_root, "images_185/"),
    fmri_response_filename=os.path.join(data_root, "all_data.pickle"),
    transforms=image_encoder.transforms,  ## resize and normalize
)

# Create the output folders without going through a shell: the paths embed a
# config value, and os.makedirs raises if a folder cannot be created.
os.makedirs(image_encodings_folder, exist_ok=True)
os.makedirs(image_filenames_and_labels_folder, exist_ok=True)

builder = ImageEncodingDatasetBuilder(
    intermediate_output_extractor=intermediate_output_extractor,
    image_encodings_folder=image_encodings_folder,
    image_filenames_and_labels_folder=image_filenames_and_labels_folder,
    device=device,
)

# Fill the two folders above from the dataset using the extractor.
builder.build(dataset=dataset)

# Architecture of the ConvMapper, taken from the "model" section of the config.
model_section = config["model"]
conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=model_section["conv_layer_size_sequence"],
    linear_layer_size_sequence=model_section["linear_layer_size_sequence"],
    reduce_fn=model_section["conv_mapper_reduce_fn"],
    activation=model_section["conv_mapper_activation"],
    conv_layer_kernel_size=model_section["conv_mapper_kernel_size"],
)

# Learning-rate scheduler settings, taken from the "scheduler" section.
scheduler_section = config["scheduler"]
scheduler_config = SchedulerConfig(
    step_size=scheduler_section["step_size"],
    gamma=scheduler_section["gamma"],
    verbose=scheduler_section["verbose"],
)


# Names of the ConvMapper sub-modules that receive a nesim loss term.
# Only even indices are used (presumably the odd slots hold activations --
# confirm against ConvMapper). The last conv entry and the last two linear
# entries are left out.
conv_sizes = config["model"]["conv_layer_size_sequence"]
layer_names = [f"conv_layers.{2 * k}" for k in range(len(conv_sizes) - 1)]

linear_sizes = config["model"]["linear_layer_size_sequence"]
layer_names += [
    f"linear_mapper.model.{2 * k}" for k in range(len(linear_sizes) - 2)
]

# Two loss entries per selected layer, in this order:
#   1. NeighbourhoodCosineSimilarity with scale=None
#      (scale = None -> just watch the layer's loss, do not backprop)
#   2. LaplacianPyramid with scale / shrink_factor taken from the config
layer_wise_configs = [
    loss_term
    for layer_name in layer_names
    for loss_term in (
        NeighbourhoodCosineSimilarity(layer_name=layer_name, scale=None),
        LaplacianPyramid(
            layer_name=layer_name,
            scale=config["laplacian_pyramid_loss"]["scale"],
            shrink_factor=config["laplacian_pyramid_loss"]["shrink_factor"],
        ),
    )
]
nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)

## very good performance on this config
regression_trainer_config = NesimRegressionTrainerConfig(
    # --- optimisation hyper-parameters ---
    num_epochs=20_000,
    batch_size=32,
    learning_rate=0.000096,
    momentum=0.7,
    weight_decay=0.000916035717274417,
    scheduler_config=scheduler_config,
    # --- hardware ---
    device_ids=[0],
    # --- data locations ---
    image_encodings_folder=image_encodings_folder,
    image_filenames_and_labels_folder=image_filenames_and_labels_folder,
    # --- cross-validation: the fold itself is passed to train_single_fold ---
    num_folds=10,
    train_on_fold=None,
    # --- nesim loss ---
    nesim_config=nesim_config,
    apply_nesim_every_n_steps=10,
    # --- checkpointing ---
    checkpoint_folder=checkpoint_folder,
    save_checkpoint_every_n_steps=100,
    save_checkpoint_every_n_steps_folder=config["save_checkpoint_every_n_steps_folder"],
    # --- logging / console output ---
    wandb_log=args.wandb_log,
    progress=False,
    quiet=False,
)
trainer = NesimRegressionTrainer(config=regression_trainer_config)

# Ensure the checkpoint directory handed to the trainer config exists.
# os.makedirs avoids spawning a shell and raises if creation fails.
os.makedirs(checkpoint_folder, exist_ok=True)

if args.wandb_log:
    # Runs are named "<layer>-fold-<k>"; a run whose laplacian pyramid scale
    # is None is additionally prefixed with "baseline-".
    run_name = f'{config["model"]["intermediate_layer_name"]}-fold-{args.fold}'
    if config["laplacian_pyramid_loss"]["scale"] is None:
        run_name = f"baseline-{run_name}"
    run = wandb.init(
        project="murty185-clip-rn50-convmapper",
        config={
            "trainer": regression_trainer_config.model_dump(),
            "loaded_config": config,
        },
        name=run_name,
    )
model = ConvMapper(config=conv_mapper_config)

# Optionally load previously saved state into the model before training.
if args.resume_from_checkpoint is not None:
    model.load(args.resume_from_checkpoint)
    print(f"resuming from checkpoint: {args.resume_from_checkpoint}")
trainer.train_single_fold(fold_idx=args.fold, model=model)
# NOTE: finish() is intentionally not in a `finally`; calling it after a crash
# would report the run as finished instead of failed.
if args.wandb_log:
    wandb.finish()
