from scipy.stats import binom
from scipy.stats import norm

import matplotlib.pyplot as plt
import numpy as np
from statistics import NormalDist
import math
import tensorflow as tf
import time 
# Restrict TensorFlow to a fixed-size slice of the first GPU, when one exists.
gpu_config = tf.config.experimental
gpus = gpu_config.list_physical_devices('GPU')
if gpus:
    try:
        # A single virtual device with memory_limit=6000 is carved out of
        # gpus[0] (the TensorFlow docs give this limit in megabytes - confirm
        # against the installed version).
        gpu_config.set_virtual_device_configuration(
            gpus[0],
            [gpu_config.VirtualDeviceConfiguration(memory_limit=6000)],
        )
        logical_gpus = gpu_config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Per the TensorFlow GPU guide, virtual devices must be configured
        # before the GPUs have been initialized; report and carry on.
        print(err)

def sparsification_with_prob(value,A,B,d):
    """Randomly map every coordinate of ``value`` to -1, 0 or +1.

    Two independent uniform draws of length ``d`` are taken, in this order:

    1. Sign draw: a coordinate becomes +1 when its draw is below
       ``(value + A) / (2 * A)`` and -1 otherwise.
    2. Keep draw: the sign is kept when its draw is below ``A / B`` and is
       replaced by zero otherwise.

    Args:
        value: Array-like of ``d`` numbers to encode.
        A: Scalar that sets the +1 probability together with ``value``.
        B: Scalar; ``A / B`` is the probability of keeping a coordinate.
        d: Number of coordinates to draw.

    Returns:
        Float array of length ``d`` whose entries are -1, 0 or +1.
    """
    sign_draw = np.random.rand(d)
    plus_one_threshold = (np.asarray(value) + A) / (2 * A)
    # {0, 1} indicator shifted and scaled to {-1.0, +1.0}
    signs = 2 * ((sign_draw < plus_one_threshold).astype(int) - 0.5)

    keep_draw = np.random.rand(d)
    kept = (keep_draw < np.divide(A, B)).astype(int)

    return signs * kept


def prob_correct_aggregation(true_value,aggregated_value,d):
    """Return the average per-coordinate sign-agreement score.

    For each of the first ``d`` coordinates the score is:

    * 0.5 when the aggregated entry equals zero,
    * 1 when the aggregated entry is non-zero and has the same sign as the
      true entry,
    * 0 otherwise.

    Args:
        true_value: Indexable sequence of reference values.
        aggregated_value: Indexable sequence of aggregated estimates.
        d: Number of leading coordinates to score (also the divisor).

    Returns:
        Sum of the per-coordinate scores divided by ``d``.
    """
    def _score(idx):
        estimate = aggregated_value[idx]
        if estimate == 0:
            # A zero estimate carries no sign; it counts as half a match.
            return 0.5
        return 1 if np.sign(true_value[idx]) == np.sign(estimate) else 0

    return sum(_score(idx) for idx in range(d)) / d


def gaussian_sparsification_with_prob(value,sigma,A,B,d):
    random_vec = np.random.rand(d)
    compare_vec = np.less(random_vec,np.ones(d)*A/B)
    sparsified_vec = np.multiply(np.multiply(np.add(value,np.random.normal(0,sigma,d)),compare_vec.astype(int)),B/A)
    
    return sparsified_vec

# ---------------------------------------------------------------------------
# Simulation setup
# ---------------------------------------------------------------------------
time1 = time.time()  # start of the wall-clock measurement printed after the sweep

MC = 10000            # Monte Carlo repetitions per (sparsity, sigma) setting
N = 1000              # number of user vectors aggregated in each repetition
d = 250               # number of coordinates per vector
c_inf = 1/np.sqrt(d)  # A is checked against this bound in the sweep
c_2 = 1               # multiplies sigma in the noise scale and in A
mean = 1e-4           # only stored with the results (and used by the disabled code below)

sigma_set = [1/5,1/4,1/3,1/2,1,2,3,4,5]
sparsification_ratio_set = [0.2,0.4,0.6,0.8,1.0]


# Disabled: in-script generation of the per-user vectors and their average.
# The same two names are read from disk by the np.load call further down.
# true_value = np.zeros(d)
# user_value = []
# for m in range(N):
#     rand_value = np.random.rand(d)/np.sqrt(d) + (mean-0.5/np.sqrt(d))*np.ones(d)
#     clipped_rand_value = np.clip(rand_value,-c_inf,c_inf)
#     user_value.append(clipped_rand_value)
#     true_value = np.add(true_value,clipped_rand_value)

# true_value = np.divide(true_value,N)

# NOTE(review): allow_pickle=True unpickles arbitrary objects from the file,
# so this must only ever be pointed at a trusted, locally produced record.
# The record is unpacked into exactly two entries.
user_value, true_value = np.load('record_mean0_0001_uservalue.npy',allow_pickle=True)

# Result containers, filled by the sweep and indexed as
# [sparsity_index][sigma_index][0].
ternary_mean_results, ternary_sigma_results, ternary_variance_results = [], [], []
gaussian_mean_results, gaussian_sigma_results, gaussian_variance_results = [], [], []

# Monte Carlo sweep over every (sparsity ratio, sigma) pair.
#
# For each pair, MC repetitions are run. In every repetition all N user
# vectors are encoded with both mechanisms, summed, and turned into an
# estimate that is compared with true_value. Per pair the sweep records, for
# each mechanism, the per-coordinate mean and standard deviation of the
# estimates and the squared error summed over coordinates and averaged over
# the MC repetitions. Results are stored as
# <name>_results[sparsity_index][sigma_index] == [value].
for sparsity_index, sparsity in enumerate(sparsification_ratio_set):
    # One list per sparsity ratio; it receives one entry per sigma below.
    ternary_mean_results.append([])
    ternary_sigma_results.append([])
    ternary_variance_results.append([])
    gaussian_mean_results.append([])
    gaussian_sigma_results.append([])
    gaussian_variance_results.append([])

    for sigma_index, sigma in enumerate(sigma_set):
        # B = A / sparsity, so the keep probability A / B used by both
        # mechanisms equals the sparsity ratio.
        A = np.sqrt((c_inf**2+ (sigma*2*c_2)**2)*sparsity)
        B = A/sparsity
        if A < c_inf:
            # Warning only: the run deliberately continues with this setting.
            print('A cannot be smaller than c_inf')

        # Noise scale of the Gaussian mechanism; it depends only on sigma,
        # so it is computed once per setting instead of once per user.
        sigma_ = sigma*2*c_2

        ternary_aggregated = []
        gaussian_aggregated = []

        # Per-coordinate squared error, accumulated over the MC repetitions.
        ternary_variance = np.zeros(d)
        gaussian_variance = np.zeros(d)
        for mc in range(MC):
            aggregated_value = np.zeros(d)
            gaussian_aggregated_value = np.zeros(d)
            for m in range(N):
                # Ternary first, Gaussian second: this keeps the order in
                # which the two mechanisms consume the global NumPy RNG.
                aggregated_value += sparsification_with_prob(user_value[m],A,B,d)
                gaussian_aggregated_value += gaussian_sparsification_with_prob(user_value[m],sigma_,A,B,d)

            # The ternary sum is scaled by B/N; the Gaussian values are
            # already rescaled by B/A inside the mechanism, so only the
            # average over the N users is taken.
            aggregated_value = aggregated_value*(B/N)
            ternary_aggregated.append(aggregated_value)
            ternary_variance += np.square(aggregated_value-true_value)

            gaussian_aggregated_value = gaussian_aggregated_value/N
            gaussian_aggregated.append(gaussian_aggregated_value)
            gaussian_variance += np.square(gaussian_aggregated_value-true_value)

        # Stack the MC estimates once and reuse the array for mean and std.
        ternary_samples = np.array(ternary_aggregated).astype(float)
        ternary_aggregated_mean = ternary_samples.mean(axis=0)
        ternary_aggregated_sigma = ternary_samples.std(axis=0)
        ternary_aggregated_variance = np.sum(ternary_variance)/MC

        gaussian_samples = np.array(gaussian_aggregated).astype(float)
        gaussian_aggregated_mean = gaussian_samples.mean(axis=0)
        gaussian_aggregated_sigma = gaussian_samples.std(axis=0)
        gaussian_aggregated_variance = np.sum(gaussian_variance)/MC

        # Each setting is stored as a one-element list; the save and plot
        # code read the value back with a trailing [0].
        ternary_mean_results[sparsity_index].append([ternary_aggregated_mean])
        ternary_sigma_results[sparsity_index].append([ternary_aggregated_sigma])
        ternary_variance_results[sparsity_index].append([ternary_aggregated_variance])
        gaussian_mean_results[sparsity_index].append([gaussian_aggregated_mean])
        gaussian_sigma_results[sparsity_index].append([gaussian_aggregated_sigma])
        gaussian_variance_results[sparsity_index].append([gaussian_aggregated_variance])

# Report the elapsed wall-clock time since time1 was taken.
time2=time.time()
print(time2-time1)

# Persist the sweep settings and results as one record.
# The record mixes lists, a scalar and nested per-setting results, so it is
# ragged. NumPy >= 1.24 raises ValueError when asked to convert a ragged
# sequence implicitly, which would make np.save fail after the whole
# simulation has run. The 1-D object array is therefore built explicitly,
# one slot per field; its layout matches what older NumPy versions produced
# implicitly. Reading it back requires np.load(..., allow_pickle=True).
record_fields = [sigma_set,mean,sparsification_ratio_set,ternary_mean_results,ternary_sigma_results,ternary_variance_results,gaussian_mean_results,gaussian_sigma_results,gaussian_variance_results]
record = np.empty(len(record_fields), dtype=object)
for slot, field in enumerate(record_fields):
    record[slot] = field
np.save('record_uniformsqrtd_mean0_0001_10000mc',record)

# Plot the recorded error of both mechanisms against 1/sigma, one curve per
# sparsity ratio. Gaussian curves are dashed and drawn first, ternary curves
# are solid; marker and colour identify the sparsity ratio.
fig,ax = plt.subplots()

inverse_sigma = [1/sigma for sigma in sigma_set]

# (row in the result tables, marker, colour, ratio text used in the label)
ratio_styles = [
    (0, '*', 'blue', "0.2"),
    (1, '^', 'black', "0.4"),
    (2, '+', 'red', "0.6"),
    (3, 'P', 'green', "0.8"),
    (4, 'H', 'purple', "1.0"),
]
# (label prefix, result table, line style)
mechanisms = [
    ("Gaussian", gaussian_variance_results, '--'),
    ("Ternary", ternary_variance_results, 'solid'),
]

for mechanism, variance_table, dash in mechanisms:
    for row, marker, color, ratio_text in ratio_styles:
        # Every stored entry is a one-element list, hence the trailing [0].
        mse_curve = [variance_table[row][i][0] for i in range(len(sigma_set))]
        ax.plot(inverse_sigma, mse_curve, marker = marker, linestyle = dash,
                color=color, label=mechanism + "-Sparsity Ratio " + ratio_text)

plt.legend(loc=4)
plt.xlabel(r'$\mu$')  # the plotted x values are 1/sigma
plt.ylabel(r'MSE')
plt.axis([0,5,0,30])
# The font size is changed only after the axis labels exist; the legend is
# then created again at its final position.
plt.rcParams['font.size'] = 12
plt.legend(loc='center left', bbox_to_anchor=(0.35, 0.55))

