from typing import List, Optional

import numpy as np
from gym import spaces
from stable_baselines3.common.envs import IdentityEnv


class InvalidActionEnvDiscrete(IdentityEnv):
    """
    Identity env with a discrete action space. Supports action masking.

    Every time a new state is chosen, ``n_invalid_actions`` actions other than
    the current state are drawn at random (without replacement) and reported
    as invalid by :meth:`action_masks`.

    :param dim: Number of discrete actions (defaults to 1).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of actions masked out per state. Must be
        strictly smaller than ``dim`` so the current state can stay valid.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dim is None:
            dim = 1
        # At least one action (the current state) must always remain valid.
        assert (
            n_invalid_actions < dim
        ), f"Too many invalid actions: {n_invalid_actions} >= {dim}"

        space = spaces.Discrete(dim)
        # Set before super().__init__ in case the parent constructor resets the
        # env and thereby calls _choose_next_state (NOTE(review): presumably
        # IdentityEnv does this — confirm).
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(space.n)
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and a fresh set of invalid actions excluding it."""
        self.state = self.action_space.sample()
        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = self.possible_actions[
            self.possible_actions != self.state
        ]
        # NOTE(review): draws from the global NumPy RNG rather than a per-env
        # generator, so these draws are not tied to the env's own seed.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """Return one flag per action: ``True`` if allowed, ``False`` if masked."""
        return np.isin(
            self.possible_actions, self.invalid_actions, invert=True
        ).tolist()


class InvalidActionEnvMultiDiscrete(IdentityEnv):
    """
    Identity env with a multidiscrete action space. Supports action masking.

    The sub-spaces are flattened into a single index range of size
    ``sum(dims)``: value ``v`` of sub-space ``i`` maps to
    ``sum(dims[:i]) + v``. Every time a new state is chosen,
    ``n_invalid_actions`` flat indices that do not belong to the current state
    are drawn at random (without replacement) and reported as invalid by
    :meth:`action_masks`.

    :param dims: Number of choices per sub-space (defaults to ``[1, 1]``).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of flat indices masked out per state.
    :raises ValueError: If ``n_invalid_actions > sum(dims) - len(dims)``,
        i.e. more indices than exist outside the current state.
    """

    def __init__(
        self,
        dims: Optional[List[int]] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = [1, 1]

        # Each sub-space must keep its current-state value valid, leaving
        # sum(dims) - len(dims) indices that can be masked.
        if n_invalid_actions > sum(dims) - len(dims):
            raise ValueError(
                f"Cannot find a valid action for each dim. Set n_invalid_actions <= {sum(dims) - len(dims)}"
            )

        space = spaces.MultiDiscrete(dims)
        # Set before super().__init__ in case the parent constructor resets the
        # env and thereby calls _choose_next_state (NOTE(review): presumably
        # IdentityEnv does this — confirm).
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(sum(dims))
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and a fresh set of invalid flat indices excluding it."""
        self.state = self.action_space.sample()

        nvec = self.action_space.nvec
        # Exclusive prefix sum: starting flat index of each sub-space.
        offsets = np.cumsum(nvec) - nvec
        converted_state = offsets + self.state

        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = self.possible_actions[
            np.isin(self.possible_actions, converted_state, invert=True)
        ]
        # NOTE(review): draws from the global NumPy RNG rather than a per-env
        # generator, so these draws are not tied to the env's own seed.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """Return one flag per flat index: ``True`` if allowed, ``False`` if masked."""
        return np.isin(
            self.possible_actions, self.invalid_actions, invert=True
        ).tolist()


class InvalidActionEnvMultiBinary(IdentityEnv):
    """
    Identity env with a multibinary action space. Supports action masking.

    Each binary dimension ``i`` contributes two flat indices: ``2 * i`` for
    value 0 and ``2 * i + 1`` for value 1. Every time a new state is chosen,
    ``n_invalid_actions`` flat indices that do not belong to the current state
    are drawn at random (without replacement) and reported as invalid by
    :meth:`action_masks`.

    :param dims: Number of binary dimensions (defaults to 1).
    :param ep_length: Episode length, forwarded to ``IdentityEnv``.
    :param n_invalid_actions: Number of flat indices masked out per state.
    :raises ValueError: If ``n_invalid_actions > dims``, i.e. more indices
        than exist outside the current state.
    """

    def __init__(
        self,
        dims: Optional[int] = None,
        ep_length: int = 100,
        n_invalid_actions: int = 0,
    ):
        if dims is None:
            dims = 1

        # Of the 2 * dims flat indices, dims belong to the current state and
        # must stay valid, so at most dims can be masked.
        if n_invalid_actions > dims:
            raise ValueError(
                f"Cannot find a valid action for each dim. Set n_invalid_actions <= {dims}"
            )

        space = spaces.MultiBinary(dims)
        # Set before super().__init__ in case the parent constructor resets the
        # env and thereby calls _choose_next_state (NOTE(review): presumably
        # IdentityEnv does this — confirm).
        self.n_invalid_actions = n_invalid_actions
        self.possible_actions = np.arange(2 * dims)
        self.invalid_actions: List[int] = []
        super().__init__(space=space, ep_length=ep_length)

    def _choose_next_state(self) -> None:
        """Sample a new state and a fresh set of invalid flat indices excluding it."""
        self.state = self.action_space.sample()

        # Build the flat indices in int64 (np.arange) rather than adding Python
        # ints to the sampled state's elements: the state may have a narrow
        # integer dtype, which can overflow for large dims under NumPy >= 2
        # promotion rules.
        converted_state = 2 * np.arange(self.action_space.n) + self.state

        # Randomly choose invalid actions that are not the current state
        potential_invalid_actions = self.possible_actions[
            np.isin(self.possible_actions, converted_state, invert=True)
        ]
        # NOTE(review): draws from the global NumPy RNG rather than a per-env
        # generator, so these draws are not tied to the env's own seed.
        self.invalid_actions = np.random.choice(
            potential_invalid_actions, self.n_invalid_actions, replace=False
        ).tolist()

    def action_masks(self) -> List[bool]:
        """Return one flag per flat index: ``True`` if allowed, ``False`` if masked."""
        return np.isin(
            self.possible_actions, self.invalid_actions, invert=True
        ).tolist()
