import torch


def compute_weights_safe(state, logpdf_fn, logproposal_fn, eps=1.0, clip=20.0):
    """
    Compute normalized importance weights, hardened against non-finite values.

    Args :
        state : samples passed unchanged to both log-density functions
        logpdf_fn : fct taking state and returning log π(x)
        logproposal_fn : fct taking state and returning log mu(x)
        eps : factor multiplying the log-ratio log π(x) - log mu(x)
        clip : log-weights are clamped to [median - clip, median + clip]

    Returns :
        Tensor with the shape of the log-ratio whose entries sum to 1.
        An empty log-ratio yields an empty tensor.
    """
    log_w = eps * (logpdf_fn(state) - logproposal_fn(state))

    # Nothing to normalize; median() would raise on an empty tensor.
    if log_w.numel() == 0:
        return torch.zeros_like(log_w)

    # Remove NaNs and infs. NaN signals a failed density evaluation, so it is
    # treated like -inf (lowest weight) instead of log-weight 0, which could
    # otherwise make a broken sample the dominant one.
    log_w = torch.nan_to_num(log_w, nan=-1e10, neginf=-1e10, posinf=1e10)

    # Robust clipping around median: bounds the spread of log-weights so a
    # few extreme samples cannot take all the mass.
    med = log_w.median()
    log_w = torch.clamp(log_w, med - clip, med + clip)

    # Stable normalization: shift so the largest log-weight is 0 before exp.
    log_w = log_w - log_w.max()
    w = torch.exp(log_w)

    # Final safety
    w = torch.nan_to_num(w, nan=0.0, posinf=0.0)
    w_sum = w.sum()

    if w_sum <= 0:
        # fallback: uniform weights
        return torch.ones_like(w) / w.numel()

    return w / w_sum


def compute_weights(state, logpdf_fn, logproposal_fn, eps=1.0, clip=5.0):
    """
    Compute importance weights of E2MC algo.

    The log-ratio eps * (log π(x) - log mu(x)) is clamped from above at
    median + clip, shifted by its maximum for numerical stability, then
    exponentiated and normalized to sum to 1.

    Args :
        state : Tensor of shape [N,1]
        logpdf_fn : fct taking x and returning log π(x)
        logproposal_fn : fct taking x and returning log mu(x).
        eps : epsilon parameter multiplying the log-ratio
        clip : offset above the median log-weight at which log-weights are
            clamped (default 5.0, the previously hard-coded value)

    Returns :
        w_x Tensor of shape [N]

    Note: NaN or all -inf log-weights are not sanitized here and can yield
    NaN weights; compute_weights_safe handles those cases.
    """
    log_w_x = eps * (logpdf_fn(state) - logproposal_fn(state))  # [N]

    # Cap only the upper tail so a few samples cannot absorb all the mass.
    log_w_x = torch.clamp(log_w_x, max=log_w_x.median() + clip)

    # Subtract the max before exp to avoid overflow.
    log_w_max = log_w_x.max(dim=0, keepdim=True).values
    w = torch.exp(log_w_x - log_w_max)

    return w / w.sum()  # [N]




def sample_from_mixture(x, w_x, y, w_y, lamda):
    """
    Draw samples from a two-component mixture of weighted empirical measures.

    The target is Z ~ λ · Σ w_x δ_x + (1 - λ) · Σ w_y δ_y. The points of x
    and y are pooled and indices are drawn with replacement, in proportion
    to the mixture-scaled weights. torch.multinomial accepts unnormalized
    weights, so no explicit renormalization is done.

    Args :
        x, y : tensors of shape [N,1]
        w_x, w_y : tensors of shape [N], each summing to 1
        lamda : float in [0,1], mass given to the x component

    Returns :
        Tensor [N,1] with as many rows as x.
    """
    pool = torch.cat((x, y), dim=0)
    mix_probs = torch.cat((lamda * w_x, (1 - lamda) * w_y))
    picks = torch.multinomial(mix_probs, num_samples=x.shape[0], replacement=True)
    return pool[picks]
