from .base import Algo
import torch
import tqdm
import numpy as np

class MetropolisHasting(Algo):
    """
        Metropolis-Hastings sampler over batches of discrete sequences.

        Each chain is an integer tensor of length `net.length` whose entries are
        drawn from `range(net.dim)`. Proposals resample individual positions
        uniformly at random and are accepted or rejected using only the ratio of
        `forward_op.log_likelihood` values (the proposal is symmetric).
    """

    def __init__(self, net, forward_op, mh_steps=1000, max_dist=1, device='cuda'):
        """
            Initializes the Metropolis-Hastings sampler.

            Parameters:
                net: Model object forwarded to the base class. This sampler only
                    reads `net.dim` (number of values per position) and
                    `net.length` (sequence length).
                forward_op: Forward operator forwarded to the base class. Must
                    provide `log_likelihood(x, observation)`.
                mh_steps (int): Number of Metropolis-Hastings iterations to run.
                max_dist (int): Number of single-position resampling moves applied
                    to build each proposal, i.e. an upper bound on how many
                    positions a proposal may differ from the current state.
                device (str): Device on which the chains are allocated.
        """
        # NOTE(review): the base class is presumed to store `net` and
        # `forward_op` as attributes of the same name -- inference() relies on it.
        super().__init__(net=net, forward_op=forward_op)

        self.device = device
        self.mh_steps = mh_steps
        self.max_dist = max_dist

    @torch.no_grad()
    def inference(self, observation=None, num_samples=1, verbose=True):
        """
            Runs `mh_steps` Metropolis-Hastings iterations on `num_samples` chains.

            Parameters:
                observation: Passed unchanged to `forward_op.log_likelihood`.
                num_samples (int): Number of independent chains (batch size).
                verbose (bool): If True, show a tqdm progress bar.

            Returns:
                torch.Tensor: Integer tensor of shape [num_samples, net.length]
                    holding the final state of every chain.
        """
        # Uniformly random initial state, shape = [N, L]
        xt = torch.randint(self.net.dim, (num_samples, self.net.length), device=self.device)
        pbar = tqdm.trange(self.mh_steps) if verbose else range(self.mh_steps)
        # assumes log_likelihood returns one value per chain, shape = [N] -- TODO confirm
        current_log_likelihood = self.forward_op.log_likelihood(xt, observation)
        for _ in pbar:
            # Get proposal: clone once, then accumulate `max_dist` single-site
            # moves. Cloning inside the loop would discard all but the last
            # move, and would leave `proposal` undefined when max_dist == 0.
            proposal = xt.clone()  # proposal, shape = [N, L]
            for _ in range(self.max_dist):
                idx = torch.randint(self.net.length, (num_samples,), device=self.device)
                v = torch.randint(self.net.dim, (num_samples,), device=self.device)
                proposal.scatter_(1, idx.unsqueeze(1), v.unsqueeze(1))

            # Compute log prob difference
            log_likelihood = self.forward_op.log_likelihood(proposal, observation)
            log_ratio = log_likelihood - current_log_likelihood

            # Metropolis-Hasting step
            rho = torch.clip(torch.exp(log_ratio), max=1.0)
            seed = torch.rand_like(rho)
            # A single boolean mask keeps accept/reject mutually exclusive and
            # exhaustive: ties and NaN ratios fall through to "reject" instead
            # of zeroing the state.
            accept = seed < rho
            # torch.where selects values rather than multiplying by 0/1 masks,
            # so a rejected -inf log-likelihood cannot turn into NaN (-inf * 0).
            xt = torch.where(accept.unsqueeze(-1), proposal, xt)
            current_log_likelihood = torch.where(accept, log_likelihood, current_log_likelihood)
        return xt