import multiprocessing as mp
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Union

# Set before ``import d4rl`` below so the flag is already present in the
# environment when d4rl is imported; keep this line above that import.
# NOTE(review): going by its name, the flag silences d4rl's import-error
# messages -- confirm against the d4rl source.
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"

import d4rl
import gym
import numpy as np
import torch
import torch.nn as nn
from common.dail.models import DAILAgent
from omegaconf import DictConfig

from ours.utils.utils import task_id_map_inv_m2p, get_success

from .make_policy import make_policy


def evaluate(
    args: DictConfig,
    env_id: str,
    task_ids: List[int],
    model: DAILAgent,
    domain_id: int,
    num_task_ids: int,
    env_kwargs: Union[Dict, None] = None,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
    n_episodes: int = 30,
    device: str = "cuda:0",
):
    """Roll out a goal-conditioned policy and return its success rate.

    For each episode a goal is drawn uniformly at random from the targets
    associated with ``task_ids``, the environment is reset and pointed at
    that goal, and the policy built by ``make_policy`` is stepped until the
    environment reports ``done`` or ``get_success`` reports success.

    Args:
        args: Experiment configuration; forwarded to ``make_policy``.
        env_id: Gym environment id. The text before the first ``-`` selects
            the morphology (``"point"``, ``"ant"`` or ``"maze2d"``) and
            hence the action dimension handed to ``make_policy``.
        task_ids: Task ids whose target positions are used as goals. For
            ``maze2d`` environments each id is first translated through
            ``task_id_map_inv_m2p[maze_type]`` before indexing
            ``env.id_to_xy``.
        model: Agent passed to ``make_policy`` as ``agent``.
        domain_id: Forwarded to ``make_policy``.
        num_task_ids: Forwarded to ``make_policy``.
        env_kwargs: Keyword arguments for ``gym.make``. ``None`` (the
            default) means ``{"eval": True}``.
        reverse_observations: Forwarded to ``make_policy``.
        reverse_actions: Forwarded to ``make_policy``.
        n_episodes: Number of evaluation episodes.
        device: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        ``success / n_episodes``, the fraction of successful episodes.

    Raises:
        KeyError: If the morphology prefix of ``env_id`` is not one of
            ``"point"``, ``"ant"`` or ``"maze2d"``.
        ZeroDivisionError: If ``n_episodes`` is 0.
    """
    # Build the default per call rather than sharing one dict object across
    # every invocation through the signature.
    if env_kwargs is None:
        env_kwargs = {"eval": True}

    env = gym.make(env_id, **env_kwargs)
    try:
        max_episode_steps = env.spec.max_episode_steps

        morph = env_id.split("-")[0]
        act_dim_domain = {"point": 2, "ant": 8, "maze2d": 2}[morph]

        policy = make_policy(
            args=args,
            env_id=env_id,
            agent=model,
            reverse_observations=reverse_observations,
            reverse_actions=reverse_actions,
            domain_id=domain_id,
            num_task_ids=num_task_ids,
            act_dim_domain=act_dim_domain,
        )

        if "maze2d" in env_id:
            maze_type = env_id.split("-")[1]
            id_map = task_id_map_inv_m2p[maze_type]
            targets = [env.id_to_xy[id_map[task_id]] for task_id in task_ids]
        else:
            targets = [env.id_to_xy[task_id] for task_id in task_ids]

        success = 0
        for _ in range(n_episodes):
            env.reset()
            goal = targets[np.random.randint(len(targets))]
            env.set_target(goal)
            env.set_init_xy()
            env.set_marker()
            obs = env.get_obs()

            done = False
            cumulative_reward = 0
            # ``t`` starts at 1 and is bumped after every step, so after k
            # environment steps it equals k + 1.
            t = 1
            while not done:
                act = policy(obs, goal)
                obs, rew, done, _ = env.step(act)
                cumulative_reward += rew
                t += 1
                if done or get_success(obs, goal, env_id):
                    break

            # Because of the offset above, an episode only counts as a
            # success when it ended in fewer than ``max_episode_steps``
            # steps. NOTE(review): the -1000 reward floor presumably rejects
            # episodes that hit a large penalty -- confirm with the env.
            if t <= max_episode_steps and cumulative_reward > -1000:
                success += 1
    finally:
        # Release simulator resources even if a rollout raised.
        env.close()

    return success / n_episodes


def evaluate_fn(
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    num_episodes: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Restore a ``DAILAgent`` checkpoint and measure its success rate.

    A fresh ``DAILAgent`` is built from ``args``, moved to ``args.device``
    and populated with the state dict stored at ``model_path``. The agent is
    then passed to ``evaluate`` together with ``args.task_ids`` and
    ``args.num_task_ids``.

    Args:
        args: Experiment configuration. This function reads ``device``,
            ``task_ids`` and ``num_task_ids`` from it and also forwards the
            whole object to ``DAILAgent`` and ``evaluate``.
        env_id: Gym environment id to evaluate in.
        model_path: Location of the saved state dict.
        domain_id: Forwarded to ``evaluate``.
        num_episodes: Number of episodes to run (``n_episodes`` of
            ``evaluate``).
        reverse_observations: Forwarded to ``evaluate``.
        reverse_actions: Forwarded to ``evaluate``.

    Returns:
        The success rate returned by ``evaluate``.
    """
    # Rebuild the agent and restore its weights on the configured device.
    agent = DAILAgent(args)
    agent = agent.to(args.device)
    state_dict = torch.load(model_path, map_location=args.device)
    agent.load_state_dict(state_dict)

    # Everything else is delegated to the single-process evaluation loop.
    eval_kwargs = dict(
        args=args,
        env_id=env_id,
        task_ids=args.task_ids,
        model=agent,
        domain_id=domain_id,
        num_task_ids=args.num_task_ids,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
        n_episodes=num_episodes,
    )
    return evaluate(**eval_kwargs)


def parallel_evaluate(
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Evaluate a checkpoint across several worker processes.

    ``args.num_eval_episodes`` episodes are split evenly over
    ``args.processes`` workers. Each worker runs ``evaluate_fn`` on its
    share and the per-worker success rates are averaged.

    Side effect: if the current multiprocessing start method is ``"fork"``
    it is switched process-wide to ``"spawn"``.

    Args:
        args: Experiment configuration. ``processes`` and
            ``num_eval_episodes`` are read here; the whole object is also
            sent to every worker.
        env_id: Gym environment id to evaluate in.
        model_path: Location of the saved state dict loaded by each worker.
        domain_id: Forwarded to ``evaluate_fn``.
        reverse_observations: Forwarded to ``evaluate_fn``.
        reverse_actions: Forwarded to ``evaluate_fn``.

    Returns:
        Mean of the per-worker success rates (``np.mean`` result).

    Raises:
        ValueError: If ``args.processes`` is smaller than 1 or does not
            divide ``args.num_eval_episodes`` evenly.
    """
    num_processes = args.processes
    num_episodes = args.num_eval_episodes

    # Explicit checks instead of ``assert``: assertions vanish under
    # ``python -O``, which would silently drop the remainder episodes.
    if num_processes < 1:
        raise ValueError(
            f"args.processes must be at least 1, got {num_processes}"
        )
    if num_episodes % num_processes != 0:
        raise ValueError(
            f"args.num_eval_episodes ({num_episodes}) must be divisible by "
            f"args.processes ({num_processes})"
        )
    num_episodes_per_process = num_episodes // num_processes

    # NOTE(review): presumably needed because workers load the model onto
    # ``args.device`` and CUDA cannot be used in forked children -- confirm.
    if torch.multiprocessing.get_start_method() == "fork":
        torch.multiprocessing.set_start_method("spawn", force=True)

    fn = partial(
        evaluate_fn,
        env_id=env_id,
        model_path=model_path,
        domain_id=domain_id,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
        num_episodes=num_episodes_per_process,
    )

    pool = mp.Pool(processes=num_processes)
    try:
        # One task per worker; every task receives the same ``args``.
        success_rates = pool.map(fn, [args] * num_processes)
    except BaseException:
        # Do not leave workers running if a task failed or we were
        # interrupted.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # Always reap the worker processes before returning or re-raising.
        pool.join()

    return np.mean(success_rates)
