from typing import Dict, List, Optional

import d4rl
import gym
import numpy as np
from omegaconf import DictConfig
from stable_baselines3 import DDPG
from tqdm import tqdm

from ours.utils.utils import task_id_map_inv_m2p, get_success

from .make_policy import make_policy


def evaluate(
    args: DictConfig,
    env_id: str,
    task_ids: List[int],
    model: DDPG,
    domain_id: int,
    num_task_ids: int,
    env_kwargs: Optional[Dict] = None,
    reverse_observations: bool = False,
    reverse_actions: bool = False,
    n_episodes: int = 30,
) -> Dict[str, list]:
    """Roll out a policy built from ``model`` and collect per-episode metrics.

    For each episode a target is drawn uniformly at random from the targets
    selected by ``task_ids``, the environment is reset and pointed at that
    target, and the policy is stepped until the environment reports ``done``
    or ``get_success(obs, goal, env_id)`` is truthy.

    Args:
        args: Configuration forwarded unchanged to ``make_policy``.
        env_id: Gym environment id. The part before the first ``"-"`` must be
            one of ``"point"``, ``"ant"`` or ``"maze2d"``; it selects the
            action dimension handed to ``make_policy``.
        task_ids: Keys used to look up targets in ``env.id_to_xy``. When
            ``env_id`` contains ``"maze2d"`` each id is first translated
            through ``task_id_map_inv_m2p[maze_type]``, where ``maze_type``
            is the second ``"-"``-separated field of ``env_id``.
        model: Agent forwarded to ``make_policy`` as ``agent``.
        domain_id: Forwarded to ``make_policy``.
        num_task_ids: Forwarded to ``make_policy``.
        env_kwargs: Keyword arguments for ``gym.make``. ``None`` (the
            default) means ``{"reward_type": "dense"}``.
        reverse_observations: Forwarded to ``make_policy``.
        reverse_actions: Forwarded to ``make_policy``.
        n_episodes: Number of evaluation episodes to run.

    Returns:
        A dict with three lists of length ``n_episodes``:
        ``"ep_lengths"`` (steps taken), ``"rewards"`` (summed reward) and
        ``"success"`` (1 if the episode ended before
        ``env.spec.max_episode_steps`` steps with a summed reward above
        -1000, else 0).

    Raises:
        KeyError: If the morphology prefix of ``env_id`` is not recognised.
    """
    # Build the default per call instead of sharing one dict object between
    # all invocations through the signature.
    if env_kwargs is None:
        env_kwargs = {"reward_type": "dense"}

    env = gym.make(env_id, **env_kwargs)
    try:
        max_episode_steps = env.spec.max_episode_steps

        morph = env_id.split("-")[0]
        act_dim_domain = {"point": 2, "ant": 8, "maze2d": 2}[morph]

        policy = make_policy(
            args=args,
            env_id=env_id,
            agent=model,
            reverse_observations=reverse_observations,
            reverse_actions=reverse_actions,
            domain_id=domain_id,
            num_task_ids=num_task_ids,
            act_dim_domain=act_dim_domain,
        )

        if "maze2d" in env_id:
            # maze2d environments index their targets differently, so the
            # task ids are remapped before the lookup.
            maze_type = env_id.split("-")[1]
            targets = [
                env.id_to_xy[task_id_map_inv_m2p[maze_type][i]]
                for i in task_ids
            ]
        else:
            targets = [env.id_to_xy[i] for i in task_ids]

        metrics_dict = {
            "ep_lengths": [],
            "rewards": [],
            "success": [],
        }
        for _ in tqdm(range(n_episodes)):
            env.reset()
            # The sampled target is used as-is; no random offset is added.
            goal = targets[np.random.randint(len(targets))]
            env.set_target(goal)
            env.set_init_xy()
            env.set_marker()
            obs = env.get_obs()
            cumulative_reward = 0
            t = 0
            while True:
                act = policy(obs, goal)
                obs, rew, done, info = env.step(act)
                cumulative_reward += rew
                t += 1
                if done or get_success(obs, goal, env_id):
                    break

            metrics_dict["ep_lengths"].append(t)
            metrics_dict["rewards"].append(cumulative_reward)
            # An episode counts as solved only if it stopped before the step
            # limit and did not accumulate a very low return.
            solved = t < max_episode_steps and cumulative_reward > -1000
            metrics_dict["success"].append(1 if solved else 0)

        return metrics_dict
    finally:
        # Release the environment's resources even if the rollout raised.
        env.close()
