import time
from pathlib import Path
from typing import Literal, Optional, Union

from comet_ml import ExistingExperiment, Experiment
from omegaconf import DictConfig, OmegaConf
from utils.evaluate import parallel_evaluate
from utils.visualize_episodes import parallel_visualize
from utils.visualize_latents import visualize_latents


def _parse_epoch(model_path: Path) -> Optional[int]:
    """Return the epoch encoded in the first three characters of the file stem.

    Returns ``None`` when those characters do not parse as an integer, so
    callers can skip non-checkpoint entries in the model directory.
    NOTE(review): names are presumably zero-padded (e.g. ``005...``) so that
    name order matches epoch order — confirm against the checkpoint saver.
    """
    try:
        return int(model_path.stem[:3])
    except ValueError:
        return None


def evaluate_(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]] = None,
    prefix: Literal["", "align", "adapt"] = "",
):
    """Visualize and/or evaluate every saved checkpoint of a run stage.

    Checkpoints are the entries of ``<args.logdir>/<prefix>/model`` whose
    names start with a three-digit epoch; they are processed in name order.
    For each checkpoint, depending on the config flags:

    * ``args.visualize_latents``: calls ``visualize_latents``, passing it the
      ``<logdir>/<prefix>/image`` directory.
    * ``args.visualize_episodes``: calls ``parallel_visualize`` for the source
      and target domains with output paths
      ``<logdir>/<prefix>/video/{prefix}_{source,target}_{epoch:03d}.mp4``.
    * ``args.evaluate``: computes source/target success rates with
      ``parallel_evaluate``, appends them to ``<logdir>/<prefix>/eval.csv``
      and, if ``experiment`` is given, logs them as
      ``{prefix}_source_success_rate`` / ``{prefix}_target_success_rate``.

    ``eval.csv`` is recreated (with a header row) on every call, and a
    one-line progress summary is printed per checkpoint.

    Args:
        args: Run configuration (log dir, env/domain ids, observation/action
            reversal flags, and the three feature flags above).
        experiment: Optional Comet experiment that receives the metrics.
        prefix: Sub-directory of ``args.logdir`` holding the stage to
            evaluate; also used as a prefix for video names and metric keys.

    Raises:
        FileNotFoundError: If the model directory contains no checkpoint
            with an epoch prefix (including when the directory is missing).
    """
    logdir = Path(args.logdir) / prefix
    model_dir = logdir / "model"
    video_dir = logdir / "video"
    image_dir = logdir / "image"

    # Parse the directory listing once; entries without an epoch prefix
    # (stray files, sub-directories with other names) are skipped.
    checkpoints = []
    for model_path in sorted(model_dir.glob("*")):
        epoch = _parse_epoch(model_path)
        if epoch is not None:
            checkpoints.append((epoch, model_path))

    if not checkpoints:
        # Nothing to evaluate: fail loudly rather than spin or silently
        # produce an empty CSV.
        raise FileNotFoundError(
            f"No checkpoint with a three-digit epoch prefix found in {model_dir}"
        )
    max_epoch = max(epoch for epoch, _ in checkpoints)

    video_dir.mkdir(exist_ok=True)
    image_dir.mkdir(exist_ok=True)

    # The context manager guarantees the CSV is closed even if a
    # visualization/evaluation call raises part-way through.
    with open(logdir / "eval.csv", "w", encoding="utf-8") as logfile:
        logfile.write("epoch,source_success_rate,target_success_rate\n")

        for epoch, model_path in checkpoints:
            start = time.perf_counter()
            s = f"Epoch: {epoch:3d}/{max_epoch:3d} | "

            if args.visualize_latents:
                print("Visualizing latents...")
                visualize_latents(
                    args=args,
                    experiment=experiment,
                    model_path=model_path,
                    image_dir=image_dir,
                    epoch=epoch,
                    prefix=prefix,
                )

            source_mp4_path = video_dir / f"{prefix}_source_{epoch:03d}.mp4"
            target_mp4_path = video_dir / f"{prefix}_target_{epoch:03d}.mp4"

            if args.visualize_episodes:
                print("Visualizing episodes...")
                parallel_visualize(
                    args=args,
                    experiment=experiment,
                    env_id=args.source_env_id,
                    model_path=model_path,
                    mp4_path=source_mp4_path,
                    epoch=epoch,
                    domain_id=args.source_domain_id,
                    reverse_observations=args.reverse_source_observations,
                    reverse_actions=args.reverse_source_actions,
                )
                parallel_visualize(
                    args=args,
                    experiment=experiment,
                    env_id=args.target_env_id,
                    model_path=model_path,
                    mp4_path=target_mp4_path,
                    epoch=epoch,
                    domain_id=args.target_domain_id,
                    reverse_observations=args.reverse_target_observations,
                    reverse_actions=args.reverse_target_actions,
                )

            if args.evaluate:
                print("Evaluating...")
                success_rate_dict = {
                    "source": parallel_evaluate(
                        args=args,
                        env_id=args.source_env_id,
                        model_path=model_path,
                        domain_id=args.source_domain_id,
                        reverse_observations=args.reverse_source_observations,
                        reverse_actions=args.reverse_source_actions,
                    ),
                    "target": parallel_evaluate(
                        args=args,
                        env_id=args.target_env_id,
                        model_path=model_path,
                        domain_id=args.target_domain_id,
                        reverse_observations=args.reverse_target_observations,
                        reverse_actions=args.reverse_target_actions,
                    ),
                }

                logfile.write(
                    f"{epoch},{success_rate_dict['source']},{success_rate_dict['target']}\n"
                )
                # Flush per checkpoint so partial results survive a crash
                # in a later (long-running) evaluation.
                logfile.flush()
                if experiment is not None:
                    experiment.log_metric(f"{prefix}_source_success_rate",
                                          success_rate_dict["source"],
                                          epoch=epoch)
                    experiment.log_metric(f"{prefix}_target_success_rate",
                                          success_rate_dict["target"],
                                          epoch=epoch)

                s += "Success rate | "
                s += f"Source: {success_rate_dict['source']:.4f} | "
                s += f"Target: {success_rate_dict['target']:.4f} | "

            s += f"Duration: {time.perf_counter() - start:.2f} sec. | "
            print(s)
