import math

import torch
import torch.distributions as D
import numpy as np
import abc

from torch.distributions import Uniform, SigmoidTransform, AffineTransform, TransformedDistribution


def gaussian_log_prob(x, mu=torch.tensor(0.0), sigma=torch.tensor(1.0)):
    """Elementwise log-density of Normal(mu, sigma) evaluated at ``x``.

    Args:
        x: tensor of evaluation points.
        mu: mean (tensor or Python number), broadcast against ``x``.
        sigma: standard deviation (tensor or Python number), broadcast against ``x``.

    Returns:
        Tensor of log-densities with the broadcast shape of the inputs.
    """
    # torch.log rejects plain Python numbers; as_tensor is a no-op for tensors.
    sigma = torch.as_tensor(sigma)
    return -0.5 * math.log(2 * math.pi) - torch.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def cauchy_log_prob(x, loc=torch.tensor(0.0), scale=torch.tensor(1.0)):
    """Elementwise log-density of Cauchy(loc, scale) evaluated at ``x``.

    Args:
        x: tensor of evaluation points.
        loc: location (tensor or Python number), broadcast against ``x``.
        scale: scale (tensor or Python number), broadcast against ``x``.

    Returns:
        Tensor of log-densities with the broadcast shape of the inputs.
    """
    # torch.log rejects plain Python numbers; as_tensor is a no-op for tensors.
    scale = torch.as_tensor(scale)
    return -math.log(math.pi) - torch.log(scale) - torch.log1p(((x - loc) / scale) ** 2)


def centered_positive_half_cauchy_log_prob(x, scale=torch.tensor(1.0)):
    """Log-density of a half-Cauchy on [0, inf) with location 0.

    The density is twice the Cauchy(0, scale) density restricted to x >= 0.

    Args:
        x: tensor of non-negative evaluation points.
        scale: Cauchy scale parameter.

    Raises:
        ValueError: if any element of ``x`` is negative or NaN. This is the
            exception type the ``log_prob_instance`` methods in this module catch.
    """
    # An explicit check (not ``assert``) so validation survives ``python -O``.
    if not torch.all(torch.as_tensor(x) >= 0):
        raise ValueError("half-Cauchy log prob requires x >= 0")
    return math.log(2) + cauchy_log_prob(x, scale=scale)


class CustomDistribution(abc.ABC):
    """Minimal distribution interface built from per-sample primitives.

    Subclasses implement ``log_prob_instance`` and ``sample_instance``; the
    batched ``log_prob`` and ``sample`` loop over them and stack the results.
    """

    @abc.abstractmethod
    def log_prob_instance(self, x):
        """
        Evaluate log prob for a single sample.
        """

    def log_prob(self, x: torch.Tensor):
        """Apply ``log_prob_instance`` to each entry along dim 0 of ``x`` and stack."""
        # Python-level loop over rows; subclasses may override with a vectorized version.
        return torch.stack(list(map(self.log_prob_instance, x)))

    @abc.abstractmethod
    def sample_instance(self):
        """
        Draw a single sample.
        """

    def sample(self, n):
        """Draw ``n`` independent samples via ``sample_instance`` and stack them along dim 0."""
        draws = []
        for _ in range(n):
            draws.append(self.sample_instance())
        return torch.stack(draws)


class Funnel(CustomDistribution):
    """Funnel-shaped distribution in ``n_dim`` dimensions.

    * x[0] ~ N(0, base_sigma**2) with base_sigma = 3 / beta
    * x[i] ~ N(0, exp(x[0] / 2)**2) independently for i >= 1, given x[0]
    """

    def __init__(self, n_dim: int = 2, beta: float = 1.0):
        super().__init__()
        self.n_dim = n_dim
        self.beta = beta
        self.base_sigma = torch.tensor(3.0 / self.beta)

    def log_prob_instance(self, x):
        """Return the 0-d log-density of a single point ``x`` (length ``n_dim``).

        If evaluation raises ValueError, a 0-d NaN tensor is returned instead.
        """
        try:
            part0 = gaussian_log_prob(x[0], sigma=self.base_sigma)
            parts = gaussian_log_prob(x[1:], sigma=torch.exp(x[0] / 2)).sum()
            return part0 + parts
        except ValueError:
            # 0-d, matching the normal return, so torch.stack in log_prob still works.
            return torch.tensor(torch.nan)

    def sample_instance(self):
        """Draw one sample of shape (n_dim,)."""
        x = torch.zeros(self.n_dim)
        x[0] = torch.randn(1) * self.base_sigma
        # The remaining coordinates' scale depends on the freshly drawn x[0].
        x[1:] = torch.randn(self.n_dim - 1) * torch.exp(x[0] / 2)
        return x


class CorrelatedGaussian(CustomDistribution):
    """Multivariate normal in ``n_dim`` dimensions, randomly parameterised by default.

    * If ``mean`` is None it is drawn from N(0, I).
    * If ``scale_tril`` is None it is a lower-triangular matrix with
      entries ``randn * scale`` whose diagonal is replaced by its absolute value.
    Explicitly supplied ``mean`` / ``scale_tril`` are passed through unchanged.
    """

    def __init__(self, n_dim: int = 2, scale: float = 1.0, mean=None, scale_tril=None):
        super().__init__()
        self.n_dim = n_dim
        # Draw the mean before the factor so the RNG consumption order is fixed.
        loc = torch.randn(n_dim) if mean is None else mean
        if scale_tril is None:
            factor = torch.tril(torch.randn(n_dim, n_dim) * scale)
            # diagonal() is a view, so abs_() makes the diagonal non-negative in place.
            factor.diagonal().abs_()
            scale_tril = factor
        self.dist = D.MultivariateNormal(loc=loc, scale_tril=scale_tril)

    def log_prob_instance(self, x):
        """Log-density of a single point under the underlying MultivariateNormal."""
        density = self.dist
        return density.log_prob(x)

    def sample_instance(self):
        """Draw one sample of shape (n_dim,)."""
        density = self.dist
        return density.sample()


class Mixture(CustomDistribution):
    """Finite mixture p(x) = sum_k mix[k] * p_k(x).

    Args:
        mix: 1-D tensor of mixture weights, one per component. It is used as
            Categorical probabilities when sampling.
        components: list of distributions exposing ``n_dim``, ``log_prob`` and
            ``sample_instance``. ``n_dim`` is taken from the first component.
    """

    def __init__(self, mix, components):
        super().__init__()
        self.n_dim = components[0].n_dim
        self.mix = mix
        self.components = components

    def log_prob_instance(self, x):
        """Log-density of a single point, evaluated as a batch of one."""
        return self.log_prob(x.reshape(1, -1)).sum()

    def log_prob(self, x):
        """Batched log-density via log-sum-exp over components.

        Presumably returns shape (len(x),) when each component's ``log_prob``
        returns one value per row of ``x``. TODO confirm for new component types.
        """
        log_probs_stack = torch.stack([c.log_prob(x) for c in self.components])
        # A (K, 1) column of log-weights broadcasts across the batch axis;
        # there is no need to tile it len(x) times.
        log_mix = torch.log(self.mix).reshape(-1, 1)
        return torch.logsumexp(log_mix + log_probs_stack, dim=0)

    def sample_instance(self):
        """Pick a component according to ``mix`` and draw one sample from it."""
        component_index = D.Categorical(probs=self.mix).sample().item()
        return self.components[component_index].sample_instance()


class GaussianMixture(Mixture):
    """Mixture of ``n_components`` randomly initialised CorrelatedGaussians.

    ``mix`` defaults to uniform weights 1 / n_components. ``scale`` is forwarded
    to each CorrelatedGaussian.
    """

    def __init__(self, n_dim=2, scale: float = 1.0, n_components=5, mix=None):
        weights = mix if mix is not None else torch.ones(n_components) / n_components
        gaussians = [CorrelatedGaussian(n_dim=n_dim, scale=scale) for _ in range(n_components)]
        super().__init__(mix=weights, components=gaussians)


def funnel_log_prob(x: torch.Tensor, a: float = 1.0, b: float = 0.5) -> torch.Tensor:
    """Batched log-density of a funnel distribution.

    * x[:, 0] ~ N(0, a**2)
    * x[:, i] ~ N(0, exp(b * x[:, 0])**2) for every i >= 1, given x[:, 0]

    Args:
        x: 2-D tensor of shape (batch, n_dim). ``n_dim == 1`` is allowed.
        a: standard deviation of the first coordinate.
        b: rate controlling how the other coordinates' scale grows with x[:, 0].

    Returns:
        Tensor of shape (batch,) with one log-density per row.
    """
    x0 = x[:, 0]
    base_log_prob = torch.distributions.Normal(loc=0.0, scale=a).log_prob(x0)

    # (batch, 1) scale broadcasts across all remaining columns at once. With
    # zero remaining columns the sum below is 0 instead of stacking an empty list.
    cond_scale = torch.exp(b * x0).unsqueeze(1)
    cond_dist = torch.distributions.Normal(loc=torch.zeros_like(cond_scale), scale=cond_scale)
    cond_log_prob = cond_dist.log_prob(x[:, 1:]).sum(dim=1)
    return base_log_prob + cond_log_prob


class TestDistribution1(CustomDistribution):
    def __init__(self):
        """
        Structure (the second argument of N is the variance; the code uses its
        square root as the standard deviation):
        * x ~ N(0, 1)
        * y ~ N(x, 1)
        * z ~ N(y ** 2, sigmoid(x))
        * w ~ N(x, exp(y))
        """
        super().__init__()
        self.n_dim = 4

    def sample_instance(self):
        """Draw a single sample as a 1-D tensor of shape (4,), like the other distributions."""
        return self.sample(1).reshape(-1)

    def sample(self, n: int):
        """Draw ``n`` samples in one vectorized pass; returns shape (n, 4) ordered (x, y, z, w)."""
        x_samples = torch.randn(n)
        y_samples = x_samples + torch.randn(n)
        z_samples = y_samples ** 2 + torch.randn(n) * torch.sqrt(torch.sigmoid(x_samples))
        w_samples = x_samples + torch.randn(n) * torch.sqrt(torch.exp(y_samples))
        samples = torch.stack([x_samples, y_samples, z_samples, w_samples]).T
        return samples

    def log_prob_instance(self, x):
        """Joint log-density of one point (x, y, z, w); returns NaN if a factor raises ValueError."""
        try:
            x_ = x[0]
            y_ = x[1]
            z_ = x[2]
            w_ = x[3]

            # Each factor mirrors the corresponding line of ``sample``.
            total = (
                    torch.distributions.Normal(0, 1).log_prob(x_) +
                    torch.distributions.Normal(x_, 1).log_prob(y_) +
                    torch.distributions.Normal(y_ ** 2, torch.sqrt(torch.sigmoid(x_))).log_prob(z_) +
                    torch.distributions.Normal(x_, torch.sqrt(torch.exp(y_))).log_prob(w_)
            )
            return total
        except ValueError:
            return torch.tensor(torch.nan)


class TestDistribution2(CustomDistribution):
    def Logistic(self, loc, scale):
        """Return a Logistic(loc, scale) distribution.

        Built as logit(U) * scale + loc with U ~ Uniform(0, 1);
        ``SigmoidTransform().inv`` is the logit map.
        """
        return TransformedDistribution(Uniform(0, 1), [SigmoidTransform().inv, AffineTransform(loc, scale)])

    def __init__(self):
        """
        Structure (the second argument of each distribution is the squared
        scale; the code passes its square root as the scale):
        * A ~ N(0, 1)
        * B ~ Laplace(A ** 2, exp(A))
        * C ~ Logistic(A, 0.01)
        * D ~ N((B + C) / 2, log1p(exp(C)))
        * E ~ N(D ** 2 / 4, 0.01)
        * F ~ Logistic(-3, exp(B / 200))
        * G ~ N(F ** 2 / 10, exp(F / 10))
        * H ~ Laplace(C * E / 10, (E / 10)**2)
        """
        self.n_dim = 8

    def sample_instance(self):
        """Draw one sample of shape (8,) ordered (A, ..., H) by ancestral sampling."""
        a_ = torch.randn(1)
        b_ = torch.distributions.Laplace(a_ ** 2, torch.sqrt(torch.exp(a_))).sample()
        c_ = self.Logistic(a_, math.sqrt(0.01)).sample()
        d_ = torch.distributions.Normal((b_ + c_) / 2, torch.sqrt(torch.log1p(torch.exp(c_)))).sample()
        e_ = torch.distributions.Normal(d_ ** 2 / 4, math.sqrt(0.01)).sample()
        f_ = self.Logistic(-3, torch.sqrt(torch.exp(b_ / 200))).sample()
        g_ = torch.distributions.Normal(f_ ** 2 / 10, torch.sqrt(torch.exp(f_ / 10))).sample()
        h_ = torch.distributions.Laplace(c_ * e_ / 10, torch.sqrt(torch.square(e_ / 10))).sample()
        sample = torch.concat([a_, b_, c_, d_, e_, f_, g_, h_])
        return sample

    def log_prob_instance(self, x):
        """Joint log-density of one point; returns NaN if a factor raises ValueError."""
        try:
            a_ = x[0]
            b_ = x[1]
            c_ = x[2]
            d_ = x[3]
            e_ = x[4]
            f_ = x[5]
            g_ = x[6]
            h_ = x[7]

            # Every factor must mirror the corresponding line of sample_instance.
            total = (
                    torch.distributions.Normal(0, 1).log_prob(a_) +
                    torch.distributions.Laplace(a_ ** 2, torch.sqrt(torch.exp(a_))).log_prob(b_) +
                    self.Logistic(a_, math.sqrt(0.01)).log_prob(c_) +
                    torch.distributions.Normal((b_ + c_) / 2, torch.sqrt(torch.log1p(torch.exp(c_)))).log_prob(d_) +
                    torch.distributions.Normal(d_ ** 2 / 4, math.sqrt(0.01)).log_prob(e_) +
                    # F's scale is sqrt(exp(B / 200)); F itself is evaluated unscaled.
                    self.Logistic(-3, torch.sqrt(torch.exp(b_ / 200))).log_prob(f_) +
                    torch.distributions.Normal(f_ ** 2 / 10, torch.sqrt(torch.exp(f_ / 10))).log_prob(g_) +
                    torch.distributions.Laplace(c_ * e_ / 10, torch.sqrt(torch.square(e_ / 10))).log_prob(h_)
            )
            return total
        except ValueError:
            return torch.tensor(torch.nan)


class TestDistribution3(CustomDistribution):
    """Two-dimensional distribution:

    * x0 ~ 0.5 * N(-3, 1) + 0.5 * N(3, 1)
    * x1 ~ N(1 / x0, sigma1**2) with sigma1 = 0.05
    """

    def __init__(self):
        super().__init__()
        self.d0 = Mixture(
            mix=torch.tensor([0.5, 0.5]),
            components=[
                CorrelatedGaussian(n_dim=1, mean=torch.tensor([-3.0]), scale_tril=torch.tensor([[1.0]])),
                CorrelatedGaussian(n_dim=1, mean=torch.tensor([3.0]), scale_tril=torch.tensor([[1.0]])),
            ]
        )
        # Standard deviation of x1 around 1 / x0.
        self.sigma1 = torch.tensor(0.05)
        self.n_dim = 2

    def sample_instance(self):
        """Draw one sample of shape (2,)."""
        x0 = self.d0.sample_instance()  # shape (1,): the mixture components are 1-D
        x1 = torch.randn(1) * self.sigma1 + 1 / x0
        return torch.cat([x0, x1])

    def log_prob_instance(self, x):
        """Joint log-density of one point (x0, x1); returns NaN if evaluation raises ValueError."""
        try:
            return self.d0.log_prob_instance(x[0]) + gaussian_log_prob(x[1], 1 / x[0], self.sigma1)
        except ValueError:
            return torch.tensor(torch.nan)


class TestDistribution4(CustomDistribution):
    """Ten-dimensional distribution: a correlated Gaussian plus funnels.

    * x[0:5] ~ a randomly initialised CorrelatedGaussian
    * x[5 + i] ~ N(0, exp(x[i] / 2)**2) for i in 0..4, given x[i]
    """

    # Size of the Gaussian block; the funnel block has the same size.
    _N_GAUSSIAN = 5

    def __init__(self):
        super().__init__()
        self.n_dim = 2 * self._N_GAUSSIAN
        self.d0 = CorrelatedGaussian(n_dim=self._N_GAUSSIAN)

    def sample_instance(self):
        """Draw one sample of shape (10,)."""
        mvt_gaussian_part = self.d0.sample_instance()
        funnel_part = torch.randn(self._N_GAUSSIAN) * torch.exp(mvt_gaussian_part / 2)
        return torch.cat([mvt_gaussian_part, funnel_part])

    def log_prob_instance(self, x):
        """Joint log-density of one point; returns NaN if evaluation raises ValueError."""
        k = self._N_GAUSSIAN
        try:
            mvt_gaussian_part = self.d0.log_prob_instance(x[:k])
            # Coordinate k + i has scale exp(x[i] / 2); evaluate all k funnel terms at once.
            funnel_part = gaussian_log_prob(x[k:2 * k], sigma=torch.exp(x[:k] / 2)).sum()
            return mvt_gaussian_part + funnel_part
        except ValueError:
            return torch.tensor(torch.nan)


class DoubleGaussian(Mixture):
    """Equal-weight mixture of two identity-covariance Gaussians.

    The two means are -3 and +3 in every coordinate.
    """

    def __init__(self, n_dim=2):
        centres = (-3.0, 3.0)
        gaussians = [
            CorrelatedGaussian(n_dim=n_dim, mean=torch.tensor([centre] * n_dim), scale_tril=torch.eye(n_dim))
            for centre in centres
        ]
        super().__init__(mix=torch.tensor([0.5, 0.5]), components=gaussians)


class TenComponentGaussian(Mixture):
    """Equal-weight mixture of ten identity-covariance Gaussians.

    Component k (k = 1..10) has every coordinate of its mean equal to
    3 * k - 15, i.e. means -12, -9, ..., 12, 15.
    """

    def __init__(self, n_dim=2):
        centres = [3.0 * k - 15.0 for k in range(1, 11)]
        weights = torch.ones(10) / 10.0
        gaussians = [
            CorrelatedGaussian(n_dim=n_dim, mean=torch.tensor([centre] * n_dim), scale_tril=torch.eye(n_dim))
            for centre in centres
        ]
        super().__init__(mix=weights, components=gaussians)


if __name__ == '__main__':
    import seaborn as sns
    import matplotlib.pyplot as plt
    import pandas as pd

    torch.manual_seed(0)

    # Substitute any other distribution from this module to visualise it, e.g.
    # TestDistribution1(), TestDistribution2(), TestDistribution3(), Funnel(),
    # CorrelatedGaussian() or GaussianMixture().
    target = TestDistribution4()

    draws = target.sample(1000)
    print(draws.shape)
    print(target.log_prob(draws).shape)

    # Lower-triangular pairwise scatter/marginal plot of the drawn samples.
    frame = pd.DataFrame(draws.numpy())
    sns.pairplot(frame, corner=True)
    plt.show()
