# Trains one `tinyRSNN` model per session for every monkey listed in `cfg.train_monkeys`.
# # # # # # # # #

# Pretraining is switched on by default and performed on all sessions of each monkey.
# If pretraining should be limited to the three sessions used in the challenge, run with data.pretrain_filenames=challenge-data


# CONFIG
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path

# NUMERIC
from challenge import get_model, train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate, train_validate_model_step_by_step
from challenge import (
    get_model_attention_V1,
    get_model_attention_V2,
    get_model_attention_V3,
    get_model_attention_V4,
    get_model_attention_V5,
    get_model_attention_V6,
    get_model_attention_V7,
    get_model_attention_V8,
    get_model_attention_V9,
    get_model_attention_V10,
    get_model_attention_V12,
    get_model_attention_V13,
    get_model_attention_V14,
    get_model_attention_V15,
    get_model_attention_V16,
    get_model_attention_V17,
                       )
from challenge import get_dataloader_foundation, evaluate_with_testdata, get_dataloader_MAZE, get_dataloader_C05, get_dataloader_foundation_crossSet
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np


# LOGGING
import logging
import time
import json

# SET UP LOGGER
# Configure the root logger with the library defaults, then create the
# module-wide logger used for all progress messages in this script.
logging.basicConfig()
logger = logging.getLogger("train-tinyRSNN-attention-test")

# Number of input channels added when cfg.session_encode is set. Matches the
# length of the codes produced by _session_code_for (4-char monkey prefix +
# 6-bit session index).
_SESSION_CODE_WIDTH = 10

# Monkey-specific prefix of the session code (MAZE uses a fixed full code).
_MONKEY_CODE_PREFIX = {"indy": "0000", "loco": "0001", "C05": "0010"}


def _seed_everything(seed, device):
    """Seed torch / numpy RNGs and request deterministic cuDNN when running on CUDA."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    # NOTE: is_available must be *called*; the bare function object is always truthy.
    if torch.cuda.is_available() and 'cuda' in device:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE(review): setting PYTHONHASHSEED here cannot change hashing in the
    # already-running interpreter; it is only inherited by child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)


def _session_code_for(monkey_name, session_name):
    """Return the 10-character binary string identifying a (monkey, session) pair."""
    if monkey_name == "MAZE":
        return "0011000000"
    if monkey_name not in _MONKEY_CODE_PREFIX:
        raise NotImplementedError(
            f"No session code defined for monkey '{monkey_name}'"
        )
    # The last two characters of the session name are parsed as a 1-based
    # session number and encoded as a zero-padded 6-bit binary index.
    session_index = int(session_name[-2:]) - 1
    return _MONKEY_CODE_PREFIX[monkey_name] + format(session_index, '06b')


def _select_model_builder(cfg):
    """Pick the model factory matching the config flags (first matching flag wins)."""
    if not cfg.attentionFlag:
        return get_model
    if cfg.model.firingRate_Decoder:
        return get_model_attention_V17
    if cfg.model.separate_brain_area:
        return get_model_attention_V15
    if cfg.model.Attention_qkv_mix:
        return get_model_attention_V16
    if cfg.model.output_crossAttention:
        return get_model_attention_V4
    if cfg.model.nb_conv_hidden > 0:
        return get_model_attention_V3
    if cfg.session_encode:
        return get_model_attention_V5
    if cfg.model.area_wise:
        return get_model_attention_V6
    if cfg.model.no_MLP:
        return get_model_attention_V8
    if cfg.model.readout_MLP:
        return get_model_attention_V9
    if cfg.model.self_and_crossAttention:
        return get_model_attention_V13
    if cfg.fake_attention:
        return get_model_attention_V12
    if cfg.model.multiSyn_LIF:
        return get_model_attention_V14
    # Default attention model (V1 was used here previously).
    return get_model_attention_V2


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Train, save and evaluate one model per session for each monkey in ``cfg.train_monkeys``.

    For every session listed under ``cfg.data.filenames[monkey]`` this:

    1. loads the train / validation / test splits for the session,
    2. builds and configures a fresh model chosen from the config flags,
    3. trains it (step-by-step or regular training loop),
    4. saves the model state (``tinyRSNN-<session>.pth`` in the working
       directory, plus ``models/<session>/tinyRSNN-<seed>.pth`` when a seed
       is set),
    5. evaluates on the test split and writes the metrics to
       ``tinyRSNN-results-<session>.json``, with optional plots.

    Args:
        cfg: Hydra configuration. When ``cfg.testFlag`` is set, several
            training / model options are overridden in place for a quick
            smoke-test run.

    Raises:
        NotImplementedError: if ``cfg.session_encode`` is set and a monkey
            has no session code defined.
        AssertionError: if both linear and conv hidden layers are requested.
    """

    if cfg.testFlag:
        # Quick smoke-test overrides; commented lines are optional toggles.
        cfg.training.nb_epochs_train = 1
        # cfg.training.earlystop_patience = 3
        # cfg.training.batchsize_finetuning = 400
        cfg.training.earlystop_min_epochs = 0
        cfg.data.extend_data = False
        # cfg.data.mix_continuous_uncontinuous = True
        # cfg.data.sample_duration = 1.2
        # cfg.model.step_training = True
        # cfg.model.step_num = 2
        cfg.model.Attention_qkv = 'conv_linear'
        # cfg.model.multiSyn_LIF = True
        # cfg.model.firingRate_Decoder = True
        cfg.model.output_BN = False
        # cfg.model.Repconv = "LayerNorm"
        # cfg.model.reshape_num = 4
        cfg.model.MLP_size = [196]
        cfg.data.nb_inputs["indy"] = 192
        cfg.data.nb_inputs["C05"] = 192
        # cfg.model.multi_hidden = True
        # cfg.model.nb_linear_hidden = 1
        # cfg.model.Attention_parameter_scale = 100
        # cfg.model.forward_shortcut = True
        # cfg.train_monkeys = ["B04"]

    logger.info("Starting new simulation...")

    # SETUP
    # # # # # # # # #

    # convert dtype string to torch dtype
    dtype = getattr(torch, cfg.dtype)

    # SETTING RANDOM SEED
    # # # # # # # # #

    # NOTE: a falsy seed (None or 0) is treated as "no seed", as before.
    if cfg.seed:
        _seed_everything(cfg.seed, cfg.device)

    # Holds the model of the previous session so it can be released before
    # the next one is built.
    model = None

    for monkey_name in cfg.train_monkeys:

        # Get dataloader
        if cfg.session_encode:
            dataloader = get_dataloader_foundation(cfg, dtype=dtype)
        else:
            dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)

        base_nb_inputs = cfg.data.nb_inputs[monkey_name]

        # TRAINING & EVALUATION
        # # # # # # # # #
        for session_name, filename in cfg.data.filenames[monkey_name].items():

            logger.info("=" * 50)
            logger.info("Constructing model for " + session_name + "...")
            logger.info("=" * 50)

            if cfg.session_encode:
                # Derive the width from the per-monkey base every time; the
                # previous in-place `+= 10` accumulated across sessions.
                nb_inputs = base_nb_inputs + _SESSION_CODE_WIDTH
                session_code = _session_code_for(monkey_name, session_name)
                train_dat, val_dat, test_dat = dataloader.get_single_session_data(
                    filename,
                    nb_inputs=nb_inputs,
                    zscore=cfg.data.zscore,
                    session_code=session_code,
                )
            else:
                nb_inputs = base_nb_inputs
                train_dat, val_dat, test_dat = dataloader.get_single_session_data(
                    filename,
                    monkey_name,
                    nb_inputs=nb_inputs,
                )

            assert cfg.model.nb_linear_hidden==0 or cfg.model.nb_conv_hidden==0, \
                "at most one kind of hidden can be used"

            # Drop the reference to the previous session's model before
            # allocating a new one.
            model = None
            build_model = _select_model_builder(cfg)
            model = build_model(
                cfg,
                nb_inputs=nb_inputs,
                dtype=dtype,
                data=train_dat,
                stateFlag='fine-tune',
            )
            model.pretrainFlag = 'train'

            logger.info("Configuring model...")
            model = configure_model(model, cfg, dtype)
            if cfg.reparameterization:
                model.rep()
            # if not cfg.model.output_BN:
            #     print("Removing output batch normalization layer...")
            #     if cfg.model.multiple_readouts:
            #         for c in model.connections[-5:]:
            #             if hasattr(c, "bn") and c.bn is not None:
            #                 c.bn = None
            #     else:
            #         c= model.connections[-1]
            #         if hasattr(c, "bn") and c.bn is not None:
            #             c.bn = None

            logger.info("Training on " + session_name + "...")
            snapshot_prefix = "tinyRSNN_" + session_name + "_"
            if cfg.model.step_training and not cfg.model.self_and_crossAttention:
                model, history = train_validate_model_step_by_step(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    nb_epochs=cfg.training.nb_epochs_train,
                    filenames=filename,
                    finetune_monkeyname=monkey_name,
                    stateFlag='fine-tune',
                    verbose=cfg.training.verbose,
                    snapshot_prefix=snapshot_prefix,
                )
            else:
                model, history = train_validate_model(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    cfg.training.nb_epochs_train,
                    verbose=cfg.training.verbose,
                    snapshot_prefix=snapshot_prefix,
                )

            # Validation metrics keep their key; everything else is
            # prefixed with "train_".
            results = {
                (k if "val" in k else "train_" + k): v.tolist()
                for k, v in history.items()
            }

            logger.info("Training complete.")

            # SAVE MODEL STATE
            # # # # # # # # #

            # Local save in hydra run directory
            save_model_state(model, "tinyRSNN-" + session_name + ".pth")

            # If seed is set, save model state as 'models / {session_name} / tinyRSNN-{seed}.pth'
            if cfg.seed:
                path = Path(to_absolute_path('models')) / session_name
                path.mkdir(parents=True, exist_ok=True)
                filepath = path / ("tinyRSNN-" + str(cfg.seed) + ".pth")
                save_model_state(model, filepath)

            logger.info("Saved model state.")

            # EVALUATE MODEL
            # # # # # # # # #
            logger.info("Evaluating model...")

            if cfg.plotting.plot_cumulative_mse:
                plot_cumulative_mse(
                    model, val_dat, save_path="tinyRSNN_cumulative_se_" + session_name + ".png"
                )

            model, scores, pred, bm_results = evaluate_with_testdata(model, cfg, test_dat)

            logger.info("Benchmark results:")
            for k, v in bm_results.items():
                # log key and value rounded to 4 decimal places
                if isinstance(v, float):
                    logger.info(f"{k}: {v:.4f}")
                else:
                    logger.info(f"{k}: {v}")

            for k, v in model.get_metrics_dict(scores).items():
                results["test_" + k] = v

            # Save to JSON file with indentation
            converted_results = convert_np_float_to_float(results)
            with open("tinyRSNN-results-" + session_name + ".json", "w") as f:
                json.dump(converted_results, f, indent=4)

            if cfg.plotting.plot_training:
                plot_training(
                    results,
                    cfg.training.nb_epochs_train,
                    names=["loss", "r2"],
                    save_path="tinyRSNN_training_" + session_name + ".png",
                )

            model.reset_states()

# Script entry point: train_all is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == "__main__":
    train_all()
