import time
import warnings
from pathlib import Path
from typing import Optional, Union

from comet_ml import ExistingExperiment, Experiment
from omegaconf import DictConfig

from utils.evaluate import parallel_evaluate
from utils.visualize_episodes import parallel_visualize


def _evaluate(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]] = None,
    log_prefix: str = "",
):
    """Visualize and/or evaluate every saved model found under ``args.logdir``.

    Each entry of ``<logdir>/model`` is processed in sorted (lexicographic)
    order. The epoch number is parsed from the first three characters of the
    entry's file stem, so entries are expected to be named with a
    three-digit epoch prefix.

    For every model:

    * if ``args.visualize_episodes`` is truthy, ``parallel_visualize`` is
      called for the source and the target environment, with
      ``<logdir>/video/source_<epoch>.mp4`` and
      ``<logdir>/video/target_<epoch>.mp4`` as the mp4 paths;
    * if ``args.evaluate`` is truthy, ``parallel_evaluate`` is called for
      both environments, and the two success rates are appended as a row to
      ``<logdir>/<log_prefix>_eval.csv`` and, when ``experiment`` is given,
      logged through ``experiment.log_metric`` with ``step=epoch``.

    A one-line summary including the elapsed time is printed per model.

    Args:
        args: Configuration object. The attributes read here are ``logdir``,
            ``visualize_episodes``, ``evaluate``, and the ``source_*`` /
            ``target_*`` / ``reverse_*`` settings forwarded to the helpers.
        experiment: Optional Comet experiment used for metric logging and
            forwarded to ``parallel_visualize``.
        log_prefix: Prefix of the CSV log file name.

    Returns:
        None. If the model directory contains no entries, a warning is
        issued and nothing else is done.
    """
    logdir = Path(args.logdir)
    model_dir = logdir / "model"
    video_dir = logdir / "video"
    log_path = logdir / f"{log_prefix}_eval.csv"

    # Glob once and reuse the sorted list for both the "latest" lookup and
    # the main loop.
    model_paths = sorted(model_dir.glob("*"))
    if not model_paths:
        warnings.warn(
            f"No models found in {model_dir}; nothing to evaluate.",
            stacklevel=2,
        )
        return

    video_dir.mkdir(exist_ok=True)

    # The last entry in sorted order is treated as the latest model.
    max_epoch = int(model_paths[-1].stem[:3])

    # The log is opened in append mode, so only emit the header when the
    # file is new or empty; otherwise repeated runs would scatter header
    # rows through the data.
    needs_header = not log_path.exists() or log_path.stat().st_size == 0

    # The context manager guarantees the file is closed even if
    # visualisation or evaluation raises part-way through.
    with open(log_path, "a") as log_file:
        if needs_header:
            log_file.write("epoch,source_success_rate,target_success_rate,\n")

        for model_path in model_paths:
            start = time.time()
            epoch = int(model_path.stem[:3])
            s = f"Epoch: {epoch:3d}/{max_epoch:3d} | "

            source_mp4_path = video_dir / f"source_{epoch:03d}.mp4"
            target_mp4_path = video_dir / f"target_{epoch:03d}.mp4"

            if args.visualize_episodes:
                print("Visualizing episodes...")
                parallel_visualize(
                    args=args,
                    experiment=experiment,
                    env_id=args.source_env_id,
                    model_path=model_path,
                    mp4_path=source_mp4_path,
                    epoch=epoch,
                    domain_id=args.source_domain_id,
                    reverse_observations=args.reverse_source_observations,
                    reverse_actions=args.reverse_source_actions,
                )
                parallel_visualize(
                    args=args,
                    experiment=experiment,
                    env_id=args.target_env_id,
                    model_path=model_path,
                    mp4_path=target_mp4_path,
                    epoch=epoch,
                    domain_id=args.target_domain_id,
                    reverse_observations=args.reverse_target_observations,
                    reverse_actions=args.reverse_target_actions,
                )

            if args.evaluate:
                print("Evaluating...")
                success_rate_dict = {}
                success_rate_dict["source"] = parallel_evaluate(
                    args=args,
                    env_id=args.source_env_id,
                    model_path=model_path,
                    domain_id=args.source_domain_id,
                    reverse_observations=args.reverse_source_observations,
                    reverse_actions=args.reverse_source_actions,
                )
                success_rate_dict["target"] = parallel_evaluate(
                    args=args,
                    env_id=args.target_env_id,
                    model_path=model_path,
                    domain_id=args.target_domain_id,
                    reverse_observations=args.reverse_target_observations,
                    reverse_actions=args.reverse_target_actions,
                )

                if experiment:
                    experiment.log_metric(
                        "source_success_rate",
                        success_rate_dict["source"],
                        step=epoch,
                    )
                    experiment.log_metric(
                        "target_success_rate",
                        success_rate_dict["target"],
                        step=epoch,
                    )

                s += "Success rate | "
                s += f"Source: {success_rate_dict['source']:.4f} | "
                s += f"Target: {success_rate_dict['target']:.4f} | "
                # Row format (including the trailing comma) is kept as-is so
                # existing consumers of the CSV keep working.
                log_file.write(
                    f"{epoch},"
                    f"{success_rate_dict['source']},"
                    f"{success_rate_dict['target']},"
                    "\n"
                )
                # Flush per row so partial results are on disk while later
                # models are still being processed.
                log_file.flush()

            end = time.time()
            s += f"Duration: {end-start:.2f} sec. | "
            print(s)
