import torch
import mcmc


class RandomWalkSampler:
    """
    Random-Walk Metropolis-Hastings sampler
    used as a baseline competitor.

    Thin wrapper around the project's ``mcmc`` module: it stores the
    random-walk settings, builds an ``mcmc.MCMCConfig`` that uses
    ``mcmc.rw_kernel``, and returns the last entry of whatever
    ``mcmc.mcmc`` produces.
    """

    def __init__(
        self,
        sigma: float,
        n_steps: int,
        device="cpu",
    ):
        """
        Args:
            sigma   : std of RW Gaussian kernel
            n_steps : number of MCMC steps
            device  : cpu / cuda
        """
        self.sigma = sigma
        self.n_steps = n_steps
        self.device = device

    def _make_noise_dist(self):
        """
        Build the zero-mean Gaussian used as RW proposal noise.

        The parameters are created as tensors on ``self.device``.
        Passing plain Python numbers would always place the
        distribution on the CPU, while ``run`` moves the initial
        state to ``self.device``; samples drawn from a CPU
        distribution would then not live on the same device as the
        state.

        Returns:
            torch.distributions.Normal with loc 0 and scale sigma
        """
        scale = torch.as_tensor(self.sigma, device=self.device)
        if not scale.is_floating_point():
            # Integer sigma: use the default float dtype, which is what
            # Normal does with plain Python numbers.
            scale = scale.to(torch.get_default_dtype())
        return torch.distributions.Normal(
            loc=torch.zeros_like(scale),
            scale=scale,
        )

    def _make_mcmc_config(self):
        """
        Build MCMCConfig for RW kernel.

        Returns:
            mcmc.MCMCConfig with ``n_iter=self.n_steps``, the
            ``mcmc.rw_kernel`` kernel, the Gaussian noise distribution
            under the ``"noise_dist"`` key, and no gradient function.
        """
        return mcmc.MCMCConfig(
            n_iter=self.n_steps,
            kernel_fn=mcmc.rw_kernel,
            kernel_params={"noise_dist": self._make_noise_dist()},
            # No gradient function is supplied for the RW kernel.
            grad_logpdf_fn=None,
        )

    def run(
        self,
        target,
        init_proposal,
        N: int,
    ):
        """
        Run RW-MH and return final samples.

        Args:
            target        : object with target.logpi(x)
            init_proposal : proposal with sample(N)
            N             : number of particles

        Returns:
            x_rw : last entry of the result of mcmc.mcmc, detached
                   (presumably a Tensor [N, d] -- confirm against
                   mcmc.mcmc)
        """

        # Initial samples, moved to the sampler's device
        x0 = init_proposal.sample(N).to(self.device)

        # RW configuration
        mcmc_config = self._make_mcmc_config()

        # Run MCMC
        traj = mcmc.mcmc(
            init_state=x0,
            logpdf_fn=target.logpi,
            config=mcmc_config,
        )

        # Return final state only
        return traj[-1].detach()
