import numpy as np
import torch
import math
from sampler_utils import *
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

# Let cuDNN benchmark and cache the fastest kernels for the tensor shapes seen.
torch.backends.cudnn.benchmark = True

# Module-level helpers shared by the experiment functions below.
mse = torch.nn.MSELoss()  # default reduction='mean': mean squared error over all elements
exp_sampler = Exp_Sampler()  # presumably provided by the sampler_utils star import — TODO confirm


def ERS_WZ_comm(B, N, n_trials, dim, L_max, var_u=None, var_u_x=0.01, var_x=1.0, var_y_x=0.01, threshold=3.0, debug=False):
    """Simulate the ERS-based encoder/decoder selection scheme.

    Each trial:

    1. Draws an encoder input ``x`` and clips it to ``[-threshold, threshold]``.
    2. Forms noisy decoder side information
       ``x + sqrt(var_y_x) * noise``.
    3. Repeatedly draws rounds of ``N`` proposals until the encoder's ERS
       acceptance test (``random.random() < ers_prob_a``) passes.
    4. Lets the decoder re-run the selection with its own target and the
       transmitted message.

    Args:
        B: Batch size passed to the samplers. Only element ``[0, 0]`` is
            used for matching and ``omega``, so presumably ``B == 1`` —
            TODO confirm.
        N: Number of proposals drawn per round.
        n_trials: Number of independent trials; must be positive.
        dim: Dimension of each sample.
        L_max: Hash range passed to ``ber_rv``; contributes
            ``log2(L_max)`` to the reported rate.
        var_u: Proposal variance; defaults to ``var_x + var_u_x``.
        var_u_x: Encoder target variance (desired distortion).
        var_x: Source variance.
        var_y_x: Side-information noise variance.
        threshold: Clipping bound applied to each source sample.
        debug: If true, print ``ers_prob_a`` for every round.

    Returns:
        A tuple ``(rate, match_rate, distortion_B_dB, distortion_A_dB, ESS)``:

        - ``rate`` is ``(log2(L_max) + mean_rounds) / dim``.
        - ``match_rate`` is the fraction of trials where the encoder and
          decoder picked the same index (a tensor).
        - The two distortions are the mean MSE in dB for the decoder and
          encoder selections respectively.
        - ``ESS`` is ``int(N * mean_rounds)``.

    Raises:
        ValueError: If ``n_trials`` is not positive (the averages below
            would otherwise divide by zero / fail on an untouched int).
    """
    if n_trials <= 0:
        raise ValueError("n_trials must be positive")

    # torch.as_tensor avoids the copy-construct UserWarning torch.tensor()
    # emits for tensor input; var_u is a tensor whenever it is defaulted.
    var_x = torch.as_tensor(var_x).cuda()
    var_y_x = torch.as_tensor(var_y_x).cuda()
    var_side = var_x + var_y_x  # variance of side info
    var_u_x = torch.as_tensor(var_u_x).cuda()  # desired distortion
    if var_u is None:
        var_u = var_x + var_u_x  # proposal variance
    var_u = torch.as_tensor(var_u).double().cuda()

    match = 0
    distortion = 0
    distortion_A = 0
    all_ns = 0  # total number of proposal rounds summed over all trials

    # Nothing here needs gradients; matches PML / IS_WZ.
    with torch.no_grad():
        # Only used to build a zero prior mean with the right shape/device.
        x = gauss_gen(var=var_x, B=B, N=1, dim=dim)
        mean_p = torch.zeros_like(x)

        for _ in tqdm(range(n_trials)):
            x = gauss_gen(var=var_x, B=B, N=1, dim=dim)  # encoder input
            x.clamp_(-threshold, threshold)
            side_info = x + torch.sqrt(var_y_x) * torch.randn_like(x)  # decoder side info
            # Decoder target distribution.
            mean_dec, var_dec = compute_decoder_target(side_info, var_side, var_x, var_u)
            mean_dec = mean_dec.cuda()
            var_dec = var_dec.cuda()

            omega = 1
            for d in range(dim):
                omega *= estimate_omega(0.0, var_u, x[0, 0, d], var_u_x)

            n = 1  # rounds used in this trial
            # NOTE: there is no cap on the number of rounds; this loops until
            # the ERS acceptance test passes.
            while True:
                y = gauss_gen(var=var_u, B=B, N=N, dim=dim)  # proposal
                logS_ = logexp_rv(B=B, N=N)
                random_hash = ber_rv(B=B, N=N, L=L_max)

                # Encoder selection.
                k_selected_A, y_selected_A, out_message, ers_prob_a = exp_sampler.select(
                    logS_, y, mean_t=x, var_t=var_u_x,
                    mean_p=mean_p, var_p=var_u,
                    hash_val=random_hash, ers_selection=True, omega=omega)
                if debug:
                    print(ers_prob_a)

                if random.random() < ers_prob_a:
                    # Decoder selection, same randomness, decoder's target.
                    k_selected_B, y_selected_B, _, _ = exp_sampler.select(
                        logS_, y, mean_t=mean_dec, var_t=var_dec,
                        mean_p=mean_p, var_p=var_u,
                        message=out_message, hash_val=random_hash)
                    all_ns += n
                    match += k_selected_A[0, 0] == k_selected_B[0, 0]
                    distortion += mse(y_selected_B, x)
                    distortion_A += mse(y_selected_A, x)
                    break
                n += 1

    ESS = int(N * all_ns / n_trials)
    return (
        (np.log2(L_max) + all_ns / n_trials) / dim,
        match / n_trials,
        10 * np.log10(distortion.item() / n_trials),
        10 * np.log10(distortion_A.item() / n_trials),
        ESS,
    )


def _pml_encoder_search(x, mean_p, omega, B, N, dim, L_max, var_u, var_u_x):
    """Run the encoder's sequential minimum-score search for one trial.

    Blocks of ``N`` proposals are drawn. Each gets arrival times ``t``: the
    running cumulative sum of ``exp(logexp_rv(...))`` draws, presumably
    exponential inter-arrival times — TODO confirm. The score of a proposal
    is ``log t + log p(y) - log q(y | x)``. The search stops once the
    running minimum score is below the block's benchmark
    ``log t - log omega``.

    NOTE(review): the stopping rule presumes ``omega`` upper-bounds the
    target/proposal density ratio; verify against ``estimate_omega``.

    ``t[:, -1, :].item()`` requires a single element, i.e. ``B == 1``.

    Returns:
        ``(y_A, message, blocks_y, blocks_t, blocks_hash)``: the selected
        sample, its hash value, and every block drawn, in order. The blocks
        are returned so the decoder can replay the same randomness.
    """
    blocks_y, blocks_t, blocks_hash = [], [], []
    t0 = 0
    min_s = torch.inf
    min_y = None
    message = -1
    log_omega = torch.log(omega)  # loop-invariant

    while True:
        y = gauss_gen(var=var_u, B=B, N=N, dim=dim)  # proposal block
        t = t0 + torch.cumsum(logexp_rv(B=B, N=N).exp(), dim=1)
        random_hash = ber_rv(B=B, N=N, L=L_max)
        blocks_y.append(y)
        blocks_t.append(t)
        blocks_hash.append(random_hash)

        log_t = torch.log(t)
        log_proposal = gauss_log_p(y, mean_p, var_u).sum(dim=-1, keepdim=True)
        log_target = gauss_log_p(y, x, var_u_x).sum(dim=-1, keepdim=True)
        s = (log_t + log_proposal - log_target).flatten().cpu()
        benchmark = (log_t - log_omega).flatten()

        min_s_c = s.min().item()
        argmins = 0
        if min_s_c < min_s:
            argmins = s.argmin().item()
            min_s = min_s_c
            min_y = y[:, argmins, :].clone()
            message = random_hash[:, argmins].clone()

        if min_s < benchmark[argmins:].max():
            return min_y, message, blocks_y, blocks_t, blocks_hash
        t0 = t[:, -1, :].item()  # continue the arrival process in the next block


def _pml_decoder_search(message, mean_p, mean_dec, var_dec, omega, blocks_y, blocks_t, blocks_hash,
                        B, N, dim, L_max, var_u):
    """Run the decoder's sequential minimum-score search for one trial.

    The first pass replays the encoder's blocks (concatenated, assuming
    ``B == 1``). Later passes draw fresh blocks that continue the arrival
    times. Proposals whose hash differs from ``message`` get a ``1e-25``
    weight, i.e. a large score penalty, which effectively excludes them.

    Returns:
        The selected sample ``y_B``, sliced as ``y[:, k, :]``.
    """
    past = len(blocks_y)
    y = torch.cat(blocks_y).view(1, past * N, dim)
    t = torch.cat(blocks_t).view(1, past * N, 1)
    random_hash = torch.cat(blocks_hash).view(1, past * N, 1)
    min_s = torch.inf
    min_y = None
    log_omega = torch.log(omega)  # loop-invariant

    while True:
        log_t = torch.log(t)
        log_proposal = gauss_log_p(y, mean_p, var_u).sum(dim=-1, keepdim=True)
        log_target = gauss_log_p(y, mean_dec, var_dec).sum(dim=-1, keepdim=True)
        log_filtered = torch.log((random_hash == message[0]) * 1.0 + 1e-25)
        s = (log_t.flatten() + log_proposal.flatten()
             - log_target.flatten() - log_filtered.flatten()).cpu()
        benchmark = (log_t - log_omega).flatten()

        min_s_c = s.min().item()
        argmins = 0
        if min_s_c < min_s:
            argmins = s.argmin().item()
            min_s = min_s_c
            min_y = y[:, argmins, :].clone()

        if min_s < benchmark[argmins:].max():
            return min_y

        # Not yet decided: extend with a fresh block after the last arrival.
        t0 = t[:, -1, :].item()
        y = gauss_gen(var=var_u, B=B, N=N, dim=dim)  # proposal
        t = t0 + torch.cumsum(logexp_rv(B=B, N=N).exp(), dim=1)
        random_hash = ber_rv(B=B, N=N, L=L_max)


def PML(B, N, n_trials, dim, L_max, var_u=None, var_u_x=0.01, var_x=1.0, var_y_x=0.01, threshold=3.0, debug=False):
    """Simulate the sequential minimum-score scheme (named PML).

    The name presumably refers to the Poisson matching lemma — TODO confirm.

    Each trial:

    1. Draws a clipped source ``x`` and noisy decoder side information.
    2. Runs the encoder search against target ``N(x, var_u_x)``.
    3. Runs the decoder search against the decoder target, restricted to
       proposals whose hash equals the encoder's message.
    4. Counts a match when the two selections are within squared distance
       ``0.001``.

    Args:
        B: Batch size for the samplers. The searches call ``.item()`` on
            per-block tensors and reshape to a leading dim of 1, so
            ``B`` must be 1.
        N: Number of proposals per block.
        n_trials: Number of independent trials; must be positive.
        dim: Dimension of each sample.
        L_max: Hash range passed to ``ber_rv``; the rate is
            ``log2(L_max) / dim``.
        var_u: Proposal variance; defaults to ``var_x + var_u_x``.
        var_u_x: Encoder target variance (desired distortion).
        var_x: Source variance.
        var_y_x: Side-information noise variance.
        threshold: Clipping bound applied to each source sample.
        debug: Unused; kept for signature parity with ``ERS_WZ_comm``.

    Returns:
        A tuple ``(rate, match_rate, distortion_B_dB, distortion_A_dB,
        mean_encoder_blocks)``, where the last element is the average number
        of encoder blocks drawn per trial.

    Raises:
        ValueError: If ``n_trials`` is not positive.
    """
    if n_trials <= 0:
        raise ValueError("n_trials must be positive")

    # torch.as_tensor avoids the copy-construct UserWarning torch.tensor()
    # emits for tensor input; var_u is a tensor whenever it is defaulted.
    var_x = torch.as_tensor(var_x).cuda()
    var_y_x = torch.as_tensor(var_y_x).cuda()
    var_side = var_x + var_y_x  # variance of side info
    var_u_x = torch.as_tensor(var_u_x).cuda()  # desired distortion
    if var_u is None:
        var_u = var_x + var_u_x  # proposal variance
    var_u = torch.as_tensor(var_u).double().cuda()

    match = 0
    distortion = 0.0
    distortion_A = 0.0
    total_N = 0  # encoder blocks drawn, summed over trials
    with torch.no_grad():
        # Only used to build a zero prior mean with the right shape/device.
        x = gauss_gen(var=var_x, B=B, N=1, dim=dim)
        mean_p = torch.zeros_like(x)

        for _ in tqdm(range(n_trials)):
            x = gauss_gen(var=var_x, B=B, N=1, dim=dim)  # encoder input
            x.clamp_(-threshold, threshold)
            side_info = x + torch.sqrt(var_y_x) * torch.randn_like(x)  # decoder side info
            # Decoder target distribution.
            mean_dec, var_dec = compute_decoder_target(side_info, var_side, var_x, var_u)
            mean_dec = mean_dec.cuda()
            var_dec = var_dec.cuda()

            omega = 1
            for d in range(dim):
                omega *= estimate_omega(0.0, var_u, x[0, 0, d], var_u_x)

            y_A, message, blocks_y, blocks_t, blocks_hash = _pml_encoder_search(
                x, mean_p, omega, B, N, dim, L_max, var_u, var_u_x)
            total_N += len(blocks_y)
            y_B = _pml_decoder_search(
                message, mean_p, mean_dec, var_dec, omega, blocks_y, blocks_t, blocks_hash,
                B, N, dim, L_max, var_u)

            # y_A / y_B are (B, dim) slices while x is (B, 1, dim); reshape so
            # mse_loss compares element-wise instead of broadcasting (which
            # warns, and mixes batch rows when B > 1).
            distortion += mse(y_B.reshape_as(x), x).item()
            distortion_A += mse(y_A.reshape_as(x), x).item()
            if ((y_A - y_B) ** 2).sum() < 0.001:
                match += 1
            torch.cuda.empty_cache()

    return (
        np.log2(L_max) / dim,
        match / n_trials,
        10 * np.log10(distortion / n_trials),
        10 * np.log10(distortion_A / n_trials),
        total_N / n_trials,
    )

import time

def IS_WZ(B, N, n_trials, dim, L_max, var_u_x=0.01, var_u=None, var_x=1.0, var_y_x=0.01, threshold=3.0):
    """Simulate the single-block (non-ERS) selection scheme.

    Each trial:

    1. Draws a clipped source ``x``, noisy decoder side information, and
       one block of ``N`` proposals.
    2. Runs ``exp_sampler.select`` once for the encoder
       (``ers_selection=False``).
    3. Runs it once for the decoder, using the encoder's message.

    Args:
        B: Batch size passed to the samplers. Only element ``[0, 0]`` is
            compared, so presumably ``B == 1`` — TODO confirm.
        N: Number of proposals in the block; also returned as-is.
        n_trials: Number of independent trials; must be positive.
        dim: Dimension of each sample.
        L_max: Hash range passed to ``ber_rv``; the rate is
            ``log2(L_max) / dim``.
        var_u_x: Encoder target variance (desired distortion).
        var_u: Proposal variance; defaults to ``var_x + var_u_x``.
        var_x: Source variance.
        var_y_x: Side-information noise variance.
        threshold: Clipping bound applied to each source sample.

    Returns:
        A tuple ``(rate, match_rate, distortion_B_dB, distortion_A_dB, N)``.
        ``match_rate`` is a tensor.

    Raises:
        ValueError: If ``n_trials`` is not positive.
    """
    if n_trials <= 0:
        raise ValueError("n_trials must be positive")

    # torch.as_tensor avoids the copy-construct UserWarning torch.tensor()
    # emits for tensor input; var_u is a tensor whenever it is defaulted.
    var_x = torch.as_tensor(var_x).cuda()
    var_y_x = torch.as_tensor(var_y_x).cuda()
    var_side = var_x + var_y_x  # variance of side info
    var_u_x = torch.as_tensor(var_u_x).cuda()  # desired distortion
    if var_u is None:
        var_u = var_x + var_u_x  # proposal variance
    var_u = torch.as_tensor(var_u).double().cuda()

    match = 0
    distortion = 0
    distortion_A = 0
    with torch.inference_mode():
        # Only used to build a zero prior mean with the right shape/device.
        x = gauss_gen(var=var_x, B=B, N=1, dim=dim)
        mean_p = torch.zeros_like(x)

        for _ in tqdm(range(n_trials)):
            x = gauss_gen(var=var_x, B=B, N=1, dim=dim)  # encoder input
            x.clamp_(-threshold, threshold)
            side_info = x + torch.sqrt(var_y_x) * torch.randn_like(x)  # decoder side info
            y = gauss_gen(var=var_u, B=B, N=N, dim=dim)  # proposal

            # Decoder target distribution.
            mean_dec, var_dec = compute_decoder_target(side_info, var_side, var_x, var_u)
            mean_dec = mean_dec.cuda()
            var_dec = var_dec.cuda()
            logS_ = logexp_rv(B=B, N=N)
            random_hash = ber_rv(B=B, N=N, L=L_max)

            # Encoder selection.
            k_selected_A, y_selected_A, out_message, _ = exp_sampler.select(
                logS_, y, mean_t=x, var_t=var_u_x,
                mean_p=mean_p, var_p=var_u,
                hash_val=random_hash, ers_selection=False)
            # Decoder selection, same randomness, decoder's target.
            k_selected_B, y_selected_B, _, _ = exp_sampler.select(
                logS_, y, mean_t=mean_dec, var_t=var_dec,
                mean_p=mean_p, var_p=var_u,
                message=out_message, hash_val=random_hash)

            match += k_selected_A[0, 0] == k_selected_B[0, 0]
            distortion += mse(y_selected_B, x)
            distortion_A += mse(y_selected_A, x)

    return (
        np.log2(L_max) / dim,
        match / n_trials,
        10 * np.log10(distortion.item() / n_trials),
        10 * np.log10(distortion_A.item() / n_trials),
        N,
    )

    
import argparse
import json

def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``N_list``, ``var_u_x``, ``n_trials``,
        ``dim``, ``L_max_list``, ``B``, ``threshold`` and ``method``.
    """
    parser = argparse.ArgumentParser(description="Run ERS_WZ_comm, IS_WZ, or PML experiments")
    parser.add_argument('--N_list', nargs='+', type=int, required=True, help="List of N values")
    parser.add_argument('--var_u_x', type=float, required=True,
                        help="Encoder target variance var_u_x (desired distortion)")
    parser.add_argument('--n_trials', type=int, required=True, help="Number of trials")
    parser.add_argument('--dim', type=int, required=True, help="Dimension")
    parser.add_argument('--L_max_list', nargs='+', type=int, required=True, help="List of L_max values")
    parser.add_argument('--B',
                        type=int,
                        default=1,
                        help='B parameter')
    parser.add_argument('--threshold',
                        type=float,
                        default=3.0,
                        help='Threshold value')
    parser.add_argument('--method', choices=['ERS_WZ_comm', 'IS_WZ', 'PML'], default='ERS_WZ_comm',
                        help="Method to use: ERS_WZ_comm, IS_WZ, or PML (default: ERS_WZ_comm)")

    return parser.parse_args()

def main():
    """Run the selected experiment over every (N, L_max) pair and save results.

    The output file name encodes all CLI parameters. The file holds a
    plain-text parameter header followed by a JSON dump of
    ``{N: [[(rate, match_rate, distortion_B_dB, distortion_A_dB, stat)], ...]}``,
    with one inner list per ``L_max``.
    """
    args = parse_arguments()

    N_list_ers = args.N_list
    var_u_x = args.var_u_x
    n_trials = args.n_trials
    dim = args.dim
    L_max_list = args.L_max_list
    method = args.method

    # All three experiments accept the same keyword arguments used below and
    # return a 5-tuple, so dispatch by name instead of duplicating the call.
    experiments = {
        'ERS_WZ_comm': ERS_WZ_comm,
        'IS_WZ': IS_WZ,
        'PML': PML,
    }
    run_experiment = experiments[method]

    # Create filename with parameters
    N_str = '_'.join(map(str, N_list_ers))
    L_max_str = '_'.join(map(str, L_max_list))
    filename = (f"{method}_N{N_str}_var{var_u_x}_trials{n_trials}_dim{dim}_"
                f"Lmax{L_max_str}_B{args.B}_thresh{args.threshold}.txt")

    ERS_N = {}
    for N in N_list_ers:
        ERS_data = []
        for L_max in L_max_list:
            ers_r, ers_m, ers_d, ers_da, ers_N = run_experiment(
                B=args.B,
                N=N,
                n_trials=n_trials,
                dim=dim,
                L_max=L_max,
                var_u_x=var_u_x,
                threshold=args.threshold,
            )
            # ERS_WZ_comm / IS_WZ return the match rate as a one-element
            # tensor, PML as a float; float() makes both JSON-serializable.
            ERS_data.append([(ers_r, float(ers_m), ers_d, ers_da, ers_N)])
            print(L_max)  # progress indicator
        ERS_N[N] = ERS_data

    # Writing results to file
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("Experiment Parameters:\n")
        f.write(f"Method: {method}\n")
        f.write(f"N_list: {N_list_ers}\n")
        f.write(f"var_u_x: {var_u_x}\n")
        f.write(f"n_trials: {n_trials}\n")
        f.write(f"dim: {dim}\n")
        f.write(f"L_max_list: {L_max_list}\n")
        f.write(f"B: {args.B}\n")
        f.write(f"threshold: {args.threshold}\n")
        f.write("\nResults:\n")
        # JSON turns the int N keys into strings.
        f.write(json.dumps(ERS_N, indent=2))

    print(f"Results written to {filename}")

if __name__ == "__main__":
    main()
