import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union

import cv2
import d4rl
import gym
import imageio.v2
import numpy as np
import torch
import torch.nn as nn
from comet_ml import ExistingExperiment, Experiment
from common.models import Policy
from omegaconf import DictConfig

from .make_policy import make_policy
from .utils import task_id_map_inv_m2p, get_success


def _format_info_lines(infos: Dict) -> List:
    """Build the overlay rows for ``put_infos``.

    Returns a list of ``(text, font_scale)`` pairs, one per text row, in the
    order in which they are drawn from top to bottom.
    """
    header = []
    if "i" in infos:
        header.append(f"i:{infos['i']:3d} ")
    if "t" in infos:
        header.append(f"t:{infos['t']:3d} ")
    if "pos" in infos:
        pos = infos["pos"]
        header.append(f"pos:[{pos[0]:.2f}  {pos[1]:.2f}] ")
    if "rew" in infos:
        header.append(f"rew:{infos['rew']:.2f} ")
    # The first row is always emitted, even when none of its keys are
    # present, so the remaining rows always start from the second row.
    rows = [("".join(header), 1.0)]

    for key in ("goal", "act"):
        if key in infos:
            values = "".join(f"{v:+.1f} " for v in infos[key])
            rows.append((f"{key}: [{values}]", 1.0))

    if "obs" in infos:
        obs = infos["obs"]
        # Nine values per row. An empty observation still yields one row so
        # that "obs: []" is shown. Chunking by the actual length (rather than
        # ``len(obs) // 9 + 1`` rows) avoids a trailing row that holds only
        # the closing bracket when len(obs) is a multiple of nine.
        chunks = [obs[k:k + 9] for k in range(0, len(obs), 9)] or [[]]
        last = len(chunks) - 1
        for j, chunk in enumerate(chunks):
            prefix = "obs: [" if j == 0 else "      "
            values = "".join(f"{o:+.2f} " for o in chunk)
            suffix = "]" if j == last else ""
            rows.append((prefix + values + suffix, 0.8))

    return rows


def put_infos(frame: np.ndarray, infos: Dict):
    """Draw debugging text rows onto ``frame`` with ``cv2.putText``.

    Args:
        frame: Image passed to ``cv2.putText`` as the drawing target.
        infos: Values to display. All keys are optional:

            * ``"i"``, ``"t"``: integers, formatted with width 3.
            * ``"pos"``: indexable whose first two entries are formatted
              with two decimals.
            * ``"rew"``: number formatted with two decimals.
            * ``"goal"``, ``"act"``: iterables of numbers, one row each,
              signed with one decimal.
            * ``"obs"``: sliceable sequence of numbers, signed with two
              decimals, wrapped at nine values per row and drawn with a
              smaller font scale (0.8 instead of 1.0).

    Rows are drawn at x=10 and y=15*row (row counted from 1) in light grey.
    The function returns ``None``.
    """
    for row, (text, font_scale) in enumerate(_format_info_lines(infos), 1):
        cv2.putText(frame,
                    text, (10, 15 * row),
                    fontFace=cv2.FONT_HERSHEY_PLAIN,
                    fontScale=font_scale,
                    color=(220, 220, 220))


def visualize(
    args: DictConfig,
    env_id: str,
    mp4_path: Optional[str],
    task_ids: List[int],
    model: nn.Module,
    domain_id: int,
    num_task_ids: int,
    env_kwargs: Optional[Dict] = None,
    reverse_obs: bool = False,
    reverse_act: bool = False,
    repeat: int = 10,
    skip: int = 5,
    fps: int = 20,
    device: str = "cuda:0",
    return_frames: bool = False,
    start_idx: int = 0,
):
    """Roll out a policy for ``repeat`` episodes and collect annotated frames.

    For every episode a goal is drawn uniformly from the targets derived from
    ``task_ids``; the policy built by ``make_policy`` is then stepped until
    the environment reports ``done`` or ``get_success`` is truthy. One RGB
    frame is rendered after every ``skip`` environment steps (and at episode
    end), annotated via ``put_infos`` with the episode index, step counter,
    observation and last action.

    Args:
        args: Configuration forwarded to ``make_policy``.
        env_id: Gym environment id. Its first ``-``-separated token must be
            one of ``"point"``, ``"ant"`` or ``"maze2d"``.
        mp4_path: Output video path; only used when ``return_frames`` is
            false and the path is truthy.
        task_ids: Task ids that are mapped to goal coordinates through
            ``env.id_to_xy`` (for ``maze2d`` ids, after translating them with
            ``task_id_map_inv_m2p``).
        model: Network forwarded to ``make_policy``.
        domain_id: Forwarded to ``make_policy``.
        num_task_ids: Forwarded to ``make_policy``.
        env_kwargs: Keyword arguments for ``gym.make``. Defaults to
            ``{"eval": True}`` when ``None``.
        reverse_obs: Forwarded to ``make_policy`` as ``reverse_observations``.
        reverse_act: Forwarded to ``make_policy`` as ``reverse_actions``.
        repeat: Number of episodes to roll out.
        skip: Number of environment steps between rendered frames; must be
            at least 1.
        fps: Frame rate of the written video.
        device: Not used by this function; kept for interface compatibility.
        return_frames: If true, return the frames instead of writing a video.
        start_idx: Offset added to the episode index shown in the overlay.

    Returns:
        The list of frames when ``return_frames`` is true or ``mp4_path`` is
        falsy; otherwise ``None`` after writing the video to ``mp4_path``.

    Raises:
        ValueError: If ``skip`` is smaller than 1.
        KeyError: If the morphology prefix of ``env_id`` is not supported.
    """
    if skip < 1:
        # With no steps per frame the episode could never advance and no
        # action would exist to annotate the frame with.
        raise ValueError(f"skip must be a positive integer, got {skip}")
    if env_kwargs is None:
        env_kwargs = {"eval": True}

    # Throw-away environment that is rendered once and discarded.
    # NOTE(review): presumably a warm-up for the renderer - confirm.
    env = gym.make(env_id, **env_kwargs)
    env.reset()
    env.render(mode="rgb_array")
    del env

    env = gym.make(env_id, **env_kwargs)
    try:
        morph = env_id.split("-")[0]
        act_dim_domain = {"point": 2, "ant": 8, "maze2d": 2}[morph]

        policy = make_policy(
            args=args,
            env_id=env_id,
            model=model,
            reverse_observations=reverse_obs,
            reverse_actions=reverse_act,
            domain_id=domain_id,
            num_task_ids=num_task_ids,
            act_dim_domain=act_dim_domain,
        )

        if "maze2d" in env_id:
            # maze2d task ids are translated before the coordinate lookup.
            maze_type = env_id.split("-")[1]
            id_map = task_id_map_inv_m2p[maze_type]
            targets = [env.id_to_xy[id_map[i]] for i in task_ids]
        else:
            targets = [env.id_to_xy[i] for i in task_ids]

        # One step and one render before the recorded episodes start.
        env.reset()
        env.step(env.action_space.sample())
        env.render(mode="rgb_array")

        frames = []
        for i in range(repeat):
            env.reset()
            idx = np.random.randint(len(targets))
            # Goal jitter is currently disabled:
            # goal = targets[idx] + np.random.uniform(-0.5, 0.5, size=(2, ))
            goal = targets[idx]
            env.set_target(goal)
            env.set_init_xy()
            env.set_marker()
            obs = env.get_obs()
            done = False
            t = 1
            while not done:
                for _ in range(skip):
                    act = policy(obs, goal)
                    obs, rew, done, info = env.step(act)
                    t += 1
                    # Reaching the goal ends the episode as well.
                    done |= get_success(obs, goal, env_id)
                    if done:
                        break
                frame = env.render(mode="rgb_array").astype("uint8")
                put_infos(frame,
                          infos={
                              "i": start_idx + i,
                              "t": t,
                              "obs": obs,
                              "act": act,
                          })
                frames.append(frame)
    finally:
        # Release the environment even if the rollout fails.
        env.close()

    if return_frames or not mp4_path:
        return frames
    imageio.mimsave(mp4_path, frames, fps=fps)


def visualize_fn(
    start_idx: int,
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    num_episodes: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Load a ``Policy`` checkpoint and return the frames of its rollouts.

    A ``Policy`` is built from the hyper-parameters in ``args.policy``, moved
    to ``args.device`` and initialised from the state dict stored at
    ``model_path``. It is then rolled out through ``visualize`` for
    ``num_episodes`` episodes with ``return_frames=True``.

    ``start_idx`` is the first positional parameter so that the remaining
    arguments can be bound with ``functools.partial`` and the function mapped
    over a list of start indices.

    Returns:
        The frames returned by ``visualize``.
    """
    policy_cfg = args.policy

    # The flag is optional in the config; fall back to False when absent.
    use_state_in_decoder = getattr(policy_cfg, "decode_with_state", False)

    # Build the network on the target device, then restore its weights.
    network = Policy(
        state_dim=policy_cfg.state_dim,
        cond_dim=policy_cfg.cond_dim,
        out_dim=policy_cfg.out_dim,
        domain_dim=policy_cfg.domain_dim,
        latent_dim=policy_cfg.latent_dim,
        hid_dim=policy_cfg.hid_dim,
        num_hidden_layers=policy_cfg.num_hidden_layers,
        activation=policy_cfg.activation,
        repr_activation=policy_cfg.latent_activation,
        enc_sn=policy_cfg.spectral_norm,
        decode_with_state=use_state_in_decoder,
    )
    network = network.to(args.device)
    state_dict = torch.load(model_path, map_location=args.device)
    network.load_state_dict(state_dict)

    # No video is written here; the caller receives the raw frames.
    return visualize(
        args=args,
        env_id=env_id,
        mp4_path=None,
        task_ids=args.task_ids,
        model=network,
        domain_id=domain_id,
        num_task_ids=policy_cfg.cond_dim,
        reverse_obs=reverse_observations,
        reverse_act=reverse_actions,
        repeat=num_episodes,
        skip=args.skip_frames,
        fps=args.fps,
        device=args.device,
        return_frames=True,
        start_idx=start_idx,
    )


def parallel_visualize(
    args: DictConfig,
    experiment: Optional[Union[Experiment, ExistingExperiment]],
    env_id: str,
    model_path: Union[str, Path],
    mp4_path: Union[str, Path],
    epoch: int,
    domain_id: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Render rollouts in a worker pool and save them as a single video.

    ``args.num_visualize_episodes`` episodes are split evenly across
    ``args.processes`` workers, each running ``visualize_fn``. The returned
    frame lists are concatenated in worker order, written to ``mp4_path``
    with ``args.fps`` frames per second and, when ``experiment`` is truthy,
    logged as an asset for step ``epoch``.

    Args:
        args: Configuration; reads ``processes``, ``num_visualize_episodes``
            and ``fps`` here and is forwarded to ``visualize_fn``.
        experiment: Optional experiment object used to log the video.
        env_id: Forwarded to ``visualize_fn``.
        model_path: Checkpoint path forwarded to ``visualize_fn``.
        mp4_path: Destination of the written video.
        epoch: Step under which the asset is logged.
        domain_id: Forwarded to ``visualize_fn``.
        reverse_observations: Forwarded to ``visualize_fn``.
        reverse_actions: Forwarded to ``visualize_fn``.

    Raises:
        AssertionError: If the number of episodes is not divisible by the
            number of processes.
    """
    num_processes = args.processes
    num_episodes = args.num_visualize_episodes
    assert num_episodes % num_processes == 0, (
        f"num_visualize_episodes ({num_episodes}) must be divisible by "
        f"processes ({num_processes})")
    num_episodes_per_process = num_episodes // num_processes

    # Switch the global start method from "fork" to "spawn".
    # NOTE(review): presumably needed because the workers may use CUDA via
    # args.device - confirm.
    if torch.multiprocessing.get_start_method() == "fork":
        torch.multiprocessing.set_start_method("spawn", force=True)

    fn = partial(
        visualize_fn,
        args=args,
        env_id=env_id,
        model_path=model_path,
        domain_id=domain_id,
        num_episodes=num_episodes_per_process,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
    )

    # Computed before the pool exists so that an invalid range cannot leave
    # worker processes behind.
    start_idxs = list(range(0, num_episodes, num_episodes_per_process))

    # The context manager tears the pool down even when a worker raises.
    with mp.Pool(processes=num_processes) as pool:
        frames_list = pool.map(fn, start_idxs)

    frames = [frame for chunk in frames_list for frame in chunk]

    imageio.mimsave(mp4_path, frames, fps=args.fps)
    if experiment:
        experiment.log_asset(mp4_path, step=epoch)

    # ANSI "cursor up one line" + "erase line", once per worker.
    # NOTE(review): presumably removes one line of output per worker from
    # the terminal - confirm.
    for _ in range(num_processes):
        print("\033[F\033[K", end="")
