import logging
import multiprocessing as mp
from multiprocessing import Manager, Process
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import comet_ml
import gym
import imageio.v2
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm

import common.utils.metaworld_utils
import common.utils.policy
import common.utils.robot_utils
from common.ours.utils import maze2d_to_ant
from common.utils.policy import create_eval_policy
from common.utils.process_dataset import (ActionConverter,
                                          AntMazeTaskIDManager,
                                          ObservationConverter,
                                          PairedTrajDataset,
                                          PointMazeTaskIDManager)
from d4rl.pointmaze.maze_model import MazeEnv

# Module logger; INFO is set here so evaluation summaries are emitted.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# tqdm layout: default left/right sections around a fixed 50-column bar.
TQDM_BAR_FORMAT = "{l_bar}" "{bar:50}" "{r_bar}"


def save_video(path: Path,
               frames: List[np.ndarray],
               epoch: int,
               fps: int = 20,
               experiment: Optional[comet_ml.Experiment] = None):
    """Write ``frames`` to an MP4 file and optionally upload it to Comet.

    Args:
        path: Destination path. Its suffix is replaced with ``.mp4`` and
            missing parent directories are created.
        frames: Frames in playback order (in this module they come from
            ``env.render(mode="rgb_array")``).
        epoch: Step under which the asset is logged to ``experiment``.
        fps: Playback frame rate of the written video.
        experiment: Comet experiment to upload the video to; ``None`` skips
            the upload.
    """
    path.parent.mkdir(exist_ok=True, parents=True)

    mp4_path = path.with_suffix('.mp4')
    # Call the v2 API explicitly: it is the API this module imports
    # (``import imageio.v2``), rather than whatever the top-level
    # ``imageio`` namespace happens to expose in the installed version.
    imageio.v2.mimsave(mp4_path, frames, fps=fps)
    logger.info(f'video is saved to {mp4_path}')

    if experiment is not None:
        experiment.log_asset(mp4_path, step=int(epoch))


def evaluate(
    args: DictConfig,
    env_id: str,
    model: Any,
    goal_ids: List[int],
    domain_id: int,
    task_id_manager: Union[PointMazeTaskIDManager, AntMazeTaskIDManager],
    obs_converter: ObservationConverter,
    action_converter: ActionConverter,
    epoch: int,
    env_kwargs: Dict = {"eval": True},
    num_episodes: int = 100,
    num_render_episodes: int = 0,
    render_kwargs: Dict = {
        "height": 512,
        "width": 512
    },
    skip_rate: int = 5,
    fps: int = 20,
    video_path: Path = Path("."),
    experiment: Optional[comet_ml.Experiment] = None,
    traj_dataset: Optional[PairedTrajDataset] = None,
    prefix: str = "",
    pretrained_image_encoder: Optional[torch.nn.Module] = None,
):
    """Roll out ``model`` in ``env_id`` and measure goal-reaching success.

    Each episode samples a task uniformly from the task ids derived from
    ``goal_ids`` (expanded through ``task_id_manager`` when one is given).
    The exception is the adaptation phase (exactly one goal id) with
    ``args.complex_task`` set: there a fixed start/checkpoint/goal layout is
    used, and the episode only counts as a success if the checkpoint was
    reached first.

    Args:
        args: Run configuration. Reads ``image_observation``, ``complex_task``,
            ``device`` and optionally ``mask_proprio``.
        env_id: Gym id of the environment to build.
        model: Model wrapped by ``create_eval_policy``. It is switched to eval
            mode during the rollout and back to train mode afterwards.
        goal_ids: Goal ids to sample tasks from.
        domain_id: Domain id forwarded to the policy (and to the env when
            using image observations).
        task_id_manager: Maps goal ids to task ids and task ids to start/goal
            ids. If falsy, ``goal_ids`` are used directly as task ids.
        obs_converter, action_converter, traj_dataset: Forwarded to
            ``create_eval_policy``.
        epoch: Step used when logging the rendered video.
        env_kwargs: Keyword arguments for ``gym.make``. The dict is copied,
            so neither the caller's dict nor the shared default is modified.
        num_episodes: Number of evaluation episodes.
        num_render_episodes: The first this-many episodes are rendered.
        render_kwargs: Keyword arguments for ``env.render``. Also copied.
        skip_rate: Render every ``skip_rate``-th step.
        fps: Frame rate of the saved video.
        video_path: Where the video of rendered episodes is written.
        experiment: Optional Comet experiment passed to ``save_video``.
        prefix, pretrained_image_encoder: Accepted for interface
            compatibility; not used here.

    Returns:
        ``(success_rate, steps_mean)`` over all episodes.
    """
    # Work on copies: the defaults above are shared between calls, and the
    # branches below add keys to these dicts.
    env_kwargs = dict(env_kwargs)
    render_kwargs = dict(render_kwargs)

    model.eval()
    env: Optional[MazeEnv] = None
    try:
        policy = create_eval_policy(
            args=args,
            model=model,
            env_id=env_id,
            domain_id=domain_id,
            obs_converter=obs_converter,
            action_converter=action_converter,
            traj_dataset=traj_dataset,
        )

        if args.image_observation:
            env_kwargs["image_observation"] = True
            env_kwargs["domain_id"] = domain_id
            render_kwargs["resolution"] = (256, 256)
        elif 'Lift' in env_id and num_render_episodes <= 0:
            # Nothing is rendered: create the Lift env without camera
            # observations or an offscreen renderer.
            env_kwargs['env_kwargs'] = {
                'use_camera_obs': False,
                'has_offscreen_renderer': False
            }

        env = gym.make(env_id, **env_kwargs)
        max_episode_steps = env.spec.max_episode_steps

        is_maze = "medium" in env_id or "umaze" in env_id
        if is_maze:
            get_success = common.utils.policy.get_success
        else:
            get_success = env.get_success

        if num_render_episodes != 0 and is_maze:
            env.render(mode="rgb_array")
        env.reset()

        if task_id_manager:
            task_id_list = []
            for goal_id in goal_ids:
                task_id_list += task_id_manager.goal_id_to_task_id_list(
                    goal_id=goal_id)
        else:
            task_id_list = goal_ids

        results = []
        steps = []
        frames = []
        is_adapt_phase = len(goal_ids) == 1
        use_complex_task = args.complex_task and is_adapt_phase

        pbar = tqdm(range(num_episodes), bar_format=TQDM_BAR_FORMAT)
        for episode in pbar:
            env.reset()

            if use_complex_task:
                checkpt_flag = False
                task_id = task_id_manager.n_task_id  # added ID
                reset_pos = np.array([2, 4])
                checkpt = np.array([6, 2])
                goal_pos = np.array([2, 2])
                if 'ant' in env_id:
                    start_for_ant = maze2d_to_ant(reset_pos[None])[0]
                    goal_for_ant = maze2d_to_ant(goal_pos[None])[0]
                    checkpt = maze2d_to_ant(checkpt[None])[0]
                    env.set_target_goal(goal_input=goal_for_ant)
                    env.set_init_xy(init_xy=start_for_ant)
                else:
                    env.reset_to_location(location=reset_pos)
                    env.set_target(target_location=goal_pos)
                # NOTE(review): for ant envs the checkpoint is converted to ant
                # coordinates but the goal used for get_success is not -
                # confirm get_success expects maze2d coordinates here.
                goal = goal_pos
                policy.init(task_id=task_id)
            else:
                # Resample until the policy accepts the task. This loops
                # forever if policy.init rejects every task id.
                while True:
                    idx = np.random.randint(len(task_id_list))
                    task_id = task_id_list[idx]
                    if task_id_manager:
                        start_id = task_id_manager.task_id_to_start_id(
                            task_id=task_id)
                        goal_id = task_id_manager.task_id_to_goal_id(
                            task_id=task_id)
                    else:
                        start_id = 0
                        goal_id = task_id

                    goal = env.setup_task(goal_id=goal_id, start_id=start_id)
                    try:
                        policy.init(task_id=task_id)
                        break
                    except ValueError:
                        pass

            if "Stack" in env_id:
                env.reset()

            obs = env.get_obs()
            done = False
            cumulative_reward = 0
            t = 0
            while not done:
                if args.image_observation:
                    image = torch.from_numpy(obs["image"]).to(args.device)[None]
                    with torch.no_grad():
                        image_encoded = model.image_encoder(image)[0]
                    image_encoded = image_encoded.cpu().numpy()

                    if args.get("mask_proprio", False):
                        # Hide proprioception but keep the input width the
                        # same (previously read `state` before assignment).
                        state = np.zeros_like(obs["state"])
                    else:
                        state = obs["state"]

                    obs = np.concatenate((state, image_encoded), axis=-1)

                act = policy(obs, task_id)
                obs, rew, done, _ = env.step(act)
                cumulative_reward += rew
                if use_complex_task and not checkpt_flag:
                    checkpt_flag = get_success(obs,
                                               target=checkpt,
                                               env_id=env_id,
                                               checkpt=True)

                if episode < num_render_episodes and t % skip_rate == 0:
                    frame = env.render(mode="rgb_array", **render_kwargs)
                    frames.append(frame)

                t += 1
                if done or get_success(obs, goal, env_id):
                    break

            # A very negative return is treated as an ant failure.
            ant_fail_flag = 'ant' in env_id and cumulative_reward < -1000
            if t < max_episode_steps and not ant_fail_flag:
                if use_complex_task and not checkpt_flag:
                    success = False
                else:
                    success = get_success(obs, goal, env_id)
                    if not success:  # Only reach-color env does force termination
                        assert 'color' in env_id
            else:
                success = False

            results.append(success)
            steps.append(t)

            pbar.set_postfix_str(f"Success Rate: {np.mean(results):.3f}")

        if len(frames) != 0:
            try:
                save_video(
                    path=video_path,
                    frames=frames,
                    fps=fps,
                    epoch=epoch,
                    experiment=experiment,
                )
            except Exception:
                # Video export is best-effort; never fail the evaluation.
                logger.exception('Error happened when saving a video.')

        success_rate = np.mean(results)
        steps_mean = np.mean(steps)
        return success_rate, steps_mean
    finally:
        # Release the env and restore training mode even if a rollout fails.
        if env is not None:
            env.close()
        model.train()


def evaluate_process(manager_dict, domain_index, *evaluate_args,
                     **evaluate_kwargs):
    """Process entry point: run :func:`evaluate` and publish its result.

    The ``(success_rate, mean_steps)`` pair returned by :func:`evaluate` is
    stored under ``manager_dict[domain_index]`` so the launching process can
    read it after joining.
    """
    rate, avg_steps = evaluate(*evaluate_args, **evaluate_kwargs)
    manager_dict[domain_index] = rate, avg_steps


def eval_callback(
    args,
    policy,
    epoch: int,
    goal_ids: List[int],
    savedir_root: Path,
    task_id_manager: Union[PointMazeTaskIDManager, AntMazeTaskIDManager],
    log_prefix: str = 'align',
    experiment: Optional[comet_ml.Experiment] = None,
    traj_dataset: Optional[PairedTrajDataset] = None,
    skip_source: bool = False,
    target_env: bool = False,
):
    """Evaluate ``policy`` on every configured domain and log the metrics.

    Delegates to :func:`multienv_eval_callback` when ``args.multienv`` is set.
    Otherwise each domain in ``args.domains`` is evaluated with
    :func:`evaluate`. The domain with ``domain_id == 0`` is skipped when
    ``skip_source`` is true. With ``args.evaluate_parallel`` each domain runs
    in its own process, started from a private ``spawn`` context so the
    process-wide start method is left untouched.

    Per-domain success rate and mean episode length are logged and, when
    ``experiment`` is given, sent to Comet together with any rendered videos.

    Returns:
        The success rate of the last domain (highest index in
        ``args.domains``) that produced a result. Returns ``None`` when
        evaluation is disabled or no domain was evaluated.
    """
    if args.multienv:
        return multienv_eval_callback(
            args=args,
            policy=policy,
            epoch=epoch,
            goal_ids=goal_ids,
            savedir_root=savedir_root,
            task_id_manager=task_id_manager,
            log_prefix=log_prefix,
            experiment=experiment,
            traj_dataset=traj_dataset,
            skip_source=skip_source,
            target_env=target_env,
        )

    if not args.evaluate:
        return
    from common.utils.process_dataset import (get_action_converter,
                                              get_obs_converter)
    pretrained_image_encoder = None

    # One keyword-argument set per domain, shared by the sequential and the
    # parallel path so both evaluate with exactly the same settings.
    jobs = []
    video_path_list = []
    for domain_index, domain_info in enumerate(args.domains):
        if skip_source and domain_info['domain_id'] == 0:
            continue
        if trans_args := domain_info.get('obs_converter_args'):
            obs_converter = get_obs_converter(
                name=domain_info.get('obs_converter'), **trans_args)
        else:
            obs_converter = get_obs_converter(
                name=domain_info.get('obs_converter'))
        action_converter = get_action_converter(
            name=domain_info.get('action_converter'))

        video_path = savedir_root / f'{log_prefix}_domain{domain_info.domain_id}_{epoch}.mp4'
        video_path_list.append(video_path)

        jobs.append((domain_index, dict(
            args=args,
            env_id=domain_info.env,
            model=policy,
            goal_ids=goal_ids,
            domain_id=domain_info.domain_id,
            task_id_manager=task_id_manager,
            obs_converter=obs_converter,
            action_converter=action_converter,
            epoch=epoch,
            env_kwargs={"eval": True},
            num_episodes=args.n_eval_episodes,
            num_render_episodes=args.n_render_episodes,
            render_kwargs={
                "height": 512,
                "width": 512
            },
            skip_rate=5,
            fps=20,
            video_path=video_path,
            # Videos are uploaded below once every domain has finished.
            experiment=None,
            traj_dataset=traj_dataset,
            prefix="",
            pretrained_image_encoder=pretrained_image_encoder,
        )))

    results = {}
    if args.evaluate_parallel:
        ctx = mp.get_context('spawn')
        # The manager is shut down when the block exits, so copy the results
        # into a plain dict before leaving it.
        with ctx.Manager() as manager:
            shared_results = manager.dict()
            processes = []
            for domain_index, eval_kwargs in jobs:
                process = ctx.Process(
                    target=evaluate_process,
                    args=(shared_results, domain_index),
                    kwargs=eval_kwargs,
                )
                process.start()
                processes.append((domain_index, process))
            for domain_index, process in processes:
                process.join()
                if process.exitcode != 0:
                    logger.warning(
                        f'Evaluation process for domain index {domain_index} '
                        f'exited with code {process.exitcode}; its result is missing.')
            results = dict(shared_results.items())
    else:
        for domain_index, eval_kwargs in jobs:
            results[domain_index] = evaluate(**eval_kwargs)

    metrics_dict = {}
    for domain_index, (success_rate, mean_steps) in sorted(results.items()):
        domain_id = args.domains[domain_index].domain_id
        logger.info(
            f'Epoch {epoch}: Success rate (domain {domain_id}): {success_rate * 100:.1f}% '
            f'Steps len (domain {domain_id}): {mean_steps:.1f}'
        )
        metrics_dict[f'{log_prefix}_domain{domain_id}_success_rate'] = success_rate
        metrics_dict[f'{log_prefix}_domain{domain_id}_steps_mean'] = mean_steps

    if experiment is not None:
        experiment.log_metrics(metrics_dict, epoch=epoch)
        for video_path in video_path_list:
            if video_path.exists():
                experiment.log_asset(video_path, step=epoch)

    if not results:
        return None
    # Pick by domain index, not insertion order: the parallel path fills
    # ``results`` in process-completion order.
    success_rate, _ = results[max(results)]
    return success_rate


def multienv_evaluate_process(manager_dict, domain_tag, *evaluate_args,
                              **evaluate_kwargs):
    """Process entry point: run :func:`multienv_evaluate` and publish it.

    The ``(success_rate, mean_steps)`` pair is stored under
    ``manager_dict[domain_tag]`` for the launching process to collect.
    """
    rate, avg_steps = multienv_evaluate(*evaluate_args, **evaluate_kwargs)
    manager_dict[domain_tag] = rate, avg_steps


def multienv_eval_callback(
    args,
    policy,
    epoch: int,
    goal_ids: List[int],
    savedir_root: Path,
    task_id_manager: Union[PointMazeTaskIDManager, AntMazeTaskIDManager],
    log_prefix: str = 'align',
    experiment: Optional[comet_ml.Experiment] = None,
    traj_dataset: Optional[PairedTrajDataset] = None,
    skip_source: bool = False,
    target_env: bool = False,
):
    """Evaluate ``policy`` on every (domain, environment) pair in parallel.

    For each domain in ``args.domains`` and each tag in its ``env_tags``:
    with ``target_env=True`` only the domain's ``target_env`` is evaluated,
    otherwise ("proxy" phase) only the other environments are. In the proxy
    phase the goal ids of the i-th environment are restricted to
    ``args.task_id_offset_list[i] <= gid < args.task_id_offset_list[i + 1]``.

    Each pair runs :func:`multienv_evaluate` in its own process, started from
    a private ``spawn`` context. The process-wide start method is therefore
    neither changed nor reset by this function.

    Returns:
        The success rate of the last launched job that produced a result.
        Returns ``None`` when evaluation is disabled or nothing was evaluated.
    """
    if not args.evaluate:
        return
    from common.utils.process_dataset import (get_action_converter,
                                              get_obs_converter)
    phase = "Target" if target_env else "Proxy"
    pretrained_image_encoder = None

    # (domain_index, env_tag, domain_tag, evaluation kwargs) in launch order.
    jobs = []
    video_path_list = []
    for domain_index, domain_info in enumerate(args.domains):
        if skip_source and not hasattr(domain_info, "target"):
            continue
        for env_index, env_tag in enumerate(domain_info.env_tags):
            # Target phase keeps only the target env; proxy phase the rest.
            if target_env != (env_tag == domain_info.target_env):
                continue
            env_info = domain_info[env_tag]
            domain_tag = f"{domain_index}_{env_tag}"

            if target_env:
                env_goal_ids = goal_ids
            else:
                lower = args.task_id_offset_list[env_index]
                upper = args.task_id_offset_list[env_index + 1]
                env_goal_ids = [
                    gid for gid in goal_ids if lower <= gid < upper
                ]

            print(
                f"Phase: {phase} Task | Domain ID: {domain_info.domain_id} | Env: {env_tag} | Task IDs: {env_goal_ids}"
            )
            if trans_args := domain_info.get('obs_converter_args'):
                obs_converter = get_obs_converter(
                    name=domain_info.get('obs_converter'), **trans_args)
            else:
                obs_converter = get_obs_converter(
                    name=domain_info.get('obs_converter'))
            action_converter = get_action_converter(
                name=domain_info.get('action_converter'))

            video_path = savedir_root / f'{log_prefix}_domain{domain_info.domain_id}_{env_tag}_{epoch:03d}.mp4'
            video_path_list.append(video_path)

            jobs.append((domain_index, env_tag, domain_tag, dict(
                args=args,
                env_id=env_info.env,
                model=policy,
                goal_ids=env_goal_ids,
                domain_id=domain_info.domain_id,
                task_id_manager=task_id_manager,
                obs_converter=obs_converter,
                action_converter=action_converter,
                epoch=epoch,
                env_kwargs={"eval": True},
                num_episodes=args.n_eval_episodes,
                num_render_episodes=args.n_render_episodes,
                render_kwargs={
                    "height": 512,
                    "width": 512
                },
                skip_rate=5,
                fps=20,
                video_path=video_path,
                # Videos are uploaded below once every job has finished.
                experiment=None,
                traj_dataset=traj_dataset,
                prefix="",
                target_env=target_env,
                source_env_idx=env_index,
                pretrained_image_encoder=pretrained_image_encoder,
            )))

    ctx = mp.get_context('spawn')
    # The manager is shut down when the block exits, so copy the results
    # into a plain dict before leaving it.
    with ctx.Manager() as manager:
        shared_results = manager.dict()
        processes = []
        for _, _, domain_tag, eval_kwargs in jobs:
            process = ctx.Process(
                target=multienv_evaluate_process,
                args=(shared_results, domain_tag),
                kwargs=eval_kwargs,
            )
            process.start()
            processes.append((domain_tag, process))
        for domain_tag, process in processes:
            process.join()
            if process.exitcode != 0:
                logger.warning(
                    f'Evaluation process for {domain_tag} exited with code '
                    f'{process.exitcode}; its result is missing.')
        results = dict(shared_results.items())

    metrics_dict = {}
    for domain_index, env_tag, domain_tag, _ in jobs:
        if domain_tag not in results:
            continue
        success_rate, mean_steps = results[domain_tag]
        logger.info(
            f'Epoch {epoch}: Success rate (domain {domain_index}, env_id {env_tag}): {success_rate * 100:.1f}% '
            f'Steps len (domain {domain_index}, env_id {env_tag}): {mean_steps:.1f}')
        # The double underscore is intentional: it keeps the existing
        # success-rate metric names. Steps are keyed per environment too, so
        # environments of the same domain no longer overwrite each other.
        metric_stem = f'{log_prefix}_domain{domain_index}__{env_tag}'
        metrics_dict[f'{metric_stem}_success_rate'] = success_rate
        metrics_dict[f'{metric_stem}_steps_mean'] = mean_steps

    if experiment is not None:
        experiment.log_metrics(metrics_dict, epoch=epoch)
        for video_path in video_path_list:
            if video_path.exists():
                experiment.log_asset(video_path, step=epoch)

    # ``results`` is filled in completion order; choose by launch order.
    for _, _, domain_tag, _ in reversed(jobs):
        if domain_tag in results:
            success_rate, _ = results[domain_tag]
            return success_rate
    return None


def multienv_evaluate(args: DictConfig,
                      env_id: str,
                      model: Any,
                      goal_ids: List[int],
                      domain_id: int,
                      task_id_manager: Union[PointMazeTaskIDManager,
                                             AntMazeTaskIDManager],
                      obs_converter: ObservationConverter,
                      action_converter: ActionConverter,
                      epoch: int,
                      env_kwargs: Dict = {"eval": True},
                      num_episodes: int = 100,
                      num_render_episodes: int = 0,
                      render_kwargs: Dict = {
                          "height": 512,
                          "width": 512
                      },
                      skip_rate: int = 5,
                      fps: int = 20,
                      video_path: Path = Path("."),
                      experiment: Optional[comet_ml.Experiment] = None,
                      traj_dataset: Optional[PairedTrajDataset] = None,
                      prefix: str = "",
                      target_env: bool = False,
                      source_env_idx: int = 0,
                      pretrained_image_encoder: Optional[nn.Module] = None):
    """Roll out ``model`` in one environment of a multi-environment domain.

    Tasks are sampled uniformly from the task ids derived from ``goal_ids``.
    Before ``env.setup_task`` is called, the goal id is shifted into the
    environment's local numbering:

    - target env: ``goal_id - args.target_task_id_offset + args.target_goal_id``
    - otherwise: ``goal_id - args.task_id_offset_list[source_env_idx]``

    Args:
        env_kwargs: Keyword arguments for ``gym.make``. The dict is copied,
            so neither the caller's dict nor the shared default is modified.
        render_kwargs: Keyword arguments for ``env.render``. Also copied.
        prefix, pretrained_image_encoder: Accepted for interface
            compatibility; not used here.
        Other arguments are as in :func:`evaluate`.

    Returns:
        ``(success_rate, steps_mean)`` over all episodes. The model is put
        back into train mode and the env is closed before returning.
    """
    # Work on copies: the defaults above are shared between calls, and the
    # code below adds keys to these dicts.
    env_kwargs = dict(env_kwargs)
    render_kwargs = dict(render_kwargs)

    model.eval()
    env: Optional[MazeEnv] = None
    try:
        policy = create_eval_policy(
            args=args,
            model=model,
            env_id=env_id,
            domain_id=domain_id,
            obs_converter=obs_converter,
            action_converter=action_converter,
            traj_dataset=traj_dataset,
        )

        env_kwargs["domain_id"] = domain_id
        render_kwargs["resolution"] = (512, 512)
        if args.image_observation:
            env_kwargs["image_observation"] = True
        env = gym.make(env_id, **env_kwargs)
        max_episode_steps = env.spec.max_episode_steps

        is_maze = "medium" in env_id or "umaze" in env_id
        if is_maze:
            get_success = common.utils.policy.get_success
        else:
            get_success = env.get_success

        if num_render_episodes != 0 and is_maze:
            env.render(mode="rgb_array")
        env.reset()

        if task_id_manager:
            task_id_list = []
            for goal_id in goal_ids:
                task_id_list += task_id_manager.goal_id_to_task_id_list(
                    goal_id=goal_id)
        else:
            task_id_list = goal_ids

        results = []
        steps = []
        frames = []
        for episode in tqdm(range(num_episodes), bar_format=TQDM_BAR_FORMAT):
            env.reset()
            # Resample until the policy accepts the task. This loops forever
            # if policy.init rejects every task id.
            while True:
                idx = np.random.randint(len(task_id_list))
                task_id = task_id_list[idx]
                if task_id_manager:
                    start_id = task_id_manager.task_id_to_start_id(task_id=task_id)
                    goal_id = task_id_manager.task_id_to_goal_id(task_id=task_id)
                else:
                    start_id = 0
                    goal_id = task_id

                # Translate the global goal id into this env's local numbering.
                if target_env:
                    goal_id -= args.target_task_id_offset
                    goal_id += args.target_goal_id
                else:
                    goal_id -= args.task_id_offset_list[source_env_idx]
                goal = env.setup_task(goal_id=goal_id, start_id=start_id)
                try:
                    policy.init(task_id=task_id)
                    break
                except ValueError:
                    pass

            if "Stack" in env_id:
                env.reset()

            obs = env.get_obs()
            done = False
            cumulative_reward = 0
            t = 0
            while not done:
                if args.image_observation:
                    image = torch.from_numpy(obs["image"]).to(args.device)[None]
                    with torch.no_grad():
                        image_encoded = model.image_encoder(image)[0]
                    image_encoded = image_encoded.cpu().numpy()

                    if args.get("mask_proprio", False):
                        # Hide proprioception but keep the input width the
                        # same (previously read `state` before assignment).
                        state = np.zeros_like(obs["state"])
                    else:
                        state = obs["state"]

                    obs = np.concatenate((state, image_encoded), axis=-1)

                act = policy(obs, task_id)
                obs, rew, done, _ = env.step(act)
                cumulative_reward += rew

                if episode < num_render_episodes and t % skip_rate == 0:
                    frame = env.render(mode="rgb_array", **render_kwargs)
                    frames.append(frame)

                t += 1
                if done or get_success(obs, goal, env_id):
                    break

            # A very negative return is treated as an ant failure.
            ant_fail_flag = 'ant' in env_id and cumulative_reward < -1000
            # NOTE(review): unlike evaluate(), any episode that ends before
            # max_episode_steps counts as a success without re-checking
            # get_success - confirm early `done` always means success here.
            success = t < max_episode_steps and not ant_fail_flag

            results.append(success)
            steps.append(t)

        if len(frames) != 0:
            try:
                save_video(
                    path=video_path,
                    frames=frames,
                    fps=fps,
                    epoch=epoch,
                    experiment=experiment,
                )
            except Exception:
                # Video export is best-effort; never fail the evaluation.
                logger.exception('Error happened when saving a video.')

        success_rate = np.mean(results)
        steps_mean = np.mean(steps)
        return success_rate, steps_mean
    finally:
        # Release the env and restore training mode even if a rollout fails.
        if env is not None:
            env.close()
        model.train()
