import numpy as np
import torch
import math
from torch.distributions import Normal

def kl_divergence_gaussian(mean1, var1, mean2, var2):
    """
    Return KL(P || Q) in bits for two univariate Gaussians.

    P is N(mean1, var1) and Q is N(mean2, var2). The closed-form value is
    first evaluated in nats and then divided by ln(2).

    Args:
        mean1 (float): Mean of P
        var1 (float): Variance of P
        mean2 (float): Mean of Q
        var2 (float): Variance of Q

    Returns:
        float: KL(P || Q) expressed in bits.

    Raises:
        ValueError: If either variance is zero or negative.
    """
    if any(v <= 0 for v in (var1, var2)):
        raise ValueError("Variances must be positive")

    sigma_p, sigma_q = np.sqrt(var1), np.sqrt(var2)

    # log(sigma_q / sigma_p) + (var_p + (mu_p - mu_q)^2) / (2 var_q) - 1/2
    log_scale_ratio = np.log(sigma_q / sigma_p)
    quadratic = (var1 + (mean1 - mean2)**2) / (2 * var2)
    nats = log_scale_ratio + quadratic - 0.5

    # Change of unit: nats -> bits.
    return nats / np.log(2)

def logexp_rv(B=1, N=1024, dim=1, dtype=torch.float, device='cuda'):
    """Return a (B, N, dim) tensor holding the log of Exponential draws.

    The buffer is filled in place by ``Tensor.exponential_()`` (default rate)
    and the natural logarithm of every entry is returned.
    """
    shape = (B, N, dim)
    draws = torch.empty(shape, dtype=dtype, device=device)
    draws.exponential_()
    return draws.log()

def uniform_rv(B=1, N=1024, dim=1, dtype=torch.float, device='cuda'):
    """Return a (B, N, dim) tensor filled in place by ``Tensor.uniform_()``.

    ``uniform_`` is called without arguments, i.e. with its default range.
    """
    shape = (B, N, dim)
    buffer = torch.empty(shape, dtype=dtype, device=device)
    return buffer.uniform_()

def ber_rv(L=2, B=1, N=1024, dim=1, dtype=torch.long, device='cuda'):
    """Return a (B, N, dim) tensor of integers drawn by ``Tensor.random_(L)``.

    Each entry is sampled uniformly from ``{0, ..., L - 1}``; with the default
    ``L=2`` this yields fair binary symbols.
    """
    shape = (B, N, dim)
    symbols = torch.empty(shape, dtype=dtype, device=device)
    symbols.random_(L)
    return symbols

@torch.no_grad()
def gauss_log_p(x, mean, variance):
    """Return the element-wise log-density of ``x`` under N(mean, variance).

    Args:
        x: Points to evaluate (tensor, or a Python number / sequence that is
            converted to a tensor).
        mean: Mean of the Gaussian (tensor or Python number).
        variance: Variance of the Gaussian (tensor or Python number). It is
            converted to a standard deviation before building the
            distribution.

    Returns:
        torch.Tensor: ``Normal(mean, sqrt(variance)).log_prob(x)``; no
        reduction over any dimension is applied here.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x, dtype=torch.get_default_dtype())
    if not torch.is_tensor(variance):
        # torch.sqrt only accepts tensors, while callers in this module pass
        # plain floats (e.g. the default var_p=1.0), so promote scalars here.
        float_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        variance = torch.as_tensor(variance, dtype=float_dtype, device=x.device)
    std = torch.sqrt(variance)  # Normal is parameterised by scale, not variance
    dist = Normal(loc=mean, scale=std)
    return dist.log_prob(x)

@torch.no_grad()
def estimate_omega(mean_p, var_p, mean_t, var_t):
    """Evaluate the density ratio t(x)/p(x) at its stationary point.

    ``t`` is N(mean_t, var_t) (target) and ``p`` is N(mean_p, var_p)
    (proposal). The point ``x`` below is where the derivative of
    ``log t(x) - log p(x)`` vanishes; it is the maximiser of the ratio only
    when ``var_t < var_p``. When ``var_t == var_p`` the expression divides by
    zero and no finite bound exists.

    Args:
        mean_p, var_p: Mean and variance of the proposal.
        mean_t, var_t: Mean and variance of the target.

    Returns:
        torch.Tensor: ``t(x*) / p(x*)`` at the stationary point ``x*``.
    """
    # Stationary point of log t(x) - log p(x).
    x = (var_t*mean_p - var_p*mean_t)/(var_t - var_p)
    # Subtract in log space before exponentiating: exp(log t)/exp(log p)
    # becomes 0/0 = NaN as soon as both densities underflow.
    log_ratio = gauss_log_p(x, mean_t, var_t) - gauss_log_p(x, mean_p, var_p)
    return torch.exp(log_ratio)

def estimate_ers_batch_mean_prob(N, mean_t, var_t, mean_p=0.0, var_p=1.0, num_iters=100000):
    """Monte-Carlo estimate of the mean ERS probability over repeated trials.

    Each trial draws ``N`` Gaussian candidates with ``gauss_gen`` (using its
    default mean and variance) and ``N`` log-exponential race variables with
    ``logexp_rv``, runs the module-level ``exp_sampler`` with ERS selection
    enabled, and records the returned ERS probability.

    NOTE(review): ``exp_sampler`` is read from module globals and is not
    defined in the visible part of this file; the ``Exp_Sampler`` class shown
    here has no ``select`` method with this signature. Confirm the intended
    sampler before relying on this function.

    Args:
        N: Number of candidates per trial.
        mean_t, var_t: Mean and variance of the target Gaussian.
        mean_p, var_p: Mean and variance of the proposal Gaussian; only used
            to compute ``omega``.
        num_iters: Number of independent trials to average over
            (defaults to the previously hard-coded 100000).

    Returns:
        float: Mean of the per-trial ERS probabilities.

    Raises:
        ValueError: If ``num_iters`` is smaller than 1.
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")

    B = 1  # a single batch per trial, so the ERS probability is a scalar
    omega = estimate_omega(mean_p, var_p, mean_t, var_t)

    avg_prob = []
    for _ in range(num_iters):
        y = gauss_gen(B=B, N=N)
        logS_ = logexp_rv(B=B, N=N)
        _, _, _, ers_prob_a = exp_sampler.select(logS_, y, mean_t, var_t, hash_val=None, ers_selection=True, omega=omega)
        avg_prob.append(ers_prob_a.item())
    return np.asarray(avg_prob).mean()

def torch_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion.

    Args:
        tensor (torch.Tensor): Tensor to convert.

    Returns:
        numpy.ndarray: The tensor's data as a NumPy array.

    Raises:
        ValueError: If ``tensor`` is not a ``torch.Tensor``.
    """
    # Validate before touching tensor-only methods; otherwise a non-tensor
    # input fails with AttributeError on .detach() instead of this error.
    if not isinstance(tensor, torch.Tensor):
        raise ValueError("Input must be a Torch tensor.")
    return tensor.detach().cpu().numpy()

def gauss_gen(mu=0.0, var=1.0, dim=1,  B=128, N=1024, dtype=torch.float, device='cuda'):
    """Draw a (B, N, dim) tensor of samples from N(mu, var).

    Args:
        mu: Mean of the samples (Python number or broadcastable tensor).
        var: Variance of the samples (Python number or broadcastable tensor).
        dim: Size of the last (feature) dimension.
        B: Size of the first dimension.
        N: Size of the second dimension.
        dtype: Floating-point dtype of the result.
        device: Device on which the samples are created.

    Returns:
        torch.Tensor: Samples of shape (B, N, dim).
    """
    # torch.sqrt rejects Python floats, so promote the variance first.
    std = torch.sqrt(torch.as_tensor(var, dtype=dtype, device=device))
    standard = torch.empty((B, N, dim), dtype=dtype, device=device).normal_()
    # Shift after scaling so the mean stays mu instead of mu * sqrt(var).
    return standard * std + mu

def fast_index_select(y, index_tensor):
    """Gather entries of ``y`` along dimension 1 at the given positions.

    Args:
        y (torch.Tensor): Source tensor, either (B, N, D) or (B, N).
        index_tensor (torch.Tensor): Integer positions of shape (B, K) along
            dimension 1 of ``y``.

    Returns:
        torch.Tensor: Gathered values of shape (B, K, D) for 3-D ``y`` or
        (B, K) for 2-D ``y``, with dimension 1 squeezed away when K == 1.
    """
    if y.dim() != index_tensor.dim():
        # torch.gather needs an index of the same rank as the input, so
        # replicate each position across the trailing feature dimension.
        feature_dim = y.shape[-1]
        index_tensor = index_tensor.unsqueeze(-1).expand(-1, -1, feature_dim)
    gathered = torch.gather(y, dim=1, index=index_tensor)
    # Drop the selection dimension when a single position was requested.
    return gathered.squeeze(1)

#import time

class Exp_Sampler():
    """Select candidates by minimising exponential-race scores.

    ``encode`` scores every candidate as
    ``logS + log p(y) - log t(y)`` and keeps the arg-min; ``decode`` restricts
    the candidates to those whose hash matches a message and scores them with
    a learned log-likelihood ratio instead.
    """

    # Maximum number of candidates passed to the target density in one call.
    _TARGET_CHUNK = 2**18

    def __init__(self):
        pass

    @torch.no_grad()
    def encode(self, logS_, y, multiN, mean_p=0.0, var_p=1.0, hash_val=None, message=None, ers_selection=None, omega=None, num_batch=1):
        """Pick the candidate with the smallest race score.

        Args:
            logS_: Log race variables, broadcastable against the
                (B, N, 1) log-density terms.
            y: Candidates of shape (B, N, D). Only ``y[0]`` is scored by the
                target, so the method effectively assumes B == 1.
            multiN: Target distribution object exposing ``log_prob``.
            mean_p, var_p: Mean and variance of the Gaussian proposal.
            hash_val: Optional per-candidate hash values; when given, the
                hash of the selected candidate is returned.
            message: Optional message; when given, candidates whose hash
                differs from ``message[0]`` are penalised.
            ers_selection: When truthy, also compute the ERS probability.
            omega: Bound on t/p used in the ERS normaliser; required when
                ``ers_selection`` is truthy.
            num_batch: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(K_min, selected_y, out_m, ers_prob)`` where ``out_m``
            and ``ers_prob`` are ``None`` when not requested.

        Raises:
            TypeError: If ``ers_selection`` is truthy but ``omega`` is None.
        """
        if ers_selection and omega is None:
            raise TypeError("omega must be provided when ers_selection is enabled")

        # Proposal log-density, summed over the feature dimension.
        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)

        # Target log-density, evaluated in bounded-size chunks. Stepping with
        # range(0, n, chunk) also covers the final partial chunk, which a
        # floor-divided chunk count would silently drop.
        candidates = y[0]
        chunk = self._TARGET_CHUNK
        log_t_parts = []
        for start in range(0, max(len(candidates), 1), chunk):
            piece = candidates[start:start + chunk].cuda()
            log_t_parts.append(multiN.log_prob(piece)[None, :, None])
        log_t_ = torch.cat(log_t_parts, dim=1)

        score_x_ = logS_ + log_p_ - log_t_
        if message is not None:
            # Soft mask: matching hashes contribute ~log(1) = 0, all others
            # log(1e-25), which pushes their score up and out of the arg-min.
            filtered_ = (hash_val == message[0])*1.0 + 1e-25
            score_x_ = score_x_ - torch.log(filtered_)[..., None]

        K_min = torch.argmin(score_x_, dim=1)
        selected_y = fast_index_select(y, K_min)

        out_m = fast_index_select(hash_val, K_min) if hash_val is not None else None

        ers_prob = None
        if ers_selection:
            # Importance weights t/p, formed in log space so that tiny
            # densities do not underflow to 0/0.
            weights = torch.exp(log_t_ - log_p_)
            z_hat = weights.sum()
            # Replace the selected candidate's weight by the bound omega.
            z_bar = z_hat - weights[0, K_min, 0].sum() + omega
            ers_prob = z_hat/z_bar

        return K_min, selected_y, out_m, ers_prob

    def decode(self, nce_model, logS_, hash_val, y,  cor_img, message, num_batch):
        """Pick, among hash-matching candidates, the one with the smallest score.

        Args:
            nce_model: Object whose ``forward(cor_img_rep, y_rep)`` returns a
                pair; the second element is used as a log-likelihood ratio.
            logS_: Log race variables, indexed along dimension 1.
            hash_val: Per-candidate hash values compared against ``message``.
            y: Candidates, indexed along dimension 1; the leading dimension
                is squeezed after filtering (presumably size 1 - verify
                against callers).
            cor_img: Side-information batch fed to ``nce_model``.
            message: Hash value that the retained candidates must match.
            num_batch: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(K_min, selected_y)``. ``K_min`` indexes into the
            hash-filtered candidate list, not the original one.
        """
        # Positions along dimension 1 whose hash equals the message.
        idx = torch.nonzero(hash_val == message, as_tuple=True)[1].view(1, -1, 1)

        logS_ = torch.gather(logS_, dim=1, index=idx.expand(logS_.shape[0], -1, -1))
        y = torch.gather(y, dim=1, index=idx.expand(-1, -1, y.shape[-1])).squeeze(0)

        with torch.no_grad():
            # Pair every side-information row with every retained candidate.
            _, llr = nce_model.forward(torch.repeat_interleave(cor_img, y.shape[0], 0),
                                       y.repeat(cor_img.shape[0], 1))
            log_value = -llr.reshape(logS_.shape)

        score_y_ = logS_ + log_value
        K_min = torch.argmin(score_y_, dim=1)

        selected_y = torch.gather(y.expand(logS_.shape[0], -1, -1), dim=1,
                                  index=K_min.unsqueeze(-1).expand(-1, -1, y.shape[-1]))

        return K_min, selected_y.squeeze(1)

class RejectionSampler():
    """Computes rejection-sampling acceptance probabilities for Gaussians."""

    def __init__(self):
        pass

    def select(self, y, mean_t, var_t, mean_p=0.0, var_p=1.0, hash_val=None, message=None, omega=None):
        """Return the acceptance probability ``t(y) / (omega * p(y))``.

        Args:
            y: Candidate samples; log-densities are summed over the last
                dimension.
            mean_t, var_t: Mean and variance of the target Gaussian.
            mean_p, var_p: Mean and variance of the proposal Gaussian.
            hash_val: Unused; kept for interface compatibility.
            message: Unused; kept for interface compatibility.
            omega: Bound on the ratio t/p. Must be provided.

        Returns:
            torch.Tensor: Acceptance probability per candidate, with the
            last dimension kept as size 1.

        Raises:
            TypeError: If ``omega`` is None.
        """
        if omega is None:
            raise TypeError("omega must be provided to compute the acceptance probability")

        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)
        log_t_ = gauss_log_p(y, mean_t, var_t).sum(dim=-1, keepdim=True)
        # Form the ratio in log space: summed log-densities easily underflow
        # when exponentiated separately, which would give 0/0 = NaN.
        accept_prob = torch.exp(log_t_ - log_p_)/omega

        return accept_prob
        
def compute_decoder_target(side_mean, side_var, var_x, var_u):
    """Return the mean and variance of the decoder-side target.

    Args:
        side_mean: Mean associated with the side information.
        side_var: Variance associated with the side information.
        var_x: Variance of the encoder's input.
        var_u: Variance of the encoder's target.

    Returns:
        tuple: ``(mean_t, var_t)`` with
        ``mean_t = side_mean * var_x / side_var`` and
        ``var_t = var_u - var_x**2 / side_var``.
    """
    target_mean = side_mean*var_x/side_var
    variance_reduction = var_x**2/side_var
    target_var = var_u - variance_reduction
    return target_mean, target_var