import argparse
import importlib
import os
from datetime import datetime

import numpy as np
import torch

from ts_tcc.dataloader.dataloader import data_generator
from ts_tcc.models.model import base_Model
from ts_tcc.models.TC import TC, TS_SD
from ts_tcc.trainer.trainer import Trainer, model_evaluate
from ts_tcc.utils import _calc_metrics, _logger, copy_Files, set_requires_grad


def _pretrained_checkpoint_path(logs_save_dir, experiment_description, seed):
    """Return the path of the last self-supervised checkpoint for *seed*.

    Pre-trained weights are always read from the ``pretrain`` run of the
    experiment, not from the current ``--run_description``.
    """
    return os.path.join(
        logs_save_dir,
        experiment_description,
        "pretrain",
        f"self_supervised_seed_{seed}",
        "saved_models",
        "ckp_last.pt",
    )


def _load_pretrained_state_dict(path, device):
    """Load the checkpoint at *path* onto *device* and return its ``model_state_dict``."""
    # weights_only=False unpickles arbitrary objects: only load checkpoints
    # produced by this project's own training runs.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    return checkpoint["model_state_dict"]


def _without_keys_containing(state_dict, substrings):
    """Return a new dict without the entries whose key contains any of *substrings*."""
    return {
        key: value
        for key, value in state_dict.items()
        if not any(sub in key for sub in substrings)
    }


def main():
    """Main entry point for the TS-TCC training script.

    Parses command-line arguments and loads the dataset-specific config
    (``ts_tcc.config_files.<dataset>_Configs``) and data loaders. It then builds
    the encoder (``base_Model``) and the temporal-contrasting module (``TC``, or
    ``TS_SD`` for the ``ts_sd*`` modes). Depending on ``--training_mode`` it may
    initialise and/or freeze weights, then trains with ``Trainer``. For every
    mode except ``ts_sd`` and ``self_supervised`` the model is also evaluated on
    the test split, and the metrics are written to the experiment log directory.
    """
    start_time = datetime.now()

    parser = argparse.ArgumentParser()

    # Model parameters ########################
    home_dir = os.getcwd()
    parser.add_argument(
        "--experiment_description", default="exp1", type=str, help="Experiment Description"
    )
    parser.add_argument(
        "--run_description", default="run1", type=str, help="Run Description"
    )
    parser.add_argument("--seed", default=0, type=int, help="seed value")
    parser.add_argument(
        "--training_mode",
        default="supervised",
        type=str,
        help=(
            "Modes of choice: random_init, supervised, self_supervised, fine_tune, "
            "train_linear, ts_sd, ts_sd_finetune"
        ),
    )
    parser.add_argument(
        "--selected_dataset",
        default="Epilepsy",
        type=str,
        help="Dataset of choice: sleepEDF, HAR, Epilepsy, pFD",
    )
    parser.add_argument(
        "--logs_save_dir",
        default="/Volumes/sandbox/ts-tcc/experiments_logs",
        type=str,
        help="saving directory",
    )
    parser.add_argument("--device", default="cuda", type=str, help="cpu or cuda")
    parser.add_argument(
        "--skip_validation",
        action="store_true",
        help="Skip validation evaluation step during training to speed up runs",
    )
    parser.add_argument("--home_path", default=home_dir, type=str, help="Project home directory")
    args = parser.parse_args()

    device = torch.device(args.device)
    experiment_description = args.experiment_description
    data_type = args.selected_dataset
    method = "TS-TCC"
    training_mode = args.training_mode
    run_description = args.run_description

    logs_save_dir = args.logs_save_dir
    os.makedirs(logs_save_dir, exist_ok=True)

    # Dynamic import of the dataset-specific config module.
    config_module = importlib.import_module(f"ts_tcc.config_files.{data_type}_Configs")
    configs = config_module.Config()

    # ##### fix random seeds for reproducibility ########
    seed = args.seed
    torch.manual_seed(seed)
    # Without deterministic=True cuDNN may select non-deterministic convolution
    # algorithms, so two runs with the same seed could still diverge.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    #####################################################

    experiment_log_dir = os.path.join(
        logs_save_dir,
        experiment_description,
        run_description,
        f"{training_mode}_seed_{seed}",
    )
    os.makedirs(experiment_log_dir, exist_ok=True)

    # Logging
    log_file_name = os.path.join(
        experiment_log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log"
    )
    logger = _logger(log_file_name)
    logger.debug("=" * 45)
    logger.debug(f"Dataset: {data_type}")
    logger.debug(f"Method:  {method}")
    logger.debug(f"Mode:    {training_mode}")
    logger.debug("=" * 45)

    # Load datasets
    data_path = os.path.join(args.home_path, "src", "ts_tcc", "data", data_type)
    train_dl, valid_dl, test_dl = data_generator(data_path, configs, training_mode)
    logger.debug("Data loaded ...")

    # Load Model
    model = base_Model(configs).to(device)
    if training_mode in ("ts_sd", "ts_sd_finetune"):
        # Positional TS_SD hyper-parameters; see TS_SD's signature for their meaning.
        temporal_contr_model = TS_SD(
            configs, device, [3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25], 96, 2
        ).to(device)
    else:
        temporal_contr_model = TC(configs, device).to(device)

    # Key substrings identifying the classifier head. These entries are never
    # copied from the pretrained checkpoint and are excluded from freezing.
    head_keys = ("logits",)

    if training_mode == "fine_tune":
        # Initialise from the self-supervised checkpoint (head excluded). No
        # parameters are frozen in this mode.
        checkpoint_path = _pretrained_checkpoint_path(logs_save_dir, experiment_description, seed)
        pretrained_dict = _without_keys_containing(
            _load_pretrained_state_dict(checkpoint_path, device), head_keys
        )
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    # TODO: "ts_sd_finetune" currently loads no pretrained weights; the TS_SD
    # checkpoint-loading code for this mode was left disabled. Restore it here
    # if ts_sd_finetune should start from ts_sd weights.

    # Any mode name containing "tl" also takes this linear-evaluation path.
    if training_mode == "train_linear" or "tl" in training_mode:
        model_dict = model.state_dict()
        checkpoint_path = _pretrained_checkpoint_path(logs_save_dir, experiment_description, seed)
        # Ignore checkpoint entries this model does not have, then drop the head.
        pretrained_dict = _without_keys_containing(
            {
                key: value
                for key, value in _load_pretrained_state_dict(checkpoint_path, device).items()
                if key in model_dict
            },
            head_keys,
        )
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        # Freeze everything loaded from the checkpoint, i.e. all but the last layer.
        set_requires_grad(model, pretrained_dict, requires_grad=False)

    if training_mode == "random_init":
        # Keep the random weights but freeze everything except the head.
        set_requires_grad(
            model, _without_keys_containing(model.state_dict(), head_keys), requires_grad=False
        )

    model_optimizer = torch.optim.Adam(
        model.parameters(), lr=configs.lr, betas=(configs.beta1, configs.beta2), weight_decay=3e-4
    )
    temporal_contr_optimizer = torch.optim.Adam(
        temporal_contr_model.parameters(),
        lr=configs.lr,
        betas=(configs.beta1, configs.beta2),
        weight_decay=3e-4,
    )

    if training_mode in ("ts_sd", "self_supervised"):  # to do it only once
        copy_Files(os.path.join(logs_save_dir, experiment_description, run_description), data_type)

    # Trainer
    Trainer(
        model,
        temporal_contr_model,
        model_optimizer,
        temporal_contr_optimizer,
        train_dl,
        valid_dl,
        test_dl,
        device,
        logger,
        configs,
        experiment_log_dir,
        training_mode,
        skip_validation=args.skip_validation,
    )

    if training_mode not in ("ts_sd", "self_supervised"):
        # Evaluate on the test split (skipped for ts_sd / self_supervised).
        _, _, pred_labels, true_labels, metrics = model_evaluate(
            model, temporal_contr_model, test_dl, device, training_mode
        )
        _calc_metrics(pred_labels, true_labels, experiment_log_dir, args.home_path)
        print(metrics)

    logger.debug(f"Training time is : {datetime.now() - start_time}")


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()
