import numpy as np
import torch
import math
from torch.distributions import Normal

def kl_divergence_gaussian(mean1, var1, mean2, var2, base=2.0):
    """
    Compute KL(P || Q) between two 1D Gaussians P = N(mean1, var1) and Q = N(mean2, var2).

    The closed form (in nats) is
        KL = 0.5 * ln(var2 / var1) + (var1 + (mean1 - mean2)**2) / (2 * var2) - 0.5
    and the result is converted to logarithm ``base``. The default base is 2,
    so by default the value is in **bits**, matching the historical behavior.

    Args:
        mean1 (float): Mean of P.
        var1 (float): Variance of P (must be > 0).
        mean2 (float): Mean of Q.
        var2 (float): Variance of Q (must be > 0).
        base (float): Logarithm base of the returned value. Defaults to 2
            (bits); pass ``math.e`` for nats.

    Returns:
        float: KL(P || Q) expressed in log-``base`` units.

    Raises:
        ValueError: If a variance is not strictly positive (NaN included), or
            if ``base`` is not a valid logarithm base (> 0 and != 1).
    """
    # Written as "not (> 0)" so that NaN variances are rejected too.
    if not (var1 > 0 and var2 > 0):
        raise ValueError("Variances must be positive")
    if not (base > 0) or base == 1:
        raise ValueError("base must be positive and different from 1")

    # ln(std2 / std1) == 0.5 * ln(var2 / var1); no need to take square roots.
    log_std_ratio = 0.5 * np.log(var2 / var1)
    mismatch = (var1 + (mean1 - mean2) ** 2) / (2 * var2)
    kl_nats = log_std_ratio + mismatch - 0.5

    return kl_nats / np.log(base)

def logexp_rv(B=128, N=1024, device="cuda"):
    """
    Draw log(E) for i.i.d. E ~ Exp(1).

    Args:
        B (int): Batch size.
        N (int): Number of samples per batch row.
        device: Torch device for the result. Defaults to "cuda", which was the
            only option before this parameter existed.

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, 1).
    """
    exp_samples = torch.empty((B, N, 1), dtype=torch.float64, device=device).exponential_()
    return torch.log(exp_samples)

def uniform_rv(B=128, N=1024, device="cuda"):
    """
    Draw i.i.d. Uniform[0, 1) samples.

    Args:
        B (int): Batch size.
        N (int): Number of samples per batch row.
        device: Torch device for the result. Defaults to "cuda", which was the
            only option before this parameter existed.

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, 1).
    """
    return torch.empty((B, N, 1), dtype=torch.float64, device=device).uniform_()

def ber_rv(B=128, N=1024, dim=1, L=2, device="cuda"):
    """
    Draw i.i.d. integer symbols uniformly from {0, ..., L-1}.

    Each symbol can represent log2(L) bits when L is a power of two.

    Args:
        B (int): Batch size.
        N (int): Number of symbols per batch row.
        dim (int): Unused; kept only for call-site compatibility.
        L (int): Alphabet size.
        device: Torch device for the result. Defaults to "cuda", which was the
            only option before this parameter existed.

    Returns:
        torch.Tensor: int64 tensor of shape (B, N).
    """
    # Sample on a float64 buffer (as before) so the random stream is unchanged,
    # then cast the integer-valued doubles to int64.
    symbols = torch.empty((B, N), dtype=torch.float64, device=device).random_(L)
    return symbols.long()

@torch.no_grad()
def gauss_log_p(x, mean, variance):
    """
    Elementwise log-density of ``x`` under the Gaussian N(mean, variance).

    Args:
        x: Evaluation points (tensor or Python number).
        mean: Gaussian mean (tensor or number), broadcast against ``x``.
        variance: Gaussian variance (tensor or number); must be positive.

    Returns:
        torch.Tensor: log N(x; mean, variance), with the broadcast shape of the inputs.

    Raises:
        ValueError: From ``torch.distributions.Normal`` when argument
            validation is enabled and the variance is not positive.
    """
    if not torch.is_tensor(x):
        # Normal.log_prob requires a tensor argument.
        x = torch.as_tensor(x, dtype=torch.float64)
    if not torch.is_tensor(variance):
        # torch.sqrt only accepts tensors. The default variances used by callers
        # in this file (var_p=1.0) are plain floats and used to raise TypeError.
        ref_dtype = x.dtype if x.is_floating_point() else torch.float64
        variance = torch.as_tensor(variance, dtype=ref_dtype, device=x.device)
    std = torch.sqrt(variance)
    return Normal(loc=mean, scale=std).log_prob(x)

@torch.no_grad()
def estimate_omega(mean_p, var_p, mean_t, var_t):
    """
    Return omega = sup_x t(x) / p(x) for target t = N(mean_t, var_t) and proposal p = N(mean_p, var_p).

    log t(x) - log p(x) is a quadratic in x and is bounded above only when
    var_t < var_p. Its maximum is at
        x* = (var_t * mean_p - var_p * mean_t) / (var_t - var_p)
    and has the closed form
        omega = sqrt(var_p / var_t) * exp((mean_t - mean_p)**2 / (2 * (var_p - var_t))).
    Evaluating it in the log domain avoids the NaN produced by
    exp(log t(x*)) / exp(log p(x*)) when both densities underflow.

    This is the 1-D bound. NOTE(review): if the log-densities are summed over
    several i.i.d. feature dims, the bound is the product over those dims.

    Args:
        mean_p, var_p: Proposal mean and variance (tensors or numbers).
        mean_t, var_t: Target mean and variance (tensors or numbers).

    Returns:
        torch.Tensor: omega, broadcast over the inputs (float64 for plain numbers).

    Raises:
        ValueError: If var_t <= 0 or var_t >= var_p (the ratio is then unbounded
            or ill-defined).
    """
    if bool(torch.as_tensor(var_t <= 0).any()):
        raise ValueError("var_t must be positive")
    if bool(torch.as_tensor(var_t >= var_p).any()):
        raise ValueError("estimate_omega requires var_t < var_p; otherwise t(x)/p(x) is unbounded")

    var_ratio = var_p / var_t
    if not torch.is_tensor(var_ratio):
        var_ratio = torch.tensor(var_ratio, dtype=torch.float64)
    log_omega = 0.5 * torch.log(var_ratio) + (mean_t - mean_p) ** 2 / (2 * (var_p - var_t))
    return torch.exp(log_omega)

def estimate_ers_batch_mean_prob(N, mean_t, var_t, mean_p=0.0, var_p=1.0, num_trials=100000):
    """
    Monte-Carlo estimate of the mean ERS probability returned by ``exp_sampler.select``.

    Each trial does the following:
      - draws N proposal samples y ~ N(mean_p, var_p) and N log-exponential race variables;
      - runs ``exp_sampler.select(..., ers_selection=True)``;
      - records the returned ``ers_prob``.
    The average over ``num_trials`` trials is returned.

    NOTE(review): relies on a module-level ``exp_sampler`` (presumably an
    ``Exp_Sampler`` instance) that is not defined in this part of the file; confirm.

    Args:
        N (int): Number of candidates per trial.
        mean_t, var_t: Target mean and variance.
        mean_p (float): Proposal mean. Now used both for sampling and for the
            densities inside ``select``; previously it only affected omega.
        var_p (float): Proposal variance (same remark as ``mean_p``).
        num_trials (int): Number of Monte-Carlo trials (default: the previously
            hard-coded 100000).

    Returns:
        numpy.float64: Mean of the per-trial ERS probabilities.

    Raises:
        ValueError: If ``num_trials`` < 1.
    """
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1")

    B = 1  # the ERS terms in Exp_Sampler.select index batch row 0 only
    # Tensor form of the proposal variance so it is accepted wherever torch.sqrt is applied to it.
    var_p_t = torch.as_tensor(var_p, dtype=torch.float64)
    omega = estimate_omega(mean_p, var_p_t, mean_t, var_t)

    probs = np.empty(num_trials)
    for trial in range(num_trials):
        # Zero-mean draw scaled by sqrt(var_p), then shifted: y ~ N(mean_p, var_p).
        y = gauss_gen(mu=0.0, var=var_p_t, B=B, N=N) + mean_p
        logS_ = logexp_rv(B=B, N=N)
        *_, ers_prob = exp_sampler.select(
            logS_, y, mean_t, var_t,
            mean_p=mean_p, var_p=var_p_t,
            hash_val=None, ers_selection=True, omega=omega,
        )
        probs[trial] = ers_prob.item()
    return probs.mean()

def torch_to_numpy(tensor):
    """
    Convert a torch tensor to a NumPy array on the host.

    The tensor is detached from autograd and moved to CPU first. For a tensor
    that is already on CPU, the returned array shares memory with it (standard
    ``Tensor.numpy`` behavior).

    Args:
        tensor (torch.Tensor): Tensor to convert.

    Returns:
        numpy.ndarray: The tensor's data.

    Raises:
        ValueError: If ``tensor`` is not a ``torch.Tensor``.
    """
    # Type-check before calling any tensor method. Previously .detach() ran first,
    # so non-tensors raised AttributeError instead of this ValueError.
    if not isinstance(tensor, torch.Tensor):
        raise ValueError("Input must be a Torch tensor.")
    # .cpu() is a no-op for tensors already on the CPU.
    return tensor.detach().cpu().numpy()

def gauss_gen(mu=0.0, var=1.0, dim=1, B=128, N=1024, device="cuda"):
    """
    Draw i.i.d. Gaussian samples from N(mu, var).

    Args:
        mu: Mean (number or tensor broadcastable to (B, N, dim)).
        var: Variance (number or tensor broadcastable to (B, N, dim)).
        dim (int): Feature dimension of each sample.
        B (int): Batch size.
        N (int): Number of samples per batch row.
        device: Torch device for the result. Defaults to "cuda", which was the
            only option before this parameter existed.

    Returns:
        torch.Tensor: float64 tensor of shape (B, N, dim).
    """
    # torch.sqrt rejects Python numbers, so the default var=1.0 used to raise TypeError.
    std = torch.sqrt(var) if torch.is_tensor(var) else math.sqrt(var)
    standard = torch.empty((B, N, dim), dtype=torch.float64, device=device).normal_(0.0, 1.0)
    # Scale first, then shift. The old normal_(mu, 1) * std produced mean mu * std, not mu.
    return standard * std + mu

def fast_index_select(y, index_tensor):
    """
    Pick, for every batch row, the feature vector(s) at the given candidate indices.

    ``y`` is a 3-D tensor indexed along dim 1, with features in its last dim.
    ``index_tensor`` is a 2-D (batch, k) integer tensor. The result has shape
    (batch, k, features), with dim 1 squeezed away when k == 1.
    """
    # Repeat every index across the feature axis so gather copies whole vectors.
    gather_idx = index_tensor[..., None].expand(-1, -1, y.shape[-1])
    picked = torch.gather(y, 1, gather_idx)
    return picked.squeeze(1)
    

#import time

class Exp_Sampler():
    """Exponential-race candidate selector.

    Each candidate y_i gets the score ``logS_i + log p(y_i) - log t(y_i)``, where
    p = N(mean_p, var_p) is the proposal and t = N(mean_t, var_t) the target. The
    score is optionally penalized for candidates whose hash differs from the
    message. The argmin over the candidate axis (dim 1) is selected.
    """

    def __init__(self):
        pass

    @torch.no_grad()
    def select(self, logS_, y, mean_t, var_t, mean_p=0.0, var_p=1.0, hash_val=None, message=None, ers_selection=None, omega=None):
        """
        Select one candidate per batch row.

        Args:
            logS_: Log race variables added to each candidate's score. Presumably
                shaped (B, N, 1) like ``logexp_rv`` output; confirm with callers.
            y: Candidates, indexed along dim 1 with features in the last dim
                (log-densities are summed over that last dim).
            mean_t, var_t: Target Gaussian parameters.
            mean_p, var_p: Proposal Gaussian parameters.
            hash_val: Optional per-candidate hashes, indexed like ``y`` without
                the feature dim. When given, the selected hash is returned.
            message: Optional; ``message[0]`` is compared with ``hash_val``.
                Requires ``hash_val``.
            ers_selection: If truthy, also compute the ERS probability.
            omega: Additive term in the ERS denominator. Required when
                ``ers_selection`` is truthy; presumably the bound from
                ``estimate_omega``.

        Returns:
            tuple: ``(K_min, selected_y, out_m, ers_prob)``
              - ``K_min``: argmin indices over dim 1;
              - ``selected_y``: the chosen candidates;
              - ``out_m``: the chosen hashes, or None;
              - ``ers_prob``: z_hat / z_bar, or None.

        Raises:
            ValueError: If ``message`` is given without ``hash_val``.
        """
        if message is not None and hash_val is None:
            raise ValueError("hash_val is required when message is given")

        # Per-candidate log-densities, summed over the feature dimension.
        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)
        log_t_ = gauss_log_p(y, mean_t, var_t).sum(dim=-1, keepdim=True)

        score_x_ = logS_ + log_p_ - log_t_
        if message is not None:
            # Soft mask: matching candidates get log(1 + 1e-25) == 0.
            # Non-matching ones get -log(1e-25), about +57.6, added to their score.
            filtered_ = (hash_val == message[0]) * 1.0 + 1e-25
            score_x_ = score_x_ - torch.log(filtered_)[..., None]

        K_min = torch.argmin(score_x_, dim=1)
        selected_y = fast_index_select(y, K_min)
        out_m = fast_index_select(hash_val[..., None], K_min) if hash_val is not None else None

        ers_prob = None
        if ers_selection:
            # Importance weights t(y) / p(y), formed in the log domain. Dividing two
            # separately exponentiated densities gives 0/0 = NaN on underflow.
            weights = torch.exp(log_t_ - log_p_)
            z_hat = weights.sum()
            # NOTE(review): sums weights over every batch row but removes only batch
            # row 0's selected weight, so this is meaningful for B == 1 only.
            z_bar = z_hat - weights[0, K_min, 0].sum() + omega
            ers_prob = z_hat / z_bar

        return K_min, selected_y, out_m, ers_prob

class RejectionSampler():
    """Rejection-sampling acceptance probabilities for a Gaussian proposal/target pair."""

    def __init__(self):
        pass

    @torch.no_grad()
    def select(self, y, mean_t, var_t, mean_p=0.0, var_p=1.0, hash_val=None, message=None, omega=None):
        """
        Return the acceptance probability t(y) / (omega * p(y)) for each candidate.

        Log-densities are summed over the last (feature) dim of ``y``, so the
        result keeps ``y``'s leading dims with a trailing size-1 dim.

        Args:
            y: Candidate samples, features in the last dim.
            mean_t, var_t: Target Gaussian t = N(mean_t, var_t).
            mean_p, var_p: Proposal Gaussian p = N(mean_p, var_p).
            hash_val, message: Unused; accepted for signature parity with
                ``Exp_Sampler.select``.
            omega: Upper bound on t/p (presumably from ``estimate_omega``).

        Returns:
            torch.Tensor: Acceptance probabilities.

        Raises:
            ValueError: If ``omega`` is None.
        """
        if omega is None:
            raise ValueError("omega (bound on t(y)/p(y)) is required")
        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)
        log_t_ = gauss_log_p(y, mean_t, var_t).sum(dim=-1, keepdim=True)
        # Form the density ratio in the log domain to avoid 0/0 when both densities underflow.
        return torch.exp(log_t_ - log_p_) / omega
        
def compute_decoder_target(side_mean, side_var, var_x, var_u):
    """
    Return the decoder-side target Gaussian parameters ``(mean_t, var_t)``.

    mean_t = side_mean * var_x / side_var
    var_t  = var_u - var_x**2 / side_var

    NOTE(review): this matches conditioning a zero-mean Gaussian U on side
    information with covariance var_x between them. That reading is an
    assumption; confirm against the model.

    Args:
        side_mean: Side-information value/mean seen by the decoder.
        side_var: Variance of the side information.
        var_x: Variance of the encoder's input.
        var_u: Variance of the encoder's target.
    """
    explained_var = var_x ** 2 / side_var
    return side_mean * var_x / side_var, var_u - explained_var