# CONFIG
from omegaconf import DictConfig
from omegaconf import open_dict
import hydra
from hydra.utils import to_absolute_path
import os
from pathlib import Path
import stork


# NUMERIC
from challenge import train_validate_model, evaluate_model, configure_model, prune_retrain_model_iterate, get_model_attention_V2, get_model, add_model_attention
from challenge import evaluate_with_testdata, get_dataloader, get_dataloader_foundation_crossSet
from challenge.utils import save_model_state, load_model_state
from challenge.utils.plotting import plot_training, plot_cumulative_mse
from challenge.utils.misc import convert_np_float_to_float
import torch
import numpy as np


# LOGGING
import logging
import time
import json

# SET UP LOGGER
logger = logging.getLogger("train-tinyRSNN-attention-test")
# Give the root logger a default handler (logging.basicConfig is a no-op if
# the root logger is already configured).
logging.basicConfig()


@hydra.main(config_path="conf", config_name="train-tinyRSNN-attention-test", version_base="1.1")
def train_all(cfg: DictConfig) -> None:
    """Fine-tune pretrained tinyRSNN attention models per session and evaluate them.

    For every monkey in ``cfg.train_monkeys``, one pretrained state dict is
    loaded from ``cfg.load_state[cfg.task_retrain][monkey_name]``. The entry is
    chosen by the 1-based ``cfg.load_idx``, which falls back to ``cfg.seed``.
    Then, for every session in ``cfg.data.filenames[monkey_name]``, the
    function:

    1. builds and configures a fresh model and loads the pretrained weights;
    2. trains it with ``train_validate_model``;
    3. optionally prunes it (``cfg.training.is_prune``);
    4. optionally converts it to half precision (``cfg.model.is_half``);
    5. saves the model state and evaluates it on the session's test split.

    Results are written to ``tinyRSNN-results-<session>.json`` in the current
    working directory.

    Args:
        cfg: Hydra/OmegaConf configuration. Keys read include ``testFlag``,
            ``dtype``, ``seed``, ``device``, ``train_monkeys``, ``load_idx``,
            ``task_retrain``, ``load_state``, ``data.*``, ``model.*``,
            ``training.*`` and ``plotting.*``.

    Raises:
        AssertionError: if no pretrained state is configured for a monkey,
            if ``cfg.seed`` is unset, or if loading returns ``None``.
        ValueError: if the resolved ``load_idx`` does not select an entry of
            the configured state-file list.
    """
    logger.info("Starting new simulation...")

    if cfg.testFlag:
        # Quick smoke-test overrides; uncomment further lines as needed.
        for _ in range(4):
            print("=*=" * 50)
        print("test! warning! test!")
        # cfg.training.nb_epochs_pretrain = 2
        cfg.training.nb_epochs_train = 200
        # cfg.training.earlystop_patience = 3
        cfg.training.batchsize_finetuning = 400
        # cfg.training.batchsize_pretrain = 200
        cfg.training.earlystop_min_epochs = 0
        # cfg.data.extend_data = False
        # cfg.data.mix_continuous_uncontinuous = True
        # cfg.data.sample_duration = 1
        # cfg.model.step_training = True
        # cfg.model.step_num = 2
        # cfg.reparameterization = True
        # cfg.model.output_BN = True
        cfg.model.Attention_qkv = 'conv_linear'
        # cfg.model.multi_BN = True
        # cfg.model.multi_hidden = True
        # cfg.model.nb_linear_hidden = 1
        # cfg.model.reshape_num = 4
        cfg.model.MLP_size = [512]
        cfg.data.nb_inputs["indy"] = 192
        cfg.data.nb_inputs["C05"] = 192
        cfg.data.nb_inputs["RTT"] = 192
        cfg.data.nb_inputs["B04"] = 192
        # cfg.task_retrain = 'finetune_outputnew_0708_retrain_multiBN'
        # cfg.task_retrain = 'finetune_outputnew_0708_retrain_multiBN_noOutputBN'
        cfg.training.lr = 0.01
        # cfg.train_monkeys=['RTT']

    # SETUP
    # # # # # # # # #

    # convert dtype string to torch dtype
    dtype = getattr(torch, cfg.dtype)

    # SETTING RANDOM SEED
    # # # # # # # # #

    if cfg.seed:
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)

        # is_available must be *called*: the bare function object is always
        # truthy, which made this guard depend on cfg.device alone.
        if torch.cuda.is_available() and 'cuda' in cfg.device:
            torch.cuda.manual_seed(cfg.seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        os.environ['PYTHONHASHSEED'] = str(cfg.seed)

    # Get dataloader
    dataloader = get_dataloader_foundation_crossSet(cfg, dtype=dtype)

    for monkey_name in cfg.train_monkeys:

        # LOADING PRETRAINED STATE DICT
        # # # # # # # # #
        nb_inputs = cfg.data.nb_inputs[monkey_name]

        state_files = cfg.load_state[cfg.task_retrain][monkey_name]
        assert state_files, "Pretraining or model state must be loaded."
        logger.info("Loading pretrained model for " + monkey_name)
        assert cfg.seed, "Seed must be set to load pretrained model state."
        if cfg.load_idx is None or cfg.load_idx == "None":
            load_idx = int(cfg.seed)
            print("load_idx is None, using seed as load_idx:", load_idx)
        else:
            load_idx = int(cfg.load_idx)
            print("Using load_idx:", load_idx)

        # load_idx is 1-based. Without this check, 0 would silently select the
        # last file through index -1.
        if not 1 <= load_idx <= len(state_files):
            raise ValueError(
                f"load_idx={load_idx} is out of range for the {len(state_files)} "
                f"state file(s) configured for {monkey_name!r}"
            )

        pretrain_model_dict = os.path.join(
            Path(__file__).resolve().parent,
            state_files[load_idx - 1],
        )
        pretrained_model = load_model_state(pretrain_model_dict, target_device=cfg.device)
        logger.info("Model state loaded: %s", pretrain_model_dict)
        assert pretrained_model is not None, "Pretrained model state must be loaded."

        # TRAINING & EVALUATION
        # # # # # # # # #

        for session_name, filename in cfg.data.filenames[monkey_name].items():

            logger.info("=" * 50)
            logger.info("Constructing model for " + session_name + "...")
            logger.info("=" * 50)

            train_dat, val_dat, test_dat = dataloader.get_single_session_data(
                filename,
                monkeyname=monkey_name,
                nb_inputs=nb_inputs,
            )
            model = get_model_attention_V2(cfg, nb_inputs=nb_inputs, dtype=dtype, data=None, stateFlag='fine-tune')

            model.pretrainFlag = 'train'
            logger.info("Configuring model...")
            model = configure_model(model, cfg, dtype)

            # For this task the checkpoint was saved without output BN, so the
            # readout BN modules are dropped before loading the state dict.
            if cfg.task_retrain == 'finetune_outputnew_0708_retrain_multiBN_noOutputBN':
                if not cfg.model.output_BN:
                    readouts = (
                        model.connections[-5:]
                        if cfg.model.multiple_readouts
                        else model.connections[-1:]
                    )
                    for c in readouts:
                        if getattr(c, "bn", None) is not None:
                            c.bn = None
            model.load_state_dict(pretrained_model)
            logger.info("Pretrained model state loaded.")

            if cfg.model.multi_BN:
                model.BN_switch(monkey_name)
            if cfg.model.multi_hidden:
                model.hidden_switch(monkey_name)
            if cfg.reparameterization:
                model.rep()

            logger.info("Training on " + session_name + "...")
            model, history = train_validate_model(
                model,
                cfg,
                train_dat,
                val_dat,
                cfg.training.nb_epochs_train,
                verbose=cfg.training.verbose,
                snapshot_prefix="tinyRSNN_" + session_name + "_",
            )

            # Validation metrics keep their key; all others get a "train_" prefix.
            results = {
                (k if "val" in k else "train_" + k): v.tolist()
                for k, v in history.items()
            }

            logger.info("Training complete.")

            # SAVE MODEL STATE
            # # # # # # # # #

            # Local save in hydra run directory
            save_model_state(model, "tinyRSNN-" + session_name + ".pth")

            # PRUNING
            # # # # # # # # #

            if cfg.training.is_prune:
                model = prune_retrain_model_iterate(
                    model,
                    cfg,
                    train_dat,
                    val_dat,
                    logger,
                    history['r2'][-1],
                    history['val_r2'][-1],
                    nb_epochs_retrain=cfg.training.nb_epochs_retrain,
                    prune_percentage_start=cfg.training.prune_percentage_start,
                    tolerance=cfg.training.tolerance,
                    prune_precision=cfg.training.prune_precision,
                    max_prune_percentage=cfg.training.max_prune_percentage,
                    is_plot_pruning=cfg.training.is_plot_pruning,
                    is_pruning_ver=cfg.training.is_pruning_ver,
                    session_name=session_name,
                    pruning_plot_prefix=session_name
                )

                # Save pruned model state
                save_model_state(model, "tinyRSNN-" + session_name + " pruned.pth")

            if cfg.model.is_half:

                model = model.half()
                # Save the half-precision state (pruned or not)
                if cfg.training.is_prune:
                    save_model_state(model, "tinyRSNN-" + session_name + " pruned half.pth")
                else:
                    save_model_state(model, "tinyRSNN-" + session_name + " half.pth")

                logger.info("Model converted to half precision.")

                # Test data must match the model's half precision.
                if isinstance(test_dat, stork.datasets.RasDataset):
                    test_dat.dtype = torch.float16
                else:
                    for trialIndex in range(len(test_dat)):
                        test_dat[trialIndex].dtype = torch.float16
                logger.info("Test data converted to half precision.")

            # If seed is set, save model state as 'models / {session_name} / tinyRSNN-{seed}.pth'
            if cfg.seed:
                path = Path(to_absolute_path('models')) / session_name
                path.mkdir(parents=True, exist_ok=True)
                filepath = path / ("tinyRSNN-" + str(cfg.seed) + ".pth")
                save_model_state(model, filepath)

            logger.info("Saved model state.")

            # EVALUATE MODEL
            # # # # # # # # #
            logger.info("Evaluating model...")

            if cfg.plotting.plot_cumulative_mse:
                plot_cumulative_mse(
                    model, val_dat, save_path="tinyRSNN_cumulative_se_" + session_name + ".png"
                )

            model, scores, pred, bm_results = evaluate_with_testdata(model, cfg, test_dat)

            logger.info("Benchmark results:")
            for k, v in bm_results.items():
                # log key and value rounded to 4 decimal places
                if isinstance(v, float):
                    logger.info(f"{k}: {v:.4f}")
                else:
                    logger.info(f"{k}: {v}")

            for k, v in model.get_metrics_dict(scores).items():
                results["test_" + k] = v

            # Save to JSON file with indentation
            converted_results = convert_np_float_to_float(results)
            with open("tinyRSNN-results-" + session_name + ".json", "w", encoding="utf-8") as f:
                json.dump(converted_results, f, indent=4)

            if cfg.plotting.plot_training:
                plot_training(
                    results,
                    cfg.training.nb_epochs_train,
                    names=["loss", "r2"],
                    save_path="tinyRSNN_training_" + session_name + ".png",
                )


if __name__ == "__main__":
    # No arguments: the @hydra.main decorator builds and injects ``cfg``.
    train_all()
