import torch

from ...targets.target_distribution import TargetDistribution

# Smallest time value allowed after clamping `time` to [MIN_TIME, 1]; it is
# also used to guard the MALA proposal-density ratio against a (near-)zero
# denominator.
MIN_TIME = 1e-10
# Elementwise bound applied to base-target gradients (clamped to [-TOL, TOL]).
# NOTE(review): despite the name this is a large magnitude cap, not a small
# numerical tolerance.
TOL = 1e5

def mala_corrector_step(
    base_target: TargetDistribution,
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: torch.Tensor,
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    log_prob_x: torch.Tensor = None,
    grad_log_prob_x: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs a single Metropolis-Adjusted Langevin Algorithm (MALA) step with adaptive step size scaling.

    Samples ``x`` are mapped to the base space with the affine map
    ``x0 = z + (x - z) / t``, where ``t`` is ``time`` clamped to ``[MIN_TIME, 1]``.
    The base target is evaluated at ``x0``, and its gradient is clamped to
    ``[-TOL, TOL]``. The Langevin proposal is::

        y = x + step_size * grad(x0) + sqrt(2 * step_size * t) * noise

    It is accepted with the Metropolis-Hastings ratio. The Jacobian of the affine
    map does not depend on ``x``, so it cancels in the ratio and is omitted.

    Rows whose current log-probability or gradient is non-finite are always
    rejected. NaN acceptance ratios also lead to rejection.

    Arguments:
        base_target: the target distribution to sample from
        x: (..., dim) tensor of current samples
        z: (..., dim) tensor of reference samples
        step_size: current step size scaling factor (tensor)
        time: current time; clamped to [MIN_TIME, 1]
        target_acceptance: target acceptance rate for adaptation
        adaptation_rate: adaptation rate for step size adaptation; adaptation
            is skipped when this is <= 0
        log_prob_x: (..., 1) tensor of current log probabilities at x (optional)
        grad_log_prob_x: (..., dim) tensor of current gradients of log probability at x (optional).
            Both of these are recomputed from ``base_target`` if either is None.

    Returns:
        new_x: (..., dim) tensor of updated samples after MALA step
        new_log_prob: (..., 1) tensor of updated log probabilities at new_x
        new_grad_log_prob: (..., dim) tensor of updated gradients of log probability at new_x
        new_step_size: updated step size tensor, multiplied by
            ``exp(adaptation_rate * (mean_acceptance - target_acceptance))`` when adapting
    """
    time_clamp = torch.clamp(time, min=MIN_TIME, max=1.0)
    inv_time = 1.0 / time_clamp

    if log_prob_x is None or grad_log_prob_x is None:
        x0 = z + inv_time * (x - z)  # (..., dim)
        grad_log_prob_x, log_prob_x = base_target.grad_log_prob(x0, return_log_prob=True)
        grad_log_prob_x = torch.clamp(grad_log_prob_x, min=-TOL, max=TOL)

    # The current state is only trusted if both its log-prob and its gradient are finite.
    is_valid_logprob = torch.isfinite(log_prob_x)  # (..., 1)
    is_valid_grad = torch.isfinite(grad_log_prob_x).all(dim=-1, keepdim=True)  # (..., 1)
    is_valid_step = is_valid_logprob & is_valid_grad  # (..., 1)

    noise = torch.randn_like(x)  # (..., dim)

    # Forward proposal from x.
    drift_x = step_size * grad_log_prob_x  # (..., dim)
    diff_scale = torch.sqrt(2 * step_size * time_clamp)
    y = x + drift_x + diff_scale * noise  # (..., dim)

    # Evaluate the base target at the proposal (mapped to base space).
    y0 = z + inv_time * (y - z)  # (..., dim)
    grad_log_prob_y, log_prob_y = base_target.grad_log_prob(y0, return_log_prob=True)
    grad_log_prob_y = torch.clamp(grad_log_prob_y, min=-TOL, max=TOL)
    drift_y = step_size * grad_log_prob_y  # (..., dim)

    norm_fwd = ((y - (x + drift_x)) ** 2).sum(dim=-1, keepdim=True)  # (..., 1)
    norm_bwd = ((x - (y + drift_y)) ** 2).sum(dim=-1, keepdim=True)  # (..., 1)

    # The proposal is Gaussian with per-coordinate variance 2*step_size*t, so
    #   log q(x|y) - log q(y|x) = (norm_fwd - norm_bwd) / (4*step_size*t).
    # Clamp the denominator instead of adding an epsilon to it, so the ratio is
    # exact whenever the denominator is not degenerate.
    denom = torch.clamp(4.0 * step_size * time_clamp, min=MIN_TIME)
    log_q_dif = (norm_fwd - norm_bwd) / denom  # (..., 1)

    log_acceptance = log_prob_y - log_prob_x + log_q_dif  # (..., 1)
    # Force rejection for invalid current states. full_like keeps the dtype/device of log_acceptance.
    log_acceptance = torch.where(
        is_valid_step,
        log_acceptance,
        torch.full_like(log_acceptance, -float('inf')),
    )

    # Accept/Reject (a NaN log_acceptance compares False, i.e. rejects).
    log_u = torch.rand_like(log_acceptance).log()  # (..., 1)
    accept = log_u < log_acceptance  # (..., 1)

    x = torch.where(accept, y, x)  # (..., dim)
    log_prob_x = torch.where(accept, log_prob_y, log_prob_x)  # (..., 1)
    grad_log_prob_x = torch.where(accept, grad_log_prob_y, grad_log_prob_x)  # (..., dim)

    # Multiplicative step-size adaptation towards the target acceptance rate.
    if adaptation_rate > 0.0:
        safe_log_acc = torch.nan_to_num(log_acceptance, nan=-float('inf'))  # NaN counts as 0 acceptance
        acceptance = torch.clamp(torch.exp(safe_log_acc), min=0.0, max=1.0)  # (..., 1)
        diff = acceptance.mean() - target_acceptance
        step_size = step_size * torch.exp(adaptation_rate * diff)

    return x.detach(), log_prob_x.detach(), grad_log_prob_x.detach(), step_size.detach()

def mala_corrector(
    base_target: TargetDistribution,
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: torch.Tensor,
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    n_corrector_steps: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs Metropolis-Adjusted Langevin Algorithm (MALA) steps with adaptive step size scaling.

    The base target is evaluated once at ``x0 = z + (x - z) / t``, with ``t`` being
    ``time`` clamped to ``[MIN_TIME, 1]``. The gradient is clamped to ``[-TOL, TOL]``.
    The resulting log-probability and gradient are then threaded through
    ``n_corrector_steps`` calls to ``mala_corrector_step``, so each step only
    evaluates the target at its proposal.

    Arguments:
        base_target: the target distribution to sample from
        x: (batch_size, dim) tensor of current samples
        z: (batch_size, dim) tensor of reference samples
        step_size: current step size scaling factor (tensor)
        time: current time
        target_acceptance: target acceptance rate for adaptation
        adaptation_rate: adaptation rate for step size adaptation
        n_corrector_steps: number of MALA corrector steps to perform (0 or
            negative performs no steps)

    Returns:
        new_x: (batch_size, dim) tensor of updated samples after the MALA steps
        new_log_prob: (batch_size, 1) tensor of updated log probabilities at new_x
        new_grad_log_prob: (batch_size, dim) tensor of updated gradients of log probability at new_x
        new_step_size: updated step size tensor after adaptation
    """
    # Initial evaluation of the base target at the current samples (base space).
    inv_time = 1.0 / torch.clamp(time, min=MIN_TIME, max=1.0)
    x0 = z + inv_time * (x - z)  # (batch_size, dim)
    grad_log_prob_x, log_prob_x = base_target.grad_log_prob(x0, return_log_prob=True)
    grad_log_prob_x = torch.clamp(grad_log_prob_x, min=-TOL, max=TOL)

    for _ in range(n_corrector_steps):
        x, log_prob_x, grad_log_prob_x, step_size = mala_corrector_step(
            base_target,
            x,
            z,
            step_size,
            time,
            target_acceptance,
            adaptation_rate,
            log_prob_x=log_prob_x,
            grad_log_prob_x=grad_log_prob_x,
        )

    return x.detach(), log_prob_x.detach(), grad_log_prob_x.detach(), step_size.detach()