import os
import argparse

from utils.io import load_trainer, get_full_config, save_pickle, save_yaml
from utils.logging import get_logger, log_class_scores
from utils.constants import SPLIT_LIST
from utils.evaluate import eval_classification
from utils.dataset import get_data_from_config, get_eval_dataloader
from utils.common import set_seed
from trainers.transfer import TransferTrainer
from pathlib import Path
from omegaconf import OmegaConf

# import torch
import numpy as np


# from trainers.base import BaseTrainer


if __name__ == "__main__":
    # Transfer-learning entry point: fine-tune from a pretrained checkpoint on the
    # dataset described by the config, evaluate the resulting embeddings with
    # `eval_classification`, and write config.yaml, log.log and scores.pkl under
    # experiments/<data_name>/<model_name>/<seed_dir>/.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        help="keys from transfer_config.py",
        type=str,
        required=True,
    )
    parser.add_argument(
        "-p",
        "--ckpt_path",
        help="overrides load_from_checkpoint",
        type=str,
        default=None,
    )
    parser.add_argument("-s", "--seed", help="overrides seed", type=int, default=None)

    args = parser.parse_args()

    transfer_config = get_full_config(args.config_path)

    # Apply CLI overrides and validate *before* loading data: a missing checkpoint
    # fails fast, and any randomness in data loading is covered by the seed.
    if args.seed is not None:
        transfer_config["seed"] = args.seed

    if args.ckpt_path is not None:
        transfer_config["load_from_checkpoint"] = args.ckpt_path

    if transfer_config["load_from_checkpoint"] is None:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        parser.error("load_from_checkpoint must be specified for transfer learning")

    set_seed(transfer_config["seed"])

    data, _ = get_data_from_config(transfer_config, "subseq", train=True)
    logger = get_logger()

    logger.info(f"Loading data from:\t{transfer_config.data_args.path}")
    logger.info(f"train data shape:\t{data['train_data'].shape}")
    logger.info(f"val data shape:\t{data['val_data'].shape}")
    logger.info("")

    # generate save_dir
    if "data_name" in transfer_config and transfer_config["data_name"] is not None:
        data_name = transfer_config["data_name"]
    else:
        # Fallback: second "/"-separated component of data_args.path
        # (e.g. "data/<name>/..." -> "<name>"); requires at least one "/".
        data_name = transfer_config["data_args"]["path"].split("/")[1]
    logger.info(f"loading from checkpoint: {transfer_config['load_from_checkpoint']}")

    # Model name and seed directory are read from the checkpoint location, which is
    # assumed to look like .../<model_name>/<seed_dir>/<checkpoint file>.
    # NOTE(review): seed_dir comes from the checkpoint path, not from --seed, so runs
    # with different --seed values on the same checkpoint share one save_dir.
    ckpt_path = Path(transfer_config["load_from_checkpoint"])
    seed_dir = ckpt_path.parent.stem
    model_name = ckpt_path.parent.parent.stem

    transfer_config["save_dir"] = os.path.join(
        "experiments", data_name, model_name, seed_dir
    )
    save_dir = Path(transfer_config["save_dir"])
    save_dir.mkdir(parents=True, exist_ok=True)

    save_yaml(
        OmegaConf.to_container(transfer_config, resolve=True), save_dir / "config.yaml"
    )

    # Re-create the logger so subsequent messages are also written to the run's log file.
    log_file = save_dir / "log.log"
    logger = get_logger(log_file=log_file)
    logger.info(f"Log file created:\t{log_file}")

    # NOTE(review): the validation split is concatenated into the training data but is
    # also passed to the trainer as the validation set, so any validation metric the
    # trainer reports is not on held-out data. Confirm this is intended.
    train_data = np.concatenate([data["train_data"], data["val_data"]], axis=0)
    train_labels = np.concatenate([data["train_labels"], data["val_labels"]], axis=0)

    trainer = TransferTrainer(
        transfer_config,
        train_data,
        train_labels,
        data["val_data"],
        data["val_labels"],
    )
    trainer.fit()

    eval_loaders, _ = get_eval_dataloader(transfer_config)

    # Embed every split listed in SPLIT_LIST with the fine-tuned model.
    eval_results = {}
    for split in SPLIT_LIST:
        results = trainer.evaluate(eval_loaders[f"{split}_loader"])
        # NOTE(review): squeeze() with no axis drops every size-1 dimension, including
        # the sample axis if a split has a single sample — confirm the embed shape.
        results["embed"] = results["embed"].squeeze()
        eval_results[f"{split}_results"] = results

    # Downstream classification: fit on train embeddings, score on test embeddings.
    eval_class_results = eval_classification(
        eval_results["train_results"]["embed"],
        eval_results["train_results"]["labels"],
        eval_results["test_results"]["embed"],
        eval_results["test_results"]["labels"],
    )

    log_class_scores(eval_class_results)

    scores = {"classification": eval_class_results}

    # save scores
    score_path = save_dir / "scores.pkl"
    logger.info(f"Saving scores to {score_path}")
    save_pickle(scores, score_path)
