from typing import Tuple

import torch

# Compact alias for ``torch.Tensor`` so tensor type annotations stay short.
from torch import Tensor as T


class MetropolisHastingsAcceptance:
    """Metropolis-Hastings accept/reject step for a single proposal.

    Given log-densities of a proposed state ``x'`` and the current state
    ``x``, the proposal is accepted with probability

        a = min(p(x') / p(x), 1)

    or, equivalently, in log space (log(1) == 0):

        log_a = min(log_p(x') - log_p(x), 0)

    The decision compares ``log_a`` against ``log(u)`` with ``u ~ U[0, 1)``.
    Working in log space avoids overflow/underflow of the raw density ratio.

    NOTE(review): only the target-density ratio is used; no proposal
    correction q(x | x') / q(x' | x) is applied. This presumes a symmetric
    proposal, or that the caller folds the correction into the inputs.
    """

    def __call__(self, log_prob_proposal: T, log_prob_state: T) -> Tuple[bool, T]:
        """Decide whether to accept the proposed state.

        Args:
            log_prob_proposal: log-density of the proposed state ``x'``.
                Must hold a single element (a Python number is also
                accepted), since the decision is reduced with ``.item()``.
            log_prob_state: log-density of the current state ``x``.

        Returns:
            ``(accepted, log_ratio)``, where ``accepted`` is a Python bool
            and ``log_ratio = min(log_p(x') - log_p(x), 0)``. ``log_ratio``
            is the log acceptance probability.

        Behavior at the edges:
            * A ``-inf`` proposal (zero probability) is always rejected,
              with ``log_ratio == -inf``.
            * A NaN ``log_ratio`` compares false against ``log(u)``, so the
              proposal is rejected. This happens, for example, when the
              state is NaN or both inputs are ``-inf``.

        Raises:
            ValueError: if ``log_prob_proposal`` is NaN or ``+inf``. Either
                value indicates a broken log-density rather than a
                legitimate rejection.
        """
        log_prob_proposal = torch.as_tensor(log_prob_proposal)

        # -inf is deliberately allowed: it is the log of zero probability and
        # must simply lead to rejection, not an error.
        if torch.isnan(log_prob_proposal).any() or torch.isposinf(log_prob_proposal).any():
            raise ValueError(f"log_prob_proposal is nan or inf {log_prob_proposal=}")

        # Cap at log(1) == 0: acceptance probability never exceeds 1.
        log_ratio = torch.clamp(log_prob_proposal - log_prob_state, max=0.0)

        # log(u) with u ~ U[0, 1); accept iff log_ratio > log(u).
        log_u = torch.rand_like(log_ratio).log()
        accepted = bool(torch.gt(log_ratio, log_u).item())

        return accepted, log_ratio
