from typing import List, Optional

import torch
from torch.distributions import MultivariateNormal

from action_masking.rlsampling.sets import ConvexSet


class ConvexSetNormal(MultivariateNormal):
    """Base class for a multivariate normal distribution constrained to convex sets.

    The underlying unconstrained distribution is
    ``MultivariateNormal(loc, covariance_matrix)``. Subclasses are expected to
    implement ``sample``, ``log_prob`` and ``mode`` for the constrained
    distribution; ``expand`` and ``rsample`` are not supported.

    Shape definitions follow ``torch.distributions``
    (https://bochang.me/blog/posts/pytorch-distributions/):
    ``event_shape`` is the dimension of one draw, ``batch_shape`` is the shape
    of the batch of distributions, and ``sample_shape`` is the shape of the
    requested draws. Discussion of distributions and the reparameterization
    trick: https://pytorch.org/docs/stable/distributions.html
    """

    # ``MultivariateNormal`` advertises reparameterized sampling, but ``rsample``
    # below always raises, so do not claim support for it.
    has_rsample = False

    def __init__(self, loc, covariance_matrix, sets: List[ConvexSet], validate_args: bool = False):
        """
        Args:
            loc: Mean of the unconstrained normal distribution.
            covariance_matrix: Covariance of the unconstrained normal distribution.
            sets: Non-empty list of convex sets; all must share the same ``d``.
            validate_args: Forwarded to ``MultivariateNormal``.

        Raises:
            ValueError: If ``sets`` is empty or the sets differ in dimension.
        """
        super().__init__(loc, covariance_matrix, validate_args=validate_args)

        # Explicit checks (not ``assert``) so validation survives ``python -O``.
        if not sets:
            raise ValueError("At least one set is required.")
        self.d = sets[0].d
        if any(s.d != self.d for s in sets):
            raise ValueError("All sets must have the same dimension.")

        self._sets: List[ConvexSet] = sets
        # Stays None until assigned through the ``normalizing_constant`` setter,
        # which requires exactly one entry per set.
        self._normalizing_constant: Optional[torch.Tensor] = None

    @property
    def mode(self) -> torch.Tensor:
        """Mode of the constrained distribution; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def normalizing_constant(self) -> Optional[torch.Tensor]:
        """Normalizing constants, one per set, or ``None`` if not yet assigned."""
        return self._normalizing_constant

    @normalizing_constant.setter
    def normalizing_constant(self, value: torch.Tensor):
        # Accept any Tensor (including subclasses) of shape ``(len(sets),)``.
        # Raises TypeError for non-tensors and ValueError for a wrong shape.
        if not isinstance(value, torch.Tensor):
            raise TypeError(
                f"normalizing_constant must be a torch.Tensor, got {type(value).__name__}."
            )
        expected_shape = torch.Size([len(self._sets)])
        if value.shape != expected_shape:
            raise ValueError(
                f"normalizing_constant must have shape {tuple(expected_shape)}, "
                f"got {tuple(value.shape)}."
            )
        self._normalizing_constant = value

    # All methods that do not work for the constrained normal distribution
    def expand(self, batch_shape, _instance=None):
        raise NotImplementedError("Not applicable for a constrained normal distribution")

    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError("Not applicable for a constrained normal distribution")

    # Methods that must be implemented by subclasses
    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        # Input -> Output
        # () -> (#S,)
        raise NotImplementedError

    def log_prob(self, value) -> torch.Tensor:
        # Input -> Output
        # (d,) -> (#G,)
        # (#G, d) -> (#G,)
        # (n, #G, d) -> (n, #G)
        raise NotImplementedError

    def entropy(self):
        """Entropy of the *unconstrained* normal distribution.

        NOTE: this delegates to ``MultivariateNormal.entropy`` and does not
        account for the set constraint.
        """
        return super().entropy()

    def normal_log_prob(self, value) -> torch.Tensor:
        """Log density of the unconstrained ``MultivariateNormal(loc, covariance_matrix)``.

        Calls ``MultivariateNormal.log_prob`` directly, bypassing the
        ``log_prob`` defined on this class and any subclass override.
        """
        return super(ConvexSetNormal, self).log_prob(value)
