from nesim.experiments.brain_model.natural_scenes_dataset import (
    NaturalScenesImageEncodingDatasetBuilder,
)
from nesim.experiments.brain_model.natural_scenes_dataset import (
    MurtyNSD,
    NaturalScenesRegressionTrainerConfig,
    NaturalScenesRegressionTrainer,
    NaturalScenesImageEncodingDataset,
    SchedulerConfig,
)
from neuro.models.image_encoder import get_image_encoder_from_name
from neuro.models.intermediate_output_extractor.model import IntermediateOutputExtractor
from neuro.utils.model_building import make_convmapper_config_from_size_sequences
import os
import argparse
from nesim.utils.json_stuff import load_json_as_dict
from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)
from neuro.models.brain_response_predictor.conv_mapper import ConvMapper
import wandb

### parse CLI args
# --config is mandatory: everything after argument parsing indexes into the
# loaded JSON, so fail fast with an argparse usage error instead of passing
# filename=None on to load_json_as_dict.
parser = argparse.ArgumentParser(
    description="Trains a convmapper on the murty185 dataset with nesim loss"
)

parser.add_argument("--config", type=str, required=True, help="config file")
parser.add_argument("--wandb-log", action="store_true", help="Enable logging to wandb")
parser.add_argument(
    "--fold",
    type=int,
    default=None,
    help="Fold index to train on, if None then will train on all folds",
)

args = parser.parse_args()

# Fixed inputs consumed by MurtyNSD: brain-response array and ordered images.
brain_signals_filename = "./datasets/nsd/component_responses.npy"
image_data_filename = "./datasets/nsd/test_images_ordered.npy"
device = "cuda:0"

# Experiment settings from the JSON file passed via --config.
config = load_json_as_dict(filename=args.config)
intermediate_layer_name = config["model"]["intermediate_layer_name"]

# Output folders for the encoding builder, one pair per encoder layer.
image_encodings_folder, image_filenames_and_labels_folder = (
    f"./datasets/image_encodings/{intermediate_layer_name}/",
    f"./datasets/image_filenames_and_labels/{intermediate_layer_name}/",
)

# Backbone: CLIP RN50, moved onto `device`.
image_encoder = get_image_encoder_from_name(name="clip_rn50")
image_encoder = image_encoder.to(device)

# Wrapper around the encoder, keyed on `intermediate_layer_name`; presumably
# exposes that layer's activations (see IntermediateOutputExtractor).
intermediate_output_extractor = IntermediateOutputExtractor.from_module_name(
    model=image_encoder,
    name=intermediate_layer_name,
)

# Stimulus images + brain responses, preprocessed with the encoder's transforms.
dataset = MurtyNSD(
    image_data_filename=image_data_filename,
    brain_signals_filename=brain_signals_filename,
    transforms=image_encoder.transforms,
)

# Create output folders with os.makedirs instead of shelling out to `mkdir -p`:
# the paths embed a value read from the config file, so a shell string is open
# to injection / word-splitting, and os.system silently ignores a non-zero exit
# status. exist_ok=True keeps the `-p` (no error if present) semantics.
os.makedirs(image_encodings_folder, exist_ok=True)
os.makedirs(image_filenames_and_labels_folder, exist_ok=True)
os.makedirs(os.path.join("checkpoints", intermediate_layer_name), exist_ok=True)

# Run every image in `dataset` through the intermediate-layer extractor; the
# builder is handed the per-layer output folders (presumably it writes the
# encodings and filename/label indices there — confirm in the builder).
builder = NaturalScenesImageEncodingDatasetBuilder(
    device=device,
    intermediate_output_extractor=intermediate_output_extractor,
    image_filenames_and_labels_folder=image_filenames_and_labels_folder,
    image_encodings_folder=image_encodings_folder,
)
builder.build(dataset=dataset)

# NOTE: an alternative (currently unused) is to load previously built encodings
# via NaturalScenesImageEncodingDataset(image_filenames_and_labels_folder=...,
# image_encodings_folder=...) instead of rebuilding them.

# ConvMapper architecture, taken entirely from the "model" section of the config.
conv_mapper_config = make_convmapper_config_from_size_sequences(
    conv_layer_size_sequence=config["model"]["conv_layer_size_sequence"],
    conv_layer_kernel_size=config["model"]["conv_mapper_kernel_size"],
    linear_layer_size_sequence=config["model"]["linear_layer_size_sequence"],
    activation=config["model"]["conv_mapper_activation"],
    reduce_fn=config["model"]["conv_mapper_reduce_fn"],
)

# Learning-rate scheduling is disabled. To enable it, set this to e.g.
# SchedulerConfig(step_size=10_000, gamma=0.9, verbose=False).
scheduler_config = None

# Names of the ConvMapper submodules that receive a nesim loss term.
# Only even module indices (0, 2, 4, ...) are used, presumably because every
# layer is followed by an activation module in the underlying Sequential —
# TODO confirm against make_convmapper_config_from_size_sequences.
# For a size sequence of length N this selects N-1 conv names and N-2 linear
# names; assuming N sizes -> N-1 layers, that is every conv layer and every
# linear layer except the final output layer.
layer_names = [
    f"conv_layers.{2 * idx}"
    for idx in range(len(config["model"]["conv_layer_size_sequence"]) - 1)
] + [
    f"linear_mapper.model.{2 * idx}"
    for idx in range(len(config["model"]["linear_layer_size_sequence"]) - 2)
]

# One LaplacianPyramid nesim term per selected layer.
# A NeighbourhoodCosineSimilarity(layer_name=..., scale=None) term could be
# added per layer as well; scale=None means the layer's loss is only watched,
# not backpropagated.
layer_wise_configs = [
    LaplacianPyramid(layer_name=name, scale=1e-5, shrink_factor=[5.0])
    for name in layer_names
]

nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)

# Trainer hyperparameters. train_on_fold stays None because folds are
# dispatched explicitly through trainer.train_single_fold.
regression_trainer_config = NaturalScenesRegressionTrainerConfig(
    # data
    image_filenames_and_labels_folder=image_filenames_and_labels_folder,
    image_encodings_folder=image_encodings_folder,
    transforms=image_encoder.transforms,
    num_folds=10,
    train_on_fold=None,
    # optimisation
    num_epochs=10_000,
    batch_size=32,
    learning_rate=3e-3,
    momentum=0.7,
    weight_decay=1e-5,
    scheduler_config=scheduler_config,
    device_ids=[0],
    # nesim regularisation
    nesim_config=nesim_config,
    # NOTE(review): such a large interval presumably means the nesim loss is
    # effectively never applied — confirm this is intended.
    apply_nesim_every_n_steps=999999999,
    # checkpointing
    checkpoint_folder=f"./checkpoints/{intermediate_layer_name}",
    save_checkpoint_every_n_steps=None,
    save_checkpoint_every_n_steps_folder="./step_checkpoints",
    # logging
    progress=False,
    quiet=False,
    wandb_log=args.wandb_log,
)

trainer = NaturalScenesRegressionTrainer(config=regression_trainer_config)

# Train either the single requested fold or every fold, with one wandb run per
# fold. The fold count comes from the trainer config so it cannot drift from
# num_folds; an out-of-range --fold is rejected up front with a usage error.
num_folds = regression_trainer_config.num_folds
if args.fold is None:
    fold_indices = range(num_folds)
elif 0 <= args.fold < num_folds:
    fold_indices = [args.fold]
else:
    parser.error(f"--fold must be in [0, {num_folds}), got {args.fold}")

for fold_idx in fold_indices:
    if args.wandb_log:
        wandb.init(
            project="nsd-clip-rn50-convmapper",
            config=regression_trainer_config.model_dump(),
            name=f"{intermediate_layer_name}-fold-{fold_idx}",
        )
    # A fresh ConvMapper per fold so folds do not share weights.
    trainer.train_single_fold(
        fold_idx=fold_idx, model=ConvMapper(config=conv_mapper_config)
    )
    if args.wandb_log:
        wandb.finish()
