import torch

from ...targets.target_distribution import TargetDistribution

# Lower bound that `time` is clamped to before computing 1 / time in the
# corrector functions below, so the inverse stays finite when time is zero.
MIN_TIME = 1e-10


def mh_corrector_step(
    base_target: TargetDistribution,
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: float,
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    *args,
    log_prob_x: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, "torch.Tensor | float"]:
    """
    Performs a single Metropolis-Hastings (MH) step with adaptive step size scaling.

    A Gaussian random-walk proposal ``y = x + step_size * noise`` is drawn, both
    the current and proposed points are mapped through
    ``z + (1 / time) * (point - z)`` and scored with ``base_target.log_prob``.
    Each row is accepted or rejected independently.

    Rows whose current log probability is not finite (NaN or +/-inf) always
    reject the proposal. Rows whose proposal log probability is NaN are
    rejected as well, because the comparison against NaN is False.

    Arguments:
        base_target: the target distribution to sample from
        x: (batch_size, dim) tensor of current samples
        z: (batch_size, dim) tensor of reference samples
        step_size: current step size scaling factor (float or tensor)
        time: current time; clamped to [MIN_TIME, 1.0] before being inverted
        target_acceptance: target acceptance rate for adaptation
        adaptation_rate: adaptation rate for step size adaptation; the step
            size is only adapted when this is strictly positive
        *args: accepted for call-site compatibility and ignored
        log_prob_x: optional (batch_size, 1) tensor with the log probability of
            the current samples; computed from ``x`` when omitted
        **kwargs: accepted for call-site compatibility and ignored

    Returns:
        new_x: (batch_size, dim) tensor of updated samples after MH step
        new_log_prob: (batch_size, 1) tensor of updated log probabilities at new_x
        new_step_size: updated step size scaling factor after adaptation. This
            is a detached tensor when adaptation ran or a tensor was passed in;
            a plain number passed in without adaptation is returned unchanged.
    """

    # Keep time away from zero so that 1 / time stays finite.
    time_clamp = torch.clamp(time, min=MIN_TIME, max=1.0)
    inv_time = 1.0 / time_clamp

    if log_prob_x is None:
        x0 = z + inv_time * (x - z)  # (batch_size, dim)
        log_prob_x = base_target.log_prob(x0)  # (batch_size, 1)

    is_valid_logprob = torch.isfinite(log_prob_x)  # (batch_size, 1)

    # randn_like already inherits the device and dtype of x.
    noise = torch.randn_like(x)  # (batch_size, dim)

    y = x + step_size * noise  # (batch_size, dim)

    y0 = z + inv_time * (y - z)  # (batch_size, dim)
    log_prob_y = base_target.log_prob(y0)  # (batch_size, 1)

    log_acceptance = log_prob_y - log_prob_x  # (batch_size, 1)
    # Force rejection wherever the current log probability is not finite.
    # full_like keeps the filler on the same device and dtype as the ratio.
    log_acceptance = torch.where(
        is_valid_logprob,
        log_acceptance,
        torch.full_like(log_acceptance, -float('inf')),
    )

    # Accept/Reject
    rand_val = torch.rand_like(log_acceptance).log()  # (batch_size, 1)
    mask = rand_val < log_acceptance  # (batch_size, 1)
    x = torch.where(mask, y, x)  # (batch_size, dim)
    log_prob_x = torch.where(mask, log_prob_y, log_prob_x)  # (batch_size, 1)

    # Update step
    if adaptation_rate > 0.0:
        # NaN ratios (e.g. NaN proposal log prob) count as zero acceptance.
        log_acceptance = torch.nan_to_num(log_acceptance, nan=-float('inf'))
        acceptance = torch.exp(log_acceptance)
        acceptance = torch.clamp(acceptance, min=0.0, max=1.0)
        avg_acc = acceptance.mean()
        diff = avg_acc - target_acceptance
        # Grow the step when accepting more often than targeted, shrink otherwise.
        step_size = step_size * torch.exp(adaptation_rate * diff)

    # step_size is only a tensor if it was passed as one or adaptation ran;
    # a plain Python number has no .detach() and is returned unchanged.
    if isinstance(step_size, torch.Tensor):
        step_size = step_size.detach()

    return x.detach(), log_prob_x.detach(), step_size

def mh_corrector(
    base_target: TargetDistribution,
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: float,
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    n_corrector_steps: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, "torch.Tensor | float"]:
    """
    Performs Metropolis-Hastings (MH) steps with adaptive step size scaling.

    The log probability of the initial samples is evaluated once and then
    threaded through ``n_corrector_steps`` calls to ``mh_corrector_step``.
    After the loop, the log probability and its gradient are evaluated at the
    final samples via ``base_target.grad_log_prob``.

    Arguments:
        base_target: the target distribution to sample from
        x: (batch_size, dim) tensor of current samples
        z: (batch_size, dim) tensor of reference samples
        step_size: current step size scaling factor (float or tensor)
        time: current time; clamped to [MIN_TIME, 1.0] before being inverted
        target_acceptance: target acceptance rate for adaptation
        adaptation_rate: adaptation rate for step size adaptation
        n_corrector_steps: number of MH corrector steps to perform

    Returns:
        new_x: (batch_size, dim) tensor of updated samples after MH step
        new_log_prob: (batch_size, 1) tensor of updated log probabilities at new_x
        new_grad_log_prob: (batch_size, dim) tensor of gradients of log probabilities at new_x
        new_step_size: updated step size scaling factor after adaptation. This
            is a detached tensor when it is a tensor; a plain number that was
            never adapted is returned unchanged.
    """

    # Keep time away from zero so that 1 / time stays finite.
    time_clamp = torch.clamp(time, min=MIN_TIME, max=1.0)
    inv_time = 1.0 / time_clamp

    # Score the starting samples once so the first step does not recompute it.
    x0 = z + inv_time * (x - z)  # (batch_size, dim)
    log_prob_x = base_target.log_prob(x0)  # (batch_size, 1)

    for _ in range(n_corrector_steps):

        x, log_prob_x, step_size = mh_corrector_step(
            base_target,
            x,
            z,
            step_size,
            time,
            target_acceptance,
            adaptation_rate,
            log_prob_x=log_prob_x
        )

    # Compute gradient of log prob at final x
    # Note: MH does not use gradients for proposals, so this is just for output consistency
    x0 = z + inv_time * (x - z)  # (batch_size, dim)
    grad_log_prob_x, log_prob_x = base_target.grad_log_prob(x0, return_log_prob=True)  # (batch_size, dim), (batch_size, 1)

    # With zero corrector steps (or no adaptation) step_size may still be the
    # plain number that was passed in, which has no .detach().
    if isinstance(step_size, torch.Tensor):
        step_size = step_size.detach()

    return x.detach(), log_prob_x.detach(), grad_log_prob_x.detach(), step_size