# Trains `tinyRSNN` model on all sessions of both monkeys.
# # # # # # # # #

# Pretraining is switched on by default and performed on all sessions of each monkey.
# If pretraining should be limited to the three sessions used in the challenge, run with data.pretrain_filenames=challenge-data


# CONFIG
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork

# NUMERIC
from challenge import get_model, get_model_attention_V2, train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate
from challenge import get_dataloader_foundation_crossSet, evaluate_with_testdata
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np

# LOGGING
import logging
import time
import json

# SET UP LOGGER
# Root logging gets a default configuration; this script logs through its own
# named logger so its messages can be told apart from library output.
_LOGGER_NAME = "train-tinyRSNN-attention-test"

logging.basicConfig()
logger = logging.getLogger(_LOGGER_NAME)


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Pretrain a tinyRSNN model on the sessions of the configured monkeys.

    When ``cfg.pretraining`` is truthy the function loads the pretraining
    data for every monkey listed in ``cfg.pretrain_monkeys``, builds either
    the plain or the attention model (selected by ``cfg.attentionFlag``),
    trains it for ``cfg.training.nb_epochs_pretrain`` epochs, and writes the
    training history (JSON) and the model state (``.pth``) using relative
    paths, i.e. into the current working directory. When ``cfg.pretraining``
    is falsy only a log message is emitted.

    Args:
        cfg: Hydra/OmegaConf configuration. Fields read here include
            ``testFlag``, ``dtype``, ``seed``, ``device``, ``pretraining``,
            ``pretrain_monkeys``, ``train_monkeys``, ``with_S1``,
            ``attentionFlag``, ``data.*`` and ``training.*``. With
            ``cfg.testFlag`` set, several of these fields are overwritten
            in place for a short smoke run.

    Raises:
        AssertionError: If the monkeys in ``cfg.pretrain_monkeys`` do not all
            share the input count configured for ``"indy"``.
    """
    banner = "=*=" * 50

    logger.info("Starting new simulation...")
    if cfg.testFlag:
        # Loud banner so a shortened test run is never mistaken for a real one.
        for _ in range(6):
            print(banner)
        print(banner, 'test!!!!!!!!!!!!!', banner)
        for _ in range(6):
            print(banner)

        # Test-mode overrides: very short pretraining on a single monkey.
        cfg.training.nb_epochs_pretrain = 2
        cfg.data.nb_inputs['indy'] = 192
        cfg.data.nb_inputs['C05'] = 192
        # cfg.pretrain_monkeys = ["loco"]
        cfg.pretrain_monkeys = ["indy"]

    # SETUP
    # # # # # # # # #

    # convert dtype string to torch dtype
    dtype = getattr(torch, cfg.dtype)

    # SETTING RANDOM SEED
    # # # # # # # # #

    # NOTE(review): a seed of 0 is falsy and therefore skips seeding entirely;
    # confirm whether 0 is meant to be a valid seed.
    if cfg.seed:
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)

        # `is_available` must be *called*: the bare function object is always
        # truthy, which previously made this availability check a no-op.
        if torch.cuda.is_available() and 'cuda' in cfg.device:
            torch.cuda.manual_seed(cfg.seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        os.environ['PYTHONHASHSEED'] = str(cfg.seed)

    # Get dataloader
    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)
    print(cfg.data.sample_duration)

    # # # # # # # # #
    # # # # # # # # #
    # PRETRAINING OR LOADING STATE DICT
    # # # # # # # # # # #

    # A single model is trained across monkeys, so every pretraining monkey
    # must have the same input count as the "indy" reference entry.
    reference_inputs = cfg.data.nb_inputs["indy"]
    assert all(
        cfg.data.nb_inputs[monkey_name] == reference_inputs
        for monkey_name in cfg.pretrain_monkeys
    ), "All monkeys must have the same number of inputs."
    nb_inputs = reference_inputs

    if not cfg.pretraining:
        logger.info("No pretraining or model state loaded.")
        return

    filenames = {key: cfg.data.pretrain_filenames[key] for key in cfg.pretrain_monkeys}

    # Echo the effective run configuration to stdout.
    print(banner)
    print("if using crossBrain Data: ", cfg.with_S1)
    print("if using continus trial: ", cfg.data.continuous_trial)
    print("nb_inputs: ", nb_inputs)
    print("sample_duration: ", cfg.data.sample_duration)
    print(banner)
    print("Pretraining on the following monkeys:")
    for key, session_files in filenames.items():
        print(f"{key}: {session_files}")

    print(banner)
    print("finetuning on the following monkeys:")
    for key in cfg.train_monkeys:
        print("monkey:", key)
        for file in cfg.data.filenames[key]:
            print("file:", file)
    print(banner)

    # GET MODEL
    # # # # # # # # #
    logger.info("Constructing model for crossSet pretraining...")
    # The third element returned by the loader is not used here.
    pretrain_dat, pretrain_val_dat, _ = dataloader.get_multiple_set_data(
        filenames, nb_inputs=nb_inputs, with_S1=cfg.with_S1
    )

    # The explicit `== False` comparison is kept on purpose: an unset (None)
    # flag still selects the attention model, exactly as before.
    if cfg.attentionFlag == False:  # noqa: E712
        build_model = get_model
    else:
        build_model = get_model_attention_V2
    model = build_model(
        cfg, nb_inputs=nb_inputs, dtype=dtype, data=pretrain_dat, stateFlag='pretrain'
    )
    model.pretrainFlag = 'pretrain'

    logger.info("Configuring model...")
    model = configure_model(model, cfg, dtype)

    logger.info("Pretraining on all pretrain sessions...")
    model, history = train_validate_model(
        model,
        cfg,
        pretrain_dat,
        pretrain_val_dat,
        cfg.training.nb_epochs_pretrain,
        verbose=cfg.training.verbose,
        snapshot_prefix="tinyRSNN_pretrain_pretrain_cross_dataSet",
    )

    # Keys containing "val" are stored unchanged; every other history key is
    # prefixed with "train_". Values are converted with `.tolist()` so that
    # they can be serialised.
    results = {
        (name if "val" in name else "train_" + name): values.tolist()
        for name, values in history.items()
    }

    logger.info("Pretraining complete.")

    # Save to JSON file with indentation
    converted_results = convert_np_float_to_float(results)
    with open(
        "tinyRSNN-results-pretraining-pretrain_cross_dataSet.json", "w", encoding="utf-8"
    ) as f:
        json.dump(converted_results, f, indent=4)

    # Save pretrained model state
    save_model_state(model, "tinyRSNN-pretrained-pretrain_cross_dataSet.pth")


# Script entry point: `train_all` is called without arguments because the
# `@hydra.main` decorator supplies the `cfg` argument.
if __name__ == "__main__":
    train_all()
