import sys

import gymnasium as gym
import numpy as np

from _tianshou_custom.examples.atari.atari_wrapper import make_atari_env
from _tianshou_custom.examples.mujoco.mujoco_env import make_mujoco_env
from _examples.mo_gymnasium.mo_gymnasium_env import make_mo_gymnasium_env
from utils.argparse_util import get_atari_argparser, get_mujoco_argparser, \
    get_mo_gymnasium_argparser, get_classic_control_argparser

# Category identifiers. get_env_category() looks for these as substrings of an
# environment's lower-cased entry point, so every value must stay lower-case.
CLASSIC_CONTROL = "classic_control"
BOX2D = "box2d"
MUJOCO = "mujoco"
ATARI = "atari"
ROBOTICS = "robotics"
TOY_TEXT = "toy_text"
MO_GYMNASIUM = "mo_gymnasium"

# Every keyword that get_env_category() searches for. This is a set, so the
# keywords carry no inherent ordering or priority.
ENV_CATEGORY_KEYWORDS = {
    CLASSIC_CONTROL,
    BOX2D,
    MUJOCO,
    ATARI,
    ROBOTICS,
    TOY_TEXT,
    MO_GYMNASIUM
}

# Category -> environment factory looked up by make_env(). Categories that are
# missing here (BOX2D, ROBOTICS, TOY_TEXT) have no factory in this module.
# NOTE(review): CLASSIC_CONTROL is routed to make_mo_gymnasium_env; no
# classic-control-specific factory is imported in this file -- confirm that
# this reuse is intentional.
MAKE_ENV_MAP = {
    CLASSIC_CONTROL: make_mo_gymnasium_env,
    ATARI: make_atari_env,
    MUJOCO: make_mujoco_env,
    MO_GYMNASIUM: make_mo_gymnasium_env,
}

# Category -> argparser getter imported from utils.argparse_util. Covers the
# same four categories as MAKE_ENV_MAP.
# NOTE(review): not referenced by the functions visible in this file;
# presumably consumed by the training scripts -- verify against callers.
ENV_ARGPARSER_MAP = {
    CLASSIC_CONTROL: get_classic_control_argparser,
    ATARI: get_atari_argparser,
    MUJOCO: get_mujoco_argparser,
    MO_GYMNASIUM: get_mo_gymnasium_argparser,
}

# Task id -> category, consulted by get_env_category() only when the task's
# entry point contains none of the keywords in ENV_CATEGORY_KEYWORDS.
CUSTOM_TASK_ARGPARSER_MAP = {
    'breakable-bottles-v0': MO_GYMNASIUM
}

def _entry_point_path(entry_point):
    """Return a lower-cased dotted path for a string or callable entry point."""
    if entry_point is None:
        return ""
    if isinstance(entry_point, str):
        return entry_point.lower()
    # Gymnasium also accepts a callable as entry point; rebuild an import-like
    # path from it so that keyword matching still has something to work with.
    module = getattr(entry_point, "__module__", None) or ""
    name = (getattr(entry_point, "__qualname__", None)
            or type(entry_point).__name__)
    return f"{module}.{name}".lower()


def get_env_category(task):
    """Retrieve the category of a Gymnasium environment from its entry point.

    The environment spec registered for ``task`` is looked up and its entry
    point is searched for the keywords in ``ENV_CATEGORY_KEYWORDS``. When
    several keywords occur, the one that appears first in the entry point
    (i.e. the outermost package) wins, with the longer keyword breaking ties,
    so the result does not depend on set iteration order.

    Args:
        task: Environment id as registered with Gymnasium.

    Returns:
        The matching category keyword; otherwise the category registered for
        ``task`` in ``CUSTOM_TASK_ARGPARSER_MAP``; otherwise the string
        ``"Unknown Category"``.

    Raises:
        SystemExit: With exit code 1 (after printing a message) when
            Gymnasium cannot resolve ``task``.
    """
    # Keep the try body minimal: only the registry lookup is expected to fail.
    try:
        env_spec = gym.spec(task)
    except gym.error.Error:
        print("Environment not found")
        sys.exit(1)

    entry_point = _entry_point_path(env_spec.entry_point)

    # (position, -length, keyword) sorts by earliest occurrence first and, for
    # equal positions, by longest keyword.
    matches = [
        (entry_point.find(keyword), -len(keyword), keyword)
        for keyword in ENV_CATEGORY_KEYWORDS
        if keyword in entry_point
    ]
    if matches:
        return min(matches)[2]

    category = CUSTOM_TASK_ARGPARSER_MAP.get(task)
    if category is None:
        return "Unknown Category"
    return category


def make_env(task, *args, **kwargs):
    """Build the environments for ``task`` using its category's factory.

    The category is resolved with ``get_env_category`` and the matching
    factory from ``MAKE_ENV_MAP`` is called with ``task`` as a keyword
    argument plus any extra positional and keyword arguments.

    Args:
        task: Environment id as registered with Gymnasium.
        *args: Extra positional arguments forwarded to the factory.
        **kwargs: Extra keyword arguments forwarded to the factory.

    Returns:
        The 4-tuple ``(env, train_envs, test_envs, watch_env)`` produced by
        the factory.

    Raises:
        KeyError: If no factory is registered for the task's category.
    """
    env_category = get_env_category(task)

    try:
        factory = MAKE_ENV_MAP[env_category]
    except KeyError:
        supported = ", ".join(sorted(MAKE_ENV_MAP))
        # Still a KeyError, as before, but now it says what went wrong.
        raise KeyError(
            f"No environment factory registered for category "
            f"{env_category!r} (task {task!r}); supported categories: "
            f"{supported}"
        ) from None

    # Star-args are written before the keyword argument because that is how
    # Python binds them anyway; the call is equivalent to the previous
    # ``task=task, *args`` spelling.
    env, train_envs, test_envs, watch_env = factory(
        *args,
        task=task,
        **kwargs
    )

    return env, train_envs, test_envs, watch_env

def flatten_obs(obs, keys):
    """Concatenate selected entries of ``obs`` into a single 1-D array.

    Args:
        obs: Observation; indexed with each element of ``keys`` when ``keys``
            is given.
        keys: Iterable of keys selecting, in order, the entries to flatten,
            or ``None`` to leave ``obs`` untouched.

    Returns:
        ``obs`` itself when ``keys`` is ``None``; otherwise a 1-D NumPy array
        holding the flattened entries back to back in ``keys`` order.
    """
    if keys is not None:
        pieces = []
        for key in keys:
            # Each entry is turned into an array and flattened to 1-D first.
            pieces.append(np.asarray(obs[key]).ravel())
        return np.concatenate(pieces, axis=0)

    return obs
