
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from subspace_demo import gen_samples
from subspace_algos import EstSubspace, compute_projected_vector
from dwork_subspace import Dwork_Subspace
from sklearn.utils.extmath import randomized_svd

def L2(est):
    """Return the norm of ``est`` as computed by ``np.linalg.norm``.

    No ``ord`` or ``axis`` is passed, so NumPy's default applies: the
    Euclidean (2-)norm for 1-D input. In this file it is used to measure the
    distance between an estimated mean vector and the empirical mean.
    """
    magnitude = np.linalg.norm(est)
    return magnitude

def Test_d(iterations=30):
    """Sweep the ambient dimension ``d`` and record mean-estimation errors.

    For each ``sd`` in ``[30, 60, 90, 120, 150, 180]`` the dimension is
    ``d = sd * sd``. For every trial a fresh sample matrix is drawn with
    ``gen_samples`` and three estimates of the empirical mean are compared:

    * a baseline that adds Gaussian noise (scale driven by
      ``rho[0] + rho[1]``) directly to the empirical mean;
    * a noisy mean (scale driven by ``rho[0]``) projected with the output of
      ``EstSubspace``;
    * the same noisy mean projected with the output of ``Dwork_Subspace``.

    The error of each estimate is its ``L2`` distance to the empirical mean.
    Per dimension, the errors over all trials are summarised with a trimmed
    mean (10% cut from each tail) and the running results are written to
    text files under ``./results/synthetic_mean/`` after every dimension, so
    partial results survive an interrupted run.

    Args:
        iterations: number of independent trials per dimension.

    Returns:
        None. Results are printed and saved to disk.

    NOTE(review): ``rho``, ``delta`` and ``beta`` look like privacy /
    failure-probability parameters of the imported estimators -- confirm
    against ``subspace_algos`` and ``dwork_subspace``.
    """
    # Local import: only needed to prepare the output directory.
    import os

    # Vary d. Fix parameters k, t, n.
    k = 4
    rho = [1, 1]
    t = 125
    n = 2*k*t
    delta = 1e-5
    beta = 0.05

    # Create the output directory before doing any work. np.savetxt does not
    # create missing parent directories, so without this the run would only
    # fail after the first (expensive) batch of trials has completed.
    os.makedirs("./results/synthetic_mean", exist_ok=True)

    err_gaus_mech = []
    err_subspace = []
    err_dwork = []
    counters = []

    for sd in [30, 60, 90, 120, 150, 180]:
        d = sd*sd
        means_gaus_mech = []
        means_subspace = []
        means_dwork = []
        for i in range(iterations):
            X = gen_samples(n, d, k, gap_factor=10*d)
            print("******* sd = %d (iteration %d) *******" % (sd, i))
            # Top k+1 singular values, used only to report the spectral gap.
            _, s, _ = randomized_svd(X, n_components=k+1)
            print("sigma_k/sigma_{k+1} = %d" % int(s[k-1]/s[k]))

            mean = np.mean(X, axis=0)

            # Baseline: noise added straight to the empirical mean, with the
            # scale computed from the combined rho[0] + rho[1].
            gaus_mean = mean + np.random.normal(
                0, scale=np.sqrt(2/(rho[0]+rho[1]))/n, size=d)
            gaus_err = L2(gaus_mean - mean)
            means_gaus_mech.append(gaus_err)

            # Noisy mean (scale from rho[0] only) shared by both
            # subspace-based estimators below.
            subspace_mean_tmp = mean + np.random.normal(
                0, scale=np.sqrt(2/rho[0])/n, size=d)

            Vt = EstSubspace(X, d, k, rho[1], delta, beta,
                             r_min=1e-6, r_max=100, t=t)
            subspace_mean = compute_projected_vector(Vt, subspace_mean_tmp)
            subspace_err = L2(subspace_mean - mean)
            means_subspace.append(subspace_err)

            Vt = Dwork_Subspace(X, d, k, rho[1], delta)
            dwork_mean = compute_projected_vector(Vt, subspace_mean_tmp)
            dwork_err = L2(dwork_mean - mean)
            means_dwork.append(dwork_err)

            print("Gaus err = %f" % gaus_err)
            print("subspace err = %f" % subspace_err)
            print("subspace_tmp err = %f" % (L2(subspace_mean_tmp - mean)))
            print("dwork err = %f\n" % dwork_err)

        # Trimmed mean drops the 10% smallest and largest trial errors.
        err_gaus_mech.append(stats.trim_mean(means_gaus_mech, 0.1))
        err_subspace.append(stats.trim_mean(means_subspace, 0.1))
        err_dwork.append(stats.trim_mean(means_dwork, 0.1))
        counters.append(sd)

        # Rewrite the result files after every dimension so that partial
        # results are available if the sweep is interrupted.
        np.savetxt("./results/synthetic_mean/gaus_mech-d.txt", np.array(err_gaus_mech))
        np.savetxt("./results/synthetic_mean/subspace-d.txt", np.array(err_subspace))
        np.savetxt("./results/synthetic_mean/dwork-d.txt", np.array(err_dwork))
        np.savetxt("./results/synthetic_mean/counters-d.txt", np.array(counters))

    # Nothing is plotted in this function; the title is set once on pyplot's
    # current axes instead of being re-set with the same text on every step.
    plt.title('Multivariate Mean Estimation')

    print("Done")



if __name__ == "__main__":
    # Script entry point: run the dimension sweep with its default settings.
    Test_d()

