import pyro
import torch
from pyro.infer.mcmc import NUTS, MCMC


class NUTSSampler:
    """
    NUTS baseline using Pyro.
    Gradient-based oracle competitor.

    Drives Pyro's ``NUTS`` kernel with a hand-written potential function
    (negative log-density of ``target``) instead of a Pyro model, and runs a
    single chain started from one draw of ``init_proposal``.
    """

    def __init__(
        self,
        num_samples: int,
        warmup_steps: int = 1000,
        max_tree_depth: int = 5,
        target_accept_prob: float = 0.8,
        device="cpu",
        disable_progbar: bool = False,
    ):
        """
        Args:
            num_samples        : number of retained (post-warmup) samples
            warmup_steps       : warmup / adaptation steps
            max_tree_depth     : NUTS tree depth
            target_accept_prob : target acceptance probability used for
                                 step-size adaptation
            device             : cpu / cuda; the initial state is moved here
            disable_progbar    : if True, suppress Pyro's MCMC progress bar
                                 (useful when running many benchmarks)
        """
        self.num_samples = num_samples
        self.warmup_steps = warmup_steps
        self.max_tree_depth = max_tree_depth
        self.target_accept_prob = target_accept_prob
        self.device = device
        self.disable_progbar = disable_progbar

    def _initial_state(self, init_proposal):
        """Draw a single starting point from ``init_proposal`` as a plain tensor.

        The draw is taken under ``torch.no_grad()`` and detached, so that no
        autograd graph from the proposal (e.g. a trainable model) is built or
        carried into the MCMC state. The leading sample dimension of size 1 is
        removed and the result is moved to ``self.device``.
        """
        with torch.no_grad():
            z0 = init_proposal.sample(1)
        return z0.detach().squeeze(0).to(self.device)

    def run(
        self,
        target,
        init_proposal,
    ):
        """
        Run a single NUTS chain and return the post-warmup samples.

        Args:
            target        : object exposing ``logpi(x)``; assumed to map a batch
                            ``[B, d]`` to per-sample log-densities
                            (TODO confirm against the target implementations)
            init_proposal : object exposing ``sample(n)``; one draw is used as
                            the chain's starting point

        Returns:
            x_nuts : Tensor [num_samples, d]
        """
        # NOTE(review): potential_fn mode registers no Pyro params; this looks
        # purely defensive against state left over from other Pyro code.
        pyro.clear_param_store()

        # Initial state (detached, on self.device)
        z0 = self._initial_state(init_proposal)

        def potential_fn(params):
            # Potential energy U(z) = -log pi(z). A batch dim of 1 is added for
            # logpi, and .sum() reduces to the scalar Pyro differentiates.
            x = params["z"]
            return -target.logpi(x.unsqueeze(0)).sum()

        nuts_kernel = NUTS(
            potential_fn=potential_fn,
            adapt_step_size=True,
            target_accept_prob=self.target_accept_prob,
            max_tree_depth=self.max_tree_depth,
        )

        # With a bare potential_fn there is no model to initialise from, so the
        # starting point must be supplied explicitly via initial_params.
        mcmc = MCMC(
            nuts_kernel,
            num_samples=self.num_samples,
            warmup_steps=self.warmup_steps,
            initial_params={"z": z0},
            disable_progbar=self.disable_progbar,
        )

        mcmc.run()
        return mcmc.get_samples()["z"]
