import torch
import mcmc
import utils

from sklearn.mixture import GaussianMixture
from typing import List



class FrozenGMM:
    """
    Frozen (non-trainable) diagonal-covariance Gaussian mixture.

    Serves as the fixed proposal μ_t inside EM2C iterations.

    Density:
        μ_t(x) = ∑_{k=1}^K w_k · N(x | m_k, diag(σ_k^2))
    """

    @staticmethod
    def _freeze(t):
        # Private copy, cut from any autograd graph, stored as float32.
        return t.clone().detach().float()

    def __init__(self, weights, means, log_stds):
        """
        Params
            weights  : Tensor [K]    : mixture weights w_k
            means    : Tensor [K, d] : component means m_k
            log_stds : Tensor [K, d] : per-coordinate log σ_k

        All three inputs are copied and detached, so later changes to the
        originals (or gradient flow through them) do not affect this object.
        """
        self.weights = self._freeze(weights)
        self.means = self._freeze(means)
        self.log_stds = self._freeze(log_stds)

        self.K, self.d = self.means.shape

        # Cache σ_k once; values below 1e-3 are raised to 1e-3.
        self.stds = self.log_stds.exp().clamp(min=1e-3)

    def log_mu(self, x):
        """
        Evaluate log μ_t at a batch of points.

        Params
            x : Tensor [N, d]

        Returns
            Tensor [N] : log-density of the frozen mixture at each row of x.
        """
        # Standardized residuals for every (point, component) pair: [N, K, d]
        residual = (x.unsqueeze(1) - self.means[None]) / self.stds[None]
        sq_dist = residual.pow(2).sum(dim=-1)                      # [N, K]

        # log of the Gaussian normalizing constant: ∑_i log σ_{k,i} + (d/2)·log(2π)
        half_d_log_2pi = 0.5 * self.d * torch.log(
            torch.tensor(2 * torch.pi)
        )
        log_normalizer = torch.log(self.stds[None]).sum(dim=-1) + half_d_log_2pi  # [1, K]

        # Per-component log-densities, then log ∑_k w_k · N_k(x)
        component_ll = -0.5 * sq_dist - log_normalizer             # [N, K]
        log_w = torch.log(self.weights).view(1, -1)                # [1, K]
        return torch.logsumexp(log_w + component_ll, dim=1)

    def sample(self, N):
        """
        Draw N i.i.d. samples from μ_t.

        Two-stage (ancestral) sampling:
            1. k ~ Categorical(w)
            2. x = m_k + σ_k ⊙ ε,  ε ~ N(0, I)

        Params :
            N : int : Number of samples.
        Returns :
            Tensor [N, d] : Samples drawn from the frozen mixture.
        """
        # Stage 1: which component generates each sample
        picker = torch.distributions.Categorical(self.weights)
        chosen = picker.sample((N,))                                # [N]

        # Stage 2: shift and scale standard normal noise
        noise = torch.randn(N, self.d, device=self.means.device)   # [N, d]
        loc = self.means[chosen]
        scale = self.stds[chosen]
        return loc + noise * scale

class FrozenGMMFull:
    """
    Gaussian Mixture Model with full covariance.

    Mathematical form:
        μ(x) = ∑_{k=1}^K w_k · N(x | m_k, Σ_k)

    The tensors are stored as given (no copy, no detach).
    """

    def __init__(self, weights, means, covs):
        """
        Params
            weights : Tensor [K]       : mixture weights w_k
            means   : Tensor [K, d]    : component means m_k
            covs    : Tensor [K, d, d] : component covariance matrices Σ_k
        """
        self.weights = weights
        self.means = means
        self.covs = covs
        self.K = weights.shape[0]

    def _component(self, k):
        # Gaussian distribution object for the k-th mixture component.
        return torch.distributions.MultivariateNormal(
            self.means[k], self.covs[k]
        )

    def log_mu(self, x):
        """
        Compute log μ(x).

        Params
            x : Tensor [N, d]

        Returns
            Tensor [N] : log-density of the mixture at each row of x.
        """
        log_probs = [
            torch.log(self.weights[k]) + self._component(k).log_prob(x)
            for k in range(self.K)
        ]
        # Stack to [K, N] and marginalize over the component axis.
        return torch.logsumexp(torch.stack(log_probs, dim=0), dim=0)

    def sample(self, N):
        """
        Draw N i.i.d. samples from μ.

        Component indices are drawn first; each component then fills the rows
        assigned to it.

        Params :
            N : int : Number of samples.
        Returns :
            Tensor [N, d] : Samples, with the same dtype and device as `means`.
        """
        comps = torch.distributions.Categorical(self.weights).sample((N,))

        # Allocate the output like `means`: a default (float32, CPU) buffer
        # does not match float64 or GPU parameters in the masked write below.
        samples = torch.zeros(
            N,
            self.means.shape[1],
            dtype=self.means.dtype,
            device=self.means.device,
        )
        for k in range(self.K):
            idx = comps == k
            count = int(idx.sum())  # plain int for the sample shape
            if count > 0:
                samples[idx] = self._component(k).sample((count,))
        return samples

# Tensorized mixture: a product of independent 2D mixtures, one per block of coordinates
class FrozenTensorizedGMM:
    """
    Product ("tensorized") density assembled from independent 2D blocks:

        μ(x) = ∏_{j=1}^{d/2} μ^{(j)}(x^{(j)})

    x^{(j)} is the j-th consecutive pair of coordinates of x, and each factor
    μ^{(j)} is a 2D Gaussian mixture.
    """

    def __init__(self, block_models):
        """
        block_models: list of FrozenGMMFull (2D), one per block (length d/2)
        """
        n = len(block_models)
        self.block_models = block_models
        self.n_blocks = n
        self.d = 2 * n

    def log_mu(self, x):
        """
        Sum of the per-block log-densities.

        x: [N, d]; the result is whatever the block models' log_mu return,
        accumulated over blocks ([N] for FrozenGMMFull blocks).
        """
        # Split the d coordinates into consecutive pairs: [N, n_blocks, 2]
        pairs = x.view(x.shape[0], self.n_blocks, 2)
        total = 0.0
        for j, block in enumerate(self.block_models):
            total += block.log_mu(pairs.select(1, j))
        return total

    def sample(self, N):
        """
        Draw N samples per block and concatenate them along the coordinate axis.

        Returns a tensor [N, d].
        """
        per_block = [block.sample(N) for block in self.block_models]  # each [N, 2]
        return torch.cat(per_block, dim=1)
    


# Fit a GMM with scikit-learn's EM algorithm (works in any dimension)
def fit_gmm_sklearn(
    Z,
    K,
    random_state=0,
    max_iter=500,
    reg_covar=1e-3,
):
    """
    Fit a full-covariance GMM with scikit-learn and wrap it as a FrozenGMMFull.

    Params
        Z            : torch.Tensor [N, d] : samples to fit; detached and copied
                       to CPU/NumPy before fitting
        K            : int : number of mixture components
        random_state : seed forwarded to GaussianMixture
        max_iter     : EM iteration cap forwarded to GaussianMixture
        reg_covar    : covariance regularization forwarded to GaussianMixture

    Returns
        FrozenGMMFull whose weights, means and covariances are float32 CPU
        tensors built from the fitted estimator.
    """
    data = Z.detach().cpu().numpy()

    # Full covariances, k-means++ initialization, 3 initializations per fit.
    estimator = GaussianMixture(
        n_components=K,
        covariance_type="full",
        init_params="k-means++",
        n_init=3,
        max_iter=max_iter,
        reg_covar=reg_covar,
        random_state=random_state,
    )
    estimator.fit(data)

    # Convert the fitted parameters to float32 torch tensors.
    weights, means, covs = (
        torch.tensor(param, dtype=torch.float32)
        for param in (
            estimator.weights_,
            estimator.means_,
            estimator.covariances_,
        )
    )

    return FrozenGMMFull(weights, means, covs)

# Fits one GMM per 2D block of Z by calling fit_gmm_sklearn on each block
def fit_tensorized_gmm_by_em(
    Z,
    d,
    K_2d, #K0
    random_state=0,
):
    """
    Tensorized EM projection: fit one independent 2D GMM per pair of coordinates.

    Params
        Z            : Tensor [N, d] : samples to project
        d            : int : dimension of Z; d // 2 blocks are fitted
        K_2d         : int : number of components in each 2D mixture
        random_state : seed reused for every block fit

    Returns : FrozenTensorizedGMM
    """
    # Regroup the coordinates as consecutive pairs: [N, d // 2, 2]
    paired = Z.view(Z.shape[0], d // 2, 2)

    # One EM fit per block; each slice along axis 1 is an [N, 2] block.
    fitted_blocks = [
        fit_gmm_sklearn(block, K=K_2d, random_state=random_state)
        for block in paired.unbind(dim=1)
    ]

    return FrozenTensorizedGMM(fitted_blocks)


def run_e2mc_gmm(
    *,
    target,
    init_proposal,
    K_config: mcmc.MCMCConfig,
    L_config: mcmc.MCMCConfig,
    N: int,
    T: int,
    eps: float,
    lamda: float,
    K0: int,
):
    """
    Run E2MC with a tensorized GMM proposal refitted at every iteration.

    Params
        target        : object exposing `logpi`, passed to the MCMC kernels and
                        to the weight computation
        init_proposal : initial proposal exposing `sample(N)` and `log_mu(x)`
        K_config      : configuration of the exploration kernel
        L_config      : configuration of the move kernel
        N             : number of samples drawn from each proposal
        T             : number of iterations
        eps           : forwarded to utils.compute_weights
        lamda         : forwarded to utils.sample_from_mixture
        K0            : number of components per 2D block in the refitted proposal

    Returns:
        x_final          : N samples drawn from the last proposal
        final_proposal   : the last fitted proposal (init_proposal if T == 0)
        proposal_history : list of all proposals, starting with init_proposal
    """
    print("use lamda : ",lamda)
    current = init_proposal
    history = [current]
    particles = current.sample(N)

    for _step in range(T):
        # Exploration (kernel K): keep the last element returned by mcmc.mcmc,
        # presumably the final chain state (confirm against the mcmc module).
        explored = mcmc.mcmc(particles, target.logpi, K_config)[-1]

        # Importance weights of both point clouds under the current proposal.
        w_particles, w_explored = (
            utils.compute_weights(cloud, target.logpi, current.log_mu, eps)
            for cloud in (particles, explored)
        )

        # Mixture resampling between the two weighted clouds.
        pool = utils.sample_from_mixture(
            particles, w_particles, explored, w_explored, lamda
        )

        # Move step (kernel L); detach before handing the points to the EM fit.
        moved = mcmc.mcmc(pool, target.logpi, L_config)[-1].detach()

        # EM projection onto the tensorized GMM family.
        current = fit_tensorized_gmm_by_em(
            Z=moved,
            d=moved.shape[1],
            K_2d=K0,
        )
        history.append(current)

        # Fresh draws from the new proposal seed the next iteration.
        particles = current.sample(N)

    return particles, current, history
