# Trains the `tinyRSNN` model (optionally the attention variant) on all sessions of both monkeys.
# # # # # # # # #

# Pretraining is switched on by default and is performed on all sessions of each monkey listed in `pretrain_monkeys`.
# To limit pretraining to the three sessions used in the challenge, run with the override `data.pretrain_filenames=challenge-data`.


# CONFIG
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork

# NUMERIC
from challenge import get_model, get_model_attention_V2, train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate
from challenge import get_dataloader_foundation_crossSet, evaluate_with_testdata
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np

# LOGGING
import logging
import time
import json

# SET UP LOGGER
# basicConfig() with no arguments attaches a default stderr handler to the root
# logger but does not lower its level from the stdlib default (WARNING).
# NOTE(review): logger.info(...) output below therefore relies on Hydra's job
# logging configuration raising the level to INFO — confirm for non-Hydra runs.
logging.basicConfig()
# Named after the Hydra config used by train_all below.
logger = logging.getLogger("train-tinyRSNN-attention-test")


def _seed_everything(cfg):
    """Seed torch and NumPy (and CUDA, when used) from ``cfg.seed``.

    Does nothing when ``cfg.seed`` is falsy.
    NOTE(review): a seed of 0 is falsy and therefore disables seeding.
    """
    if not cfg.seed:
        return

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # ``is_available`` must be *called*; the bare function object is always
    # truthy, which previously made this check a no-op.
    if torch.cuda.is_available() and "cuda" in cfg.device:
        torch.cuda.manual_seed(cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # only influences child processes, not the current one.
    os.environ["PYTHONHASHSEED"] = str(cfg.seed)


def _build_model(cfg, nb_inputs, dtype, data, state_flag):
    """Construct the plain or the attention tinyRSNN, selected by ``cfg.attentionFlag``."""
    # Explicit ``== False`` keeps the original dispatch: only a False/0 flag
    # selects the plain model, anything else selects the attention model.
    builder = get_model if cfg.attentionFlag == False else get_model_attention_V2  # noqa: E712
    return builder(cfg, nb_inputs=nb_inputs, dtype=dtype, data=data, stateFlag=state_flag)


def _history_to_results(history):
    """Flatten a training ``history`` mapping into a JSON-ready results dict.

    Keys containing ``"val"`` are copied unchanged; all other keys get a
    ``"train_"`` prefix. Values are converted with ``.tolist()`` (presumably
    NumPy arrays — TODO confirm against ``train_validate_model``).
    """
    return {
        (key if "val" in key else "train_" + key): values.tolist()
        for key, values in history.items()
    }


def _write_results_json(results, filename):
    """Write ``results`` to ``filename`` as indented JSON (via ``convert_np_float_to_float``)."""
    converted_results = convert_np_float_to_float(results)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(converted_results, f, indent=4)


def _pretrain(cfg, dataloader, nb_inputs, dtype):
    """Pretrain one model on the sessions of ``cfg.pretrain_monkeys``; return its state dict.

    Also writes the pretraining history JSON and the pretrained weights into the
    current working directory (the Hydra run directory).
    """
    filenames = {key: cfg.data.pretrain_filenames[key] for key in cfg.pretrain_monkeys}

    separator = "=*=" * 50
    print(separator)
    print("Pretraining on the following monkeys:")
    for monkey, monkey_files in filenames.items():
        print(f"{monkey}: {monkey_files}")

    print(separator)
    print("finetuning on the following monkeys:")
    for monkey in cfg.train_monkeys:
        print("monkey:", monkey)
        for file in cfg.data.filenames[monkey]:
            print("file:", file)
    print(separator)

    logger.info("Constructing model for crossSet pretraining...")
    pretrain_dat, pretrain_val_dat, _ = dataloader.get_multiple_set_data(
        filenames, nb_inputs=nb_inputs, with_S1=cfg.with_S1
    )
    model = _build_model(cfg, nb_inputs, dtype, pretrain_dat, "pretrain")
    model.pretrainFlag = "pretrain"

    logger.info("Configuring model...")
    model = configure_model(model, cfg, dtype)

    logger.info("Pretraining on all pretrain sessions...")
    model, history = train_validate_model(
        model,
        cfg,
        pretrain_dat,
        pretrain_val_dat,
        cfg.training.nb_epochs_pretrain,
        verbose=cfg.training.verbose,
        snapshot_prefix="tinyRSNN_pretrain_pretrain_cross_dataSet",
    )
    results = _history_to_results(history)

    logger.info("Pretraining complete.")

    _write_results_json(results, "tinyRSNN-results-pretraining-pretrain_cross_dataSet.json")
    save_model_state(model, "tinyRSNN-pretrained-pretrain_cross_dataSet.pth")

    return model.state_dict()


def _train_session(cfg, dataloader, monkeyname, session_name, filename, nb_inputs, dtype, pretrained_state):
    """Fine-tune, optionally prune / half-convert, save and evaluate one session's model.

    ``pretrained_state`` is a state dict loaded into the freshly built model,
    or ``None`` to train from the configured initialisation.
    """
    logger.info("=" * 50)
    logger.info("Constructing model for " + session_name + "...")
    logger.info("=" * 50)

    train_dat, val_dat, test_dat = dataloader.get_single_session_data(
        filename,
        monkeyname=monkeyname,
        nb_inputs=nb_inputs,
    )
    model = _build_model(cfg, nb_inputs, dtype, train_dat, "fine-tune")
    model.pretrainFlag = "train"

    logger.info("Configuring model...")
    model = configure_model(model, cfg, dtype)

    if pretrained_state is not None:
        model.load_state_dict(pretrained_state)
        logger.info("Pretrained model state loaded.")

    logger.info("Training on " + session_name + "...")
    model, history = train_validate_model(
        model,
        cfg,
        train_dat,
        val_dat,
        cfg.training.nb_epochs_train,
        verbose=cfg.training.verbose,
        snapshot_prefix="tinyRSNN_" + session_name + "_",
    )
    results = _history_to_results(history)

    logger.info("Training complete.")

    # SAVE MODEL STATE (local save in the hydra run directory)
    save_model_state(model, "tinyRSNN-" + session_name + ".pth")

    # PRUNING
    if cfg.training.is_prune:
        model = prune_retrain_model_iterate(
            model,
            cfg,
            train_dat,
            val_dat,
            logger,
            history["r2"][-1],
            history["val_r2"][-1],
            nb_epochs_retrain=cfg.training.nb_epochs_retrain,
            prune_percentage_start=cfg.training.prune_percentage_start,
            tolerance=cfg.training.tolerance,
            prune_precision=cfg.training.prune_precision,
            max_prune_percentage=cfg.training.max_prune_percentage,
            is_plot_pruning=cfg.training.is_plot_pruning,
            is_pruning_ver=cfg.training.is_pruning_ver,
            session_name=session_name,
            pruning_plot_prefix=session_name,
        )
        save_model_state(model, "tinyRSNN-" + session_name + " pruned.pth")

    # HALF PRECISION
    if cfg.model.is_half:
        model = model.half()
        half_suffix = " pruned half.pth" if cfg.training.is_prune else " half.pth"
        save_model_state(model, "tinyRSNN-" + session_name + half_suffix)
        logger.info("Model converted to half precision.")

        # test_dat is either a single RasDataset or an indexable collection of them.
        if isinstance(test_dat, stork.datasets.RasDataset):
            test_dat.dtype = torch.float16
        else:
            for idx in range(len(test_dat)):
                test_dat[idx].dtype = torch.float16
        logger.info("Test data converted to half precision.")
        # NOTE(review): val_dat is not converted, yet plot_cumulative_mse below
        # runs the half model on it — confirm this is intended.

    # With a seed set, also keep a copy at 'models/{session_name}/tinyRSNN-{seed}.pth'.
    if cfg.seed:
        path = Path(to_absolute_path("models")) / session_name
        path.mkdir(parents=True, exist_ok=True)
        save_model_state(model, path / ("tinyRSNN-" + str(cfg.seed) + ".pth"))

    logger.info("Saved model state.")

    # EVALUATE MODEL
    logger.info("Evaluating model...")

    if cfg.plotting.plot_cumulative_mse:
        plot_cumulative_mse(
            model, val_dat, save_path="tinyRSNN_cumulative_se_" + session_name + ".png"
        )

    model, scores, pred, bm_results = evaluate_with_testdata(model, cfg, test_dat)

    logger.info("Benchmark results:")
    for k, v in bm_results.items():
        # floats are logged rounded to 4 decimal places
        if isinstance(v, float):
            logger.info(f"{k}: {v:.4f}")
        else:
            logger.info(f"{k}: {v}")

    for k, v in model.get_metrics_dict(scores).items():
        results["test_" + k] = v

    _write_results_json(results, "tinyRSNN-results-" + session_name + ".json")

    if cfg.plotting.plot_training:
        plot_training(
            results,
            cfg.training.nb_epochs_train,
            names=["loss", "r2"],
            save_path="tinyRSNN_training_" + session_name + ".png",
        )


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Hydra entry point: optional cross-session pretraining, then per-session fine-tuning.

    If ``cfg.pretraining`` is set, one model is pretrained on the sessions of
    ``cfg.pretrain_monkeys`` and its weights initialise every fine-tuned model.
    Then, for each monkey in ``cfg.train_monkeys`` and each of its sessions in
    ``cfg.data.filenames``, a model is trained, optionally pruned and
    half-converted, saved, and evaluated on the session's test data.
    """
    logger.info("Starting new simulation...")

    # convert dtype string to torch dtype
    dtype = getattr(torch, cfg.dtype)

    _seed_everything(cfg)

    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)
    logger.info("Sample duration: %s", cfg.data.sample_duration)

    # NOTE(review): every monkey is assumed to share the input dimensionality
    # configured for "indy"; this is not checked.
    nb_inputs = cfg.data.nb_inputs["indy"]

    if cfg.pretraining:
        pretrained_state = _pretrain(cfg, dataloader, nb_inputs, dtype)
    else:
        logger.info("No pretraining or model state loaded.")
        pretrained_state = None

    for monkeyname in cfg.train_monkeys:
        for session_name, filename in cfg.data.filenames[monkeyname].items():
            _train_session(
                cfg, dataloader, monkeyname, session_name, filename,
                nb_inputs, dtype, pretrained_state,
            )


if __name__ == "__main__":
    # train_all is wrapped by @hydra.main, so it is called without arguments:
    # Hydra composes cfg from conf/train-tinyRSNN-attention-test plus any
    # command-line overrides (e.g. data.pretrain_filenames=challenge-data).
    train_all()
