# Usage: python scripts/run_semisupervised.py -c <checkpoint_dir>
import argparse
from utils.logging import get_logger
from pathlib import Path

from utils.io import load_trainer, get_full_config, save_pickle
from utils.dataset import (
    get_data_from_config,
    get_eval_dataloader,
    SPLIT_LIST,
    get_eval_dataloader,
)
from utils.evaluate import eval_semisupervised


def get_parser():
    """Build the command-line parser for the semi-supervised evaluation script.

    Every option defaults to ``None`` so the caller can distinguish "not given
    on the command line" from an explicit value and only override the loaded
    config for options that were actually supplied.

    Returns:
        argparse.ArgumentParser: parser exposing ``ckpt_path``, ``save_dir``,
        ``seed`` and ``n_samples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--ckpt_path",
        # The script reads ``<ckpt_path>/config.yaml``, so this is a directory,
        # not the yaml file itself.
        help="path to the checkpoint directory containing config.yaml",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-sd",
        "--save_dir",
        help="set save_dir. this overrides the config file",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-s",
        "--seed",
        help="set seed. this overrides the config file",
        # Parsed as an integer so the override is numeric rather than the raw
        # command-line string.
        type=int,
        default=None,
    )
    parser.add_argument(
        "-n",
        "--n_samples",
        help="set number of samples. this overrides the config file",
        type=int,
        default=None,
    )
    # parser.add_argument("-p", "--data_path", help="set data_path. this overrides the config file", type=str, default=None)
    return parser


def main():
    """Evaluate a trained checkpoint and pickle the semi-supervised results.

    Steps:
      1. Load ``<ckpt_path>/config.yaml`` and point ``save_dir`` at the
         checkpoint directory's parent.
      2. Apply the optional command-line overrides to the config.
      3. Run ``trainer.evaluate`` on the loader of every split in
         ``SPLIT_LIST``.
      4. Pass the train/test embeddings and labels to
         ``eval_semisupervised`` and pickle the result next to the
         checkpoint directory as ``semi_supervised_results.pkl``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Without a checkpoint path ``Path(None)`` would raise an opaque
    # TypeError; report it as a normal usage error instead.
    if args.ckpt_path is None:
        parser.error("-c/--ckpt_path is required")

    ckpt_dir = Path(args.ckpt_path)
    run_dir = ckpt_dir.parent

    config = get_full_config(ckpt_dir / "config.yaml")
    # The stored save_dir may be stale (e.g. the run was moved); always use
    # the directory the checkpoint actually lives under.
    config["save_dir"] = str(run_dir)

    # NOTE(review): the log file location is fixed here, before ``--save_dir``
    # is applied, so the log always goes under the checkpoint's parent
    # directory — confirm this is intended.
    logger = get_logger(log_file=run_dir / "log.log")

    # Command-line overrides take precedence over the config file.
    if args.save_dir is not None:
        config["save_dir"] = args.save_dir
    if args.seed is not None:
        config["seed"] = args.seed
    if args.n_samples is not None:
        config["eval_args"]["eval_n_seeds"] = args.n_samples

    data, _ = get_data_from_config(config, "subseq")
    # "checkpoint_best" is presumably the name of the best saved checkpoint
    # — verify against load_trainer.
    trainer, _ = load_trainer(config, data, "checkpoint_best")
    # Loaders are built once; the loop below is their only consumer.
    eval_loaders, _ = get_eval_dataloader(config)

    eval_results = {}
    for split in SPLIT_LIST:
        results = trainer.evaluate(eval_loaders[f"{split}_loader"])
        # Drop singleton dimensions from the embeddings before they are
        # handed to the downstream evaluation.
        results["embed"] = results["embed"].squeeze()
        eval_results[f"{split}_results"] = results

    train_results = eval_results["train_results"]
    test_results = eval_results["test_results"]
    eval_semisupervised_results = eval_semisupervised(
        config,
        train_results["embed"],
        train_results["labels"],
        test_results["embed"],
        test_results["labels"],
    )

    # NOTE(review): results are written next to the checkpoint directory,
    # regardless of any ``--save_dir`` override — confirm this is intended.
    output_path = run_dir / "semi_supervised_results.pkl"
    logger.info(f"saving results to: {output_path}")
    save_pickle(eval_semisupervised_results, output_path)


if __name__ == "__main__":
    main()
