from gym.spaces import Discrete, Box, MultiDiscrete, Space
import numpy as np
import tree  # pip install dm_tree
from typing import Union, Optional

from models.action_dist import ActionDistribution
from models.modelv2 import ModelV2
from utils.annotations import override
from utils.exploration.exploration import Exploration
from utils import force_tuple
from utils.framework import try_import_torch, TensorType
from utils.spaces.simplex import Simplex
from utils.spaces.space_utils import get_base_struct_from_space

torch, _ = try_import_torch()


class Random(Exploration):
    """A random action selector (deterministic/greedy for explore=False).

    If explore=True, returns actions randomly from `self.action_space` (via
    Space.sample()).
    If explore=False, returns the greedy/max-likelihood action.

    Only a torch code path is implemented in this class; the log-probs it
    returns are always zero.
    """

    def __init__(
        self, action_space: Space, *, model: ModelV2, framework: Optional[str], **kwargs
    ):
        """Initialize a Random Exploration object.

        Args:
            action_space: The gym action space used by the environment.
            model: The model; its `model_config` attribute (if any) is passed
                to `required_model_output_shape` when exploring.
            framework: One of None, "tf", "tfe", "torch". Note that
                `get_exploration_action` always takes the torch path.
            **kwargs: Forwarded unchanged to `Exploration.__init__`.
        """
        super().__init__(
            action_space=action_space, model=model, framework=framework, **kwargs
        )

        # Presumably the nested structure of leaf spaces of `action_space`
        # (per `get_base_struct_from_space`). NOTE(review): not read anywhere
        # in this class; kept for external users.
        self.action_space_struct = get_base_struct_from_space(self.action_space)

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[int, TensorType],
        explore: bool = True
    ):
        """Return a random (or greedy) action together with a zero log-prob.

        Args:
            action_distribution: The action distribution built from the model
                outputs. Its input shape decides batching, and, when
                `explore=False`, its deterministic sample is returned.
            timestep: The current timestep (unused by this strategy).
            explore: If True, sample from the action space; otherwise return
                `action_distribution.deterministic_sample()`.

        Returns:
            Tuple (action, logp) as produced by `get_torch_exploration_action`.
        """
        # Only a torch implementation exists; dispatch to it unconditionally.
        return self.get_torch_exploration_action(action_distribution, explore)

    def get_torch_exploration_action(
        self, action_dist: ActionDistribution, explore: bool
    ):
        """Compute a torch action and a matching all-zeros log-prob tensor.

        Args:
            action_dist: The action distribution. The rank of its `inputs`
                tensor is compared against `required_model_output_shape(...)`
                to detect a leading batch dimension.
            explore: If True, draw `self.action_space.sample()` (once per batch
                item when batched). If False, use
                `action_dist.deterministic_sample()`.

        Returns:
            Tuple (action, logp). In the explore branch `action` is a torch
            tensor on `self.device` (or, for nested action spaces, a matching
            nested structure of such tensors). `logp` is a float32 zeros
            tensor on `self.device` whose shape is the leading dimension of the
            (first) action leaf, i.e. `(B,)` for a batch of B actions and `()`
            for a 0-dim action.
        """
        if explore:
            req = force_tuple(
                action_dist.required_model_output_shape(
                    self.action_space, getattr(self.model, "model_config", None)
                )
            )
            # Inputs having one more dim than the required model output shape
            # means a leading batch dimension is present.
            if len(action_dist.inputs.shape) == len(req) + 1:
                batch_size = action_dist.inputs.shape[0]
                samples = [self.action_space.sample() for _ in range(batch_size)]
                # Stack leaf-wise so Dict/Tuple spaces give a structure of
                # batched arrays instead of failing inside np.stack.
                a = tree.map_structure(lambda *leaves: np.stack(leaves), *samples)
            else:
                # np.asarray: e.g. Discrete.sample() can return a plain Python
                # int (or numpy scalar), which torch.from_numpy() rejects.
                a = tree.map_structure(np.asarray, self.action_space.sample())
            # Convert every numpy leaf to a torch tensor on the target device.
            action = tree.map_structure(
                lambda leaf: torch.from_numpy(leaf).to(self.device), a
            )
        else:
            action = action_dist.deterministic_sample()
        # Zero log-likelihood, one entry per batch item; the batch size is read
        # from the first leaf so nested actions and 0-dim actions both work.
        first_leaf = tree.flatten(action)[0]
        logp = torch.zeros(
            first_leaf.shape[:1], dtype=torch.float32, device=self.device
        )
        return action, logp
