
import torch



# ---------------------------------------------------------------------------
# GM2: two-component base mixture in 2D (unequal weights)
# ---------------------------------------------------------------------------
weights_gm2 = torch.tensor([0.2, 0.8])                  # shape [K] = [2]

means_gm2 = torch.tensor([[0.0, 0.0], [20.0, 20.0]])    # shape [K, 2] = [2, 2]

covs_gm2 = torch.stack([                                # shape [K, 2, 2] = [2, 2, 2]
    torch.eye(2),                                       # identity covariance
    torch.tensor([[10.0, -4.0], [-4.0, 3.0]]),          # correlated covariance
])

# ---------------------------------------------------------------------------
# GM4: four equally weighted components sharing one covariance matrix
# ---------------------------------------------------------------------------
weights_gm4 = torch.tensor([0.25, 0.25, 0.25, 0.25])    # shape [K] = [4]

means_gm4 = torch.tensor([                              # shape [K, 2] = [4, 2]
    [-10.0, 10.0], [10.0, -10.0], [15.0, 15.0], [-15.0, -15.0],
])

# The same 2x2 covariance is copied once per component: shape [K, 2, 2] = [4, 2, 2]
covs_gm4 = torch.tensor([[3.0, 4.0], [4.0, 10.0]]).repeat(4, 1, 1)

# ---------------------------------------------------------------------------
# GM25: 25 equally weighted components on a 5 x 5 grid with spacing 5
# ---------------------------------------------------------------------------
weights_gm25 = torch.ones(25) / 25                      # shape [K] = [25]

# Grid points (5*row, 5*col), enumerated with `row` as the slow index: shape [25, 2]
means_gm25 = torch.tensor([
    [5.0 * row, 5.0 * col]
    for row in range(5)
    for col in range(5)
])

# Isotropic covariance 0.25 * I for every component: shape [25, 2, 2]
covs_gm25 = (0.25 * torch.eye(2)).repeat(25, 1, 1)




class TensorizedGaussianMixture:
    """
    High-dimensional target distribution defined as a tensor product
    of identical 2D Gaussian mixtures.

        π_d(x) = ∏_{j=1}^{d/2} ~{π}(x^{(j)}),

    where each block x^{(j)} ∈ ℝ² follows the same 2D Gaussian mixture:

        ~{π}(u) = ∑_{k=1}^K w_k N(u | m_k, Σ_k).
    """

    def __init__(self, base_means, base_covs, base_weights, d):
        """
        Params :
            base_means   : Tensor [K, 2]    : means m_k of the 2D Gaussian Mixture components
            base_covs    : Tensor [K, 2, 2] : Covariance matrices Σ_k of the mixture components.
            base_weights : Tensor [K]       : Mixture weights w_k, summing to 1.
            d            : int              : dimension of the space (positive and even)
        Raises :
            ValueError : if d is not a positive even number.
        """
        # An odd d cannot be split into 2D blocks: the last coordinate would
        # never be sampled and the block reshape in logpi would fail.
        if d <= 0 or d % 2 != 0:
            raise ValueError(f"d must be a positive even integer, got {d!r}")

        self.base_means = base_means          # [K,2]
        self.base_covs = base_covs            # [K,2,2]
        self.base_weights = base_weights      # [K]
        self.K = len(base_weights)
        self.d = d
        self.n_blocks = d // 2

        # Batched multivariate normal holding all mixture components,
        # i.e. {N(· | m_k, Σ_k)} for k ∈ {1, ..., K} (batch shape [K]).
        self.mvns = torch.distributions.MultivariateNormal(
            base_means, base_covs
        )

    def logpi(self, x):
        """
        Compute log π_d(x) for a batch of points.
        Params :
            x : Tensor [N,d] : batch of N points in ℝ^d
        Returns :
            log_pi : Tensor [N] : Log-density log π_d(x) evaluated at each input point.

        - x is reshaped into n_blocks = d/2 independent 2D blocks:
              x → (x^{(1)}, ..., x^{(d/2)}), with x^{(j)} ∈ ℝ²
        - π_d(x) = ∏_{j=1}^{d/2} ~{π}(x^{(j)}) where ~{π} is the base 2D density
        - For each block, we compute : log ~{π}(x^{(j)}) = log ∑_k w_k N(x^{(j)} | m_k, Σ_k)
        - The full log-density :
              log π_d(x) = ∑_{j=1}^{d/2} log ~{π}(x^{(j)}).
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        # The singleton axis broadcasts every block against the K components.
        blocks = x.reshape(x.shape[0], self.n_blocks, 1, 2)   # [N, d/2, 1, 2]

        # log N(x^{(j)} | m_k, Σ_k) for all blocks j and components k at once
        log_gauss = self.mvns.log_prob(blocks)                # [N, d/2, K]

        # log w_k + log N_k, then log-sum-exp over the components
        log_comps = torch.log(self.base_weights) + log_gauss  # [N, d/2, K]
        log_blocks = torch.logsumexp(log_comps, dim=-1)       # [N, d/2]

        # Blocks are independent, so their log-densities add up
        return log_blocks.sum(dim=1)                          # [N]

    def grad_logpi(self, x):
        """
        Compute the gradient ∇_x log π_d(x).
        Params :
            x : Tensor [N, d]     : Batch of input points (left untouched).
        Returns :
            grad : Tensor [N, d]  : Gradient of the log-density with respect to x.

        - log π_d(x) = ∑_{j=1}^{d/2} log ~{π}(x^{(j)}),
          so the gradient decomposes blockwise:
            ∇_x log π_d(x) = (∇_{x^{(1)}} log ~{π}, ..., ∇_{x^{(d/2)}} log ~{π}).
        """
        # enable_grad makes the method usable inside a torch.no_grad() region.
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so that the caller's tensor
            # does not silently end up with requires_grad=True.
            x_req = x.detach().requires_grad_(True)
            logp = self.logpi(x_req)                               # [N]
            # Rows are independent, so the gradient of the sum is the
            # per-row gradient stacked into [N, d].
            (grad,) = torch.autograd.grad(logp.sum(), x_req)
        return grad

    def sample(self, N):
        """
        Sample N points from the tensorized Gaussian mixture π_d.

        Params :
            N : int : number of points to draw
        Returns:
            x : Tensor [N, d], with the dtype and device of base_means
        """
        # Sample mixture component indices for each block and each sample
        # Shape: [N, n_blocks]
        comp_idx = torch.distributions.Categorical(
            probs=self.base_weights
        ).sample((N, self.n_blocks))

        # Cholesky factors of the K base covariances, [K,2,2]. Gathering these
        # avoids re-factorising N covariance matrices for every block.
        base_trils = self.mvns.scale_tril

        # Allocate output with the parameters' dtype so that e.g. float64
        # mixtures are not silently downcast to the default dtype.
        x = torch.zeros(
            N,
            self.d,
            dtype=self.base_means.dtype,
            device=self.base_means.device,
        )

        # For each block, sample from the selected 2D Gaussian
        for b in range(self.n_blocks):
            idx_b = comp_idx[:, b]               # [N]
            means_b = self.base_means[idx_b]     # [N,2]
            trils_b = base_trils[idx_b]          # [N,2,2]

            mvn = torch.distributions.MultivariateNormal(
                means_b, scale_tril=trils_b
            )
            x[:, 2 * b:2 * b + 2] = mvn.sample()  # [N,2]

        return x


def build_target(model_name, d):
    """
    Build the tensorized target associated with a named base mixture.

    Params :
        model_name : str : name of the base mixture, one of "GM2", "GM4", "GM25"
        d          : int : dimension passed on to TensorizedGaussianMixture
    Returns :
        target : TensorizedGaussianMixture built from the selected base mixture
        k0     : int : 2, 4 or 25, i.e. the number of components of the
                       selected base mixture
    Raises :
        ValueError : if model_name is not one of the known names
    """
    # name -> (means, covariances, weights, k0)
    presets = {
        "GM2": (means_gm2, covs_gm2, weights_gm2, 2),
        "GM4": (means_gm4, covs_gm4, weights_gm4, 4),
        "GM25": (means_gm25, covs_gm25, weights_gm25, 25),
    }

    if model_name not in presets:
        raise ValueError("Unknown model")

    means, covs, weights, k0 = presets[model_name]
    mixture = TensorizedGaussianMixture(
        base_means=means,
        base_covs=covs,
        base_weights=weights,
        d=d,
    )
    return mixture, k0


