import torch

from .double_well import DoubleWell
from .target_distribution import TargetDistribution


class ManyWell(TargetDistribution):
    """Many Well target distribution created by repeating the Double Well Boltzmann distribution.

    The ``dim`` coordinates are split into ``dim // 2`` consecutive pairs
    ``(x0, x1), (x2, x3), ...``. Each pair is treated as an independent copy of
    :class:`DoubleWell`. Consequently, the log density is the sum of the per-pair
    log densities, and ``log_Z`` is ``n_wells`` times the Double Well ``log_Z``.
    """

    def __init__(
        self,
        dim=4,
    ) -> None:
        """
        Arguments:
            dim: total dimensionality. Must be even, with one Double Well per
                consecutive pair of coordinates.
        """
        super().__init__()
        assert dim % 2 == 0
        self.n_wells = dim // 2
        self._dim = dim
        # Magnitude of the +/- mode location placed on the even-indexed
        # (first) coordinate of every pair in the mode test set below.
        self.centre = 1.7
        # The all-modes test set has 2**n_wells rows, so it grows exponentially
        # with dim. At or above this threshold it is not built, to avoid
        # running out of memory.
        self.max_dim_for_all_modes = 40
        if self.dim < self.max_dim_for_all_modes:
            # Enumerate every combination of -centre / +centre across the
            # n_wells even-indexed coordinates. indexing="ij" is torch's
            # historical default; it is spelled out to silence the deprecation
            # warning without changing the result.
            grids = torch.meshgrid(
                [torch.tensor([-self.centre, self.centre]) for _ in range(self.n_wells)],
                indexing="ij",
            )
            dim_1_vals = torch.stack([torch.flatten(grid) for grid in grids], dim=-1)
            n_modes = 2**self.n_wells
            assert n_modes == dim_1_vals.shape[0]
            # Even-indexed coordinates take the mode values; odd-indexed ones stay 0.
            test_set = torch.zeros((n_modes, dim))
            test_set[:, torch.arange(dim) % 2 == 0] = dim_1_vals
            self.register_buffer("_test_set_modes", test_set)
        else:
            print("using test set containing not all modes to prevent memory issues")

        # NOTE(review): these bounds are not read anywhere in this class.
        # Presumably they are consumed by external evaluation code; confirm
        # before changing them.
        self.shallow_well_bounds = [-1.75, -1.65]
        self.deep_well_bounds = [1.7, 1.8]

        self.double_well = DoubleWell()

    @property
    def dim(self):
        """Total dimensionality of the distribution (always even)."""
        return self._dim

    @property
    def log_Z(self):
        """Log normalising constant: ``n_wells`` times the Double Well ``log_Z``."""
        # as_tensor avoids torch's copy-construct warning (and an extra copy)
        # in case DoubleWell.log_Z is already a tensor. For a Python float it
        # behaves like torch.tensor.
        return torch.as_tensor(self.double_well.log_Z * self.n_wells)

    @property
    def Z(self):
        """Normalising constant, ``exp(log_Z)``."""
        return torch.exp(self.log_Z)

    def sample(self, num_samples: int) -> torch.Tensor:
        """Draw samples by sampling each coordinate pair independently from the Double Well.

        How each pair is drawn is delegated to ``DoubleWell.sample``. Per the
        original design, this uses rejection sampling for the first coordinate
        and exact sampling for the second; that behavior is not verifiable from
        this file.

        Arguments:
            num_samples: number of samples to draw.

        Returns:
            Tensor formed by concatenating ``n_wells`` independent
            ``DoubleWell.sample(num_samples)`` draws along the last dimension.
            This is presumably shaped ``(num_samples, dim)``.
        """
        pair_samples = [self.double_well.sample(num_samples) for _ in range(self.n_wells)]
        return torch.cat(pair_samples, dim=-1)

    def log_prob(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the log probability of the Many Well distribution by summing the
        Double Well log probabilities of each consecutive coordinate pair.

        Arguments:
            x: (*batch_shape, dim) tensor of samples. The common case is
                (batch_size, dim).

        Returns:
            log_prob: (*batch_shape, 1) tensor of log probabilities. For 2-D
                input this is (batch_size, 1).
        """
        batch_shape = x.shape[:-1]

        # reshape (not view) so that non-contiguous inputs, such as sliced or
        # transposed batches, are accepted. Row k holds coordinates (2k, 2k+1).
        x_pairs = x.reshape(-1, 2)

        # Presumably (batch * n_wells,) or (batch * n_wells, 1), depending on
        # DoubleWell. Either layout reshapes identically below.
        log_prob_pairs = self.double_well.log_prob(x_pairs)
        log_prob = log_prob_pairs.reshape(*batch_shape, self.n_wells).sum(dim=-1)

        return log_prob.unsqueeze(-1)

    def log_prob_2D(self, x):
        """Log probability of a single Double Well pair (delegates to ``DoubleWell.log_prob``)."""
        return self.double_well.log_prob(x)

    def log_prob_marginal_pair(self, x_2d, i, j):
        """Evaluate the full log density on a 2-D slice through coordinates ``i`` and ``j``.

        Coordinates ``i`` and ``j`` take the columns of ``x_2d``. Every other
        coordinate is fixed at 0. Despite the name, this is a slice of the joint
        density rather than an integrated marginal.

        Arguments:
            x_2d: (batch_size, 2) tensor of values for coordinates ``i`` and ``j``.
            i: index of the coordinate receiving ``x_2d[:, 0]``.
            j: index of the coordinate receiving ``x_2d[:, 1]``.

        Returns:
            (batch_size, 1) tensor of log probabilities.
        """
        # new_zeros keeps x_2d's device and dtype, so CUDA or float64 inputs
        # do not hit a device/dtype mismatch inside log_prob.
        x = x_2d.new_zeros((x_2d.shape[0], self._dim))
        x[:, i] = x_2d[:, 0]
        x[:, j] = x_2d[:, 1]
        return self.log_prob(x)
