import torch
import math
import random
from tqdm import tqdm
from sampler_utils import *
import numpy as np

# Single sampler instance shared by every run_one_test call. Both the encoder-side
# and the decoder-side selection go through its .select method.
# ExpSampler presumably comes from the `sampler_utils` star import; verify there.
exp_sampler = ExpSampler()

def decoder_target(side_info, var_side, var_x, var_u):
    """Return ``(mean, variance)`` of the decoder's Gaussian target.

    mean     = side_info * var_x / var_side
    variance = var_u - var_x**2 / var_side

    These are the formulas for the conditional Gaussian of a variable with
    variance ``var_u`` given side information of variance ``var_side`` when the
    two have covariance ``var_x``. NOTE(review): that reading relies on how
    run_one_test builds ``var_u`` and ``var_side``; confirm.

    ``side_info`` may be a scalar, array or tensor. The mean broadcasts with
    it, and the variance does not depend on it.
    """
    # Keep the original left-to-right evaluation order so rounding is unchanged.
    cond_mean = side_info * var_x / var_side
    explained = var_x**2 / var_side
    cond_var = var_u - explained
    return cond_mean, cond_var

def run_one_test(var_x, var_u_x, var_y_x, num_proposals, L_max, n_trials, test_type):
    B = 1
    N = 2 ** 15
    var_u = var_x + var_u_x     # proposal variance
    var_side = var_x + var_y_x  # variance of side info
    mse_db = 0
    for i in range(n_trials):
        #input X
        x = gauss_rv(var=var_x, B=B, N=1, dim=1, dtype=torch.double)            # encoder input
        side_info = torch.randn((num_proposals, 1, 1), 
                                device='cuda', dtype=torch.double)
        side_info = x + np.sqrt(var_y_x) * side_info                            # decoder side info
        y = gauss_rv(var=var_u, B=B, N=N, dim=1, dtype=torch.double)            # proposal
        mean_dec, var_dec = decoder_target(side_info, var_side, var_x, var_u)   # decoder target distribution
        
        if test_type == 0:
            logS_ = logexp_rv(B=num_proposals, N=N, dtype=torch.double)
            logS_min = logexp_rv(B=1, N=N, dtype=torch.double)
        elif test_type == 1:
            logS_min = logexp_rv(B=1, N=N, dtype=torch.double)
            logS_ = logS_min.repeat(num_proposals, 1, 1)
        elif test_type == 2:
            logS_ = logexp_rv(B=num_proposals, N=N, dtype=torch.double)
            logS_min, _ = torch.min(logS_, dim=0, keepdim=True)
            
        random_hash = ber_rv(B=B, N=N, L=L_max)
        
        # encoder selection
        k_selected_A, y_selected_A, out_message = exp_sampler.select(logS_min, y, mean_t=x, var_t=var_u_x,
                                                                    mean_p=0.0, var_p=var_u,
                                                                    hash_val=random_hash)
        # decoder selection
        k_selected_B, y_selected_B, _ = exp_sampler.select(logS_, y, mean_t=mean_dec, var_t=var_dec,
                                                        mean_p=0.0, var_p=var_u,
                                                        message=out_message, hash_val=random_hash)
                        
        # compute estimate
        X_hat = (y_selected_B.squeeze() * var_y_x + side_info.squeeze() * var_u_x) / (var_u_x + var_y_x + var_u_x * var_y_x / var_x)
        mse_db += 10 * np.log10(torch.min((X_hat - x.squeeze()) ** 2).item())
    
    return mse_db / n_trials
    
def test_config(L_max, num_proposals, test_type, *, var_x=1.0, var_y_x=0.5,
                var_u_x_grid=(0.01, 0.008, 0.006, 0.005, 0.003, 0.002, 0.001),
                n_trials_search=10000, n_trials_final=100000):
    """Tune ``var_u_x`` over a grid, then re-evaluate the best value with more trials.

    The grid search runs ``run_one_test`` with ``n_trials_search`` trials per grid
    value. The value with the lowest score is then re-run with
    ``n_trials_final`` trials. The keyword-only arguments default to the values
    this experiment was originally hard-coded with.

    Args:
        L_max, num_proposals, test_type: forwarded to ``run_one_test``.
        var_x: variance of the encoder input.
        var_y_x: variance of the noise added to form the side information.
        var_u_x_grid: candidate values of ``var_u_x``.
        n_trials_search: trials per grid value in the coarse search.
        n_trials_final: trials for the refined evaluation of the best value.

    Returns:
        ``(best_mse_db, best_var_u_x)``: the refined ``run_one_test`` result and
        the grid value it was computed for.

    Raises:
        ValueError: if ``var_u_x_grid`` is empty.
    """
    grid = np.asarray(var_u_x_grid, dtype=float)
    if grid.size == 0:
        raise ValueError('var_u_x_grid must contain at least one value')

    coarse_db = np.zeros_like(grid)
    for idx, var_u_x in enumerate(tqdm(grid)):
        coarse_db[idx] = run_one_test(var_x, var_u_x, var_y_x, num_proposals, L_max,
                                      n_trials_search, test_type)

    best = np.argmin(coarse_db)
    # Re-run the winner with more trials; the coarse score is noisier.
    best_mse_db = run_one_test(var_x, grid[best], var_y_x, num_proposals, L_max,
                               n_trials_final, test_type)
    return best_mse_db, grid[best]

def test_loop(test_type):
    """Print the tuned MSE for each (proposal count, log2 L_max) pair.

    Proposal counts and ``log2(L_max)`` both range over 1..8. Each
    configuration is tuned with ``test_config``. A single blank line is printed
    once the whole sweep finishes.
    """
    for n_prop in range(1, 9):
        header = f'Testing with {n_prop} proposals\n========================'
        print(header, flush=True)
        for log_l in range(1, 9):
            best_mse, best_var = test_config(2**log_l, n_prop, test_type)
            line = f'log(Lmax)={log_l:2d}  var_w_a={best_var:.3f}  MSE={best_mse:.4e}'
            print(line, flush=True)
    print('', flush=True)

if __name__ == '__main__':
    # Run the (long) sweeps only when this file is executed as a script,
    # not when it is imported for its helpers.
    print('Testing Baseline Scheme With Repeated Common Randomness', flush=True)
    test_loop(1)
    print('Testing Our Scheme', flush=True)
    test_loop(2)
