import torch
import torch.nn as nn
from torch.distributions.beta import Beta
from torch.distributions.dirichlet import Dirichlet


class DirichletPrior(nn.Module):
    """Dirichlet prior over piecewise-constant densities on a regular grid.

    The benchmark domain ``[LOWER, UPPER]`` is split into ``config["bins"]``
    equal bins per dimension, giving ``bins ** parameter_dim`` cells. A
    function is a probability vector over these cells, drawn from a Dirichlet
    whose concentration is the benchmark prior's mass per cell
    (``get_base_distribution``), scaled so that the concentrations sum to
    ``n_cells * config["concentration"]``.

    Cells are addressed by a flat index in which dimension 0 is the least
    significant base-``bins`` digit (see ``self.offset``).
    """

    def __init__(self, benchmark, config, device):
        super().__init__()

        self.benchmark = benchmark
        self.prior = benchmark.get_prior()
        self.config = config
        self.bins = config["bins"]
        self.parameter_dim = benchmark.get_parameter_dim()
        parameter_ranges = benchmark.get_domain()
        self.LOWER = parameter_ranges[0].to(device)
        self.UPPER = parameter_ranges[1].to(device)
        self.device = device

        # Log of the (identical) volume of every grid cell.
        self.bin_log_volume = torch.sum(
            torch.log((self.UPPER - self.LOWER) / self.bins)
        )
        # Place value of each dimension's digit in the flat cell index, shape [1, D].
        self.offset = torch.tensor(
            [self.bins**d for d in range(self.parameter_dim)],
            dtype=torch.long,
            device=device,
        ).unsqueeze(0)

        base_distribution = self.get_base_distribution()
        alphas = base_distribution * len(base_distribution) * config["concentration"]
        self.distribution = Dirichlet(alphas)
        # Marginal of one Dirichlet component is Beta(alpha_i, sum(alpha) - alpha_i).
        alpha_sum = alphas.sum()
        self.marginal_distributions = [
            Beta(alpha, alpha_sum - alpha) for alpha in alphas
        ]

    def get_base_distribution(self):
        """Return the prior mass of every grid cell as a normalized vector.

        The prior density is evaluated at each cell's center. The cell volume
        is constant, so it cancels in the normalization, and the result is the
        softmax of the center log densities. Computing this in log space
        avoids underflow when densities are very small.

        Returns
        -------
        Tensor[bins ** parameter_dim]
            Cell probabilities summing to 1, on ``self.device``.
        """
        n_cells = self.bins**self.parameter_dim

        # Decode each flat index into per-dimension digits (inverse of get_bin).
        flat = torch.arange(n_cells).unsqueeze(1)
        digits = (flat // self.offset.cpu()) % self.bins  # [n_cells, D]

        lower = self.LOWER.cpu()
        bin_size = ((self.UPPER - self.LOWER) / self.bins).cpu()
        # Centers on CPU in the default dtype. This is what the previous
        # torch.Tensor(...) construction produced, so self.prior keeps
        # receiving the same kind of input.
        centers = (lower + (digits + 0.5) * bin_size).to(torch.get_default_dtype())

        # One call per center: the prior is not assumed to support batching.
        log_densities = torch.stack(
            [self.prior.log_prob(center).reshape(()) for center in centers]
        )

        return torch.softmax(log_densities, dim=0).to(self.device)

    def get_bin(self, theta):
        """Map parameters to flat grid-cell indices.

        Parameters
        ----------
        theta : Tensor[n_data, parameter_dim] or Tensor[parameter_dim]
            Parameters to locate. A 1-D input is broadcast to a single row.

        Returns
        -------
        Tensor[n_data] (or Tensor[1])
            Flat cell index, with dimension 0 as the least significant digit.
        """
        lower = self.LOWER.unsqueeze(0)
        upper = self.UPPER.unsqueeze(0)

        # 0 corresponds to the lower bound and 1 to the upper bound.
        unit = (theta - lower) / (upper - lower)

        # Clamp so theta == UPPER (index `bins`) falls in the last cell instead
        # of spilling into the next dimension's digit or past the last index.
        per_dim = torch.floor(unit * self.bins).long().clamp(0, self.bins - 1)

        return (per_dim * self.offset).sum(dim=1)

    def sample_distribution(self):
        """Draw one probability vector over grid cells from the Dirichlet."""
        return self.distribution.sample()

    def sample_functions(self, theta, x, n_samples):
        """Sample function outputs.

        Parameters
        ----------
        theta : Tensor
            The simulator's parameters associated to the data points for which to sample functions.
        x : Tensor
            The observations associated to the data points for which to sample functions.
            Unused here.
        n_samples : int
           The number of functions to sample.

        Returns
        -------
        Tensor
            log probabilities of shape [n_data, n_func, 1]
        """
        # The cell of each data point does not depend on the sample.
        cells = self.get_bin(theta)

        outputs = [
            # log density = log(cell mass) - log(cell volume)
            self.sample_distribution()[cells].log() - self.bin_log_volume
            for _ in range(n_samples)
        ]

        return torch.stack(outputs, dim=1).unsqueeze(dim=2)

    def functions_log_prob(self, theta, x, outputs):
        """Compute the log probability of functions.

        Parameters
        ----------
        theta : Tensor[n_data, parameter_dim]
            The simulator's parameters associated to the data points at which the functions are evaluated.
        x : Tensor[n_data, observation_dim]
            The observations associated to the data points at which the functions are evaluated.
            Unused here.
        outputs : Tensor[n_data, n_func]
            The log densities on theta output by each function at each data
            point. A trailing singleton dimension, as returned by
            ``sample_functions``, is carried through to the result.

        Returns
        -------
        Tensor[n_func]
            The log probabilities associated to each function.
        """
        log_probs = []
        for curr_theta, curr_outputs in zip(theta, outputs):
            # Change of variable: log density on theta -> log mass of the cell.
            log_cell_proba = curr_outputs + self.bin_log_volume

            marginal = self.marginal_distributions[int(self.get_bin(curr_theta))]
            curr_log_probs = marginal.log_prob(log_cell_proba.exp())

            # log|d(mass)/d(output)| = log(mass). Use the log value directly
            # rather than log(exp(.)), which underflows to -inf.
            log_probs.append(curr_log_probs + log_cell_proba)

        return torch.stack(log_probs, dim=0).sum(dim=0)
