# CONFIG
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork


# NUMERIC
from challenge import train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate, get_model_attention_V2, get_model, add_model_attention, train_validate_model_step_by_step
from challenge import evaluate_with_testdata, get_dataloader, get_dataloader_foundation_crossSet
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np
from challenge import (
    get_model_attention_V1,
    get_model_attention_V2,
    get_model_attention_V3,
    get_model_attention_V4,
    get_model_attention_V5,
    get_model_attention_V6,
    get_model_attention_V7,
    get_model_attention_V8,
    get_model_attention_V9,
    get_model_attention_V10,
    get_model_attention_V12,
    get_model_attention_V13,
    get_model_attention_V14,
    get_model_attention_V15,
    get_model_attention_V16,
                       )


# LOGGING
import logging
import time
import json

# SET UP LOGGER
# Script-wide named logger used by train_all and its helpers.
logger = logging.getLogger("train-tinyRSNN-attention-test")
# Attach a default stderr handler to the root logger (does nothing if the
# root logger already has handlers configured).
logging.basicConfig()


def _apply_test_overrides(cfg: DictConfig) -> None:
    """Override ``cfg`` in place with cheap settings for a quick smoke-test run.

    Used when ``cfg.testFlag`` is set. A loud banner is printed so the run is
    not mistaken for a real experiment.
    """
    banner = "=*=" * 50
    for _ in range(4):
        print(banner)
    print("test! warning! test!")

    # Optional toggles kept for quick experimentation:
    #   cfg.training.nb_epochs_pretrain = 2
    #   cfg.training.earlystop_patience = 3
    #   cfg.training.batchsize_pretrain = 200
    #   cfg.data.extend_data = False
    #   cfg.data.mix_continuous_uncontinuous = True
    #   cfg.data.sample_duration = 1
    #   cfg.model.step_training = True
    #   cfg.model.step_num = 2
    #   cfg.model.multi_BN = True
    #   cfg.model.multi_hidden = True
    #   cfg.model.reshape_num = 4
    #   cfg.task_retrain = 'finetune_outputnew_0708_retrain_multiBN'
    #   cfg.train_monkeys = ['RTT']
    cfg.training.nb_epochs_train = 2
    cfg.training.batchsize_finetuning = 400
    cfg.training.earlystop_min_epochs = 0
    cfg.reparameterization = True
    cfg.model.output_BN = False
    cfg.model.Attention_qkv = 'conv_linear'
    cfg.model.nb_linear_hidden = 1
    cfg.model.MLP_size = [512]
    for monkey in ("indy", "C05", "RTT"):
        cfg.data.nb_inputs[monkey] = 192
    cfg.training.lr = 0.01
    cfg.pretrain_monkeys = ["indy", "loco", "C05", "MAZE"]
    cfg.initializer.compute_nu = False


def _seed_everything(cfg: DictConfig) -> None:
    """Seed torch, numpy and (when used) CUDA; does nothing if ``cfg.seed`` is falsy."""
    if not cfg.seed:
        return

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # ``is_available`` must be called: the bare function object is always truthy.
    if torch.cuda.is_available() and 'cuda' in cfg.device:
        torch.cuda.manual_seed(cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only
    # affects child interpreters launched afterwards, not the current process.
    os.environ['PYTHONHASHSEED'] = str(cfg.seed)


def _build_model(cfg: DictConfig, nb_inputs, dtype, train_dat):
    """Instantiate the fine-tune model variant selected by the config flags.

    Without ``cfg.attentionFlag`` the plain ``get_model`` builder is used.
    Otherwise the attention flags are tested in a fixed priority order; the
    first truthy flag picks the builder and later flags are not read.
    ``get_model_attention_V2`` is the fallback.
    """
    build_kwargs = dict(nb_inputs=nb_inputs, dtype=dtype, data=train_dat, stateFlag='fine-tune')
    if not cfg.attentionFlag:
        return get_model(cfg, **build_kwargs)

    if cfg.model.separate_brain_area:
        builder = get_model_attention_V15
    elif cfg.model.Attention_qkv_mix:
        builder = get_model_attention_V16
    elif cfg.model.output_crossAttention:
        builder = get_model_attention_V4
    elif cfg.model.nb_conv_hidden > 0:
        builder = get_model_attention_V3
    elif cfg.session_encode:
        builder = get_model_attention_V5
    elif cfg.model.area_wise:
        builder = get_model_attention_V6
    elif cfg.model.no_MLP:
        builder = get_model_attention_V8
    elif cfg.model.readout_MLP:
        builder = get_model_attention_V9
    elif cfg.model.self_and_crossAttention:
        builder = get_model_attention_V13
    elif cfg.fake_attention:
        builder = get_model_attention_V12
    elif cfg.model.multiSyn_LIF:
        builder = get_model_attention_V14
    else:
        # Previously-used alternative fallback: get_model_attention_V1.
        builder = get_model_attention_V2
    return builder(cfg, **build_kwargs)


def _drop_multi_hidden_heads(model) -> None:
    """Delete the 'RTT' and 'B04' entries of ``model.multiHidden`` when present."""
    # NOTE(review): presumably these entries have no counterpart in the
    # pretrained checkpoint, hence removal before ``load_state_dict`` — confirm.
    if not model.multiHidden:
        return
    for key in ("RTT", "B04"):
        try:
            del model.multiHidden[key]
        except KeyError:
            logger.warning("Key '%s' not found in multiHidden, skipping deletion.", key)


def _convert_test_data_to_half(test_dat) -> None:
    """Switch the test dataset(s) to ``torch.float16`` to match a half-precision model."""
    if isinstance(test_dat, stork.datasets.RasDataset):
        test_dat.dtype = torch.float16
    else:
        # Otherwise assumed to be a sequence of dataset objects — TODO confirm.
        for dataset in test_dat:
            dataset.dtype = torch.float16


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Fine-tune a pretrained tinyRSNN on every session of each ``cfg.train_monkeys`` entry.

    For each session:

    1. Load the pretrained state dict selected by ``cfg.task_retrain`` and ``cfg.seed``.
    2. Build the model, fine-tune it, and optionally prune it and convert it
       to half precision, saving a checkpoint after each step.
    3. Evaluate it on the test split.
    4. Write ``tinyRSNN-results-<session>.json`` and, optionally, plots.

    Relative file names are written to the Hydra run directory.

    Raises:
        ValueError: no pretrained state is configured for a session, or
            ``cfg.seed`` is not set (the seed selects the checkpoint).
        RuntimeError: the pretrained state could not be loaded.
    """
    logger.info("Starting new simulation...")

    if cfg.testFlag:
        _apply_test_overrides(cfg)

    # SETUP
    # # # # # # # # #
    dtype = getattr(torch, cfg.dtype)  # dtype name string -> torch dtype
    _seed_everything(cfg)

    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)

    for monkey_name in cfg.train_monkeys:
        nb_inputs = cfg.data.nb_inputs[monkey_name]

        for session_name, filename in cfg.data.filenames[monkey_name].items():
            snapshot_prefix = "tinyRSNN_" + session_name + "_"

            # LOAD PRETRAINED STATE
            # # # # # # # # #
            if not cfg.load_state[cfg.task_retrain][session_name]:
                raise ValueError("Pretraining or model state must be loaded.")
            logger.info("Loading pretrained model for %s", session_name)
            if not cfg.seed:
                raise ValueError("Seed must be set to load pretrained model state.")

            # The checkpoint list is indexed by seed, 1-based.
            pretrain_model_dict = os.path.join(
                Path(__file__).resolve().parent,
                cfg.load_state[cfg.task_retrain][session_name][cfg.seed - 1],
            )
            pretrained_model = load_model_state(pretrain_model_dict, target_device=cfg.device)
            if pretrained_model is None:
                raise RuntimeError("Pretrained model state must be loaded.")
            logger.info("Model state loaded: %s", cfg.seed)

            # BUILD MODEL
            # # # # # # # # #
            logger.info("=" * 50)
            logger.info("Constructing model for %s...", session_name)
            logger.info("=" * 50)
            train_dat, val_dat, test_dat = dataloader.get_single_session_data(
                filename,
                monkeyname=monkey_name,
                nb_inputs=nb_inputs,
            )
            model = _build_model(cfg, nb_inputs, dtype, train_dat)
            model.pretrainFlag = 'train'
            logger.info("Configuring model...")
            model = configure_model(model, cfg, dtype)

            _drop_multi_hidden_heads(model)
            model.load_state_dict(pretrained_model)
            logger.info("Pretrained model state loaded.")
            # Disabled per-monkey switches: model.BN_switch(monkey_name) when
            # cfg.model.multi_BN, model.hidden_switch(monkey_name) when
            # cfg.model.multi_hidden.
            if cfg.reparameterization:
                model.rep()

            # TRAINING
            # # # # # # # # #
            logger.info("Training on %s...", session_name)
            if cfg.model.step_training and not cfg.model.self_and_crossAttention:
                model, history = train_validate_model_step_by_step(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    nb_epochs=cfg.training.nb_epochs_train,
                    filenames=filename,
                    finetune_monkeyname=monkey_name,
                    stateFlag='fine-tune',
                    verbose=cfg.training.verbose,
                    snapshot_prefix=snapshot_prefix,
                )
            else:
                model, history = train_validate_model(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    cfg.training.nb_epochs_train,
                    verbose=cfg.training.verbose,
                    snapshot_prefix=snapshot_prefix,
                )

            # Validation metrics keep their name; all others get a "train_" prefix.
            results = {
                (k if "val" in k else "train_" + k): v.tolist()
                for k, v in history.items()
            }
            logger.info("Training complete.")

            # SAVE MODEL STATE (local save in the Hydra run directory)
            # # # # # # # # #
            save_model_state(model, "tinyRSNN-" + session_name + ".pth")

            # PRUNING
            # # # # # # # # #
            if cfg.training.is_prune:
                model = prune_retrain_model_iterate(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    logger,
                    history['r2'][-1],
                    history['val_r2'][-1],
                    nb_epochs_retrain=cfg.training.nb_epochs_retrain,
                    prune_percentage_start=cfg.training.prune_percentage_start,
                    tolerance=cfg.training.tolerance,
                    prune_precision=cfg.training.prune_precision,
                    max_prune_percentage=cfg.training.max_prune_percentage,
                    is_plot_pruning=cfg.training.is_plot_pruning,
                    is_pruning_ver=cfg.training.is_pruning_ver,
                    session_name=session_name,
                    pruning_plot_prefix=session_name
                )
                save_model_state(model, "tinyRSNN-" + session_name + " pruned.pth")

            # HALF PRECISION
            # # # # # # # # #
            if cfg.model.is_half:
                model = model.half()
                half_suffix = " pruned half.pth" if cfg.training.is_prune else " half.pth"
                save_model_state(model, "tinyRSNN-" + session_name + half_suffix)
                logger.info("Model converted to half precision.")

                _convert_test_data_to_half(test_dat)
                logger.info("Test data converted to half precision.")

            # With a seed, also save to 'models/{session_name}/tinyRSNN-{seed}.pth'
            # resolved against the original working directory.
            if cfg.seed:
                path = Path(to_absolute_path('models')) / session_name
                path.mkdir(parents=True, exist_ok=True)
                filepath = path / ("tinyRSNN-" + str(cfg.seed) + ".pth")
                save_model_state(model, filepath)

            logger.info("Saved model state.")

            # EVALUATE MODEL
            # # # # # # # # #
            logger.info("Evaluating model...")

            if cfg.plotting.plot_cumulative_mse:
                # NOTE(review): unlike test_dat, val_dat is not switched to
                # float16 when cfg.model.is_half is set — confirm the plot
                # helper tolerates the dtype mismatch.
                plot_cumulative_mse(
                    model, val_dat, save_path="tinyRSNN_cumulative_se_" + session_name + ".png"
                )

            model, scores, _pred, bm_results = evaluate_with_testdata(model, cfg, test_dat)

            logger.info("Benchmark results:")
            for k, v in bm_results.items():
                # Floats are rounded to 4 decimal places for readability.
                if isinstance(v, float):
                    logger.info(f"{k}: {v:.4f}")
                else:
                    logger.info(f"{k}: {v}")

            for k, v in model.get_metrics_dict(scores).items():
                results["test_" + k] = v

            # Save to JSON with indentation; numpy floats are converted first.
            converted_results = convert_np_float_to_float(results)
            with open("tinyRSNN-results-" + session_name + ".json", "w", encoding="utf-8") as f:
                json.dump(converted_results, f, indent=4)

            if cfg.plotting.plot_training:
                plot_training(
                    results,
                    cfg.training.nb_epochs_train,
                    names=["loss", "r2"],
                    save_path="tinyRSNN_training_" + session_name + ".png",
                )


if __name__ == "__main__":
    # Hydra builds ``cfg`` from conf/ plus any command-line overrides.
    train_all()
