import numpy as np
import torch
import math

def logexp_rv(B=1, N=1024, dim=1, dtype=torch.float, device='cuda'):
    """Return the natural log of i.i.d. Exponential(rate=1) samples.

    Args:
        B, N, dim: output shape is ``(B, N, dim)``.
        dtype: floating dtype of the result.
        device: torch device on which the samples are drawn.

    Returns:
        Tensor of shape ``(B, N, dim)`` holding ``log(E)`` with ``E ~ Exp(1)``.
    """
    shape = (B, N, dim)
    samples = torch.empty(shape, dtype=dtype, device=device)
    samples.exponential_()  # in-place fill, default rate lambd=1
    return samples.log()

def uniform_rv(B=1, N=1024, dim=1, dtype=torch.float, device='cuda'):
    """Return i.i.d. samples from the uniform distribution on [0, 1).

    Args:
        B, N, dim: output shape is ``(B, N, dim)``.
        dtype: floating dtype of the result.
        device: torch device on which the samples are drawn.

    Returns:
        Tensor of shape ``(B, N, dim)``.
    """
    # uniform_() defaults to the interval [0, 1).
    return torch.empty(B, N, dim, dtype=dtype, device=device).uniform_()

def ber_rv(L=2, B=1, N=1024, dim=1, dtype=torch.long, device='cuda'):
    """Return i.i.d. integers drawn uniformly from ``{0, ..., L-1}``.

    With the default ``L=2`` this is a fair Bernoulli draw; larger ``L``
    gives a uniform categorical draw.

    Args:
        L: number of distinct values (exclusive upper bound).
        B, N, dim: output shape is ``(B, N, dim)``.
        dtype: dtype of the result.
        device: torch device on which the samples are drawn.

    Returns:
        Tensor of shape ``(B, N, dim)``.
    """
    out = torch.empty(size=(B, N, dim), dtype=dtype, device=device)
    # random_(to) samples the discrete uniform distribution on [0, to - 1].
    out.random_(L)
    return out

def gauss_rv(mu=0.0, var=1.0, B=1, N=1024, dim=1, dtype=torch.float, device='cuda'):
    """Return i.i.d. Gaussian samples with mean ``mu`` and variance ``var``.

    Args:
        mu: mean of the distribution.
        var: variance of the distribution (standard deviation is ``sqrt(var)``).
        B, N, dim: output shape is ``(B, N, dim)``.
        dtype: floating dtype of the result.
        device: torch device on which the samples are drawn.

    Returns:
        Tensor of shape ``(B, N, dim)`` distributed as ``N(mu, var)``.
    """
    std = np.sqrt(var)
    # Draw standard normals, scale by the std, THEN shift by mu. Drawing with
    # mean mu and scaling afterwards would also scale the mean (to mu*std).
    gauss = torch.empty((B, N, dim), dtype=dtype, device=device).normal_(0.0, 1.0) * std + mu
    return gauss

def gauss_log_p(x, mean, variance):
    """Return the elementwise log-density of ``N(mean, variance)`` at ``x``.

    Computes ``-0.5 * (log(2*pi*variance) + (x - mean)**2 / variance)``.
    ``variance`` must be a Python/NumPy scalar because it is passed to
    ``math.log``; ``x`` may be a scalar or a tensor.
    """
    log_normaliser = math.log(2 * math.pi * variance)
    sq_mahalanobis = ((x - mean) ** 2) / variance
    return -0.5 * (log_normaliser + sq_mahalanobis)

def estimate_omega(mean_t, var_t, mean_p, var_p, N=2**20, device='cuda'):
    """Monte-Carlo estimate of ``max_x t(x) / p(x)`` for two 1-D Gaussians.

    Draws ``N`` samples from the proposal p = N(mean_p, var_p) and returns the
    largest observed density ratio against the target t = N(mean_t, var_t).

    Args:
        mean_t, var_t: target Gaussian parameters.
        mean_p, var_p: proposal Gaussian parameters.
        N: number of proposal samples (default ``2**20``, as before).
        device: torch device for the samples (default ``'cuda'``, as before).

    Returns:
        ``(samples, omega)`` where ``samples`` has shape ``(1, N, 1)`` and
        ``omega`` is a 0-d tensor holding the maximum ratio over the samples.
    """
    samples = gauss_rv(mean_p, var_p, B=1, N=N, dim=1, device=device)
    likelihood_proposal = gauss_log_p(samples, mean_p, var_p)
    likelihood_target = gauss_log_p(samples, mean_t, var_t)
    log_ratio = likelihood_target - likelihood_proposal
    # exp is monotonic, so max(exp(r)) == exp(max(r)); exponentiating only the
    # maximum avoids materialising another N-sized tensor.
    return samples, torch.exp(torch.max(log_ratio))

class ExpSampler:
    """Select candidates by minimising a Gaussian density-ratio-weighted score.

    Stateless: ``__init__`` stores nothing; all inputs are passed to
    :meth:`select`.
    """

    def __init__(self):
        pass

    def select(self, logS_, y, mean_t, var_t, mean_p=0.0, var_p=1.0,
               hash_val=None, message=None, omega=None):
        """Return the candidate minimising ``logS_ + log p(y) - log t(y)``.

        ``p`` is the Gaussian N(mean_p, var_p) and ``t`` is N(mean_t, var_t);
        both log-densities are summed over the last axis of ``y``. When
        ``message`` is given, only candidates whose ``hash_val`` equals
        ``message[0]`` are eligible.

        Args:
            logS_: per-candidate log-scores, broadcastable against the summed
                log-densities; presumably log-exponential draws such as those
                from ``logexp_rv`` -- confirm with callers.
            y: candidate samples. The argmin is taken over dim 1, so
                candidates are assumed to lie along dim 1 -- TODO confirm the
                full layout with callers.
            mean_t, var_t: target Gaussian parameters.
            mean_p, var_p: proposal Gaussian parameters.
            hash_val: optional per-candidate hash values, indexed like ``y``.
            message: optional value(s) the selected hash must equal.
            omega: unused; kept for interface compatibility.

        Returns:
            ``(K_min, selected_y, out_m)``: argmin indices over dim 1, the
            selected entries of ``y``, and the matching ``hash_val`` entries
            (``None`` when ``hash_val`` is ``None``).

        Raises:
            ValueError: if ``message`` is given without ``hash_val``.
            AssertionError: if the selected hash does not equal ``message``.
        """
        # One scalar log-density per candidate; keepdim keeps it broadcastable
        # against logS_.
        log_p_ = gauss_log_p(y, mean_p, var_p).sum(dim=-1, keepdim=True)
        log_t_ = gauss_log_p(y, mean_t, var_t).sum(dim=-1, keepdim=True)

        score_x_ = logS_ + log_p_ - log_t_

        if message is not None:
            if hash_val is None:
                raise ValueError("hash_val is required when message is given")
            # Hard mask: non-matching candidates get +inf so they can never win
            # the argmin while any matching candidate exists (a finite penalty
            # could be out-scored by a very low non-matching score).
            matches = hash_val == message[0]
            score_x_ = torch.where(matches, score_x_, torch.full_like(score_x_, math.inf))

        K_min = torch.argmin(score_x_, dim=1)
        n_rows = score_x_.shape[0]
        idx = K_min[:, 0]

        # NOTE(review): ``y[:n_rows, idx]`` mixes a slice with an index tensor,
        # so the result has shape (min(y.shape[0], n_rows), n_rows, ...): for
        # n_rows > 1 it pairs every row of y with every row's argmin, not just
        # its own. That is right for n_rows == 1 or for a y shared across rows
        # (y.shape[0] == 1); confirm intent before using per-row y.
        selected_y = y[:n_rows, idx]
        out_m = hash_val[:n_rows, idx] if hash_val is not None else None

        if message is not None:
            # Reduce over every element so the check works for any number of
            # rows (a per-dim reduction leaves a multi-element tensor whose
            # truth value is ambiguous).
            assert bool(torch.all(out_m == message)), "selected hash does not match message"

        return K_min, selected_y, out_m
