import torch
from torch.nn import functional as F
from vti.distributions.uniform_bs import UniformBinaryString

# Debug switch: when True, PermutationDAGUniformDistribution.log_prob asserts
# that the input's second dimension equals P_features + U_features.
DEBUGMODE = False

class PermutationDAGUniformDistribution(torch.nn.Module):
    """Uniform prior over DAGs encoded as a permutation P and a binary U.

    Each input row is laid out as ``P_features`` leading columns (the
    permutation part) followed by ``U_features`` columns (the flattened
    upper-triangular binary part). The permutation part contributes a
    constant log-probability that is precomputed once in ``__init__``; the
    binary part is scored by a ``UniformBinaryString`` distribution.
    """

    def __init__(self, num_nodes, device, dtype):
        """
        Constructs prior over DAGs using
        permutation P and upper triangular binary U matrices.

        Args:
            num_nodes: Number of nodes in the DAG; must be an ``int``.
            device: Torch device used for the precomputed log-probability
                and forwarded to ``UniformBinaryString``.
            dtype: Torch dtype used for the precomputed log-probability
                and forwarded to ``UniformBinaryString``.

        Raises:
            AssertionError: If ``num_nodes`` is not an ``int``.
        """
        super().__init__()
        assert isinstance(num_nodes, int), "ERROR: num_nodes must be integer"
        self.device = device
        self.dtype = dtype
        self.num_nodes = num_nodes
        # Number of strictly-upper-triangular entries of a num_nodes x num_nodes matrix.
        self.U_features = int(num_nodes * (num_nodes - 1) // 2)
        self.P_features = int(num_nodes - 1)
        self.flat_U_dist = UniformBinaryString(self.U_features, device, dtype)
        # Precompute the uniform categorical log-probs: the i-th of the
        # P_features categoricals has (num_nodes - i) equally likely outcomes,
        # i.e. sizes num_nodes, num_nodes - 1, ..., 2. Their log-probs sum to
        # -log(num_nodes!).
        category_sizes = torch.arange(
            self.num_nodes, 1, -1, device=self.device, dtype=self.dtype
        )
        self.P_log_prob = (-torch.log(category_sizes)).sum()

    def log_prob(self, inputs):
        """Return the log-probability of ``inputs`` under the prior.

        Args:
            inputs: Tensor indexed as ``inputs[:, j]``; the first
                ``P_features`` columns are skipped (the permutation part only
                contributes the constant ``P_log_prob``) and the remaining
                columns are passed to ``flat_U_dist.log_prob``. The total
                column count is only checked when ``DEBUGMODE`` is true.

        Returns:
            ``flat_U_dist.log_prob(...) + P_log_prob``. Presumably one value
            per input row -- this depends on ``UniformBinaryString.log_prob``,
            which is defined elsewhere (TODO confirm).
        """
        if DEBUGMODE:
            assert (
                inputs.shape[1] == self.P_features + self.U_features
            ), "Feature mismatch, expected {}, got {}".format(
                self.P_features + self.U_features, inputs.shape[1]
            )
        # the below would change for a different prior on the number of edges
        U_log_prob = self.flat_U_dist.log_prob(inputs[:, self.P_features :])
        # sum of log probs for U and P
        return U_log_prob + self.P_log_prob
