# Trains `tinyRSNN` model on all sessions of both monkeys.
# # # # # # # # #

# Pretraining is switched on by default and performed on all sessions of each monkey.
# If pretraining should be limited to the three sessions used in the challenge, run with data.pretrain_filenames=challenge-data


# CONFIG
from omegaconf import DictConfig
from omegaconf import open_dict
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork

# NUMERIC
from challenge import get_model, get_model_attention_V2, train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate
from challenge import get_dataloader_foundation_crossSet, evaluate_with_testdata
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np

# LOGGING
import logging
import time
import json

# SET UP LOGGER
# basicConfig() attaches a default stream handler to the root logger when the
# root logger has no handlers yet; all messages in this script then go through
# the named logger created below.
logging.basicConfig()
logger = logging.getLogger("train-tinyRSNN-attention-test")


def _apply_test_overrides(cfg: DictConfig) -> None:
    """Override ``cfg`` in place with the settings used for a quick test run.

    Called only when ``cfg.testFlag`` is truthy. Sets a two-epoch pretraining
    budget, a fixed model configuration, 192 inputs for both monkeys, and
    removes four pretraining sessions per monkey from
    ``cfg.data.pretrain_filenames``.
    """
    cfg.training.nb_epochs_pretrain = 2
    # cfg.training.earlystop_patience = 3
    # cfg.training.batchsize_finetuning = 200
    # cfg.training.batchsize_pretrain = 200
    cfg.training.earlystop_min_epochs = 0
    # cfg.data.extend_data = False
    # cfg.data.mix_continuous_uncontinuous = True
    # cfg.data.sample_duration = 1
    # cfg.model.step_training = True
    # cfg.model.step_num = 2
    cfg.model.Attention_qkv = 'conv_linear'
    cfg.model.multi_BN = True
    cfg.model.multi_hidden = True
    cfg.model.nb_linear_hidden = 1
    cfg.model.output_BN = False
    cfg.model.reshape_num = 4
    cfg.model.MLP_size = [512]
    cfg.data.nb_inputs["indy"] = 192
    cfg.data.nb_inputs["C05"] = 192
    # cfg.pretrain_monkeys=["C05"]

    # Sessions dropped from pretraining in test mode (same order as before).
    dropped_sessions = {
        "indy": ("indy03", "indy04", "indy05", "indy06"),
        "C05": ("C05_03", "C05_04", "C05_05", "C05_06"),
    }
    # open_dict lifts struct mode so that keys can be deleted.
    with open_dict(cfg.data.pretrain_filenames):
        for monkey, sessions in dropped_sessions.items():
            for session in sessions:
                del cfg.data.pretrain_filenames[monkey][session]


def _seed_everything(cfg: DictConfig) -> None:
    """Seed torch, numpy and (when usable) CUDA from ``cfg.seed``.

    Does nothing when ``cfg.seed`` is falsy (which includes a seed of 0).
    """
    if not cfg.seed:
        return

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # `is_available` must be *called*: the bare function object is always
    # truthy, which made the availability check a no-op.
    if torch.cuda.is_available() and 'cuda' in cfg.device:
        torch.cuda.manual_seed(cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting
    # it here only reaches child processes, not the running interpreter.
    os.environ['PYTHONHASHSEED'] = str(cfg.seed)


def _print_run_summary(cfg: DictConfig, nb_inputs, filenames: dict) -> None:
    """Print the data settings plus the pretraining and finetuning sessions."""
    separator = "=*=" * 50

    print(separator)
    print("if using crossBrain Data: ", cfg.with_S1)
    print("if using continus trial: ", cfg.data.continuous_trial)
    print("nb_inputs: ", nb_inputs)
    print("sample_duration: ", cfg.data.sample_duration)
    print(separator)
    print("Pretraining on the following monkeys:")
    for key, sessions in filenames.items():
        print(f"{key}: {sessions}")

    print(separator)
    print("finetuning on the following monkeys:")
    for key in cfg.train_monkeys:
        print("monkey:", key)
        for file in cfg.data.filenames[key]:
            print("file:", file)
    print(separator)


def _history_to_results(history: dict) -> dict:
    """Convert a training history into a dict of plain lists.

    Keys containing ``"val"`` are kept as they are; every other key gets a
    ``"train_"`` prefix. Each value is converted with its ``.tolist()`` method.
    """
    return {
        (name if "val" in name else "train_" + name): values.tolist()
        for name, values in history.items()
    }


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Load a saved model state per monkey and optionally run cross-set pretraining.

    For every monkey in ``cfg.train_monkeys`` the state file listed under
    ``cfg.load_state[cfg.task_retrain][monkey][cfg.seed - 1]`` is loaded
    (resolved relative to this file's directory). If ``cfg.pretraining`` is
    truthy, a model is built, initialised from that state, trained with
    ``train_validate_model`` for ``cfg.training.nb_epochs_pretrain`` epochs,
    and both the training history (JSON) and the model state (``.pth``) are
    written to the current working directory.

    Args:
        cfg: Hydra/OmegaConf configuration. It is modified in place when
            ``cfg.testFlag`` is set.

    Raises:
        AssertionError: if no state file is configured for a monkey, if
            ``cfg.seed`` is falsy, or if the loaded state is ``None``.
    """
    logger.info("Starting new simulation...")
    if cfg.testFlag:
        _apply_test_overrides(cfg)

    # SETUP
    # # # # # # # # #

    # convert dtype string to torch dtype
    dtype = getattr(torch, cfg.dtype)

    # SETTING RANDOM SEED
    # # # # # # # # #
    _seed_everything(cfg)

    # Get dataloader
    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)
    print(cfg.data.sample_duration)

    # The original `for ... else` ran its `else` after *every* loop (there is
    # no `break`), so this message was logged even after successful
    # pretraining. It is only accurate when there is nothing to iterate over.
    if not cfg.train_monkeys:
        logger.info("No pretraining or model state loaded.")
        return

    # # # # # # # # #
    # # # # # # # # #
    # PRETRAINING OR LOADING STATE DICT
    # # # # # # # # # # #
    for monkey_name in cfg.train_monkeys:

        nb_inputs = cfg.data.nb_inputs[monkey_name]

        # NOTE(review): these checks use `assert` and vanish under `python -O`.
        assert cfg.load_state[cfg.task_retrain][monkey_name], "Pretraining or model state must be loaded."
        logger.info("Loading pretrained model for %s", monkey_name)
        assert cfg.seed, "Seed must be set to load pretrained model state."

        # State files are listed per seed; seed N selects entry N - 1.
        pretrain_model_dict = os.path.join(
            Path(__file__).resolve().parent,
            cfg.load_state[cfg.task_retrain][monkey_name][cfg.seed - 1]
        )
        pretrained_model = load_model_state(pretrain_model_dict, target_device=cfg.device)
        logger.info("Model state loaded: %s", cfg.seed)
        assert pretrained_model is not None, "Pretrained model state must be loaded."

        if not cfg.pretraining:
            logger.info("Pretraining disabled; skipping pretraining for %s.", monkey_name)
            continue

        # Pretraining sessions of *all* monkeys in cfg.train_monkeys are used
        # in every iteration; only nb_inputs and the BN/hidden switches below
        # are specific to `monkey_name`.
        filenames = {key: cfg.data.pretrain_filenames[key] for key in cfg.train_monkeys}
        _print_run_summary(cfg, nb_inputs, filenames)

        # GET MODEL
        # # # # # # # # #
        logger.info("Constructing model for crossSet pretraining...")
        pretrain_dat, pretrain_val_dat, _ = dataloader.get_multiple_set_data(
            filenames, nb_inputs=nb_inputs, with_S1=cfg.with_S1
        )
        # `== False` is kept deliberately so that a value of None still
        # selects the attention model, exactly as before.
        if cfg.attentionFlag == False:  # noqa: E712
            model = get_model(cfg, nb_inputs=nb_inputs, dtype=dtype, data=None, stateFlag='pretrain')
        else:
            model = get_model_attention_V2(cfg, nb_inputs=nb_inputs, dtype=dtype, data=None, stateFlag='pretrain')
        model.pretrainFlag = 'pretrain'

        logger.info("Configuring model...")
        model = configure_model(model, cfg, dtype)

        model.load_state_dict(pretrained_model)
        logger.info("Pretrained model state loaded.")

        if cfg.model.multi_BN:
            model.BN_switch(monkey_name)
            # del model.multiBN
        if cfg.model.multi_hidden:
            model.hidden_switch(monkey_name)
        if cfg.reparameterization:
            model.rep()
        # if not cfg.model.output_BN:
        #     if cfg.model.multiple_readouts:
        #         for c in model.connections[-5:]:
        #             if hasattr(c, "bn") and c.bn is not None:
        #                 c.bn = None
        #     else:
        #         c= model.connections[-1]
        #         if hasattr(c, "bn") and c.bn is not None:
        #             c.bn = None

        logger.info("Pretraining on all pretrain sessions...")
        model, history = train_validate_model(
            model,
            cfg,
            pretrain_dat,
            pretrain_val_dat,
            cfg.training.nb_epochs_pretrain,
            verbose=cfg.training.verbose,
            snapshot_prefix="tinyRSNN_pretrain_pretrain_cross_dataSet",
        )
        results = _history_to_results(history)

        logger.info("Pretraining complete.")

        # NOTE(review): the output names below do not contain `monkey_name`,
        # so with several entries in cfg.train_monkeys each iteration writes
        # to the same paths — confirm that overwriting is intended.

        # Save to JSON file with indentation
        converted_results = convert_np_float_to_float(results)
        with open("tinyRSNN-results-pretraining-pretrain_cross_dataSet.json", "w", encoding="utf-8") as f:
            json.dump(converted_results, f, indent=4)

        # Save pretrained model state
        save_model_state(model, "tinyRSNN-pretrained-pretrain_cross_dataSet.pth")


# Script entry point. `train_all` is called without arguments because it is
# wrapped by `@hydra.main`, which is configured with the config path and name.
if __name__ == "__main__":
    train_all()
