import multiprocessing as mp
import os
import sys
import warnings
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# Must run before `import d4rl` below. The variable is presumably read by d4rl
# at import time to silence its import errors for missing optional
# dependencies — confirm against the installed d4rl version.
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
# Make the parent directory importable (presumably for the local `ours`
# package imported below). Note that ".." is resolved against the current
# working directory, not against this file's location.
sys.path.append("..")

import d4rl  # Import required to register environments
import gym
import h5py
import imageio.v2
import moviepy.editor as mpy
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import PPO
from tqdm import tqdm

from ours.utils.dataset_utils import DatasetWriter
from ours.utils.visualize_episodes import put_infos
from ours.utils.utils import get_success

# Silence all warnings (every category, every module) for this process,
# including deprecation warnings from third-party libraries.
warnings.filterwarnings("ignore")


def target_to_direction_id(target):
    """Classify a 2-D offset into one of four wedges split by the diagonals.

    The plane is cut by the lines ``y = x`` and ``y = -x``:

    * 0 -- ``y <= x`` and ``y > -x``  (wedge around the +x axis)
    * 1 -- ``y > x``  and ``y > -x``  (wedge around the +y axis)
    * 2 -- ``y > x``  and ``y <= -x`` (wedge around the -x axis)
    * 3 -- ``y <= x`` and ``y <= -x`` (wedge around the -y axis)

    Args:
        target: a length-2 sequence ``(x, y)``.

    Returns:
        The wedge id in ``{0, 1, 2, 3}``, or ``None`` when no comparison
        holds (e.g. a NaN coordinate).
    """
    x, y = target
    if y <= x:
        # Right or bottom wedge, decided by the anti-diagonal.
        if y > -x:
            return 0
        if y <= -x:
            return 3
    elif y > x:
        # Top or left wedge, decided by the anti-diagonal.
        if y > -x:
            return 1
        if y <= -x:
            return 2
    return None


def make_collect_policy(
    env: gym.Env,
    expert_path: Union[Path, str],
    device: str,
    env_id: str,
) -> Callable:
    """Build a maze navigation policy around a pretrained PPO goal reacher.

    The Stable-Baselines3 PPO expert at ``expert_path`` is loaded on
    ``device`` and wrapped in ``goal_reaching_policy_fn(obs, target)``. That
    function rewrites the observation into the frame the expert expects,
    appends the target and returns the expert's action. The function is then
    handed to ``env.create_navigation_policy`` (project-defined), configured
    to read the robot position from ``obs[:2]`` and the target from
    ``obs[-2:]`` with ``relative=False``.

    Args:
        env: maze environment providing ``create_navigation_policy``.
        expert_path: path of the PPO checkpoint.
        device: torch device string for the expert, e.g. ``"cuda:0"``.
        env_id: gym id. It must contain ``"ant"`` or ``"point"``, which
            selects the observation transform. ``"ant"`` wins if both appear.

    Returns:
        The callable produced by ``env.create_navigation_policy``. It is
        invoked as ``policy(obs)`` by ``collect_fn``.

    Raises:
        ValueError: if ``env_id`` contains neither ``"ant"`` nor ``"point"``.
            Previously this fell through to an unbound
            ``goal_reaching_policy_fn``, and only after the expert was loaded.
    """
    is_ant = "ant" in env_id
    if not is_ant and "point" not in env_id:
        raise ValueError(
            f"Unsupported env_id {env_id!r}: expected 'ant' or 'point' in it")

    raw_policy = PPO.load(expert_path, device=device)

    if is_ant:
        # xy offset applied after wrapping, keyed by the direction id of the
        # target (see target_to_direction_id).
        # NOTE(review): ids 1 and 3 share the offset [8, 4] while 0 and 2
        # differ; confirm this is intended and not a copy-paste slip.
        direction_offsets = {
            0: np.array([4, 0]),
            1: np.array([8, 4]),
            2: np.array([4, 8]),
            3: np.array([8, 4]),
        }

        def goal_reaching_policy_fn(obs, target):
            # ``target`` is added to the shifted robot xy below, i.e. it is
            # treated as an offset from the robot position.
            target = np.array(target)
            direction_id = target_to_direction_id(target)
            assert direction_id in direction_offsets

            # Drop the trailing two target entries (the env's target is
            # read from obs[-2:]) and copy so the caller's obs is untouched.
            obs = np.copy(obs[:-2])
            # Wrap xy into [-2, 2) (period 4), then shift by the
            # direction-specific offset.
            obs[:2] = np.mod(obs[:2] + 2, 4) - 2
            obs[:2] += direction_offsets[direction_id]
            target = obs[:2] + target

            obs = np.concatenate((obs, target), axis=-1)
            return raw_policy.predict(obs)[0]
    else:

        def goal_reaching_policy_fn(obs, target):
            # Point robot: wrap xy into [-2, 2) and append the target as-is.
            target = np.array(target)
            obs = np.copy(obs[:-2])
            obs[:2] = np.mod(obs[:2] + 2, 4) - 2

            obs = np.concatenate((obs, target), axis=-1)
            return raw_policy.predict(obs)[0]

    return env.create_navigation_policy(
        goal_reaching_policy_fn,
        obs_to_robot=lambda obs: obs[:2],
        obs_to_target=lambda obs: obs[-2:],
        relative=False,
    )


def _rollout_episode(env, policy, start, goal, env_id, max_step):
    """Run one episode from a jittered ``start`` toward a jittered ``goal``.

    Start and goal are each perturbed by up to +/-0.5 per axis. The episode
    ends when the env reports ``done``, when ``get_success`` holds for the
    un-jittered ``goal``, or when ``max_step`` steps have been taken.

    Returns:
        ``(observations, actions, rewards, dones, t, cumulative_reward)``.
        ``observations`` holds one more entry than ``actions``, and every
        stored observation has its last two entries removed.
    """
    env.reset()
    start_ = start + 0.5 * np.random.uniform(-1, 1, size=(2, ))
    goal_ = goal + 0.5 * np.random.uniform(-1, 1, size=(2, ))
    env.set_target(goal_)
    env.set_init_xy(start_)
    obs = env.get_obs()

    done = False
    t = 0
    cumulative_reward = 0
    # The last two obs entries are dropped before storing; the navigation
    # policy reads the target from obs[-2:].
    observations = [obs[:-2]]
    actions, rewards, dones = [], [], []
    while not (done or t >= max_step):
        act = policy(obs)
        obs, rew, done, _ = env.step(act)
        cumulative_reward += rew
        # Success is judged against the un-jittered goal.
        done |= get_success(obs, goal, env_id)
        t += 1
        observations.append(obs[:-2])
        actions.append(act)
        rewards.append(rew)
        dones.append(done)
    return observations, actions, rewards, dones, t, cumulative_reward


def collect_fn(
    idx_task_id_device: Tuple,
    env_id: str,
    env_kwargs: Dict,
    expert_path: Union[Path, str],
    n_task_ids: int,
    n_trajs: int,
) -> DatasetWriter:
    """Collect expert demonstrations that reach one goal cell.

    For every start id in ``1..n_task_ids`` other than ``task_id``, episodes
    are rolled out until ``n_trajs`` have been accepted. An episode is
    rejected and retried when its cumulative reward is below -1000 or it
    exhausts the distance-dependent step budget.

    Args:
        idx_task_id_device: ``(idx, task_id, device)``. ``idx`` only labels
            the progress bar. ``task_id`` selects the goal via
            ``env.id_to_xy``. ``device`` is forwarded to the expert loader.
        env_id: gym id passed to ``gym.make``.
        env_kwargs: extra keyword arguments for ``gym.make``.
        expert_path: path of the PPO expert checkpoint.
        n_task_ids: number of cell ids; start ids are ``1..n_task_ids``.
        n_trajs: accepted episodes to collect per start id.

    Returns:
        A ``DatasetWriter`` with the transitions of all accepted episodes.
        Each transition is labelled with the un-jittered ``goal`` and with
        ``task_id``.
    """
    idx, task_id, device = idx_task_id_device
    env = gym.make(env_id, **env_kwargs)
    try:
        goal = env.id_to_xy[task_id]
        policy = make_collect_policy(env, expert_path, device, env_id)
        writer = DatasetWriter()

        # Pool workers carry a 1-based ``_identity``, which gives each worker
        # its own progress-bar row. Outside a pool it is empty, so fall back
        # to row 0 instead of raising IndexError.
        identity = mp.current_process()._identity
        pos = identity[0] - 1 if identity else 0

        with tqdm(
            total=(n_task_ids - 1) * n_trajs,
            desc=f"#{idx:>2} ",
            position=pos,
            bar_format="{l_bar}{bar:64}{r_bar}",
            leave=False,
        ) as pbar:
            for start_id in range(1, n_task_ids + 1):
                if start_id == task_id:
                    continue
                start = env.id_to_xy[start_id]
                distance = env.get_distance(start, goal)
                # The step budget scales with distance. The per-square
                # allowance also grows with distance but is capped at 52.
                step_per_square = min(40 + distance * 2.5, 52)
                max_step = distance * step_per_square

                # NOTE(review): this retries indefinitely if the expert never
                # yields an acceptable episode for this start/goal pair.
                episode = 0
                while episode < n_trajs:
                    (observations, actions, rewards, dones, t,
                     cumulative_reward) = _rollout_episode(
                         env, policy, start, goal, env_id, max_step)

                    # Reject degenerate or timed-out episodes and retry.
                    # NOTE(review): a success on exactly step max_step is
                    # rejected too.
                    if cumulative_reward < -1000 or t >= max_step:
                        continue

                    episode += 1
                    pbar.update(1)

                    n_steps = len(actions)
                    writer.extend_data(
                        s=observations[:-1],
                        a=actions,
                        s_=observations[1:],
                        r=rewards,
                        done=dones,
                        goal=[goal] * n_steps,
                        goal_id=[task_id] * n_steps,
                    )
    finally:
        # Release simulator/renderer resources. Pool workers are reused
        # across jobs, so unclosed envs would accumulate.
        env.close()

    return writer


def parallel_collect_demos(
        env_id: str,
        env_kwargs: Dict,
        expert_path: Union[Path, str],
        task_ids: List[int],
        n_task_ids: int,
        n_trajs: int = 7,
        devices: Tuple[int] = (0, ),
        n_process: int = 16,
) -> DatasetWriter:
    """Collect demonstrations for every goal in ``task_ids`` with a process pool.

    One ``collect_fn`` job is submitted per goal task id. Job ``k`` runs on
    ``cuda:{devices[k % len(devices)]}``. The per-goal writers come back in
    ``task_ids`` order and are merged into one.

    Args:
        env_id: gym id of the maze environment.
        env_kwargs: extra keyword arguments for ``gym.make``.
        expert_path: path of the PPO expert checkpoint.
        task_ids: goal task ids, one pool job each.
        n_task_ids: total number of cell ids (start ids are ``1..n_task_ids``).
        n_trajs: accepted episodes per (start, goal) pair.
        devices: CUDA device indices, assigned round-robin to jobs.
        n_process: pool size.

    Returns:
        A single ``DatasetWriter`` holding every worker's data.
    """
    # Switch from "fork" to "spawn" before the pool is created. This is
    # presumably because workers load the expert onto CUDA devices, which
    # cannot be re-initialised in forked children.
    start_method = torch.multiprocessing.get_start_method()
    if start_method == "fork":
        torch.multiprocessing.set_start_method("spawn", force=True)

    collect_one_goal = partial(
        collect_fn,
        env_id=env_id,
        env_kwargs=env_kwargs,
        expert_path=expert_path,
        n_task_ids=n_task_ids,
        n_trajs=n_trajs,
    )

    merged = DatasetWriter()
    jobs = []
    for job_idx, task_id in enumerate(task_ids):
        device_id = devices[job_idx % len(devices)]
        jobs.append((job_idx, task_id, f"cuda:{device_id}"))

    print("Start collecting demonstrations...")
    with mp.Pool(n_process) as pool:
        per_goal_writers = list(pool.imap(collect_one_goal, jobs))

    for part in per_goal_writers:
        merged.merge(part)

    print("Finished!")
    print(f"Total dataset length: {len(merged)}")

    return merged


def visualize_demos(
    env_id: str,
    env_kwargs: Dict,
    dataset: h5py.File,
    video_path: Union[Path, str],
    skip_frame: int = 5,
    max_frames: int = 1000,
    offset: int = 0,
    fps: int = 20,
) -> None:
    """Replay stored states from ``dataset`` and save them as a video.

    Frames are rendered for dataset indices ``offset``, ``offset +
    skip_frame``, and so on. At most ``max_frames`` are rendered, stopping
    early at the end of the dataset. Each frame is annotated via
    ``put_infos`` with the index, observation, action, reward and goal
    stored at that index.

    Args:
        env_id: gym id passed to ``gym.make``.
        env_kwargs: extra keyword arguments for ``gym.make``.
        dataset: open HDF5 file with ``observations``, ``actions``,
            ``rewards`` and ``infos/goal`` datasets.
        video_path: output file for ``imageio.v2.mimsave``.
        skip_frame: index stride between rendered frames.
        max_frames: upper bound on the number of rendered frames.
        offset: first dataset index to render.
        fps: frame rate of the written video.

    Raises:
        ValueError: if no frame could be rendered because ``offset`` is at or
            past the end of the dataset.
    """
    env = gym.make(env_id, **env_kwargs)
    try:
        # Reset, step and render once before replaying stored states.
        # Presumably this initialises the simulator and renderer — TODO
        # confirm it is still required.
        env.reset()
        env.step(env.action_space.sample())
        env.render(mode="rgb_array")

        observations = np.array(dataset["observations"])
        actions = np.array(dataset["actions"])
        rewards = np.array(dataset["rewards"])
        goal = np.array(dataset["infos/goal"])
        n_available = len(observations)

        frames = []
        for i in range(max_frames):
            idx = offset + i * skip_frame
            # Stop at the end of the dataset. Previously a dataset shorter
            # than offset + (max_frames - 1) * skip_frame + 1 raised
            # IndexError here.
            if idx >= n_available:
                break
            env.set_target(goal[idx])
            env.reset_to_state(observations[idx])
            env.set_marker()
            frame = env.render(mode="rgb_array").astype("uint8")
            # put_infos' return value is unused, so it presumably draws onto
            # ``frame`` in place.
            put_infos(frame,
                      infos={
                          "t": idx,
                          "obs": observations[idx],
                          "act": actions[idx],
                          "rew": rewards[idx],
                          "goal": goal[idx],
                      })
            frames.append(frame)
    finally:
        env.close()

    if not frames:
        raise ValueError(
            f"No frames to render: offset={offset} is beyond the dataset "
            f"length {n_available}")

    # Use the explicit v2 API, matching the `imageio.v2` import. The
    # top-level imageio functions are deprecated aliases in recent releases.
    imageio.v2.mimsave(video_path, frames, fps=fps)


def make_dataset(
        env_id: str,
        env_kwargs: Dict,
        expert_path: Union[Path, str],
        dataset_path: Union[Path, str],
        task_ids: List[int],
        n_task_ids: int,
        max_size: Optional[int] = None,
        visualize: bool = True,
        video_path: Optional[Union[Path, str]] = None,
        n_trajs: int = 7,
        devices: Tuple[int] = (0, ),
        n_process: int = 10,
) -> None:
    """Collect demonstrations, write them to HDF5, and optionally render a video.

    Args:
        env_id: gym id of the maze environment.
        env_kwargs: extra keyword arguments for ``gym.make``.
        expert_path: path of the PPO expert checkpoint.
        dataset_path: HDF5 file written by ``DatasetWriter.write_dataset``.
        task_ids: goal task ids to collect demonstrations for.
        n_task_ids: total number of cell ids (start ids are ``1..n_task_ids``).
        max_size: forwarded to ``write_dataset``.
        visualize: if True, render a preview video of the written dataset.
        video_path: output path of the preview video. Required when
            ``visualize`` is True.
        n_trajs: accepted episodes per (start, goal) pair.
        devices: CUDA device indices, assigned round-robin to workers.
        n_process: worker pool size.

    Raises:
        ValueError: if ``visualize`` is True but ``video_path`` is None. This
            is checked before collection so a misconfiguration fails fast
            instead of after the expensive rollout phase.
    """
    if visualize and video_path is None:
        raise ValueError("video_path must be provided when visualize=True")

    writer = parallel_collect_demos(
        env_id=env_id,
        env_kwargs=env_kwargs,
        expert_path=expert_path,
        task_ids=task_ids,
        n_task_ids=n_task_ids,
        n_trajs=n_trajs,
        devices=devices,
        n_process=n_process,
    )
    writer.write_dataset(dataset_path, max_size=max_size)

    if visualize:
        # Close the HDF5 handle once rendering is done; it used to be leaked.
        with h5py.File(dataset_path, "r") as dataset:
            visualize_demos(
                env_id=env_id,
                env_kwargs=env_kwargs,
                dataset=dataset,
                video_path=video_path,
            )


# yapf: disable
def main():
    """Collect and save a demonstration dataset for one maze environment.

    Defaults in ``base_args`` can be overridden from the command line with
    OmegaConf dot-list syntax (e.g. ``n_trajs=10 n_process=8``). The dataset
    is written to ``datasets/<morph>/<env_id>.hdf5`` and a preview video to
    ``demo/<morph>/<env_id>.mp4``. Here ``<morph>`` is the part of ``env_id``
    before the first ``-`` (``ant`` for the default ``ant-umaze-v1``).

    Raises:
        ValueError: if ``env_id`` names none of the supported maze layouts
            (``umaze``, ``medium``, ``large``).
    """
    base_args = OmegaConf.create({
        "env_id": "ant-umaze-v1",
        "expert_path": "experts/ant/ant-umaze-v1_ppo.zip",
        "n_trajs": 300,
        "devices": (0,),
        "n_process": 4,
    })
    args = OmegaConf.merge(base_args, OmegaConf.from_cli())

    morph = args.env_id.split("-")[0]
    dataset_dir = Path("datasets") / morph
    demo_dir = Path("demo") / morph
    for directory in (dataset_dir, demo_dir):
        directory.mkdir(exist_ok=True, parents=True)

    # Number of 1-based task ids per maze layout. Layouts are checked in this
    # order and the first substring match wins.
    n_tasks_by_layout = {"umaze": 7, "medium": 26, "large": 46}
    layout = next(
        (name for name in n_tasks_by_layout if name in args.env_id), None)
    if layout is None:
        raise ValueError(
            f"Unsupported env_id {args.env_id!r}: expected one of "
            f"{list(n_tasks_by_layout)} in its name")
    task_ids = list(range(1, n_tasks_by_layout[layout] + 1))
    # NOTE(review): task 7 was earmarked as the held-out inference task
    # (unused PROXY_TASK_IDS without 7 / INFERENCE_TASK_IDS = [7], now
    # removed), yet demonstrations are collected for every task id,
    # including 7. Confirm this is intended.

    env_kwargs = {
        "return_direction": True,
        "eval": False,
    }

    make_dataset(
        env_id=args.env_id,
        env_kwargs=env_kwargs,
        expert_path=args.expert_path,
        dataset_path=dataset_dir / f"{args.env_id}.hdf5",
        task_ids=task_ids,
        n_task_ids=len(task_ids),
        video_path=demo_dir / f"{args.env_id}.mp4",
        n_trajs=args.n_trajs,
        devices=args.devices,
        n_process=args.n_process,
    )


# Script entry point. The guard is required here: parallel_collect_demos
# switches multiprocessing to the "spawn" start method, whose workers
# re-import this module and must not re-run main().
if __name__ == "__main__":
    main()
