import multiprocessing as mp
from functools import partial
from pathlib import Path
from typing import Dict, List, Union

import d4rl
import gym
import numpy as np
import torch
import torch.nn as nn
from common.models import Policy
from omegaconf import DictConfig

from .make_policy import make_policy
from .utils import task_id_map_inv_m2p, get_success


def _run_episode(env, policy, goal, env_id: str, max_episode_steps: int) -> bool:
    """Roll out ``policy`` towards ``goal`` in an already-reset ``env``.

    Sets the target, initial position and marker on ``env``, then steps
    ``policy(obs, goal)`` until the environment reports ``done`` or
    ``get_success(obs, goal, env_id)`` is true.

    Returns:
        True when the rollout stopped while ``t <= max_episode_steps`` and the
        accumulated reward is above -1000. ``t`` starts at 1 and is incremented
        after every step, so a rollout that took the full ``max_episode_steps``
        steps (e.g. ended by the time limit) is NOT counted as a success.
        NOTE(review): the meaning of the -1000 reward threshold is not visible
        here — confirm against the environment's reward definition.
    """
    env.set_target(goal)
    env.set_init_xy()
    env.set_marker()
    obs = env.get_obs()
    cumulative_reward = 0
    t = 1
    while True:
        act = policy(obs, goal)
        obs, rew, done, _ = env.step(act)
        cumulative_reward += rew
        t += 1
        if done or get_success(obs, goal, env_id):
            break
    return t <= max_episode_steps and cumulative_reward > -1000


def evaluate(
    args: DictConfig,
    env_id: str,
    task_ids: List[int],
    model: nn.Module,
    domain_id: int,
    num_task_ids: int,
    env_kwargs: Union[Dict, None] = None,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
    n_episodes: int = 30,
    device: str = "cuda:0",
):
    """Evaluate ``model`` on ``env_id`` and return the fraction of successful episodes.

    For every episode the environment is reset, one goal is drawn uniformly at
    random (via ``np.random``) from the goals mapped from ``task_ids``, and the
    policy built by ``make_policy`` is rolled out until ``done`` or
    ``get_success`` fires.

    Args:
        args: Experiment config forwarded to ``make_policy``.
        env_id: Gym environment id. Its first ``-``-separated token selects the
            action dimension ("point" -> 2, "ant" -> 8, "maze2d" -> 2). For
            maze2d ids the second token selects the task-id remapping.
        task_ids: Task ids to sample goals from. For maze2d envs each id is
            first remapped through ``task_id_map_inv_m2p[maze_type]``, then
            looked up in ``env.id_to_xy``.
        model: Network wrapped by ``make_policy``.
        domain_id: Domain index forwarded to ``make_policy``.
        num_task_ids: Number of task ids forwarded to ``make_policy``.
        env_kwargs: Extra keyword arguments for ``gym.make``. ``None`` (the
            default) means ``{"eval": True}``.
        reverse_observations: Forwarded to ``make_policy``.
        reverse_actions: Forwarded to ``make_policy``.
        n_episodes: Number of evaluation episodes. Must be positive.
        device: Unused in this function; kept for interface compatibility.

    Returns:
        Number of successful episodes divided by ``n_episodes``.

    Raises:
        ValueError: If ``n_episodes`` is not positive or ``task_ids`` is empty.
        KeyError: If the morphology prefix of ``env_id`` is not one of
            "point", "ant", "maze2d".
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")
    if len(task_ids) == 0:
        raise ValueError("task_ids must contain at least one task id")
    if env_kwargs is None:
        # Built per call: a dict literal in the signature would be shared across calls.
        env_kwargs = {"eval": True}

    env = gym.make(env_id, **env_kwargs)
    try:
        max_episode_steps = env.spec.max_episode_steps
        print(env_id)

        morph = env_id.split("-")[0]
        act_dim_domain = {"point": 2, "ant": 8, "maze2d": 2}[morph]

        policy = make_policy(
            args=args,
            env_id=env_id,
            model=model,
            reverse_observations=reverse_observations,
            reverse_actions=reverse_actions,
            domain_id=domain_id,
            num_task_ids=num_task_ids,
            act_dim_domain=act_dim_domain,
        )

        if "maze2d" in env_id:
            maze_type = env_id.split("-")[1]
            id_map = task_id_map_inv_m2p[maze_type]
            targets = [env.id_to_xy[id_map[i]] for i in task_ids]
        else:
            targets = [env.id_to_xy[i] for i in task_ids]

        success = 0
        for _ in range(n_episodes):
            env.reset()
            # NOTE: goals are used exactly. An earlier variant added
            # np.random.uniform(-0.5, 0.5, size=(2,)) noise to the target.
            goal = targets[np.random.randint(len(targets))]
            if _run_episode(env, policy, goal, env_id, max_episode_steps):
                success += 1

        del policy
    finally:
        # Release environment resources even if policy creation or a rollout fails;
        # dropping the reference alone does not call close().
        env.close()

    return success / n_episodes


def evaluate_fn(
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    num_episodes: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Load a ``Policy`` checkpoint and return its success rate from ``evaluate``.

    Intended as the per-worker entry point of ``parallel_evaluate``. It is
    self-contained: the model is rebuilt from ``args.policy`` and loaded from
    ``model_path`` inside the calling process.

    Args:
        args: Config providing ``args.policy.*`` hyper-parameters,
            ``args.device`` and ``args.task_ids``.
        env_id: Gym environment id to evaluate on.
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        domain_id: Domain index forwarded to ``evaluate``.
        num_episodes: Number of episodes to run in this call.
        reverse_observations: Forwarded to ``evaluate``.
        reverse_actions: Forwarded to ``evaluate``.

    Returns:
        The success rate returned by ``evaluate``.
    """
    # Optional config key. Older configs may not define it.
    decode_with_state = getattr(args.policy, "decode_with_state", False)

    # Rebuild the architecture, then load the trained weights onto args.device.
    policy = Policy(
        state_dim=args.policy.state_dim,
        cond_dim=args.policy.cond_dim,
        out_dim=args.policy.out_dim,
        domain_dim=args.policy.domain_dim,
        latent_dim=args.policy.latent_dim,
        hid_dim=args.policy.hid_dim,
        num_hidden_layers=args.policy.num_hidden_layers,
        activation=args.policy.activation,
        repr_activation=args.policy.latent_activation,
        enc_sn=args.policy.spectral_norm,
        decode_with_state=decode_with_state,
    ).to(args.device)
    policy.load_state_dict(torch.load(model_path, map_location=args.device))
    # Inference only: switch off training-mode behaviour so evaluation does not
    # mutate module buffers. This presumably matters when enc_sn enables torch
    # spectral_norm, whose power iteration updates its `u` buffer on every
    # training-mode forward.
    policy.eval()

    return evaluate(
        args=args,
        env_id=env_id,
        task_ids=args.task_ids,
        model=policy,
        domain_id=domain_id,
        num_task_ids=args.policy.cond_dim,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
        n_episodes=num_episodes,
    )


def parallel_evaluate(
    args: DictConfig,
    env_id: str,
    model_path: Union[str, Path],
    domain_id: int,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
):
    """Evaluate a saved policy across ``args.processes`` worker processes.

    ``args.num_eval_episodes`` is split evenly across the workers. Each worker
    runs ``evaluate_fn`` (loading the checkpoint itself), and the per-worker
    success rates are averaged. Because every worker runs the same number of
    episodes, this mean equals the overall success rate.

    Side effect: if the global multiprocessing start method is "fork", it is
    switched to "spawn" for the whole process (``force=True``).

    Args:
        args: Config providing ``processes``, ``num_eval_episodes`` and
            everything ``evaluate_fn`` reads.
        env_id: Gym environment id to evaluate on.
        model_path: Path to the saved policy ``state_dict``.
        domain_id: Domain index forwarded to ``evaluate_fn``.
        reverse_observations: Forwarded to ``evaluate_fn``.
        reverse_actions: Forwarded to ``evaluate_fn``.

    Returns:
        Mean success rate over all workers (a numpy float).

    Raises:
        ValueError: If ``args.processes`` is not positive, or
            ``args.num_eval_episodes`` is not divisible by it.
    """
    num_processes = args.processes
    num_episodes = args.num_eval_episodes
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if num_processes <= 0:
        raise ValueError(f"args.processes must be positive, got {num_processes}")
    if num_episodes % num_processes != 0:
        raise ValueError(
            f"num_eval_episodes ({num_episodes}) must be divisible by "
            f"processes ({num_processes})"
        )
    num_episodes_per_process = num_episodes // num_processes

    # The CUDA runtime does not support forked children, so workers are spawned.
    # torch.multiprocessing re-exports the stdlib start-method API, so this
    # also governs mp.Pool below.
    if torch.multiprocessing.get_start_method() == "fork":
        torch.multiprocessing.set_start_method("spawn", force=True)

    fn = partial(
        evaluate_fn,
        env_id=env_id,
        model_path=model_path,
        domain_id=domain_id,
        reverse_observations=reverse_observations,
        reverse_actions=reverse_actions,
        num_episodes=num_episodes_per_process,
    )

    # The context manager terminates the workers even if a worker raises.
    with mp.Pool(processes=num_processes) as pool:
        success_rates = pool.map(fn, [args] * num_processes)
        pool.close()
        pool.join()
    return np.mean(success_rates)
