import numpy as np
import torch
import math

def kl_divergence_gaussian_old(mean1, var1, mean2, var2):
    """
    Return KL(P || Q) in bits for two univariate Gaussians.

    P = N(mean1, var1) and Q = N(mean2, var2). The divergence is first
    evaluated in nats and then divided by ln(2) to express it in bits.

    Args:
        mean1 (float): Mean of P
        var1 (float): Variance of P (must be > 0)
        mean2 (float): Mean of Q
        var2 (float): Variance of Q (must be > 0)

    Returns:
        float: KL(P || Q) measured in bits

    Raises:
        ValueError: If either variance is zero or negative
    """
    if any(v <= 0 for v in (var1, var2)):
        raise ValueError("Variances must be positive")

    sigma_p, sigma_q = np.sqrt(var1), np.sqrt(var2)

    # Closed form in nats: log(sigma_q / sigma_p) + (var1 + dmu^2) / (2 var2) - 1/2
    log_scale_ratio = np.log(sigma_q / sigma_p)
    quadratic = (var1 + (mean1 - mean2) ** 2) / (2 * var2)
    kl_nats = log_scale_ratio + quadratic - 0.5

    # nats -> bits
    return kl_nats / np.log(2)

def kl_divergence_gaussian(mean1, var1, mean2, var2):
    """
    Return KL(P || Q) in bits for two univariate Gaussians.

    P = N(mean1, var1) and Q = N(mean2, var2). Every term is evaluated
    directly in base 2: the log-ratio uses log2, and the remaining terms are
    scaled by log2(e).

    Args:
        mean1 (float): Mean of P
        var1 (float): Variance of P (must be > 0)
        mean2 (float): Mean of Q
        var2 (float): Variance of Q (must be > 0)

    Returns:
        float: KL(P || Q) measured in bits

    Raises:
        ValueError: If either variance is zero or negative
    """
    if any(v <= 0 for v in (var1, var2)):
        raise ValueError("Variances must be positive")

    # Conversion factor from nats to bits.
    log2_e = np.log2(np.exp(1))
    sigma_p, sigma_q = np.sqrt(var1), np.sqrt(var2)

    scale_bits = np.log2(sigma_q / sigma_p)
    quad_bits = log2_e * (var1 + (mean1 - mean2) ** 2) / (2 * var2)
    return scale_bits + quad_bits - 0.5 * log2_e

def logexp_rv(B=128, N=1024):
    """
    Draw the natural log of unit-rate exponential variates.

    Args:
        B (int): Size of the first (batch) axis
        N (int): Size of the second axis

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, 1) holding log(E) with
        E drawn by ``Tensor.exponential_()`` (default rate). Allocated on
        CUDA when it is available, otherwise on the CPU.
    """
    # torch.empty(dtype=..., device=...) replaces the legacy
    # torch.cuda.DoubleTensor constructor, which fails outright on machines
    # without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    draws = torch.empty(B, N, 1, dtype=torch.float64, device=device).exponential_()
    return torch.log(draws)

def uniform_rv(B=128, N=1024):
    """
    Draw uniform variates via ``Tensor.uniform_()`` (default range [0, 1)).

    Args:
        B (int): Size of the first (batch) axis
        N (int): Size of the second axis

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, 1). Allocated on CUDA
        when it is available, otherwise on the CPU.
    """
    # torch.empty(dtype=..., device=...) replaces the legacy
    # torch.cuda.DoubleTensor constructor, which fails outright on machines
    # without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.empty(B, N, 1, dtype=torch.float64, device=device).uniform_()

def ber_rv(B=128, N=1024, dim=1, L=2):  # log_2(L) bits only
    """
    Draw integer symbols uniformly from {0, ..., L-1}.

    With the default L=2 every entry is a fair bit; in general each entry
    carries log2(L) bits.

    Args:
        B (int): Size of the first (batch) axis
        N (int): Size of the second axis
        dim (int): Unused; kept so existing call sites keep working
        L (int): Alphabet size (exclusive upper bound of the draws)

    Returns:
        torch.Tensor: int64 tensor of shape (B, N). Allocated on CUDA when
        it is available, otherwise on the CPU.
    """
    # torch.empty(dtype=..., device=...) replaces the legacy
    # torch.cuda.DoubleTensor constructor, which fails outright on machines
    # without CUDA. The draw is still done on a float64 buffer and then cast,
    # exactly as before.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    symbols = torch.empty(B, N, dtype=torch.float64, device=device).random_(L)
    return symbols.long()

def gauss_log_p(x, mean, variance):
    """
    Log-density of a univariate Gaussian N(mean, variance) evaluated at x.

    The normalising term goes through ``math.log``, so ``variance`` has to be
    something ``math.log`` accepts (e.g. a Python float). ``x`` only takes
    part in arithmetic operators.

    Args:
        x: Point(s) at which to evaluate the log-density
        mean: Mean of the Gaussian
        variance: Variance of the Gaussian

    Returns:
        -0.5 * (log(2*pi*variance) + (x - mean)^2 / variance)
    """
    log_normaliser = math.log(2 * math.pi * variance)
    scaled_sq_dev = ((x - mean) ** 2) / variance
    return -0.5 * (log_normaliser + scaled_sq_dev)

def cauchy_log_p(x, location, scale):
    """
    Log-density of a Cauchy(location, scale) distribution evaluated at x.

    Both logarithms go through ``math.log``, so ``x``, ``location`` and
    ``scale`` must combine into values ``math.log`` accepts.

    Args:
        x: Point at which to evaluate the log-density
        location: Location parameter
        scale: Scale parameter

    Returns:
        -log(pi * scale) - log(1 + ((x - location) / scale)^2)
    """
    standardized_sq = ((x - location) ** 2) / (scale ** 2)
    log_normaliser = math.log(math.pi * scale)
    return -log_normaliser - math.log(1 + standardized_sq)

def estimate_omega(mean_p, var_p, mean_t, var_t):
    """
    Return sup_x t(x) / p(x) for two univariate Gaussians.

    t = N(mean_t, var_t) is the target and p = N(mean_p, var_p) the proposal.

    For var_t < var_p the log-ratio is a concave quadratic whose maximum sits
    at x* = (var_t*mean_p - var_p*mean_t) / (var_t - var_p), giving

        log omega = 0.5 * log(var_p / var_t)
                    + (mean_t - mean_p)^2 / (2 * (var_p - var_t))

    The value is computed from this closed form in log-space. Evaluating
    exp(log t(x*)) / exp(log p(x*)) separately turns into 0/0 = nan when x*
    lies far in the tails, even though the ratio itself is finite.

    For var_t >= var_p the ratio has no finite upper bound (the stationary
    point is a minimum, or does not exist), so ``inf`` is returned. The one
    exception is identical distributions, where the ratio is 1 everywhere.

    Args:
        mean_p (float): Proposal mean
        var_p (float): Proposal variance (must be > 0)
        mean_t (float): Target mean
        var_t (float): Target variance (must be > 0)

    Returns:
        numpy.float64: The supremum of t(x)/p(x); ``inf`` if unbounded.

    Raises:
        ValueError: If either variance is zero or negative
    """
    if var_p <= 0 or var_t <= 0:
        raise ValueError("Variances must be positive")

    mean_gap_sq = (mean_t - mean_p) ** 2

    if var_t >= var_p:
        if var_t == var_p and mean_gap_sq == 0:
            # Same distribution: t(x)/p(x) == 1 for every x.
            return np.float64(1.0)
        # Target tails are at least as heavy as the proposal's: unbounded.
        return np.float64(np.inf)

    log_omega = 0.5 * np.log(var_p / var_t) + mean_gap_sq / (2.0 * (var_p - var_t))
    return np.exp(log_omega)

def estimate_ers_batch_mean_prob(N, mean_t, var_t, mean_p=0.0, var_p=1.0, num_iters=100000):
    """
    Monte-Carlo average of the ERS probability reported by the sampler.

    Each iteration draws N proposal samples and N log-exponential variates
    (batch size 1), runs ``exp_sampler.select`` with ``ers_selection=True``
    and records the returned probability.

    The proposal parameters are forwarded to both the sample generator and
    the selector, so a non-default proposal is used consistently rather than
    only influencing ``omega``.

    NOTE(review): ``exp_sampler`` is a module-level name defined outside this
    function; presumably an ``Exp_Sampler`` instance. Confirm.

    Args:
        N (int): Number of candidate samples per iteration
        mean_t (float): Target mean
        var_t (float): Target variance
        mean_p (float): Proposal mean
        var_p (float): Proposal variance
        num_iters (int): Number of Monte-Carlo iterations to average over

    Returns:
        numpy.float64: Mean of the per-iteration ERS probabilities
    """
    batch = 1
    omega = estimate_omega(mean_p, var_p, mean_t, var_t)

    probs = []
    for _ in range(num_iters):
        y = gauss_gen(mu=mean_p, var=var_p, B=batch, N=N)
        logS_ = logexp_rv(B=batch, N=N)
        _, _, _, ers_prob = exp_sampler.select(
            logS_, y, mean_t, var_t,
            mean_p=mean_p, var_p=var_p,
            hash_val=None, ers_selection=True, omega=omega,
        )
        probs.append(ers_prob.item())

    return np.asarray(probs).mean()

def torch_to_numpy(tensor):
    """
    Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host memory
    (a no-op when it already lives on the CPU) before conversion.

    Args:
        tensor (torch.Tensor): Tensor to convert

    Returns:
        numpy.ndarray: Array holding the tensor's values

    Raises:
        ValueError: If ``tensor`` is not a ``torch.Tensor``
    """
    # Validate before touching the object: calling .detach() first made
    # non-tensor inputs fail with AttributeError instead of this ValueError.
    if not isinstance(tensor, torch.Tensor):
        raise ValueError("Input must be a Torch tensor.")
    return tensor.detach().cpu().numpy()

def gauss_gen(mu=0.0, var=1.0, dim=1,  B=128, N=1024):
    """
    Draw Gaussian samples with mean ``mu`` and variance ``var``.

    Standard-normal draws are scaled by sqrt(var) and then shifted by mu.
    Shifting before scaling would give a mean of mu*sqrt(var) instead of mu.
    Results are unchanged when mu == 0 or var == 1.

    Args:
        mu (float): Mean of the samples
        var (float): Variance of the samples
        dim (int): Size of the last axis
        B (int): Size of the first (batch) axis
        N (int): Size of the second axis

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, dim). Allocated on CUDA
        when it is available, otherwise on the CPU.
    """
    # torch.empty(dtype=..., device=...) replaces the legacy
    # torch.cuda.DoubleTensor constructor, which fails outright on machines
    # without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    standard_normal = torch.empty(B, N, dim, dtype=torch.float64, device=device).normal_(0.0, 1.0)
    return standard_normal * np.sqrt(var) + mu

class Exp_Sampler():
    """
    Selects, per batch row, the candidate minimising
    ``logS + log p(y) - log t(y)``, where p is the Gaussian proposal and t
    the Gaussian target. Holds no state.
    """

    def __init__(self):
        pass

    def select(self, logS_, y, mean_t, var_t, mean_p=0.0, var_p=1.0, hash_val=None, message=None, ers_selection=None, omega=None):
        """
        Pick one candidate per batch row by minimum score.

        NOTE(review): shapes below are inferred from the indexing in this
        method (and from the (B, N, dim) layout produced by ``gauss_gen``);
        confirm against callers.

        Args:
            logS_: Additive score term, broadcast against the (B, N, 1)
                log-densities (presumably log-exponential variates).
            y: Candidate samples, assumed (B, N, dim). Log-densities are
                summed over the last axis.
            mean_t, var_t: Target Gaussian mean and variance.
            mean_p, var_p: Proposal Gaussian mean and variance.
            hash_val: Optional per-candidate labels, indexed as [batch, n].
            message: Optional; when given, candidates whose ``hash_val``
                differs from ``message[0]`` are penalised by
                ``-log(1e-25)`` so that a matching candidate wins.
            ers_selection: When truthy, also compute the ERS probability.
            omega: Bound added to the ERS normaliser; required when
                ``ers_selection`` is truthy.

        Returns:
            tuple: ``(K_min, selected_y, out_m, ers_prob)``:
                K_min - argmin over axis 1 of the score,
                selected_y - ``y`` at the selected index of each row,
                out_m - ``hash_val`` at the selected index, or None when
                    ``hash_val`` is None,
                ers_prob - ``z_hat / z_bar`` when ``ers_selection`` is
                    truthy, otherwise None.
        """
        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)
        log_t_ = gauss_log_p(y, mean_t, var_t).sum(dim=-1, keepdim=True)

        score_x_ = logS_ + log_p_ - log_t_
        if message is not None:
            # 1e-25 keeps the log finite for non-matching candidates while
            # still adding a large penalty to their score.
            filtered_ = (hash_val == message[0]) * 1.0 + 1e-25
            score_x_ = score_x_ - torch.log(filtered_)[..., None]

        K_min = torch.argmin(score_x_, dim=1)
        batch_idx = torch.arange(score_x_.shape[0])
        selected_y = y[batch_idx, K_min[:, 0]]

        out_m = hash_val[batch_idx, K_min[:, 0]] if hash_val is not None else None

        if message is not None:
            # NOTE(review): this comparison yields a tensor; truth-testing it
            # is only well-defined when it has a single element. Confirm the
            # intended shapes of `out_m` and `message` for batch sizes > 1.
            assert out_m == message

        # ERS selection
        if ers_selection:
            # Importance weights t/p, formed in log-space so that very small
            # densities do not underflow into 0/0.
            weights = torch.exp(log_t_ - log_p_)
            z_hat = weights.sum()
            # NOTE(review): the selected weight is read from batch row 0
            # while z_hat sums over every row; this looks designed for a
            # batch size of 1. Confirm before using larger batches.
            z_bar = z_hat - weights[0, K_min, 0].sum() + omega
            ers_prob = z_hat / z_bar
        else:
            ers_prob = None
        return K_min, selected_y, out_m, ers_prob