import torch.nn as nn
import torch
import torch.distributions as D
import torch.nn.functional as F

from nfmc_jax.utils.torch_distributions import gaussian_log_prob


def create_parameter_network(n_features, n_parameters):
    """Build a bias-free linear map over elementwise polynomial features of its input.

    The returned module concatenates ``x_cond ** 1.0`` and ``x_cond ** 2.0`` along
    dim 1 (no cross terms) and applies a ``Linear(2 * n_features, n_parameters)``
    without bias.

    Args:
        n_features: Width of the conditioning input (size of dim 1).
        n_parameters: Number of outputs produced per input row.

    Returns:
        A fresh ``ParameterNetwork`` instance exposing ``forward`` and ``regularization``.
    """

    class ParameterNetwork(nn.Module):
        # Exponents applied elementwise to the conditioning input.
        _powers = (1.0, 2.0)

        def __init__(self):
            super().__init__()
            in_width = n_features * len(self._powers)
            self.linear = nn.Linear(in_width, n_parameters, bias=False)

        def forward(self, x_cond):
            expanded = torch.cat([torch.pow(x_cond, p) for p in self._powers], dim=1)
            return self.linear(expanded)

        def regularization(self):
            # L1-style penalty: mean absolute weight.
            return torch.mean(torch.abs(self.linear.weight))

    return ParameterNetwork()


class Isotropic(nn.Module):
    """Diagonal Gaussian with a learnable mean and log standard deviation.

    Args:
        n_dim: Dimensionality of the distribution.
    """

    def __init__(self, n_dim):
        super().__init__()
        self.n_dim = n_dim
        # Stored as (n_dim, 1) column vectors so existing state_dicts keep loading;
        # they are flattened to (n_dim,) before use so they broadcast over the
        # last axis of an (n, n_dim) batch instead of over the batch axis.
        self.mu = nn.Parameter(torch.zeros(n_dim, 1))
        self.log_sigma = nn.Parameter(torch.zeros(n_dim, 1))

    def regularization(self):
        """Return the regularization penalty (this model has none)."""
        return torch.tensor(0.0)

    @staticmethod
    def static_log_prob(x, mu, log_sigma):
        """Sum ``gaussian_log_prob(x, mu, exp(log_sigma))`` over the last axis.

        NOTE(review): assumes ``gaussian_log_prob`` is an elementwise Gaussian log
        density that broadcasts like ordinary tensor arithmetic — confirm in
        ``nfmc_jax.utils.torch_distributions``.
        """
        return gaussian_log_prob(x, mu, log_sigma.exp()).sum(dim=-1)

    @staticmethod
    def static_sample(n, mu, log_sigma, n_dim):
        """Draw ``n`` reparameterized samples ``eps * exp(log_sigma) + mu``.

        ``mu`` and ``log_sigma`` must broadcast against ``(n, n_dim)``.
        """
        # Allocate noise alongside the parameters (dtype/device) rather than on the default.
        noise = torch.randn(n, n_dim, dtype=log_sigma.dtype, device=log_sigma.device)
        return noise * log_sigma.exp() + mu

    def _flat_parameters(self):
        """Return ``(mu, log_sigma)`` reshaped to ``(n_dim,)``."""
        return self.mu.reshape(self.n_dim), self.log_sigma.reshape(self.n_dim)

    def log_prob(self, x):
        """Log density of each row of ``x``; expects ``x`` of shape ``(..., n_dim)``."""
        mu, log_sigma = self._flat_parameters()
        return self.static_log_prob(x, mu, log_sigma)

    def sample(self, n):
        """Return an ``(n, n_dim)`` tensor of samples (gradients flow to mu/log_sigma)."""
        mu, log_sigma = self._flat_parameters()
        return self.static_sample(n, mu, log_sigma, self.n_dim)


class ConditionalIsotropic(nn.Module):
    """Diagonal Gaussian whose mean and log-std are predicted from a conditioning input.

    Args:
        n_dim: Dimensionality of the modelled variable.
        n_dim_cond: Width of the conditioning input.
    """

    def __init__(self, n_dim, n_dim_cond):
        super().__init__()

        self.n_dim = n_dim
        self.n_dim_cond = n_dim_cond

        # Network emits n_dim means followed by n_dim log-stds per conditioning row.
        self.network = create_parameter_network(n_dim_cond, 2 * n_dim)

    def regularization(self):
        return self.network.regularization()

    def _mu_and_log_sigma(self, x_cond):
        """Split the network output into ``(mu, log_sigma)``, each reshaped to ``(-1, n_dim)``."""
        out = self.network(x_cond)
        d = self.n_dim
        mu = out[..., :d].reshape(-1, d)
        log_sigma = out[..., d:].reshape(-1, d)
        return mu, log_sigma

    def log_prob(self, x, x_cond):
        """Log density of each row of ``x`` under the Gaussian predicted from the matching ``x_cond`` row."""
        mu, log_sigma = self._mu_and_log_sigma(x_cond)
        return Isotropic.static_log_prob(x, mu, log_sigma)

    def sample(self, x_cond):
        """Draw one sample per row of ``x_cond``."""
        mu, log_sigma = self._mu_and_log_sigma(x_cond)
        return Isotropic.static_sample(len(x_cond), mu, log_sigma, self.n_dim)


class Correlated(nn.Module):
    """Full-covariance Gaussian parameterized by a mean and a lower-triangular scale factor.

    Args:
        n_dim: Dimensionality of the distribution.
    """

    def __init__(self, n_dim):
        super().__init__()
        self.n_dim = n_dim
        self.mu = nn.Parameter(torch.zeros(n_dim))
        # Strictly-below-diagonal entries of scale_tril, in torch.tril_indices (row-major) order.
        self.scale_tril_elements_below_diag = nn.Parameter(torch.zeros(n_dim * (n_dim - 1) // 2))
        # Log of the diagonal; exponentiated so the diagonal stays strictly positive.
        self.scale_tril_elements_diag_log = nn.Parameter(torch.zeros(n_dim))

    def regularization(self):
        """Return the regularization penalty (this model has none)."""
        return torch.tensor(0.0)

    def mvt_normal_parameters(self):
        """Return ``(mean, scale_tril)`` with ``scale_tril`` lower-triangular, positive diagonal."""
        n = self.n_dim
        rows, cols = torch.tril_indices(n, n, offset=-1, device=self.mu.device)
        scale_tril = torch.zeros(n, n, dtype=self.mu.dtype, device=self.mu.device)
        # Fill the strictly lower triangle (row > col), which is what MultivariateNormal expects.
        scale_tril[rows, cols] = self.scale_tril_elements_below_diag
        scale_tril = scale_tril + torch.diag(self.scale_tril_elements_diag_log.exp())
        return self.mu, scale_tril

    @staticmethod
    def static_log_prob(x, mean, scale_tril):
        return D.MultivariateNormal(loc=mean, scale_tril=scale_tril).log_prob(x)

    def log_prob(self, x):
        """Log density of each row of ``x``."""
        mean, scale_tril = self.mvt_normal_parameters()
        return self.static_log_prob(x, mean, scale_tril)

    @staticmethod
    def static_sample(n, mean, scale_tril):
        # .sample() draws without gradient tracking.
        return D.MultivariateNormal(loc=mean, scale_tril=scale_tril).sample((n,))

    def sample(self, n):
        """Return ``n`` samples stacked along a new leading axis."""
        mean, scale_tril = self.mvt_normal_parameters()
        return self.static_sample(n, mean, scale_tril)


class ConditionalCorrelated(nn.Module):
    """Full-covariance Gaussian whose mean and scale factor are predicted per conditioning row.

    Args:
        n_dim: Dimensionality of the modelled variable.
        n_dim_cond: Width of the conditioning input.
    """

    def __init__(self, n_dim, n_dim_cond):
        super().__init__()

        self.n_dim = n_dim
        self.n_dim_cond = n_dim_cond

        # Per row: n_dim means, n_dim log-diagonal entries, n_dim*(n_dim-1)/2 below-diagonal entries.
        self.network = create_parameter_network(n_dim_cond, 2 * n_dim + n_dim * (n_dim - 1) // 2)

    def regularization(self):
        return self.network.regularization()

    def get_parameters(self, x_cond):
        """Return ``(mu, scale_tril)`` of shapes ``(..., n_dim)`` and ``(..., n_dim, n_dim)``.

        Leading dims follow the network output (one set of parameters per ``x_cond`` row).
        """
        params = self.network(x_cond)
        n = self.n_dim
        mu = params[..., :n]
        scale_tril_elements_diag_log = params[..., n:2 * n]
        scale_tril_elements_below_diag = params[..., 2 * n:]

        rows, cols = torch.tril_indices(n, n, offset=-1, device=params.device)
        scale_tril = torch.zeros(*params.shape[:-1], n, n, dtype=params.dtype, device=params.device)
        # Batched fill of the strictly lower triangle, then add the positive diagonal.
        scale_tril[..., rows, cols] = scale_tril_elements_below_diag
        scale_tril = scale_tril + torch.diag_embed(scale_tril_elements_diag_log.exp())
        return mu, scale_tril

    def _distribution(self, x_cond):
        """Batched MultivariateNormal with one component per ``x_cond`` row."""
        mu, scale_tril = self.get_parameters(x_cond)
        return D.MultivariateNormal(loc=mu, scale_tril=scale_tril)

    def log_prob(self, x, x_cond):
        """Log density of each row of ``x`` under the distribution for the matching ``x_cond`` row."""
        return self._distribution(x_cond).log_prob(x)

    def sample(self, x_cond):
        """Draw one sample per ``x_cond`` row (no gradient tracking), shape ``(len(x_cond), n_dim)``."""
        return self._distribution(x_cond).sample()


class IsotropicMixture(nn.Module):
    """Mixture of diagonal Gaussians with learnable weights, means and log-stds.

    Component means and log-stds are cumulative sums of per-component base
    parameters along the component axis.

    Args:
        n_dim: Dimensionality of the distribution.
        n_components: Number of mixture components.
    """

    def __init__(self, n_dim, n_components):
        super().__init__()
        self.n_dim = n_dim
        self.n_components = n_components

        self.mixing_logits = nn.Parameter(torch.zeros(n_components))  # Unnormalized; softmax gives weights.
        self.mu_base = nn.Parameter(torch.zeros(n_components, n_dim))
        self.log_sigma_base = nn.Parameter(torch.zeros(n_components, n_dim))

    def regularization(self):
        """Return the regularization penalty (this model has none)."""
        return torch.tensor(0.0)

    @staticmethod
    def static_log_prob(x, mus, sigmas, mixing_logits):
        """Mixture log density of each row of ``x``.

        Args:
            x: Points, presumably ``(n_samples, n_dim)``.
            mus, sigmas: Per-component means / stds, iterated along the first axis.
            mixing_logits: ``(n_components,)`` unnormalized log weights.

        NOTE(review): per-component densities come from ``gaussian_log_prob``,
        assumed to be the elementwise Gaussian log density — confirm.
        """
        # log_softmax is the numerically stable form of log(softmax(.)).
        log_weights = F.log_softmax(mixing_logits, dim=-1)
        component_log_probs = torch.stack(
            [gaussian_log_prob(x, mu, sigma).sum(dim=-1) for mu, sigma in zip(mus, sigmas)]
        )  # (n_components, n_samples)
        return torch.logsumexp(log_weights.unsqueeze(-1) + component_log_probs, dim=0)

    def component_parameters(self):
        """Return ``(mus, sigmas)``, each ``(n_components, n_dim)``."""
        mu = torch.cumsum(self.mu_base, dim=0)
        log_sigma = torch.cumsum(self.log_sigma_base, dim=0)
        return mu, log_sigma.exp()

    def log_prob(self, x: torch.Tensor):
        mus, sigmas = self.component_parameters()
        return self.static_log_prob(x, mus, sigmas, self.mixing_logits)

    @staticmethod
    def static_sample(n, mus, sigmas, mixing_logits, n_dim):
        """Draw ``n`` samples: pick a component per sample, then ``eps * sigma + mu``.

        Gradients flow to ``mus``/``sigmas`` (not through the discrete component choice).
        """
        component_index = D.Categorical(logits=mixing_logits).sample((n,))
        noise = torch.randn(n, n_dim, dtype=sigmas.dtype, device=sigmas.device)
        return noise * sigmas[component_index] + mus[component_index]

    def sample(self, n):
        """Return an ``(n, n_dim)`` tensor of samples."""
        mus, sigmas = self.component_parameters()
        return self.static_sample(n, mus, sigmas, self.mixing_logits, self.n_dim)


class ConditionalIsotropicMixture(nn.Module):
    """Mixture of diagonal Gaussians whose parameters are predicted per conditioning row.

    Args:
        n_dim: Dimensionality of the modelled variable.
        n_components: Number of mixture components.
        n_dim_cond: Width of the conditioning input.
    """

    def __init__(self, n_dim, n_components, n_dim_cond):
        super().__init__()
        self.n_dim = n_dim
        self.n_components = n_components
        self.n_dim_cond = n_dim_cond

        # Per row: K*d mean bases, K*d log-std bases, then K mixing logits.
        self.network = create_parameter_network(n_dim_cond, 2 * n_dim * n_components + n_components)

    def regularization(self):
        return self.network.regularization()

    def get_parameters(self, x_cond):
        """Return ``(mus, sigmas, mixing_logits)`` for each conditioning row.

        Shapes: ``(..., K, n_dim)``, ``(..., K, n_dim)``, ``(..., K)``, where the
        leading dims follow the network output. Means and log-stds are cumulative
        sums over the component axis.
        """
        params = self.network(x_cond)
        k, d = self.n_components, self.n_dim
        batch_shape = params.shape[:-1]
        # Split along the feature axis (not the batch axis).
        mu_base, log_sigma_base, mixing_logits = torch.split(params, [k * d, k * d, k], dim=-1)

        mu_base = mu_base.reshape(*batch_shape, k, d)
        log_sigma_base = log_sigma_base.reshape(*batch_shape, k, d)

        mu = torch.cumsum(mu_base, dim=-2)
        log_sigma = torch.cumsum(log_sigma_base, dim=-2)

        return mu, log_sigma.exp(), mixing_logits

    def log_prob(self, x: torch.Tensor, x_cond):
        """Mixture log density of each row of ``x`` given the matching ``x_cond`` row.

        NOTE(review): uses torch's ``Normal`` density directly; it should agree with
        ``gaussian_log_prob`` used by the unconditional mixture — confirm.
        """
        mus, sigmas, mixing_logits = self.get_parameters(x_cond)
        # x: (..., d) -> (..., 1, d) to broadcast against (..., K, d).
        component_log_probs = D.Normal(mus, sigmas).log_prob(x.unsqueeze(-2)).sum(dim=-1)
        return torch.logsumexp(F.log_softmax(mixing_logits, dim=-1) + component_log_probs, dim=-1)

    def sample(self, x_cond):
        """Draw one sample per ``x_cond`` row, shape ``(len(x_cond), n_dim)``."""
        mus, sigmas, mixing_logits = self.get_parameters(x_cond)
        component_index = D.Categorical(logits=mixing_logits).sample()
        rows = torch.arange(component_index.shape[0], device=component_index.device)
        mu = mus[rows, component_index]
        sigma = sigmas[rows, component_index]
        return torch.randn_like(mu) * sigma + mu


class CorrelatedMixture(nn.Module):
    """Mixture of full-covariance Gaussians with learnable weights, means and scale factors.

    Component means and scale factors are cumulative sums of per-component base
    parameters along the component axis.

    Args:
        n_dim: Dimensionality of the distribution.
        n_components: Number of mixture components.
    """

    def __init__(self, n_dim, n_components):
        super().__init__()
        self.n_dim = n_dim
        self.n_components = n_components

        self.mixing_logits = nn.Parameter(torch.zeros(n_components))  # Unnormalized; softmax gives weights.
        self.mu_base = nn.Parameter(torch.zeros(n_components, n_dim))
        # Strictly-below-diagonal entries per component, in torch.tril_indices (row-major) order.
        self.scale_tril_elements_below_diag_base = nn.Parameter(torch.zeros(n_components, n_dim * (n_dim - 1) // 2))
        # Log of the diagonal per component; exponentiated to keep it positive.
        self.scale_tril_elements_diag_log_base = nn.Parameter(torch.zeros(n_components, n_dim))

    def regularization(self):
        """Return the regularization penalty (this model has none)."""
        return torch.tensor(0.0)

    @staticmethod
    def static_log_prob(x, mus, scales_tril, mixing_logits):
        """Mixture log density of each row of ``x``.

        Args:
            x: ``(n_samples, n_dim)`` points.
            mus: ``(K, n_dim)`` component means.
            scales_tril: ``(K, n_dim, n_dim)`` lower-triangular scale factors.
            mixing_logits: ``(K,)`` unnormalized log weights.
        """
        components = D.MultivariateNormal(loc=mus, scale_tril=scales_tril)
        # (n_samples, 1, n_dim) broadcasts against batch shape (K,) -> (n_samples, K).
        component_log_probs = components.log_prob(x.unsqueeze(-2))
        return torch.logsumexp(F.log_softmax(mixing_logits, dim=-1) + component_log_probs, dim=-1)

    @staticmethod
    def static_sample(n, mus, scale_trils, mixing_logits, n_dim):
        """Draw ``n`` samples (no gradient tracking), shape ``(n, n_dim)``.

        ``n_dim`` is kept for interface compatibility; shapes come from ``mus``.
        """
        component_index = D.Categorical(logits=mixing_logits).sample((n,))
        return D.MultivariateNormal(
            loc=mus[component_index], scale_tril=scale_trils[component_index]
        ).sample()

    def component_parameters(self):
        """Return ``(mus, scale_trils)`` of shapes ``(K, n_dim)`` and ``(K, n_dim, n_dim)``."""
        mu = torch.cumsum(self.mu_base, dim=0)

        n = self.n_dim
        rows, cols = torch.tril_indices(n, n, offset=-1, device=self.mu_base.device)
        tril = torch.zeros(self.n_components, n, n, dtype=self.mu_base.dtype, device=self.mu_base.device)
        # Fill the strictly lower triangle of every component, then add the positive diagonal.
        tril[:, rows, cols] = self.scale_tril_elements_below_diag_base
        tril = tril + torch.diag_embed(self.scale_tril_elements_diag_log_base.exp())
        # Sums of lower-triangular matrices with positive diagonals stay valid scale factors.
        scale_trils = torch.cumsum(tril, dim=0)
        return mu, scale_trils

    def log_prob(self, x: torch.Tensor):
        mus, scales_tril = self.component_parameters()
        return self.static_log_prob(x, mus, scales_tril, self.mixing_logits)

    def sample(self, n):
        mus, scale_trils = self.component_parameters()
        return self.static_sample(n, mus, scale_trils, self.mixing_logits, self.n_dim)


class ConditionalCorrelatedMixture(nn.Module):
    """Mixture of full-covariance Gaussians whose parameters are predicted per conditioning row.

    Args:
        n_dim: Dimensionality of the modelled variable.
        n_components: Number of mixture components.
        n_dim_cond: Width of the conditioning input.
    """

    def __init__(self, n_dim, n_components, n_dim_cond):
        super().__init__()
        self.n_dim = n_dim
        self.n_components = n_components
        self.n_dim_cond = n_dim_cond

        # Per row and component: 1 mixing logit, n_dim mean bases,
        # n_dim*(n_dim-1)/2 below-diagonal entries, n_dim log-diagonal entries.
        self.network = create_parameter_network(
            n_dim_cond,
            n_components * (1 + n_dim + n_dim * (n_dim - 1) // 2 + n_dim)
        )

    def regularization(self):
        return self.network.regularization()

    def get_parameters(self, x_cond):
        """Return ``(mus, scale_trils, mixing_logits)`` for each conditioning row.

        Shapes: ``(..., K, n_dim)``, ``(..., K, n_dim, n_dim)``, ``(..., K)``, where the
        leading dims follow the network output. Means and scale factors are cumulative
        sums over the component axis.
        """
        params = self.network(x_cond)
        k, d = self.n_components, self.n_dim
        m = d * (d - 1) // 2
        batch_shape = params.shape[:-1]

        # Layout along the feature axis: logits | mean bases | below-diag | log-diag.
        mixing_logits, mu_base, below_base, diag_log_base = torch.split(
            params, [k, k * d, k * m, k * d], dim=-1
        )
        mu_base = mu_base.reshape(*batch_shape, k, d)
        below_base = below_base.reshape(*batch_shape, k, m)
        diag_log_base = diag_log_base.reshape(*batch_shape, k, d)

        mu = torch.cumsum(mu_base, dim=-2)

        rows, cols = torch.tril_indices(d, d, offset=-1, device=params.device)
        tril = torch.zeros(*batch_shape, k, d, d, dtype=params.dtype, device=params.device)
        # Fill the strictly lower triangle for every (row, component), then add the positive diagonal.
        tril[..., rows, cols] = below_base
        tril = tril + torch.diag_embed(diag_log_base.exp())
        # Cumulate over the component axis; lower-triangular with positive diagonal is preserved.
        scale_trils = torch.cumsum(tril, dim=-3)
        return mu, scale_trils, mixing_logits

    def log_prob(self, x: torch.Tensor, x_cond):
        """Mixture log density of each row of ``x`` given the matching ``x_cond`` row."""
        mus, scale_trils, mixing_logits = self.get_parameters(x_cond)
        components = D.MultivariateNormal(loc=mus, scale_tril=scale_trils)
        # x: (..., d) -> (..., 1, d) to broadcast against batch shape (..., K).
        component_log_probs = components.log_prob(x.unsqueeze(-2))
        return torch.logsumexp(F.log_softmax(mixing_logits, dim=-1) + component_log_probs, dim=-1)

    def sample(self, x_cond):
        """Draw one sample per ``x_cond`` row (no gradient tracking), shape ``(len(x_cond), n_dim)``."""
        mus, scale_trils, mixing_logits = self.get_parameters(x_cond)
        component_index = D.Categorical(logits=mixing_logits).sample()
        rows = torch.arange(component_index.shape[0], device=component_index.device)
        return D.MultivariateNormal(
            loc=mus[rows, component_index], scale_tril=scale_trils[rows, component_index]
        ).sample()
