# Original implementation: https://github.com/abaisero/asym-rlpo
#
####
#
# Extended to informed POMDPs by anonymous authors (2025)
#
####

from __future__ import annotations

import re
from typing import Dict, Optional, Protocol, Tuple, List

import gym
import gym.spaces
import gym_pomdps
import copy

from asym_rlpo.utils.debugging import checkraise
from asym_rlpo.wrapper import IndexWrapper

from . import extra_hai, extra_lyu
from .env import (
    Action,
    Environment,
    EnvironmentType,
    Latent,
    Information,
    LatentType,
    Observation,
    State,
)


def make_gym_env(id: str, *, latent_type: LatentType) -> Environment:
    """Create an ``Environment`` for the gym environment id ``id``.

    The id is resolved in this order:

    1. ``make_po_gym_env(id)`` is tried first.
    2. If that raises ``ValueError`` and ``id`` starts with ``POMDP-``, the
       result of ``make_pomdp_gym_env(id)`` is returned.
    3. Otherwise the env is created with ``gym.make(id)`` and wrapped in a
       ``GymEnvironment`` whose ``EnvironmentType`` is chosen from the
       unwrapped env class (``gym_pomdps.POMDP`` -> ``FLAT``) or from the
       registered spec id (``extra-*`` families), defaulting to ``OTHER``.

    Args:
        id: gym environment id.  The name shadows the builtin, but it is part
            of the public signature and is therefore kept.
        latent_type: must be ``LatentType.STATE``.

    Returns:
        The wrapped environment.

    Raises:
        ValueError: if ``latent_type`` is not ``LatentType.STATE``, if gym
            raises a ``gym.error.Error`` while making the env (chained as the
            cause), or if ``make_pomdp_gym_env`` rejects the id.
    """

    if latent_type is not LatentType.STATE:
        raise ValueError(f'Invalid latent type {latent_type} for gym envs')

    try:
        return make_po_gym_env(id)

    except ValueError:
        # Not a usable PO-* id; the fallback stays inside this handler so the
        # original PO error remains visible as exception context.
        print('Loading using gym.make')
        try:
            if id.startswith("POMDP-"):
                return make_pomdp_gym_env(id)

            gym_env = gym.make(id)

        except gym.error.Error as e:
            # Carry the offending id in the message; the gym error is chained.
            raise ValueError(f'could not make gym env {id}') from e

        if isinstance(gym_env.unwrapped, gym_pomdps.POMDP):
            return GymEnvironment(gym_env, EnvironmentType.FLAT)

        # spec-id pattern -> environment type, checked in order
        spec_id_types = (
            (r'extra-dectiger-v\d+', EnvironmentType.EXTRA_DECTIGER),
            (r'extra-cleaner-v\d+', EnvironmentType.EXTRA_CLEANER),
            (r'extra-car-flag-v\d+', EnvironmentType.EXTRA_CARFLAG),
        )
        for spec_id_pattern, environment_type in spec_id_types:
            if re.fullmatch(spec_id_pattern, gym_env.spec.id):
                return GymEnvironment(gym_env, environment_type)

        return GymEnvironment(gym_env, EnvironmentType.OTHER)


def make_po_gym_env(name: str) -> Environment:
    """Convert a fully observable gym env into a partially observable one.

    ``name`` must match ``PO-<po_type>-<env_name>-v<version>``.  The env
    ``<env_name>-v<version>`` is created with ``gym.make`` and wrapped in an
    ``IndexWrapper`` that keeps only the indices selected by ``po_type``
    (``pos``, ``vel`` or ``full``).

    Args:
        name: id of the partially observable environment.

    Returns:
        A ``GymEnvironment`` of type ``EnvironmentType.OPENAI``.

    Raises:
        ValueError: if ``name`` does not match the pattern, if ``env_name`` is
            not one of ``CartPole``, ``LunarLander`` or ``Acrobot``, or if
            ``po_type`` is not valid for that env.
    """

    pattern = r'^PO-([\w:.-]+)-([\w:.-]+)-v(\d+)$'
    m = re.match(pattern, name)
    if m is None:
        raise ValueError(f'env name {name} does not satisfy regex')

    # captures: type of partial observability, name w/o version, version
    po_type, env_name, version = m.groups()
    non_po_name = f'{env_name}-v{version}'

    env: StatefulGymEnv

    if env_name == 'CartPole':
        indices_obs_dict = {
            'pos': [0, 2],  # ignore velocities
            'vel': [1, 3],  # ignore positions
            'full': [0, 1, 2, 3],  # ignore nothing
        }

        # Information space definition
        indices_info_dict = {
            'pos': [0, 2, 1],  # ignore 1 velocity dim
            'vel': [1, 3, 0],  # ignore 1 position dim
            'full': [0, 1, 2, 3],  # ignore nothing
        }

        # validate before creating the underlying env
        if po_type not in indices_obs_dict:
            raise ValueError(f'invalid partial observability {po_type}')

        gym_env = gym.make(non_po_name)
        env = IndexWrapper(
            gym_env,
            indices_obs_dict[po_type],
            indices_info_dict[po_type],
        )

    elif env_name == 'LunarLander':
        indices_dict = {
            'pos': [0, 1, 4, 6, 7],  # ignore velocities
            'vel': [2, 3, 5, 6, 7],  # ignore positions
            'full': [0, 1, 2, 3, 4, 5, 6, 7],  # ignore nothing
        }

        if po_type not in indices_dict:
            raise ValueError(f'invalid partial observability {po_type}')

        gym_env = gym.make(non_po_name)
        # NOTE(review): unlike CartPole, no information indices are passed
        # here -- confirm IndexWrapper has a suitable default.
        env = IndexWrapper(gym_env, indices_dict[po_type])

    elif env_name == 'Acrobot':
        indices_dict = {
            'pos': [0, 1, 2, 3],  # ignore velocities
            'vel': [4, 5],  # ignore positions
            'full': [0, 1, 2, 3, 4, 5],  # ignore nothing
        }

        if po_type not in indices_dict:
            raise ValueError(f'invalid partial observability {po_type}')

        gym_env = gym.make(non_po_name)
        # NOTE(review): unlike CartPole, no information indices are passed
        # here -- confirm IndexWrapper has a suitable default.
        env = IndexWrapper(gym_env, indices_dict[po_type])

    else:
        raise ValueError(f'invalid env name {env_name}')

    return GymEnvironment(env, EnvironmentType.OPENAI)


def make_pomdp_gym_env(name: str) -> Environment:
    """Build an informed environment from a ``POMDP-*`` gym id.

    The env is created with ``gym.make(name)`` and wrapped together with a
    hard-coded table that is indexed by the env state to obtain the
    additional information.

    Args:
        name: gym id; must start with ``POMDP-heavenhell_3`` or
            ``POMDP-shopping_5``.

    Returns:
        An ``InformedHeavenHellGymEnvironment`` or an
        ``InformedShoppingGymEnvironment`` of type ``EnvironmentType.FLAT``.

    Raises:
        ValueError: if ``name`` starts with neither supported prefix.
    """

    # The env is created before the name is validated, so errors raised by
    # gym.make itself surface first.
    env = gym.make(name)

    # Informed HeavenHell-3 environment with Earth Mover's distance between current position and heaven exit as additional information
    if name.startswith("POMDP-heavenhell_3"):
        state_information_mapping = [6, 5, 4, 3, 2, 1, 0, 4, 5, 6, 7, 8, 9, 10, 6, 5, 4, 3, 4, 5, 6, 2, 1, 0, 7, 8, 9, 10]
        information_values = sorted(set(state_information_mapping))
        return InformedHeavenHellGymEnvironment(env, information_values, state_information_mapping, EnvironmentType.FLAT)

    # Informed Shopping-5 environment with Earth Mover's distance between current position and item location as additional information
    if name.startswith("POMDP-shopping_5"):
        state_information_mapping = [
            0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 2, 3, 4, 5, 6, 3, 4, 5, 6, 7, 4, 5, 6, 7, 8,
            1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4, 5, 4, 3, 4, 5, 6, 5, 4, 5, 6, 7,
            2, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4, 5, 4, 3, 4, 5, 6, 5, 4, 5, 6,
            3, 2, 1, 0, 1, 4, 3, 2, 1, 2, 5, 4, 3, 2, 3, 6, 5, 4, 3, 4, 7, 6, 5, 4, 5,
            4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 6, 5, 4, 3, 2, 7, 6, 5, 4, 3, 8, 7, 6, 5, 4,
            1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 2, 3, 4, 5, 6, 3, 4, 5, 6, 7,
            2, 1, 2, 3, 4, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4, 5, 4, 3, 4, 5, 6,
            3, 2, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4, 5, 4, 3, 4, 5,
            4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 4, 3, 2, 1, 2, 5, 4, 3, 2, 3, 6, 5, 4, 3, 4,
            5, 4, 3, 2, 1, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 6, 5, 4, 3, 2, 7, 6, 5, 4, 3,
            2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 2, 3, 4, 5, 6,
            3, 2, 3, 4, 5, 2, 1, 2, 3, 4, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4, 5,
            4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4, 3, 2, 3, 4,
            5, 4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 4, 3, 2, 1, 2, 5, 4, 3, 2, 3,
            6, 5, 4, 3, 2, 5, 4, 3, 2, 1, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 6, 5, 4, 3, 2,
            3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 1, 2, 3, 4, 5,
            4, 3, 4, 5, 6, 3, 2, 3, 4, 5, 2, 1, 2, 3, 4, 1, 0, 1, 2, 3, 2, 1, 2, 3, 4,
            5, 4, 3, 4, 5, 4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 2, 3,
            6, 5, 4, 3, 4, 5, 4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 4, 3, 2, 1, 2,
            7, 6, 5, 4, 3, 6, 5, 4, 3, 2, 5, 4, 3, 2, 1, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1,
            4, 5, 6, 7, 8, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4,
            5, 4, 5, 6, 7, 4, 3, 4, 5, 6, 3, 2, 3, 4, 5, 2, 1, 2, 3, 4, 1, 0, 1, 2, 3,
            6, 5, 4, 5, 6, 5, 4, 3, 4, 5, 4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1, 2,
            7, 6, 5, 4, 5, 6, 5, 4, 3, 4, 5, 4, 3, 2, 3, 4, 3, 2, 1, 2, 3, 2, 1, 0, 1,
            8, 7, 6, 5, 4, 7, 6, 5, 4, 3, 6, 5, 4, 3, 2, 5, 4, 3, 2, 1, 4, 3, 2, 1, 0
        ]
        information_values = sorted(set(state_information_mapping))

        return InformedShoppingGymEnvironment(env, information_values, state_information_mapping, EnvironmentType.FLAT)

    # Report the actual argument in the error message.
    raise ValueError(f'invalid env name {name}')


class StatefulGymEnv(Protocol):
    """Structural type for gym.Env-like objects that also expose their state.

    On top of the usual gym members, implementers provide ``state_space``
    and the current ``state``.
    """

    # spaces
    action_space: gym.spaces.Discrete
    observation_space: gym.spaces.Space
    state_space: gym.spaces.Space

    # current state of the environment
    state: State

    def seed(self, seed=None):
        """Seed the environment."""

    def reset(self) -> Observation:
        """Reset the environment and return an observation."""

    def step(self, action) -> Tuple[Observation, float, bool, Dict]:
        """Apply ``action`` and return an (observation, float, bool, dict) tuple."""

    def render(self, mode="human"):
        """Render the environment using ``mode``."""


class GymEnvironment(Environment):
    """Adapter exposing a stateful gym env through the Environment interface.

    The latent is read from the wrapped env's ``state`` attribute and the
    information from its ``information()`` method; all spaces (including
    ``information_space``) are taken from the wrapped env.
    """

    def __init__(self, env: StatefulGymEnv, type: EnvironmentType):
        self._env = env
        self.type = type
        self.latent_type = LatentType.STATE

        # spaces are forwarded unchanged from the wrapped env
        self.information_space: gym.spaces.Space = env.information_space
        self.latent_space: gym.spaces.Space = env.state_space
        self.observation_space: gym.spaces.Space = env.observation_space
        self.action_space: gym.spaces.Discrete = env.action_space

    def seed(self, seed: Optional[int] = None) -> None:
        """Seed the wrapped env, then every space, with the same ``seed``."""
        self._env.seed(seed)
        spaces = (
            self.action_space,
            self.observation_space,
            self.latent_space,
            self.information_space,
        )
        for space in spaces:
            space.seed(seed)

    def reset(self) -> Tuple[Observation, Latent, Information]:
        """Reset the wrapped env; return (observation, latent, information)."""
        wrapped = self._env
        first_observation = wrapped.reset()
        # state and information are read after the reset, in that order
        return first_observation, wrapped.state, wrapped.information()

    def step(self, action: Action) -> Tuple[Observation, Latent, float, bool, Information]:
        """Step the wrapped env; the info dict returned by gym is dropped.

        Returns (observation, latent, reward, done, information).
        """
        wrapped = self._env
        next_observation, reward, done, _ = wrapped.step(action)
        # state and information are read after the step, in that order
        return next_observation, wrapped.state, reward, done, wrapped.information()

    def render(self) -> None:
        """Render the wrapped env."""
        self._env.render()


# Informed Heaven-Hell environment
class InformedHeavenHellGymEnvironment(Environment):
    """Heaven-Hell env wrapper that adds a state-dependent information signal.

    The information returned alongside each observation is
    ``information_values[state_information_mapping[state]]``, where ``state``
    is the wrapped env's current ``state``.
    """

    def __init__(self, env: StatefulGymEnv, information_values: List, state_information_mapping: List, type: EnvironmentType):
        self._env = env
        self.type = type
        self.latent_type = LatentType.STATE

        # spaces forwarded from the wrapped env
        self.action_space: gym.spaces.Discrete = env.action_space
        self.observation_space: gym.spaces.Space = env.observation_space
        self.latent_space: gym.spaces.Space = env.state_space

        # information: one discrete symbol per entry of information_values
        self.information_values = information_values
        self.information_space: gym.spaces.Space = gym.spaces.Discrete(len(information_values))

        # table indexed by state
        self.state_information_mapping = state_information_mapping

    def _information_of(self, latent: Latent) -> Information:
        """Return the information associated with ``latent``."""
        return self.information_values[self.state_information_mapping[latent]]

    def seed(self, seed: Optional[int] = None) -> None:
        """Seed the wrapped env, then every space, with the same ``seed``."""
        self._env.seed(seed)
        for space in (
            self.action_space,
            self.observation_space,
            self.latent_space,
            self.information_space,
        ):
            space.seed(seed)

    def reset(self) -> Tuple[Observation, Latent, Information]:
        """Reset the wrapped env; return (observation, latent, information)."""
        first_observation = self._env.reset()
        state = self._env.state
        return first_observation, state, self._information_of(state)

    def step(self, action: Action) -> Tuple[Observation, Latent, float, bool, Information]:
        """Step the wrapped env; the info dict returned by gym is dropped.

        Returns (observation, latent, reward, done, information).
        """
        next_observation, reward, done, _ = self._env.step(action)
        state = self._env.state
        return next_observation, state, reward, done, self._information_of(state)

    def render(self) -> None:
        """Render the wrapped env."""
        self._env.render()


# Informed Shopping environment
class InformedShoppingGymEnvironment(Environment):
    """Shopping env wrapper that adds a state-dependent information signal.

    After every reset/step the current ``state`` of the wrapped env is looked
    up in ``state_information_mapping`` and the result indexes
    ``information_values`` to produce the information.
    """

    def __init__(self, env: StatefulGymEnv, information_values: List, state_information_mapping: List, type: EnvironmentType):
        self.type = type
        self.latent_type = LatentType.STATE
        self._env = env

        # spaces taken from the wrapped env
        self.latent_space: gym.spaces.Space = env.state_space
        self.observation_space: gym.spaces.Space = env.observation_space
        self.action_space: gym.spaces.Discrete = env.action_space

        # state -> information lookup table and the values it refers to
        self.state_information_mapping = state_information_mapping
        self.information_values = information_values
        self.information_space: gym.spaces.Space = gym.spaces.Discrete(len(information_values))

    def _lookup_information(self, latent: Latent) -> Information:
        """Return the information for ``latent`` via the two lookup tables."""
        index = self.state_information_mapping[latent]
        return self.information_values[index]

    def seed(self, seed: Optional[int] = None) -> None:
        """Seed the wrapped env first, then each space with the same ``seed``."""
        self._env.seed(seed)
        seeded_spaces = [
            self.action_space,
            self.observation_space,
            self.latent_space,
            self.information_space,
        ]
        for space in seeded_spaces:
            space.seed(seed)

    def reset(self) -> Tuple[Observation, Latent, Information]:
        """Reset the wrapped env; return (observation, latent, information)."""
        initial_observation = self._env.reset()
        current_state = self._env.state
        return initial_observation, current_state, self._lookup_information(current_state)

    def step(self, action: Action) -> Tuple[Observation, Latent, float, bool, Information]:
        """Step the wrapped env, ignoring the info dict returned by gym.

        Returns (observation, latent, reward, done, information).
        """
        new_observation, reward, done, _ = self._env.step(action)
        current_state = self._env.state
        return new_observation, current_state, reward, done, self._lookup_information(current_state)

    def render(self) -> None:
        """Render the wrapped env."""
        self._env.render()
