import torch
from sampling.targets.target_distribution import TargetDistribution

MIN_TIME = 1e-10  # lower clamp on `time` so that 1 / time stays finite
TOL = 1e5  # element-wise bound applied to every log-prob gradient


def _grad_and_log_prob(
    base_target: "TargetDistribution",
    x: torch.Tensor,
    z: torch.Tensor,
    inv_time: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Map x to x0 = z + (x - z) / time and return (clipped grad, log_prob) of base_target at x0."""
    x0 = z + inv_time * (x - z)
    grad, log_prob = base_target.grad_log_prob(x0, return_log_prob=True)
    return torch.clamp(grad, min=-TOL, max=TOL), log_prob


def hmc_corrector_step(
    base_target: "TargetDistribution",
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: "float | torch.Tensor",
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    n_leapfrog_steps: int = 5,
    log_prob_x: "torch.Tensor | None" = None,
    grad_log_prob_x: "torch.Tensor | None" = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs a single Hamiltonian Monte Carlo (HMC) step with optional step-size adaptation.

    The chain lives in x-space. The potential energy is U(x) = -log p(x0), where
    x0 = z + (x - z) / time and p is ``base_target``.

    Arguments:
        base_target: the target distribution to sample from
        x: (batch_size, dim) tensor of current samples
        z: (batch_size, dim) tensor of reference samples
        step_size: current leapfrog step size (float or 0-dim tensor)
        time: current time; clamped to [MIN_TIME, 1]
        target_acceptance: desired mean acceptance probability for adaptation
        adaptation_rate: multiplicative adaptation rate; values <= 0 disable adaptation
        n_leapfrog_steps: number of leapfrog steps to perform.
        log_prob_x: log probability of base_target at the x0 corresponding to x.
            Recomputed here when it or ``grad_log_prob_x`` is None.
        grad_log_prob_x: clipped gradient of log probability at that x0.

    Returns:
        x_new: (batch_size, dim) tensor of updated samples after HMC step
        log_prob_new: (batch_size, 1) tensor of updated log probabilities at x_new
        grad_log_prob_new: (batch_size, dim) tensor of updated gradients of log probability at x_new
        step_size_new: 0-dim tensor, the (possibly adapted) step size
    """
    time_clamp = torch.clamp(time, min=MIN_TIME, max=1.0)
    inv_time = 1.0 / time_clamp
    # Always work with a tensor step size so `.detach()` below is valid even
    # when no adaptation happens and a plain float was passed in.
    step_size = torch.as_tensor(step_size, dtype=x.dtype, device=x.device)

    if log_prob_x is None or grad_log_prob_x is None:
        grad_log_prob_x, log_prob_x = _grad_and_log_prob(base_target, x, z, inv_time)

    v = torch.randn_like(x)
    current_H = -log_prob_x + 0.5 * (v ** 2).sum(dim=-1, keepdim=True)

    # Leapfrog integration: initial half kick, then alternating drift / kick,
    # with the last kick being a half step.
    # NOTE(review): grad_new is the gradient w.r.t. x0, not x (d/dx carries an
    # extra factor inv_time). The force depends only on position, so leapfrog
    # stays volume-preserving and reversible. The MH test below uses the exact
    # U, so the sampler is still valid. The trajectory is the true Hamiltonian
    # flow of U(x) only when time == 1 — confirm this is intended.
    x_new = x
    grad_new = grad_log_prob_x
    log_prob_new = log_prob_x
    v_new = v + 0.5 * step_size * grad_new
    for i in range(n_leapfrog_steps):
        x_new = x_new + step_size * v_new
        grad_new, log_prob_new = _grad_and_log_prob(base_target, x_new, z, inv_time)
        kick = step_size if i < n_leapfrog_steps - 1 else 0.5 * step_size
        v_new = v_new + kick * grad_new

    proposed_H = -log_prob_new + 0.5 * (v_new ** 2).sum(dim=-1, keepdim=True)
    log_acceptance = current_H - proposed_H

    # Proposals with a non-finite log-prob or gradient are always rejected.
    is_valid_proposal = torch.isfinite(log_prob_new) & torch.isfinite(grad_new).all(dim=-1, keepdim=True)
    log_acceptance = torch.where(
        is_valid_proposal, log_acceptance, torch.full_like(log_acceptance, float("-inf"))
    )

    # Metropolis accept/reject; a NaN log_acceptance compares False and is rejected.
    accept = torch.rand_like(log_acceptance).log() < log_acceptance
    x = torch.where(accept, x_new, x)
    log_prob_x = torch.where(accept, log_prob_new, log_prob_x)
    grad_log_prob_x = torch.where(accept, grad_new, grad_log_prob_x)

    if adaptation_rate > 0.0:
        # NaN log-acceptances count as zero acceptance probability.
        log_acceptance = torch.nan_to_num(log_acceptance, nan=-float("inf"))
        acceptance = torch.clamp(torch.exp(log_acceptance), min=0.0, max=1.0)
        diff = acceptance.mean() - target_acceptance
        step_size = step_size * torch.exp(adaptation_rate * diff)

    return x.detach(), log_prob_x.detach(), grad_log_prob_x.detach(), step_size.detach()


def hmc_corrector(
    base_target: "TargetDistribution",
    x: torch.Tensor,
    z: torch.Tensor,
    step_size: "float | torch.Tensor",
    time: torch.Tensor,
    target_acceptance: float,
    adaptation_rate: float,
    n_leapfrog_steps: int = 5,
    n_corrector_steps: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs Hamiltonian Monte Carlo (HMC) steps with adaptive step size scaling.

    Arguments:
        base_target: the target distribution to sample from
        x: (batch_size, dim) tensor of current samples
        z: (batch_size, dim) tensor of reference samples
        step_size: current leapfrog step size (float or 0-dim tensor)
        time: current time; clamped to [MIN_TIME, 1]
        target_acceptance: target acceptance rate for adaptation
        adaptation_rate: rate of the multiplicative step-size update
            step_size *= exp(adaptation_rate * (mean_acceptance - target_acceptance));
            values <= 0 disable adaptation
        n_leapfrog_steps: number of leapfrog steps to perform.
        n_corrector_steps: number of HMC corrector steps to perform

    Returns:
        x_final: (batch_size, dim) tensor of updated samples after HMC step
        log_prob_final: (batch_size, 1) tensor of updated log probabilities at x_final
        grad_log_prob_final: (batch_size, dim) tensor of updated gradients of log probability at x_final
        step_size_new: 0-dim tensor, the step size after adaptation
    """
    time_clamp = torch.clamp(time, min=MIN_TIME, max=1.0)
    inv_time = 1.0 / time_clamp
    step_size = torch.as_tensor(step_size, dtype=x.dtype, device=x.device)

    # Evaluate the current state once. Each step hands back the log-prob and
    # gradient of its accepted state, so they are never recomputed.
    grad_log_prob_x, log_prob_x = _grad_and_log_prob(base_target, x, z, inv_time)

    for _ in range(n_corrector_steps):
        x, log_prob_x, grad_log_prob_x, step_size = hmc_corrector_step(
            base_target,
            x,
            z,
            step_size,
            time,
            target_acceptance,
            adaptation_rate,
            n_leapfrog_steps,
            log_prob_x=log_prob_x,
            grad_log_prob_x=grad_log_prob_x,
        )

    return x.detach(), log_prob_x.detach(), grad_log_prob_x.detach(), step_size.detach()
