
import torch
import mcmc


class AISSampler:
    """
    Annealed Importance Sampling (AIS) competitor.

    Stores an MCMC configuration and an annealing length, runs
    ``run_ais_generic`` and converts its weighted particles into an
    equally weighted sample by multinomial resampling.
    """

    def __init__(
        self,
        mcmc_config: mcmc.MCMCConfig,
        n_intermediate: int = 50,
        device="cpu",
    ):
        """
        Args:
            mcmc_config   : MCMCConfig (RW / MALA / ULA), the same kind of
                            configuration used by the RW competitor.
            n_intermediate: number of annealing steps.
            device        : torch device string, e.g. "cpu" or "cuda".
        """
        self.device = device
        self.n_intermediate = n_intermediate
        self.mcmc_config = mcmc_config

    def run(
        self,
        target,
        init_proposal,
        N: int,
    ):
        """
        Run AIS and resample its output into N equally weighted draws.

        Args:
            target        : target distribution, forwarded to ``run_ais_generic``.
            init_proposal : initial (beta = 0) distribution, forwarded as well.
            N             : number of AIS particles, and number of resampled draws.

        Returns:
            Tensor whose leading dimension is N, made of AIS particles drawn
            with replacement according to their importance weights
            (``[N, d]`` for vector-valued particles).
        """
        ais_kwargs = dict(
            mcmc_config=self.mcmc_config,
            N=N,
            n_intermediate=self.n_intermediate,
            device=self.device,
        )
        particles, weights = run_ais_generic(target, init_proposal, **ais_kwargs)
        return resample_from_weights(particles, weights, N)


def run_ais_generic(
    target,
    init_proposal,
    mcmc_config,
    N=5000,
    n_intermediate=50,
    device="cpu",
):
    """
    Canonical AIS with configurable MCMC kernel.

    Anneals along the tempered path
        log pi_beta(x) = beta * target.logpi(x)
                         + (1 - beta) * init_proposal.log_prob(x)
    using ``n_intermediate`` equally spaced steps from beta = 0 to beta = 1.
    At each step the incremental log-weight is accumulated at the current
    particles, which are then moved with ``mcmc.mcmc`` targeting pi_beta.

    Compatible with MCMCConfig as a regular class: when
    ``mcmc_config.grad_logpdf_fn`` is set, a new MCMCConfig is built per step
    whose gradient is the autograd gradient of the tempered log-density.

    Args:
        target        : object exposing ``logpi(x)``.
        init_proposal : object exposing ``sample(N)`` and ``log_prob(x)``.
        mcmc_config   : MCMCConfig used for every MCMC transition.
        N             : number of particles.
        n_intermediate: number of annealing steps (must be >= 1).
        device        : torch device for particles and weights.

    Returns:
        (x, w): final particles and normalized importance weights
        (``w`` is non-negative and sums to 1).

    Raises:
        ValueError: if ``n_intermediate < 1`` (the target would never be reached).
    """
    if n_intermediate < 1:
        raise ValueError(
            f"n_intermediate must be >= 1 to reach the target, got {n_intermediate}"
        )

    # Annealing schedule: beta_0 = 0 (proposal) ... beta_K = 1 (target)
    betas = torch.linspace(0.0, 1.0, n_intermediate + 1, device=device)

    # Initial particles drawn from init_proposal
    x = init_proposal.sample(N).to(device)

    # Accumulated log importance weights
    logw = torch.zeros(N, device=device)

    for beta_prev, beta in zip(betas[:-1], betas[1:]):
        # ----- Weight update -----
        # log pi_beta(x) - log pi_beta_prev(x)
        #   = (beta - beta_prev) * (target.logpi(x) - init_proposal.log_prob(x)).
        # The factored form evaluates each density once per step instead of
        # twice, and avoids cancellation between two nearly equal tempered
        # log-densities. The weights never need an autograd graph.
        with torch.no_grad():
            log_ratio = target.logpi(x) - init_proposal.log_prob(x)
            # In-place add keeps logw's [N] shape (no silent broadcasting).
            logw += (beta - beta_prev) * log_ratio

        # ----- Build pi_beta logpdf (beta bound eagerly, not by late closure) -----
        logpdf_beta = _tempered_logpdf(target, init_proposal, beta)

        if mcmc_config.grad_logpdf_fn is not None:
            # Gradient-based kernel: replace the configured gradient with the
            # autograd gradient of the current tempered density.
            mcmc_config_beta = mcmc.MCMCConfig(
                n_iter=mcmc_config.n_iter,
                kernel_fn=mcmc_config.kernel_fn,
                kernel_params=mcmc_config.kernel_params,
                grad_logpdf_fn=_autograd_grad(logpdf_beta),
            )
        else:
            # RW case: no gradient
            mcmc_config_beta = mcmc_config

        # ----- MCMC move targeting pi_beta -----
        # Detach so no autograd history is carried across annealing steps.
        x = mcmc.mcmc(x, logpdf_beta, mcmc_config_beta)[-1].detach()

    # Normalize weights: softmax == exp(logw - max) / sum(...), numerically stable.
    w = torch.softmax(logw, dim=0)

    return x, w


def _tempered_logpdf(target, init_proposal, beta):
    """Return z -> beta * target.logpi(z) + (1 - beta) * init_proposal.log_prob(z)."""

    def logpdf_beta(z):
        return beta * target.logpi(z) + (1.0 - beta) * init_proposal.log_prob(z)

    return logpdf_beta


def _autograd_grad(logpdf):
    """Return z -> gradient of ``logpdf(z).sum()`` with respect to z.

    The gradient is taken on a detached copy of ``z``, so the caller's tensor
    is never modified (its ``requires_grad`` flag is left untouched). Grad
    mode is forced so the function also works when called inside
    ``torch.no_grad()``.
    """

    def grad_logpdf(z):
        with torch.enable_grad():
            z_leaf = z.detach().requires_grad_(True)
            lp = logpdf(z_leaf).sum()
            return torch.autograd.grad(lp, z_leaf)[0]

    return grad_logpdf



def resample_from_weights(x, w, N):
    """
    Draw N particles from ``x`` with replacement, proportionally to ``w``.

    Args:
        x : particle tensor, indexed along its first dimension.
        w : non-negative 1-D weights, one per particle. They need not sum to
            1, because ``torch.multinomial`` normalizes internally.
        N : number of draws.

    Returns:
        Tensor made of the selected entries of ``x``, with leading dimension N.
    """
    chosen_rows = torch.multinomial(w, num_samples=N, replacement=True)
    return x[chosen_rows]
