import numpy as np
import time
from coinpress_mod.functions import coinpress_mean
from quantile_binary_search.method import random_rotation_mean, clipped_mean
from tests_functions import test_dim
from scipy import stats
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


# ---- Experiment configuration -------------------------------------------
# Samples drawn per trial, and the ambient dimension.
n = 4000
d = 128

# Radii R swept on the x-axis; each is passed to the private estimators
# (as r for COINPRESS, as 2*r for the random-rotation estimator).
# The multipliers are scaled by sqrt(d).
max_r_l = [scale * np.sqrt(d) for scale in (20.0, 65.0, 110.0, 155.0, 200.0)]

# Every coordinate of the true mean is set to each of these values in turn.
mean0_l = [0.0, 5.0, 10.0]

# Parameter forwarded to both private estimators (presumably the privacy
# budget — confirm against coinpress_mean / random_rotation_mean).
p = 0.5

# Number of independent repetitions per (mean0, R) setting. Per-setting
# errors are aggregated with a 10%-trimmed mean (scipy.stats.trim_mean).
n_trials = 20
# COINPRESS iteration counts t to compare, in the order they are run, saved
# (one row each in the .txt file) and plotted, plus the colour of each curve.
cp_t_l = [1, 2, 3, 4, 10]
cp_colors = {1: 'darkviolet', 2: 'red', 3: 'magenta', 4: 'gray', 10: 'olive'}

for mean0 in mean0_l:
    # For each estimator: one trimmed-mean l2 error per radius in max_r_l.
    cp_err_l = {t: [] for t in cp_t_l}
    rr_err_l = []
    nonpr_err_l = []

    # True mean: every coordinate equals mean0. An ndarray (not a list) so
    # the subtraction below works whatever sequence type an estimator returns.
    mean = np.full(d, float(mean0))

    for max_r in max_r_l:
        # Progress indicator: show which (mean0, R) setting is running.
        print('mean0 = {}, R = {}'.format(mean0, max_r))
        r = max_r
        # The random-rotation estimator is given twice the COINPRESS radius.
        u = 2 * r

        cp_l = {t: [] for t in cp_t_l}
        rr_l = []
        nonpr_l = []

        for _ in range(n_trials):
            # Fresh random diagonal covariance (variances drawn from
            # uniform(0, 10)) for every trial.
            cov = np.diag(np.random.uniform(0, 10, size=d))
            x = np.random.multivariate_normal(mean, cov, int(n))
            # Initial centre handed to COINPRESS: the origin.
            c = [0.0] * d

            # Estimators run in the original order (t = 1, 2, 3, 4, 10, then
            # random rotation) so a seeded run reproduces the same RNG stream.
            for t in cp_t_l:
                cp_l[t].append(np.linalg.norm(coinpress_mean(x, c, r, t, p) - mean))
            nonpr_l.append(np.linalg.norm(np.mean(x, axis=0) - mean))

            y_hat = random_rotation_mean(x, d, u, p)
            rr_l.append(np.linalg.norm(y_hat - mean))

        for t in cp_t_l:
            cp_err_l[t].append(stats.trim_mean(cp_l[t], 0.1))
        rr_err_l.append(stats.trim_mean(rr_l, 0.1))
        nonpr_err_l.append(stats.trim_mean(nonpr_l, 0.1))

    # Rows: COINPRESS t=1, 2, 3, 4, 10, then Shifted-CM (random rotation),
    # then the non-private sample mean; one column per radius in max_r_l.
    np.savetxt('./test_R_mean{}.txt'.format(int(mean0)),
               np.array([cp_err_l[t] for t in cp_t_l] + [rr_err_l, nonpr_err_l]))

    # To re-plot from a saved run instead of recomputing, replace the
    # experiment above with:
    #   res = np.loadtxt('./test_R_mean{}.txt'.format(int(mean0)))
    #   cp_err_l = dict(zip(cp_t_l, res[:len(cp_t_l)]))
    #   rr_err_l, nonpr_err_l = res[len(cp_t_l)], res[len(cp_t_l) + 1]

    for t in cp_t_l:
        plt.plot(max_r_l, cp_err_l[t], color=cp_colors[t], marker='+',
                 label=r'COINPRESS $t={}$'.format(t))
    plt.plot(max_r_l, rr_err_l, color='green', marker='*', label='Shifted-CM')
    plt.plot(max_r_l, nonpr_err_l, color='blue', marker='x', label='Non-private')

    plt.legend(loc='upper right', framealpha=0.4)
    # Fit the x-range to the radii actually swept (5% right margin). A fixed
    # upper bound far beyond max(max_r_l) squashes all curves to the left.
    plt.xlim((0.0, 1.05 * max(max_r_l)))
    plt.yscale('log')
    plt.xlabel(r'$R$', fontsize=16)
    plt.ylabel(r'$\ell_2$ error', fontsize=16)
    plt.grid()
    plt.savefig('./test_R_mean{}.pdf'.format(int(mean0)), bbox_inches='tight', pad_inches=0)
    # Reset the figure so the next mean0 starts from a blank canvas.
    plt.clf()



