
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from subspace_demo import gen_samples
from subspace_algos import EstSubspace, compute_projected_vector
from dwork_subspace import Dwork_Subspace
from sklearn.utils.extmath import randomized_svd

def L2(est):
    """Return the Euclidean (L2) norm of ``est``.

    Thin wrapper over ``np.linalg.norm`` with its default ``ord``: the
    vector 2-norm for 1-D input, the Frobenius norm for 2-D input.
    """
    norm_value = np.linalg.norm(est, ord=None)
    return norm_value

def Test_g(iterations=30,
           gap_factors=(10000, 20000, 30000, 40000, 50000, 60000, 70000),
           out_dir="./results/synthetic_mean"):
    """Compare noisy mean estimators while varying ``gap_factor``.

    For each value in ``gap_factors`` this draws ``iterations`` datasets with
    ``gen_samples(n, d, k, gap_factor)``. (Presumably n points in d dimensions
    with a k-dimensional dominant subspace -- see ``subspace_demo`` to
    confirm.) For each dataset it records the L2 distance from the empirical
    mean of three estimates:

    * gaus_mech: mean + N(0, (sqrt(2 / (rho[0] + rho[1])) / n)^2) per coordinate;
    * subspace: mean + N(0, (sqrt(2 / rho[0]) / n)^2) per coordinate, projected
      with ``compute_projected_vector`` onto the output of ``EstSubspace``
      (called with ``rho[1]``);
    * dwork: the same noisy mean projected onto the output of
      ``Dwork_Subspace`` (called with ``rho[1]``).

    The 10%-trimmed mean of the per-iteration errors is stored per gap factor.
    All result files are rewritten after every gap factor, so partial results
    survive an interrupted run.

    Args:
        iterations: number of random datasets drawn per gap factor.
        gap_factors: values passed as ``gap_factor`` to ``gen_samples``.
        out_dir: directory for the ``*-g.txt`` result files; created if missing.
    """
    import os  # local import: only this experiment touches the filesystem

    # Vary gap. Fix parameters k, t, n.
    d = 10000
    k = 4
    rho = [1, 1]  # rho[0]: noisy mean; rho[1]: subspace estimation
    t = 125
    n = 2*k*t
    delta = 1e-5
    beta = 0.05

    err_gaus_mech = []
    err_subspace = []
    err_dwork = []
    counters = []

    # np.savetxt does not create parent directories; without this the first
    # save raises FileNotFoundError after a full round of expensive trials.
    os.makedirs(out_dir, exist_ok=True)

    for gap_factor in gap_factors:
        means_gaus_mech = []
        means_subspace = []
        means_dwork = []
        for i in range(iterations):
            print("gap_factor/d = %d (iteration %d)" % (gap_factor/d, i))
            X = gen_samples(n, d, k, gap_factor)
            # Report the empirical ratio of the k-th to (k+1)-th singular value.
            _, s, _ = randomized_svd(X, n_components=k+1)
            gap = int(s[k-1]/s[k])
            print("**** sigma_k/sigma_{k+1} = %d ****" % gap)

            mean = np.mean(X, axis=0)

            # Noise draws stay in the original order so seeded runs reproduce.
            gaus_mean = mean + np.random.normal(0, scale=np.sqrt(2/(rho[0]+rho[1]))/n, size=d)
            subspace_mean_tmp = mean + np.random.normal(0, scale=np.sqrt(2/rho[0])/n, size=d)

            Vt = EstSubspace(X, d, k, rho[1], delta, beta, r_min=1e-8, r_max=100, t=t)
            subspace_mean = compute_projected_vector(Vt, subspace_mean_tmp)
            Vt = Dwork_Subspace(X, d, k, rho[1], delta)
            dwork_mean = compute_projected_vector(Vt, subspace_mean_tmp)

            # Compute each error once; it is both recorded and printed.
            err_gaus = L2(gaus_mean - mean)
            err_sub = L2(subspace_mean - mean)
            err_dw = L2(dwork_mean - mean)
            means_gaus_mech.append(err_gaus)
            means_subspace.append(err_sub)
            means_dwork.append(err_dw)

            print("Gaus err = %f" % err_gaus)
            print("subspace err = %f" % err_sub)
            print("subspace_tmp err = %f" % L2(subspace_mean_tmp - mean))
            print("dwork err = %f\n\n" % err_dw)

        # Trim 10% at each tail to damp occasional outlier trials.
        err_gaus_mech.append(stats.trim_mean(means_gaus_mech, 0.1))
        err_subspace.append(stats.trim_mean(means_subspace, 0.1))
        err_dwork.append(stats.trim_mean(means_dwork, 0.1))
        # Same gap_factor/d quantity as in the progress print above.
        counters.append(int(gap_factor/d))

        # Checkpoint after every gap factor.
        for fname, values in (("gaus_mech-g.txt", err_gaus_mech),
                              ("subspace-g.txt", err_subspace),
                              ("dwork-g.txt", err_dwork),
                              ("counters-g.txt", counters)):
            np.savetxt(os.path.join(out_dir, fname), np.array(values))

    print("Done")



def _main():
    """Script entry point: run the gap sweep with its default settings."""
    Test_g()


if __name__ == '__main__':
    _main()

