# Pretrains the `tinyRSNN` model across sessions of all monkeys/datasets listed in `pretrain_monkeys`.
# # # # # # # # #

# Pretraining is switched on by default and uses all sessions of each monkey.
# To limit pretraining to the three sessions used in the challenge, run with data.pretrain_filenames=challenge-data
# Each dataset uses its own BatchNorm layer (multi-BN training).

# CONFIG
from omegaconf import DictConfig
from omegaconf import open_dict
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork

# NUMERIC
from challenge import (get_model, get_model_attention_V2, get_model_attention_V16,
                       train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate,
                       train_validate_model_multiBN)
from challenge import get_dataloader_foundation_crossSet, evaluate_with_testdata
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np

# LOGGING
import logging
import time
import json

# SET UP LOGGER
logging.basicConfig()
logger = logging.getLogger("train-tinyRSNN-attention-test")


def _print_test_banner() -> None:
    """Print the conspicuous banner announcing that test/debug mode is active."""
    rule = "=*=" * 50
    for _ in range(6):
        print(rule)
    print(rule, 'test!!!!!!!!!!!!!', rule,)
    for _ in range(6):
        print(rule)


def _apply_test_overrides(cfg: DictConfig) -> None:
    """Overwrite ``cfg`` in place with the quick smoke-test settings.

    Only called when ``cfg.testFlag`` is set. Every value assigned here replaces
    what the Hydra config composed, so a test run differs from a real run in
    epochs, batch size, learning rate, model options and pretraining monkeys.
    """
    cfg.multi_cuda = True
    cfg.training.nb_epochs_pretrain = 2
    cfg.training.batchsize_pretrain = 200
    cfg.training.earlystop_min_epochs = 0
    # NOTE(review): lr=10 is unusually large -- confirm this is intended for smoke tests.
    cfg.training.lr = 10
    cfg.model.multi_BN = True
    cfg.model.reshape_num = 4
    cfg.model.MLP_size = [48]
    # Pretraining monkeys must share one input width (validated in train_all).
    for monkey in ("indy", "C05", "RTT"):
        cfg.data.nb_inputs[monkey] = 192
    cfg.data.padding = "copy"
    cfg.pretrain_monkeys = ["loco", "indy", "C05", "MAZE"]
    cfg.model.multi_hidden = True
    cfg.model.nb_linear_hidden = 1
    cfg.model.Attention_parameter_scale = 100
    cfg.model.hidden_shortcut = True
    # Use the fixed `nu` below instead of letting the initializer compute it.
    cfg.initializer.compute_nu = False
    cfg.initializer.nu = 9.2535


def _set_random_seed(cfg: DictConfig) -> None:
    """Seed torch, numpy and (when used) CUDA from ``cfg.seed``.

    A falsy seed (``None``, ``0``, ``False``) means "do not seed".
    """
    if not cfg.seed:
        return
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # is_available must be *called*; the bare function object is always truthy.
    if torch.cuda.is_available() and 'cuda' in cfg.device:
        torch.cuda.manual_seed(cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Only affects interpreters started after this point (e.g. worker subprocesses);
    # this process's hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(cfg.seed)


def _print_data_summary(cfg: DictConfig, filenames: dict) -> None:
    """Print the pretraining sessions per monkey and the fine-tuning files per monkey."""
    rule = "=*=" * 50
    print(rule)
    print("Pretraining on the following monkeys:")
    for key, files in filenames.items():
        print(f"{key}: {files}")
    print(rule)
    print("finetuning on the following monkeys:")
    for key in cfg.train_monkeys:
        print("monkey:", key)
        for file in cfg.data.filenames[key]:
            print("file:", file)
    print(rule)


def _build_pretrain_model(cfg: DictConfig, nb_inputs, dtype, train_data):
    """Construct the pretraining model variant selected by the config flags.

    - ``cfg.attentionFlag`` falsy            -> ``get_model``
    - attention and ``Attention_qkv_mix``    -> ``get_model_attention_V16``
    - attention otherwise                    -> ``get_model_attention_V2``

    The returned model is tagged with ``pretrainFlag = 'pretrain'``.
    """
    if not cfg.attentionFlag:
        builder = get_model
    elif cfg.model.Attention_qkv_mix:
        builder = get_model_attention_V16
    else:
        builder = get_model_attention_V2
    model = builder(cfg, nb_inputs=nb_inputs, dtype=dtype, data=train_data, stateFlag='pretrain')
    model.pretrainFlag = 'pretrain'
    return model


def _history_to_results(history) -> dict:
    """Convert training-history arrays to lists for JSON.

    Keys containing ``"val"`` are kept as-is; all others get a ``train_`` prefix.
    Assumes each history value has a ``.tolist()`` method (e.g. numpy arrays) -- TODO confirm.
    """
    return {
        (key if "val" in key else "train_" + key): values.tolist()
        for key, values in history.items()
    }


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Hydra entry point: optionally pretrain tinyRSNN across several datasets.

    Steps: apply test overrides (if ``cfg.testFlag``), seed RNGs, build the
    cross-set dataloader, validate that all ``cfg.pretrain_monkeys`` share the
    same input width, and -- when ``cfg.pretraining`` is set -- build, configure
    and train the model with per-dataset BatchNorm, then write the training
    history to JSON and the model state to a ``.pth`` file.

    Args:
        cfg: Hydra-composed configuration.

    Raises:
        ValueError: if the pretraining monkeys do not share the same number of inputs.
    """
    if cfg.testFlag:
        _print_test_banner()
        _apply_test_overrides(cfg)

    logger.info("Starting new simulation...")

    # Convert the dtype name from the config into a torch dtype.
    dtype = getattr(torch, cfg.dtype)

    _set_random_seed(cfg)

    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)
    print(cfg.data.sample_duration)

    # All pretraining monkeys feed one shared model, so their input widths must
    # match; "indy" is used as the reference width.
    nb_inputs = cfg.data.nb_inputs["indy"]
    mismatched = [m for m in cfg.pretrain_monkeys if cfg.data.nb_inputs[m] != nb_inputs]
    if mismatched:
        raise ValueError(
            "All monkeys must have the same number of inputs. "
            f"Expected {nb_inputs} (from 'indy'); mismatched: {mismatched}"
        )

    if not cfg.pretraining:
        logger.info("No pretraining or model state loaded.")
        return

    # Pretraining sessions per monkey; fine-tuning files are only listed in the summary.
    filenames = {key: cfg.data.pretrain_filenames[key] for key in cfg.pretrain_monkeys}
    _print_data_summary(cfg, filenames)

    logger.info("Constructing model for crossSet pretraining...")
    pretrain_dat_div, pretrain_dat_all = dataloader.get_multiple_set_data_divide_set(
        filenames, nb_inputs=nb_inputs, with_S1=cfg.with_S1
    )
    model = _build_pretrain_model(cfg, nb_inputs, dtype, pretrain_dat_all["dataset_all_train"])

    logger.info("Configuring model...")
    model = configure_model(model, cfg, dtype)

    logger.info("Pretraining on all pretrain sessions...")
    # Train/validation data is passed divided per dataset ("div_ds_*") to the multi-BN trainer.
    model, history = train_validate_model_multiBN(
        model,
        cfg,
        pretrain_dat_div["div_ds_train"],
        pretrain_dat_div["div_ds_valid"],
        cfg.training.nb_epochs_pretrain,
        verbose=cfg.training.verbose,
        snapshot_prefix="tinyRSNN_pretrain_pretrain_cross_dataSet",
    )
    results = _history_to_results(history)

    logger.info("Pretraining complete.")

    # Relative output paths: with version_base="1.1" Hydra runs the job inside its
    # output directory, so these files land there.
    converted_results = convert_np_float_to_float(results)
    with open("tinyRSNN-results-pretraining-pretrain_cross_dataSet.json", "w", encoding="utf-8") as f:
        json.dump(converted_results, f, indent=4)

    save_model_state(model, "tinyRSNN-pretrained-pretrain_cross_dataSet.pth")


if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator composes the config
    # (conf/train-tinyRSNN-attention-test plus any CLI overrides) and passes it as `cfg`.
    train_all()
