"""Definition of topological/mathematical spaces with probability densities defined on."""

from abc import ABC, abstractmethod
import torch
import numpy as np
import vmf
import spaces_utils as sut
import scipy.special as sc
from  torch.distributions import multivariate_normal


class Space(ABC):
    """Abstract interface for a space on which probability densities can be sampled.

    Concrete spaces report their ambient dimensionality through ``dim`` and
    provide the four sampler methods declared below.
    """

    @abstractmethod
    def uniform(self, size, device):
        """Draw ``size`` samples from the space's uniform distribution."""

    @abstractmethod
    def normal(self, mean, std, size, device):
        """Draw ``size`` normally distributed samples around ``mean``."""

    @abstractmethod
    def laplace(self, mean, std, size, device):
        """Draw ``size`` Laplace-distributed samples around ``mean``."""

    @abstractmethod
    def generalized_normal(self, mean, lbd, p, size, device):
        """Draw ``size`` generalized-normal samples with exponent ``p`` around ``mean``."""

    @property
    @abstractmethod
    def dim(self):
        """Dimensionality of the space."""


class NRealSpace(Space):
    """Unconstrained space over the real numbers, i.e., R^N."""

    def __init__(self, n):
        # n: dimensionality N of the space.
        self.n = n

    @property
    def dim(self):
        return self.n

    def uniform(self, size, device="cpu"):
        # R^N has infinite volume, so no uniform probability density exists on it.
        raise NotImplementedError("Not defined on R^n")

    def normal(self, mean, std, size, device="cpu"):
        """Sample from a Normal distribution in R^N.

        Args:
            mean: Value(s) to sample around. A tensor of shape (n,) is promoted
                to (1, n); the result must be 2-D.
            std: Concentration parameter of the distribution (=standard deviation).
                A non-tensor value is broadcast to all n dimensions.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n) on ``device``.
        """

        if len(mean.shape) == 1 and mean.shape[0] == self.n:
            mean = mean.unsqueeze(0)
        if not torch.is_tensor(std):
            std = torch.ones(self.n) * std
        if len(std.shape) == 1 and std.shape[0] == self.n:
            std = std.unsqueeze(0)
        assert len(mean.shape) == 2
        assert len(std.shape) == 2

        if torch.is_tensor(mean):
            mean = mean.to(device)
        if torch.is_tensor(std):
            std = std.to(device)

        return torch.randn((size, self.n), device=device) * std + mean

    def laplace(self, mean, lbd, size, device="cpu"):
        """Sample from a Laplace distribution in R^N.

        Args:
            mean: Value(s) to sample around. A tensor of shape (n,) is promoted
                to (1, n); the result must be 2-D.
            lbd: Concentration parameter of the distribution (must be a float).
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n) on ``device``.
        """

        if len(mean.shape) == 1 and mean.shape[0] == self.n:
            mean = mean.unsqueeze(0)
        assert len(mean.shape) == 2
        assert isinstance(lbd, float)

        mean = mean.to(device)

        # Zero-centred noise is drawn on the CPU, moved to ``device``, then shifted.
        return (
            torch.distributions.Laplace(torch.zeros(self.n), lbd)
            .rsample(sample_shape=(size,))
            .to(device)
            + mean
        )

    def generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a Generalized Normal distribution in R^N.

        Args:
            mean: Value(s) to sample around. A tensor of shape (n,) is promoted
                to (1, n); the result must be 2-D.
            lbd: Concentration parameter of the distribution (must be a float).
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier; when None the result is left
                wherever ``sut.sample_generalized_normal`` produced it.
        """

        if len(mean.shape) == 1 and mean.shape[0] == self.n:
            mean = mean.unsqueeze(0)
        assert len(mean.shape) == 2
        assert isinstance(lbd, float)

        result = sut.sample_generalized_normal(mean, lbd, p, (size, self.n))

        if device is not None:
            result = result.to(device)

        return result

    def mixed_generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a generalized normal distribution with one exponent per dimension.

        Coordinate i is drawn by inverse-transform sampling as
        ``sign(q) * (lbd * gammaincinv(1/p[i], |q|)) ** (1/p[i])`` with
        ``q = 2u - 1`` and ``u ~ U[0, 1)``. This corresponds to a density
        proportional to ``exp(-|x_i| ** p[i] / lbd)``. ``mean`` is added
        afterwards.

        Args:
            mean: Value(s) to sample around. A tensor of shape (n,) is promoted
                to (1, n); the result must be 2-D.
            lbd: Concentration parameter of the distribution (must be a float).
            p: List of exponents of the distribution, one per dimension.
            size: Number of samples to draw.
            device: torch device identifier; when None the samples stay on the CPU.

        Returns:
            float32 tensor of shape (size, n).
        """

        if len(mean.shape) == 1 and mean.shape[0] == self.n:
            mean = mean.unsqueeze(0)
        assert len(mean.shape) == 2
        assert isinstance(lbd, float)
        assert len(p) == self.n

        # Uniforms come from NumPy's global RNG (not torch's), one column per dimension.
        uniforms = np.random.rand(size, self.n)
        columns = []
        for i, p_i in enumerate(p):
            q = 2 * uniforms[:, i] - 1
            column = np.sign(q) * (lbd * sc.gammaincinv(1 / p_i, np.abs(q))) ** (1 / p_i)
            columns.append(torch.FloatTensor(column))
        result = torch.stack(columns, dim=-1)

        if device is not None:
            result = result.to(device)

        # Bring mean onto the samples' device first. Adding a tensor that
        # lives on a different device raises a device-mismatch error.
        if torch.is_tensor(mean):
            mean = mean.to(result.device)
        result += mean

        return result


class NSphereSpace(Space):
    """N-dimensional hypersphere, i.e. {x | |x| = r and x in R^N}."""

    def __init__(self, n, r=1):
        self.n = n
        # n - 1; presumably the intrinsic dimension of the sphere (unused in this class).
        self._n_sub = n - 1
        self.r = r

    @property
    def dim(self):
        return self.n

    def uniform(self, size, device="cpu"):
        """Sample uniformly on the sphere of radius ``r``.

        Normalized isotropic Gaussian draws give uniformly distributed
        directions. These are then scaled to radius ``r`` (a no-op for r=1).
        """
        x = torch.randn((size, self.n), device=device)
        x /= torch.sqrt(torch.sum(x ** 2, dim=-1, keepdim=True))
        x *= self.r
        return x

    def non_uniform(self, kappa, size, device="cpu"):
        """Sample a 50/50 mixture of a von Mises-Fisher and a uniform distribution.

        The first ``size // 2`` rows come from ``vmf.sample_vMF`` around the
        first basis vector e_1. The remaining ``size - size // 2`` rows are
        uniform directions, so exactly ``size`` rows are returned even when
        ``size`` is odd. All rows are scaled by ``r``.

        Args:
            kappa: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """
        n_vmf = size // 2
        n_uniform = size - n_vmf

        mean = np.zeros((n_vmf, self.n))
        mean[:, 0] = 1
        # NOTE(review): assumes vmf.sample_vMF returns unit-norm samples for a
        # unit-norm mean, so scaling by r below lands on the r-sphere — confirm.
        vmf_samples = torch.Tensor(vmf.sample_vMF(mean, kappa, n_vmf)).to(device)

        x = torch.randn((n_uniform, self.n), device=device)
        x /= torch.sqrt(torch.sum(x ** 2, dim=-1, keepdim=True))

        samples = torch.cat([vmf_samples, x], dim=0)

        return samples * self.r

    def normal(self, mean, std, size, device="cpu"):
        """Sample from a Normal distribution in R^N and then project back on the sphere.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n), each row
                of norm ``r``.
            std: Concentration parameter of the distribution (=standard deviation).
            size: Number of samples to draw.
            device: torch device identifier
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        mean = mean.to(device)
        if not torch.is_tensor(std):
            std = torch.ones(self.n) * std
        std = std.to(device)

        assert mean.shape[1] == self.n
        assert torch.allclose(
            torch.sqrt((mean ** 2).sum(-1)), torch.Tensor([self.r]).to(device)
        )

        result = torch.randn((size, self.n), device=device) * std + mean
        # project back on the sphere of radius r
        result /= torch.sqrt(torch.sum(result ** 2, dim=-1, keepdim=True))
        result *= self.r

        return result

    def laplace(self, mean, lbd, size, device="cpu"):
        """Sample from a Laplace distribution in R^N and then project back on the sphere.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n), each row
                of norm ``r``.
            lbd: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        mean = mean.to(device)

        assert mean.shape[1] == self.n
        assert torch.allclose(
            torch.sqrt((mean ** 2).sum(-1)), torch.Tensor([self.r]).to(device)
        )

        result = NRealSpace(self.n).laplace(mean, lbd, size, device)
        # project back on the sphere of radius r
        result /= torch.sqrt(torch.sum(result ** 2, dim=-1, keepdim=True))
        result *= self.r

        return result

    def generalized_normal(self, mean, lbd, p, size, device="cpu"):
        """Sample from a Generalized Normal distribution in R^N and then project back on the sphere.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n), each row
                of norm ``r``.
            lbd: Concentration parameter of the distribution.
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        mean = mean.to(device)

        assert mean.shape[1] == self.n
        assert torch.allclose(
            torch.sqrt((mean ** 2).sum(-1)), torch.Tensor([self.r]).to(device)
        )

        result = NRealSpace(self.n).generalized_normal(
            mean=mean, lbd=lbd, p=p, size=size, device=device
        )
        # project back on the sphere of radius r
        result /= torch.sqrt(torch.sum(result ** 2, dim=-1, keepdim=True))
        result *= self.r

        return result

    def von_mises_fisher(self, mean, kappa, size, device="cpu"):
        """Sample from a von Mises-Fisher distribution (=Normal distribution on a hypersphere).

        Args:
            mean: Value(s) to sample around; shape (n,) (repeated ``size``
                times) or (size, n), each row of norm ``r``.
            kappa: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        # vmf.sample_vMF operates on NumPy arrays.
        mean = mean.cpu().detach().numpy()

        if len(mean.shape) == 1:
            mean = np.repeat(np.expand_dims(mean, 0), size, axis=0)

        assert mean.shape[1] == self.n
        assert np.allclose(np.sqrt((mean ** 2).sum(-1)), self.r)

        # NOTE(review): the mean is passed as-is (norm r); whether sample_vMF
        # expects a unit-norm mean when r != 1 is not visible here — confirm.
        samples_np = vmf.sample_vMF(mean, kappa, size)
        samples = torch.Tensor(samples_np).to(device)

        return samples


class NBoxSpace(Space):
    """Axis-aligned box in R^N: every coordinate lies between ``min_`` and ``max_``.

    Non-uniform samplers draw from an unconstrained density and delegate to a
    ``spaces_utils`` rejection-resampling helper to keep only in-box samples.
    """

    def __init__(self, n, min_=-1, max_=1):
        self.n = n
        self.min_ = min_
        self.max_ = max_

    @property
    def dim(self):
        return self.n

    def _batched_mean(self, mean, size):
        """Validate ``mean`` and return it with a leading batch axis.

        ``mean`` must have shape (n,) or (size, n); a 1-D mean becomes (1, n).
        """
        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n
        if len(mean.shape) == 1:
            return mean.unsqueeze(0)
        return mean

    def _resample_into_box(self, sampler, size, device):
        """Run box rejection resampling over ``sampler`` and shape the output as (size, n)."""
        kept = sut.truncated_rejection_resampling(
            sampler, self.min_, self.max_, size, self.n, device=device
        )
        return kept.view((size, self.n))

    def uniform(self, size, device="cpu"):
        width = self.max_ - self.min_
        unit = torch.rand(size=(size, self.n), device=device)
        return unit * width + self.min_

    def non_uniform(self, corr, size, device='cpu'):
        """Sample from block-diagonal correlated Gaussians centred in the box.

        Consecutive coordinate pairs share a 2x2 covariance
        ``0.05 * [[1, 1/corr], [1/corr, 1]]``; samples outside the box are
        resampled. Requires an even ``n``.
        """
        assert self.n % 2 == 0
        loc = torch.ones(self.n, device=device) * (self.max_ + self.min_) * 0.5
        pair_cov = 0.05 * torch.tensor(
            [[1.0, 1.0 / corr], [1.0 / corr, 1.0]], device=device
        )
        full_cov = torch.block_diag(*([pair_cov] * (self.n // 2)))
        gaussian = multivariate_normal.MultivariateNormal(
            loc=loc, covariance_matrix=full_cov
        )

        def draw(count):
            return gaussian.sample([count])

        return self._resample_into_box(draw, size, device)

    def normal(self, mean, std, size, device="cpu"):
        """Sample from a Normal distribution in R^N, keeping only samples inside the box.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            std: Concentration parameter of the distribution (=standard deviation).
            size: Number of samples to draw.
            device: torch device identifier
        """
        centre = self._batched_mean(mean, size).to(device)

        def draw(count):
            return torch.randn((count, self.n), device=device) * std + centre

        return self._resample_into_box(draw, size, device)

    def laplace(self, mean, lbd, size, device="cpu"):
        """Sample from a Laplace distribution in R^N, keeping only samples inside the box.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """
        centre = self._batched_mean(mean, size)

        def draw(count):
            noise = torch.distributions.Laplace(torch.zeros(self.n), lbd).rsample(
                sample_shape=(count,)
            )
            return noise.to(device) + centre.to(device)

        return self._resample_into_box(draw, size, device)

    def generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a Generalized Normal distribution in R^N, keeping only samples inside the box.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """
        centre = self._batched_mean(mean, size)

        def draw(count):
            return sut.sample_generalized_normal(centre, lbd, p, (count, self.n))

        return self._resample_into_box(draw, size, device)

    def mixed_generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a per-dimension-exponent Generalized Normal, keeping only samples inside the box.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: List of exponents of the distribution, one per dimension.
            size: Number of samples to draw.
            device: torch device identifier
        """
        centre = self._batched_mean(mean, size)
        assert len(p) == self.n

        def draw(count):
            return sut.sample_mixed_generalized_normal(centre, lbd, p, (count, self.n))

        return self._resample_into_box(draw, size, device)

    def generalized_normal_with_checkerboard_pattern(self, mean, lbd, p, size, device=None):
        """Sample from a Generalized Normal distribution, then resample with a checkerboard rule.

        Uses ``sut.truncated_rejection_resampling_checkerboard`` with the box
        limits instead of the plain box resampler.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier
        """
        centre = self._batched_mean(mean, size)

        def draw(count):
            return sut.sample_generalized_normal(centre, lbd, p, (count, self.n))

        kept = sut.truncated_rejection_resampling_checkerboard(
            draw, self.min_, self.max_, size, self.n, device=device
        )
        return kept.view((size, self.n))


class NHollowBallSpace(Space):
    """N-dimensional hollow ball (spherical shell), i.e. {x | r_min <= |x| <= r_max and x in R^N}."""

    def __init__(self, n, r_min=0.5, r_max=2.0):
        self.n = n
        self.r_min = r_min
        self.r_max = r_max

    @property
    def dim(self):
        return self.n

    def uniform(self, size, device="cpu"):
        """Sample points in the shell with a uniform direction and a uniform radius.

        NOTE: the radius is drawn uniformly from [r_min, r_max). That is *not*
        uniform with respect to volume in the shell, which would need a
        radial density proportional to |x|^(n-1). It is left unchanged here.
        """
        x = torch.randn((size, self.n), device=device)
        x /= torch.sqrt(torch.sum(x ** 2, dim=-1, keepdim=True))
        # The radii are generated with the CPU RNG, then moved to ``device``.
        x *= (torch.rand((size, 1)) * (self.r_max - self.r_min) + self.r_min).to(device=device)
        return x

    def normal(self, mean, std, size, device="cpu"):
        """Sample from a Normal distribution in R^N, keeping only samples inside the hollow ball.

        Out-of-shell samples are handled by
        ``sut.truncated_rejection_resampling_hollow_ball``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            std: Concentration parameter of the distribution (=standard deviation).
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        mean = mean.to(device)

        sampler = lambda s: torch.randn((s, self.n), device=device) * std + mean
        values = sut.truncated_rejection_resampling_hollow_ball(
            sampler, self.r_min, self.r_max, size, self.n, device=device
        )

        return values.view((size, self.n))

    def laplace(self, mean, lbd, size, device="cpu"):
        """Sample from a Laplace distribution in R^N, keeping only samples inside the hollow ball.

        Out-of-shell samples are handled by
        ``sut.truncated_rejection_resampling_hollow_ball``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: torch.distributions.Laplace(
            torch.zeros(self.n), lbd
        ).rsample(sample_shape=(s,)).to(device) + mean.to(device)
        values = sut.truncated_rejection_resampling_hollow_ball(
            sampler, self.r_min, self.r_max, size, self.n, device=device
        )

        return values.view((size, self.n))

    def generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a Generalized Normal distribution in R^N, keeping only samples inside the hollow ball.

        Out-of-shell samples are handled by
        ``sut.truncated_rejection_resampling_hollow_ball``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: sut.sample_generalized_normal(mean, lbd, p, (s, self.n))
        values = sut.truncated_rejection_resampling_hollow_ball(
            sampler, self.r_min, self.r_max, size, self.n, device=device
        )

        return values.view((size, self.n))

    def mixed_generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a per-dimension-exponent Generalized Normal, keeping only samples inside the hollow ball.

        Out-of-shell samples are handled by
        ``sut.truncated_rejection_resampling_hollow_ball``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: List of exponents of the distribution, one per dimension.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n
        assert len(p) == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: sut.sample_mixed_generalized_normal(mean, lbd, p, (s, self.n))
        values = sut.truncated_rejection_resampling_hollow_ball(
            sampler, self.r_min, self.r_max, size, self.n, device=device
        )

        return values.view((size, self.n))


class NCubeGridSpace(Space):
    """N-dimensional grid of 2^N cubes separated by a gap around every axis.

    With c = cube_size and g = gap_size, each coordinate lies in
    [-c - g/2, -g/2] or [g/2, c + g/2]. These four bounds are stored, in
    ascending order, in ``self.limits``.
    """

    def __init__(self, n, cube_size=1.0, gap_size=0.25):
        self.n = n
        self.cube_size = cube_size
        self.gap_size = gap_size
        self.limits = [
            -self.cube_size - 0.5 * self.gap_size,
            -0.5 * self.gap_size,
            0.5 * self.gap_size,
            self.cube_size + 0.5 * self.gap_size
        ]

    @property
    def dim(self):
        return self.n

    def uniform(self, size, device="cpu"):
        """Sample uniformly over the union of cubes.

        Each coordinate is drawn from [-c, c) and then pushed away from zero
        by g/2 according to its sign; zero counts as positive.
        """
        x = (torch.rand(size=(size, self.n), device=device) * 2 - 1) * self.cube_size
        # Compute the sign mask once, so the second shift cannot pick up
        # values moved by the first (possible if gap_size < 0).
        positive = x >= 0
        x[positive] += 0.5 * self.gap_size
        x[~positive] -= 0.5 * self.gap_size
        return x

    def normal(self, mean, std, size, device="cpu"):
        """Sample from a Normal distribution in R^N, keeping only samples inside the cube grid.

        Out-of-grid samples are handled by
        ``sut.truncated_rejection_resampling_cube_grid`` using ``self.limits``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            std: Concentration parameter of the distribution (=standard deviation).
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        mean = mean.to(device)

        sampler = lambda s: torch.randn((s, self.n), device=device) * std + mean
        values = sut.truncated_rejection_resampling_cube_grid(
            sampler, self.limits, size, self.n, device=device
        )

        return values.view((size, self.n))

    def laplace(self, mean, lbd, size, device="cpu"):
        """Sample from a Laplace distribution in R^N, keeping only samples inside the cube grid.

        Out-of-grid samples are handled by
        ``sut.truncated_rejection_resampling_cube_grid`` using ``self.limits``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: torch.distributions.Laplace(
            torch.zeros(self.n), lbd
        ).rsample(sample_shape=(s,)).to(device) + mean.to(device)
        values = sut.truncated_rejection_resampling_cube_grid(
            sampler, self.limits, size, self.n, device=device
        )

        return values.view((size, self.n))

    def generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a Generalized Normal distribution in R^N, keeping only samples inside the cube grid.

        Out-of-grid samples are handled by
        ``sut.truncated_rejection_resampling_cube_grid`` using ``self.limits``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: Exponent of the distribution.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: sut.sample_generalized_normal(mean, lbd, p, (s, self.n))
        values = sut.truncated_rejection_resampling_cube_grid(
            sampler, self.limits, size, self.n, device=device
        )

        return values.view((size, self.n))

    def mixed_generalized_normal(self, mean, lbd, p, size, device=None):
        """Sample from a per-dimension-exponent Generalized Normal, keeping only samples inside the cube grid.

        Out-of-grid samples are handled by
        ``sut.truncated_rejection_resampling_cube_grid`` using ``self.limits``.

        Args:
            mean: Value(s) to sample around; shape (n,) or (size, n).
            lbd: Concentration parameter of the distribution.
            p: List of exponents of the distribution, one per dimension.
            size: Number of samples to draw.
            device: torch device identifier

        Returns:
            Tensor of shape (size, n).
        """

        assert len(mean.shape) == 1 or (len(mean.shape) == 2 and len(mean) == size)
        assert mean.shape[-1] == self.n
        assert len(p) == self.n

        if len(mean.shape) == 1:
            mean = mean.unsqueeze(0)

        sampler = lambda s: sut.sample_mixed_generalized_normal(mean, lbd, p, (s, self.n))
        values = sut.truncated_rejection_resampling_cube_grid(
            sampler, self.limits, size, self.n, device=device
        )

        return values.view((size, self.n))
