import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union

import cv2
import d4rl
import gym
import imageio.v2
import numpy as np
import torch
import torch.nn as nn
from comet_ml import ExistingExperiment, Experiment
from common.dail.models import DAILAgent
from omegaconf import DictConfig

from .make_policy import make_policy
from ours.utils.utils import get_success


def _put_text_line(frame: np.ndarray,
                   text: str,
                   line: int,
                   font_scale: float = 1.0) -> None:
    """Draw one light-grey text row on ``frame`` in place at x=10, y=15*line."""
    cv2.putText(frame,
                text, (10, 15 * line),
                fontFace=cv2.FONT_HERSHEY_PLAIN,
                fontScale=font_scale,
                color=(220, 220, 220))


def put_infos(frame: np.ndarray, infos: Dict):
    """Overlay diagnostic text rows onto ``frame`` in place.

    Rows are stacked top-down, 15 px apart, starting at y=15:

    1. A header row, always drawn (possibly empty), made of whichever of
       ``"i"`` / ``"t"`` (formatted as ``:3d``), ``"pos"`` (first two
       entries) and ``"rew"`` are present in ``infos``.
    2. ``"goal: [...]"`` if ``"goal"`` is present (values as ``+.1f``).
    3. ``"act: [...]"`` if ``"act"`` is present (values as ``+.1f``).
    4. ``"obs: [...]"`` if ``"obs"`` is present, wrapped at 9 values per
       row with a smaller font (values as ``+.2f``).

    Args:
        frame: image passed to ``cv2.putText``; modified in place.
        infos: mapping holding any subset of the keys listed above.
    """
    line = 1

    header = ""
    if "i" in infos:
        header += f"i:{infos['i']:3d} "
    if "t" in infos:
        header += f"t:{infos['t']:3d} "
    if "pos" in infos:
        header += f"pos:[{infos['pos'][0]:.2f}  {infos['pos'][1]:.2f}] "
    if "rew" in infos:
        header += f"rew:{infos['rew']:.2f} "
    _put_text_line(frame, header, line)
    line += 1

    # "goal" and "act" share the same one-row layout; order is preserved.
    for key in ("goal", "act"):
        if key in infos:
            values = "".join(f"{v:+.1f} " for v in infos[key])
            _put_text_line(frame, f"{key}: [{values}]", line)
            line += 1

    if "obs" in infos:
        obs = infos["obs"]
        per_row = 9
        # Ceil division: a length that is an exact multiple of `per_row` must
        # not yield an extra row holding only "]". Keep at least one row so an
        # empty observation still renders as "obs: []".
        num_rows = max(1, (len(obs) + per_row - 1) // per_row)
        for row in range(num_rows):
            prefix = "obs: [" if row == 0 else "      "
            chunk = obs[row * per_row:(row + 1) * per_row]
            values = "".join(f"{o:+.2f} " for o in chunk)
            suffix = "]" if row == num_rows - 1 else ""
            _put_text_line(frame, prefix + values + suffix, line,
                           font_scale=0.8)
            line += 1


def visualize(
    args: DictConfig,
    env_id: str,
    mp4_path: Optional[str],
    task_ids: List[int],
    model: nn.Module,
    domain_id: int,
    num_task_ids: int,
    env_kwargs: Optional[Dict] = None,
    reverse_obs: bool = False,
    reverse_act: bool = False,
    repeat: int = 10,
    skip: int = 5,
    fps: int = 20,
    device: str = "cuda:0",
    return_frames: bool = False,
    start_idx: int = 0,
):
    """Roll out ``model`` in ``env_id`` and record annotated frames.

    Each of the ``repeat`` episodes picks a random goal from the targets of
    ``task_ids``, steps the policy until the env reports ``done``, and renders
    one frame per chunk of at most ``skip`` env steps. A chunk is cut short
    when the env reports ``done`` or ``get_success`` holds.

    Args:
        args: config forwarded to ``make_policy``.
        env_id: gym id; the part before the first "-" must be one of
            "point", "ant" or "maze2d".
        mp4_path: output video path; when falsy, frames are returned instead.
        task_ids: keys into ``env.id_to_xy`` from which goals are sampled.
        model: network handed to ``make_policy``.
        domain_id: forwarded to ``make_policy``.
        num_task_ids: forwarded to ``make_policy``.
        env_kwargs: kwargs for the initial warm-up env only. ``None`` means
            ``{"eval": True}``.
        reverse_obs: forwarded to ``make_policy``.
        reverse_act: forwarded to ``make_policy``.
        repeat: number of episodes to record.
        skip: maximum env steps per rendered frame; must be >= 1.
        fps: frame rate used when writing ``mp4_path``.
        device: not used by this function; kept for interface compatibility.
        return_frames: if True, return the frames and do not write a video.
        start_idx: offset added to the episode index drawn on each frame.

    Returns:
        The list of uint8 frames from ``env.render(mode="rgb_array")`` when
        ``return_frames`` is True or ``mp4_path`` is falsy. Otherwise the
        video is written to ``mp4_path`` and ``None`` is returned.

    Raises:
        ValueError: if ``skip`` < 1.
        KeyError: if the ``env_id`` prefix is not a supported morphology.
    """
    if skip < 1:
        # With no inner steps `act` would never be bound and the loop could
        # never reach `done`; fail clearly instead.
        raise ValueError(f"skip must be >= 1, got {skip}")
    if env_kwargs is None:
        env_kwargs = {"eval": True}

    # NOTE(review): this env is built with env_kwargs, rendered once and then
    # discarded, while the rollout env below is built WITHOUT env_kwargs.
    # Presumably a renderer warm-up; confirm whether the rollout env should
    # also receive env_kwargs.
    env = gym.make(env_id, **env_kwargs)
    env.reset()
    env.render(mode="rgb_array")
    del env

    # NOTE(review): neither env is ever close()d; left as-is because closing
    # may tear down shared rendering state depending on the backend.
    env = gym.make(env_id)

    morph = env_id.split("-")[0]
    act_dim_domain = {"point": 2, "ant": 8, "maze2d": 2}[morph]

    policy = make_policy(
        args,
        env_id,
        model,
        reverse_obs,
        reverse_act,
        domain_id,
        num_task_ids,
        act_dim_domain,
    )

    targets = [env.id_to_xy[task_id] for task_id in task_ids]

    # One random step + render before recording (presumably another warm-up).
    env.reset()
    env.step(env.action_space.sample())
    env.render(mode="rgb_array")

    frames = []
    for i in range(repeat):
        env.reset()
        goal = targets[np.random.randint(len(targets))]
        env.set_target(goal)
        env.set_init_xy()
        env.set_marker()
        obs = env.get_obs()
        done = False
        t = 1
        while not done:
            for _ in range(skip):
                act = policy(obs, goal)
                obs, _, done, _ = env.step(act)
                t += 1
                # NOTE(review): success only ends this chunk early; the
                # episode keeps running until `done`, so while success keeps
                # holding a frame is rendered every step. Confirm whether
                # success should end the episode.
                if done or get_success(obs, goal, env_id):
                    break
            frame = env.render(mode="rgb_array").astype("uint8")
            put_infos(frame,
                      infos={
                          "i": start_idx + i,
                          "t": t,
                          "obs": obs,
                          "act": act,
                      })
            frames.append(frame)

    if return_frames or not mp4_path:
        return frames
    imageio.v2.mimsave(mp4_path, frames, fps=fps)
    return None


def visualize_fn(
    start_idx: int,
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    num_episodes: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Worker entry point: load a ``DAILAgent`` checkpoint and record episodes.

    Builds a ``DAILAgent`` from ``args`` on ``args.device``, loads the state
    dict stored at ``model_path``, then runs ``visualize`` for
    ``num_episodes`` episodes without writing a video.

    Args:
        start_idx: episode-index offset drawn on the frames; it is the first
            positional parameter so it can be the value mapped over a pool.
        args: config providing ``device``, ``task_ids``, ``num_task_ids``,
            ``skip_frames`` and ``fps``.
        env_id: gym environment id.
        model_path: checkpoint file passed to ``torch.load``.
        domain_id: forwarded to ``visualize``.
        num_episodes: number of episodes to record.
        reverse_observations: forwarded as ``reverse_obs``.
        reverse_actions: forwarded as ``reverse_act``.

    Returns:
        The list of frames produced by ``visualize``.
    """
    device = args.device

    agent = DAILAgent(args).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    agent.load_state_dict(checkpoint)

    return visualize(
        args=args,
        env_id=env_id,
        mp4_path=None,
        task_ids=args.task_ids,
        model=agent,
        domain_id=domain_id,
        num_task_ids=args.num_task_ids,
        reverse_obs=reverse_observations,
        reverse_act=reverse_actions,
        repeat=num_episodes,
        skip=args.skip_frames,
        fps=args.fps,
        device=device,
        return_frames=True,
        start_idx=start_idx,
    )


def parallel_visualize(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]],
    env_id: str,
    model_path: Union[str, Path],
    mp4_path: Union[str, Path],
    epoch: int,
    domain_id: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Record episodes in parallel worker processes and save a single video.

    ``args.num_visualize_episodes`` episodes are split evenly across
    ``args.processes`` workers. Each worker runs ``visualize_fn`` with its own
    episode-index offset. The frame lists are concatenated in worker order
    (``Pool.map`` preserves input order), written to ``mp4_path`` at
    ``args.fps``, and logged to ``experiment`` at step ``epoch`` when an
    experiment is given.

    Side effect: if the global multiprocessing start method is "fork", it is
    switched to "spawn" for the rest of the process.

    Raises:
        ValueError: if ``args.processes`` < 1, ``args.num_visualize_episodes``
            < 1, or the episode count is not divisible by the process count.
    """
    num_processes = args.processes
    num_episodes = args.num_visualize_episodes
    if num_processes < 1 or num_episodes < 1:
        raise ValueError(
            f"processes ({num_processes}) and num_visualize_episodes "
            f"({num_episodes}) must both be >= 1")
    if num_episodes % num_processes != 0:
        raise ValueError(
            f"num_visualize_episodes ({num_episodes}) must be divisible by "
            f"processes ({num_processes})")
    num_episodes_per_process = num_episodes // num_processes

    # Presumably because workers load models onto args.device (possibly CUDA),
    # which is not supported in forked children.
    if torch.multiprocessing.get_start_method() == "fork":
        torch.multiprocessing.set_start_method("spawn", force=True)

    fn = partial(
        visualize_fn,
        args=args,
        env_id=env_id,
        model_path=model_path,
        domain_id=domain_id,
        num_episodes=num_episodes_per_process,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
    )

    start_idxs = list(range(0, num_episodes, num_episodes_per_process))
    # The context manager terminates the pool if `map` raises; on success,
    # close + join let the workers exit cleanly before we continue.
    with mp.Pool(processes=num_processes) as pool:
        frames_list = pool.map(fn, start_idxs)
        pool.close()
        pool.join()
    frames = [frame for chunk in frames_list for frame in chunk]

    imageio.v2.mimsave(mp4_path, frames, fps=args.fps)
    if experiment:
        # Comet's log_asset is documented to take a path string.
        experiment.log_asset(str(mp4_path), step=epoch)
    # ANSI "cursor up one line" + "clear line", once per worker.
    # NOTE(review): presumably erases one line printed by each worker; confirm.
    for _ in range(num_processes):
        print("\033[F\033[K", end="")
