from typing import Callable

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiscreteSpaceInputFormatter(nn.Module):
    """Formats samples from a discrete space for input to a network.

    Integer index tensors are expanded to one-hot vectors. Tensors whose dtype
    already equals the output dtype are treated as pre-formatted one-hot
    batches and are returned unchanged after a shape check.

    Args:
        space: The space from which to format.
        output_dtype: The datatype in which to output the one-hot tensor.
    """
    def __init__(self, space: gym.spaces.Discrete, output_dtype: torch.dtype = torch.float32):
        assert isinstance(space, gym.spaces.Discrete)
        super().__init__()
        # Keep a plain Python int so `output_size` honours its `int` annotation
        # and `F.one_hot` never receives a numpy scalar as `num_classes`.
        self.n = int(space.n)
        self.output_dtype = output_dtype

    @property
    def output_size(self) -> int:
        """Returns the size of a single sample."""
        return self.n

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Converts the input to a tensor of one-hot values.

        Args:
            x: The tensor of discrete values (or already formatted values).

        Returns:
            Tensor of one-hot values. Index inputs gain one trailing dimension
            of size `output_size`; already formatted inputs are returned as-is.

        Raises:
            ValueError: If the dtype of `x` is neither the output dtype nor
                int32/int64.
        """
        if x.dtype == self.output_dtype:
            # Observations may already be one-hot formatted.
            assert len(x.shape) == 2
            assert x.shape[1] == self.output_size
            return x
        if x.dtype in (torch.int32, torch.int64):
            # `F.one_hot` is documented for LongTensor input, so widen int32 first.
            return F.one_hot(x.long(), num_classes=self.n).to(self.output_dtype)
        # Previously the exception was constructed but never raised, which made
        # this method silently return None for unsupported dtypes.
        raise ValueError(f"Input is of invalid type: {x.dtype}")


class MultiDiscreteSpaceInputFormatter(nn.Module):
    """Formats samples from a multi-discrete space for input to a network.

    Each sample is encoded as the concatenation of one one-hot vector per
    dimension of the space, so the output width is the sum of the dimension
    sizes. Tensors whose dtype already equals the output dtype are treated as
    pre-formatted and returned unchanged after a shape check.

    Args:
        space: The space from which to format.
        output_dtype: The datatype in which to output the one-hot tensor.
    """
    def __init__(self, space: gym.spaces.MultiDiscrete, output_dtype: torch.dtype = torch.float32):
        assert isinstance(space, gym.spaces.MultiDiscrete)
        super().__init__()
        # Plain Python int so `output_size` honours its `int` annotation.
        self.size = int(np.sum(space.nvec))
        self.dim_sizes = space.nvec
        self.output_dtype = output_dtype

    @property
    def output_size(self) -> int:
        """Returns the size of a single sample."""
        return self.size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Converts the input to a tensor of multi-one-hot values.

        Args:
            x: The tensor of multi-discrete indices (or already formatted values).
                Must be 2D; index inputs need one column per space dimension.

        Returns:
            Tensor of multi-one-hot values with `output_size` columns, in the
            output dtype and on the same device as `x`.

        Raises:
            ValueError: If the dtype of `x` is neither the output dtype nor
                int32/int64, or if an index lies outside its dimension's range.
        """
        assert len(x.shape) == 2, f"Expected 2D tensor but got tesnor of shape {x.shape}"
        if x.dtype == self.output_dtype:
            # Observations may already be one-hot formatted.
            assert x.shape[1] == self.size, f"Expected: {self.size}, but got {x.shape[1]}"
            return x
        if x.dtype not in (torch.int32, torch.int64):
            # Previously the exception was constructed but never raised.
            raise ValueError(f"Input is of invalid type: {x.dtype}")

        sizes = torch.tensor(
            [int(n) for n in self.dim_sizes], dtype=torch.long, device=x.device)
        assert x.shape[1] == sizes.shape[0], (
            f"Expected {sizes.shape[0]} indices per sample, but got {x.shape[1]}")

        indices = x.long()
        # An index outside its own dimension would otherwise set a bit that
        # belongs to a different dimension's one-hot slot.
        if bool(((indices < 0) | (indices >= sizes)).any()):
            raise ValueError("Input contains indices outside of the space bounds")

        # Exclusive prefix sum: the column at which each dimension's slot starts.
        offsets = torch.cumsum(sizes, dim=0) - sizes
        output = torch.zeros((x.shape[0], self.size), device=x.device)
        output.scatter_(1, indices + offsets, 1)
        return output.to(self.output_dtype)


class BoxSpaceInputFormatter(nn.Module):
    """Prepares samples from a box space for use as network input.

    The low and high bounds of the space are stored as float buffers named
    `low` and `high`.

    Args:
        space: The box space whose bounds are used.
        should_normalize: When true, inputs are clipped to the bounds and
            rescaled to the range [-1, 1]; when false, inputs pass through.
    """
    def __init__(
            self,
            space: gym.spaces.Box,
            should_normalize: bool = True,
    ):
        assert isinstance(space, gym.spaces.Box)
        super().__init__()
        self.size = len(space.low)
        self.should_normalize = should_normalize
        if should_normalize:
            # Rescaling divides by (high - low), so infinite bounds are rejected.
            for bound in (space.low, space.high):
                assert not np.isinf(bound).any()
        for buffer_name, bound in (("low", space.low), ("high", space.high)):
            self.register_buffer(buffer_name, torch.FloatTensor(bound))

    @property
    def output_size(self) -> int:
        """Returns the number of entries in a single formatted sample."""
        return self.size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the input, normalized if this formatter is set to normalize.

        Args:
            x: A continuous-valued 2D tensor.

        Returns:
            The normalized tensor when `should_normalize` is set, otherwise
            `x` itself.
        """
        assert len(x.shape) == 2, f"Expected 2D tensor but got tesnor of shape {x.shape}"
        return self._normalize(x) if self.should_normalize else x

    def _normalize(self, x):
        """Clips the tensor to the box bounds and rescales it to [-1, 1].

        The low bound maps to -1 and the high bound maps to 1.
        """
        # Element-wise min/max against the bound buffers performs the clipping.
        capped = torch.min(x, self.high)
        clipped = torch.max(capped, self.low)
        span = self.high - self.low
        unit_interval = (clipped - self.low) / span
        return (unit_interval - 0.5) * 2.0


def one_hot_to_index_transform(one_hot: np.ndarray) -> np.ndarray:
    """Converts a one-hot numpy array into the index of its hot entry.

    Meant for data loading code (e.g., a torch dataset subclass) rather than
    for use inside an nn.Module.

    Args:
        one_hot: A 1D one-hot array.

    Returns:
        The position of the largest entry of the array, as produced by argmax.
    """
    assert one_hot.ndim == 1
    hot_position = one_hot.argmax()
    return hot_position


def get_numpy_space_transform(transform_str: str) -> Callable:
    """Looks up a numpy transform function by its string name.

    Args:
        transform_str: Name of the transform; only "one_hot_to_index" is known.

    Returns:
        The transform function registered under the given name.

    Raises:
        ValueError: If the name does not match a known transform.
    """
    if transform_str != "one_hot_to_index":
        raise ValueError(f"Invalid transform string: {transform_str}")
    return one_hot_to_index_transform


def _make_not_implemented_error(space: gym.spaces.Space):
    """Builds (without raising) the error used for unsupported space types.

    Args:
        space: The space whose type is not supported.

    Returns:
        A NotImplementedError whose message names the type of `space`.
    """
    message = "Space of type {} not implemented.".format(type(space))
    return NotImplementedError(message)


def get_space_size(space: gym.spaces.Space) -> int:
    """Gets the number of elements in a space.

    Only works for discrete spaces.

    Args:
        space: The space for which to get the size.

    Returns:
        The size of the space as a Python int: `n` for a discrete space, or the
        product of the per-dimension sizes for a multi-discrete space.

    Raises:
        NotImplementedError: If the space is neither discrete nor multi-discrete.
    """
    if isinstance(space, gym.spaces.Discrete):
        return int(space.n)
    if isinstance(space, gym.spaces.MultiDiscrete):
        # Multiply as Python ints: a numpy product over a fixed-width integer
        # array wraps around silently once the space size exceeds that width.
        size = 1
        for dim_size in np.asarray(space.nvec).flat:
            size *= int(dim_size)
        return size
    raise _make_not_implemented_error(space)


def get_index_to_space_converter(space: gym.spaces.Space) -> Callable:
    """Builds a function mapping a flat index to an element of the space.

    Only works for discrete spaces.

    Args:
        space: The space for which to define the converter.

    Returns:
        A callable taking an index into the space and returning the matching
        element: the index itself for a discrete space, or the unraveled
        per-dimension indices for a multi-discrete space.

    Raises:
        NotImplementedError: If the space is neither discrete nor multi-discrete.
    """
    if isinstance(space, gym.spaces.Discrete):
        def identity(x):
            # A discrete element already is its own index.
            return x

        return identity

    if isinstance(space, gym.spaces.MultiDiscrete):
        index_sizes = tuple(space.nvec)

        def unravel(x):
            # Treat the dimension sizes as an array shape and unravel into it.
            return np.unravel_index(x, index_sizes)

        return unravel

    raise _make_not_implemented_error(space)


def clip_actions_to_space_bounds(actions: torch.Tensor, space: gym.spaces.Box) -> torch.Tensor:
    """Clips a batch of actions to the bounds of the provided box space.

    The low and high arrays of the space are converted to tensors matching the
    dtype and device of `actions` and applied element-wise.

    Args:
        actions: The actions to clip of shape (batch_size, action_dim).
        space: A box space defining the bounds.

    Returns:
        The clipped actions.
    """
    assert isinstance(space, gym.spaces.Box), "Action clipping only implemented for box spaces."
    lower_bound = torch.tensor(space.low).to(actions.dtype).to(actions.device)
    upper_bound = torch.tensor(space.high).to(actions.dtype).to(actions.device)
    # Raise values below the lower bound first, then cap at the upper bound.
    raised = torch.max(actions, lower_bound)
    return torch.min(raised, upper_bound)
