from datetime import datetime as dt
from pathlib import Path

import comet_ml
from adapt import adapt
from align import align
from comet_ml import ExistingExperiment, Experiment
from evaluate import evaluate_
from omegaconf import DictConfig, OmegaConf
from utils.utils import process_args


def main_():
    """Run the alignment phase followed by the adaptation phase.

    The configuration is assembled from the command line (which must supply
    ``config``, the path of a YAML file) merged over that file.  Depending on
    the ``train_alignment`` / ``evaluate_alignment`` / ``train_adaptation`` /
    ``evaluate_adaptation`` flags, the corresponding steps are executed.

    When ``args.logdir`` is set, that directory is reused (it must be a direct
    child of ``args.root_dir``); otherwise a fresh
    ``<root_dir>/<experiment_name>_<timestamp>`` directory is created and the
    resolved alignment config is saved to ``align/config.yaml``.

    Raises:
        ValueError: if a supplied ``logdir`` is not directly inside
            ``root_dir``.
        FileNotFoundError: if no entry exists under ``<logdir>/align/model``
            when the adaptation phase is being set up.
    """
    # --------------------
    # Alignment
    # --------------------

    cli_args = OmegaConf.from_cli()
    conf_args = OmegaConf.load(cli_args.config)

    # Later arguments win in OmegaConf.merge, so CLI values override the file.
    args = OmegaConf.merge(conf_args, cli_args)
    OmegaConf.resolve(args)
    args = process_args(
        args=args,
        phase="align",
        inference_task_ids=args.inference_task_ids,
    )
    root_dir = Path(args.root_dir)
    current_time = dt.now().strftime("%Y%m%d_%H%M%S")
    # Not consumed below while Comet logging is switched off; it is the flag
    # load_comet_experiment() expects if logging gets enabled.
    is_new_experiment = args.logdir is None

    if not is_new_experiment:
        # existing experiment
        logdir = Path(args.logdir)
        # Validate before touching the filesystem, and with a real exception:
        # an ``assert`` disappears under ``python -O``.
        if logdir.parent != root_dir:
            raise ValueError(
                f"logdir {logdir} must be a direct child of root_dir {root_dir}"
            )
        align_dir = logdir / "align"
        adapt_dir = logdir / "adapt"
        align_dir.mkdir(exist_ok=True)
        adapt_dir.mkdir(exist_ok=True)
    else:
        # new experiment
        logdir = root_dir / (args.experiment_name + "_" + current_time)
        args.logdir = str(logdir)

        align_dir = logdir / "align"
        adapt_dir = logdir / "adapt"
        align_dir.mkdir(parents=True)
        adapt_dir.mkdir(parents=True)
        OmegaConf.save(args, align_dir / "config.yaml")

    # Comet logging is disabled here: downstream calls receive ``None``.
    experiment = None

    if args.train_alignment:
        align(args, experiment)

    if args.evaluate_alignment:
        evaluate_args = OmegaConf.merge(args, args.evaluate_args)
        if args.evaluate_by_inference_task:
            evaluate_args.task_ids = evaluate_args.inference_task_ids
        evaluate_(evaluate_args, experiment, "align")

    # --------------------
    # Adaptation
    # --------------------

    # Adaptation-specific settings override the shared ones.
    args = OmegaConf.merge(args, args.adapt_args)
    args = process_args(
        args=args,
        phase="adapt",
        inference_task_ids=args.inference_task_ids,
    )

    # Use the lexicographically last entry of the alignment model directory.
    model_dir = align_dir / "model"
    checkpoints = sorted(model_dir.glob("*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no alignment model found in {model_dir}; "
            "run with train_alignment enabled first"
        )
    args.model_path = str(checkpoints[-1])

    if args.train_adaptation:
        OmegaConf.save(args, adapt_dir / "config.yaml")
        # Only log when an experiment object actually exists; calling a
        # method on ``None`` would abort the run before adaptation starts.
        if experiment is not None:
            adapt_params = OmegaConf.to_container(args.adapt_args)
            for key, val in adapt_params.items():
                experiment.log_parameter("adapt_" + key, val)

        adapt(args, experiment)

    if args.evaluate_adaptation:
        evaluate_args = OmegaConf.merge(args, args.evaluate_args)
        evaluate_(evaluate_args, experiment, "adapt")


def _read_first_line(path: Path) -> str:
    """Return the first line of *path* without surrounding whitespace."""
    with open(path, "r") as f:
        return f.readline().strip()


def load_comet_experiment(args: DictConfig, logdir: Path,
                          is_new_experiment: bool, current_time: str):
    """Create a new Comet experiment or re-attach to a previously created one.

    The API key is read from the first line of ``~/comet_key.txt``.

    Args:
        args: Run configuration; ``debug``, ``experiment_name`` and
            ``discriminator.enable`` are read for new experiments.
        logdir: Directory holding ``experiment_key.txt``.
        is_new_experiment: If true, a new ``Experiment`` is created and its
            key is written to ``logdir / "experiment_key.txt"``; otherwise the
            key is read back from that file and an ``ExistingExperiment`` is
            returned.
        current_time: Timestamp string; with underscores removed it forms the
            suffix of the new experiment key.

    Returns:
        The ``Experiment`` or ``ExistingExperiment`` instance.
    """
    # Strip the key: a trailing newline in the key file must not end up in
    # the credentials, and the file handle is closed deterministically.
    api_key = _read_first_line(Path.home() / "comet_key.txt")

    if is_new_experiment:
        experiment_key = f"crossdomaintransfer{current_time.replace('_','')}"
        # Persist the key so a later run can resume the same experiment.
        with open(logdir / "experiment_key.txt", "w") as f:
            f.write(experiment_key)
        experiment = Experiment(
            api_key=api_key,
            disabled=args.debug,
            experiment_key=experiment_key,
        )
        experiment.set_name(args.experiment_name)

        if args.discriminator.enable:
            experiment.add_tag("adv")

        experiment.log_parameters(OmegaConf.to_container(args))
    else:
        experiment_key = _read_first_line(logdir / "experiment_key.txt")
        experiment = ExistingExperiment(
            api_key=api_key,
            experiment_key=experiment_key,
        )

    return experiment


# Script entry point: run the full pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main_()
