
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy import stats
from subspace_demo import gen_samples
from subspace_algos import EstSubspace, compute_projected_vector
from dwork_subspace import Dwork_Subspace
from sklearn.utils.extmath import randomized_svd
import time

def L2(est):
    """Return the Euclidean (L2) norm of ``est``.

    Thin wrapper around ``numpy.linalg.norm`` with its default ``ord``;
    for a 2-D input NumPy returns the Frobenius norm instead.
    """
    magnitude = np.linalg.norm(est, ord=None)
    return magnitude

def Test_k(iterations=30, ks=(2, 4, 6, 8, 10, 12), out_dir="./results/synthetic_mean"):
    """Compare three noisy mean estimators on synthetic data while varying k.

    For each ``k`` in ``ks`` (with ``n = 2*k*t`` samples), run ``iterations``
    trials. Each trial draws ``X = gen_samples(n, d, k, ...)`` and measures the
    L2 distance to the empirical mean of ``X`` of:

    * ``gaus_mean``: the mean plus Gaussian noise of scale
      ``sqrt(2/(rho[0]+rho[1]))/n``;
    * ``subspace_mean``: the mean plus Gaussian noise of scale
      ``sqrt(2/rho[0])/n``, projected onto the output of ``EstSubspace``;
    * ``dwork_mean``: the same noisy mean projected onto the output of
      ``Dwork_Subspace``.

    After each ``k`` the 10%-trimmed mean of each estimator's errors is
    appended, and all results accumulated so far are rewritten to
    ``out_dir``. This means an interrupted run still leaves partial results.

    Args:
        iterations: Number of trials per value of ``k``.
        ks: Values of ``k`` to sweep (default matches the original sweep).
        out_dir: Directory receiving the ``*-k.txt`` files; created if missing.
    """
    import os

    # Vary k (and n). Fix parameters d, t.
    d = 10000
    rho = [1, 1]  # rho[0] scales the mean noise; rho[1] goes to the subspace estimators
    t = 125
    delta = 1e-5
    beta = 0.05

    # np.savetxt does not create parent directories; without this the first
    # save fails on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)

    err_gaus_mech = []
    err_subspace = []
    err_dwork = []
    counters = []

    for k in ks:
        n = 2 * k * t
        means_gaus_mech = []
        means_subspace = []
        means_dwork = []
        for i in range(iterations):
            X = gen_samples(n, d, k, gap_factor=10 * d)
            print("k = %d (iteration %d)" % (k, i))
            # Report the ratio of the k-th to the (k+1)-th singular value of X.
            _, s, _ = randomized_svd(X, n_components=k + 1)
            print("**** sigma_k/sigma_{k+1} = %f ****" % float(s[k - 1] / s[k]))

            mean = np.mean(X, axis=0)
            # Baseline: noise calibrated to the combined rho[0] + rho[1].
            # (Draw order below is kept identical to the original script.)
            gaus_mean = mean + np.random.normal(
                0, scale=np.sqrt(2 / (rho[0] + rho[1])) / n, size=d)
            # One noisy mean using rho[0] only, shared by both subspace
            # methods so they are compared on the same noise draw.
            subspace_mean_tmp = mean + np.random.normal(
                0, scale=np.sqrt(2 / rho[0]) / n, size=d)

            Vt = EstSubspace(X, d, k, rho[1], delta, beta, r_min=1e-10, r_max=1, t=t)
            subspace_mean = compute_projected_vector(Vt, subspace_mean_tmp)
            Vt = Dwork_Subspace(X, d, k, rho[1], delta)
            dwork_mean = compute_projected_vector(Vt, subspace_mean_tmp)

            # Compute each error once; reuse it for both recording and logging.
            err_g = L2(gaus_mean - mean)
            err_s = L2(subspace_mean - mean)
            err_tmp = L2(subspace_mean_tmp - mean)
            err_d = L2(dwork_mean - mean)
            means_gaus_mech.append(err_g)
            means_subspace.append(err_s)
            means_dwork.append(err_d)

            print("Gaus err = %f" % err_g)
            print("subspace err = %f" % err_s)
            print("subspace_tmp err = %f" % err_tmp)
            print("dwork err = %f\n\n" % err_d)

        # trim_mean(..., 0.1) drops 10% of values from each end before averaging.
        err_gaus_mech.append(stats.trim_mean(means_gaus_mech, 0.1))
        err_subspace.append(stats.trim_mean(means_subspace, 0.1))
        err_dwork.append(stats.trim_mean(means_dwork, 0.1))
        counters.append(k)

        # Rewrite all files after every k so partial results are kept.
        np.savetxt(os.path.join(out_dir, "gaus_mech-k.txt"), np.array(err_gaus_mech))
        np.savetxt(os.path.join(out_dir, "subspace-k.txt"), np.array(err_subspace))
        np.savetxt(os.path.join(out_dir, "dwork-k.txt"), np.array(err_dwork))
        np.savetxt(os.path.join(out_dir, "counters-k.txt"), np.array(counters))

    print("Done")



if __name__ == "__main__":
    # Script entry point: run the k-sweep experiment with default settings.
    Test_k()

