import numpy as np
import pandas as pd
import os
from scipy.stats import trim_mean
from coinpress_mod.algos import multivariate_mean_iterative_n, multivariate_mean_iterative, L2


def coinpress_mean(X, c, r, t, p, func=multivariate_mean_iterative):
    """Run an iterative mean estimator on ``X`` with a split budget ``p``.

    The total budget ``p`` is divided across the ``t`` iterations:

    * ``t == 1``: the single step receives all of ``p``.
    * ``t > 1``: each of the first ``t - 1`` steps receives
      ``p / (4 * (t - 1))`` and the final step receives ``3 * p / 4``.

    In both cases the per-step budgets sum to ``p``.

    Args:
        X: Sample array. A copy (``X.copy()``) is handed to ``func``, so the
            caller's array is not modified by it. Presumably shaped
            (n_samples, dim); confirm against callers.
        c: Passed through unchanged to ``func``. Presumably the initial
            center guess.
        r: Passed through unchanged to ``func``. Presumably the initial
            radius around ``c``.
        t: Number of iterations. Must be >= 1.
        p: Total budget to split across iterations. NOTE(review): an earlier
            revision derived it as ``eps*eps*0.5``, which suggests a zCDP rho.
            Confirm.
        func: Estimator invoked as ``func(X, c, r, t, Ps)``, where ``Ps`` is
            the list of per-iteration budgets.

    Returns:
        Whatever ``func`` returns (the mean estimate).

    Raises:
        ValueError: If ``t < 1``. Such values would otherwise produce a
            budget schedule that does not sum to ``p``.
    """
    if t < 1:
        raise ValueError(f"t must be a positive iteration count, got {t!r}")
    if t == 1:
        Ps = [p]
    else:
        # The t-1 preliminary steps share 1/4 of the budget equally; the
        # final step, which produces the returned estimate, keeps 3/4.
        Ps = [(1.0 / 4.0 / (t - 1)) * p for _ in range(t - 1)]
        Ps.append((3.0 / 4.0) * p)
    return func(X.copy(), c, r, t, Ps)

# def generate_gauss_paths(mean,cov,nPaths,nSamples,d,seed=123,fileOut=''):
#     np.random.seed(seed=seed)
#     X = []
#     for i in range(nPaths):
#         X.append(np.random.multivariate_normal(mean, cov, int(nSamples)))
#     return X
    
# def generate_laplace_paths(mean,cov,nPaths,nSamples,d,seed=123,fileOut=''):
#     np.random.seed(seed=seed)
#     X = []
#     scale = np.sqrt(cov[0,0]/2)
#     for i in range(nPaths):
#         X.append(np.random.laplace(mean[0], scale, (int(nSamples),d)))
#     return X

# def test_gauss(mean,cov,eps,nPaths,nSamples,d,c,r,bGenerate=True,folder=''):
#     if bGenerate:
#         X_paths = generate_gauss_paths(mean,cov,nPaths,nSamples,d)
#         if not(folder==''):
#             filename='gauss_'+str(nPaths)+'paths_'+str(nSamples)+'samples_dim'+str(d)+'.txt'
#             save_paths(X_paths,nPaths,nSamples,d,folder,filename)
#     else:
#         filename='gauss_'+str(nPaths)+'paths_'+str(nSamples)+'samples_dim'+str(d)+'.txt'
#         X_paths = load_paths(folder+filename,nPaths,nSamples,d)
#     non_pr = []
#     means_t1 = []
#     means_t2 = []
#     means_t3 = []
#     means_t4 = []
#     means_t10 = []

#     means_t1_n = []
#     means_t2_n = []
#     means_t3_n = []
#     means_t4_n = []
#     means_t10_n = []
#     #print('non_pr,means_t1,means_t1_n,means_t2,means_t2_n,means_t3,means_t3_n,means_t4,means_t4_n\n')
#     for i in range(nPaths):
#         X = X_paths[i]
#         non_pr.append(L2(np.mean(X, axis=0)-mean))
#         means_t1.append(L2(coinpress_mean(X.copy(), c, r, 1, eps, func=multivariate_mean_iterative)-mean))
#         means_t1_n.append(L2(coinpress_mean(X.copy(), c, r, 1, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t2.append(L2(coinpress_mean(X.copy(), c, r, 2, eps, func=multivariate_mean_iterative)-mean))
#         means_t2_n.append(L2(coinpress_mean(X.copy(), c, r, 2, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t3.append(L2(coinpress_mean(X.copy(), c, r, 3, eps, func=multivariate_mean_iterative)-mean))
#         means_t3_n.append(L2(coinpress_mean(X.copy(), c, r, 3, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t4.append(L2(coinpress_mean(X.copy(), c, r, 4, eps, func=multivariate_mean_iterative)-mean))
#         means_t4_n.append(L2(coinpress_mean(X.copy(), c, r, 4, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t10.append(L2(coinpress_mean(X.copy(), c, r, 10, eps, func=multivariate_mean_iterative)-mean))
#         means_t10_n.append(L2(coinpress_mean(X.copy(), c, r, 10, eps, func=multivariate_mean_iterative_n)-mean))
#         #strOut = ','.join([str(means_t1[i]),str(means_t1_n[i]),str(means_t2[i]),str(means_t2_n[i]),str(means_t3[i]),str(means_t3_n[i]),str(means_t4[i]),str(means_t4_n[i])])
#     #return np.mean(non_pr), np.mean(means_t1), np.mean(means_t1_n), np.mean(means_t2), np.mean(means_t2_n), np.mean(means_t3), np.mean(means_t3_n), np.mean(means_t4), np.mean(means_t4_n)
#     return non_pr, means_t1, means_t1_n, means_t2, means_t2_n, means_t3, means_t3_n, means_t4, means_t4_n, means_t10, means_t10_n

# def test_laplace(mean,cov,eps,nPaths,nSamples,d,c,r,bGenerate=True,folder=''):
#     if bGenerate:
#         X_paths = generate_laplace_paths(mean,cov,nPaths,nSamples,d)
#         if not(folder==''):
#             filename='laplace_'+str(nPaths)+'paths_'+str(nSamples)+'samples_dim'+str(d)+'.txt'
#             save_paths(X_paths,nPaths,nSamples,d,folder,filename)
#     else:
#         filename='laplace_'+str(nPaths)+'paths_'+str(nSamples)+'samples_dim'+str(d)+'.txt'
#         X_paths = load_paths(folder+filename,nPaths,nSamples,d)
#     non_pr = []
#     means_t1 = []
#     means_t2 = []
#     means_t3 = []
#     means_t4 = []
#     means_t10 = []

#     means_t1_n = []
#     means_t2_n = []
#     means_t3_n = []
#     means_t4_n = []
#     means_t10_n = []
#     #print('non_pr,means_t1,means_t1_n,means_t2,means_t2_n,means_t3,means_t3_n,means_t4,means_t4_n\n')
#     for i in range(nPaths):
#         X = X_paths[i]
#         non_pr.append(L2(np.mean(X, axis=0)-mean))
#         means_t1.append(L2(coinpress_mean(X.copy(), c, r, 1, eps, func=multivariate_mean_iterative)-mean))
#         means_t1_n.append(L2(coinpress_mean(X.copy(), c, r, 1, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t2.append(L2(coinpress_mean(X.copy(), c, r, 2, eps, func=multivariate_mean_iterative)-mean))
#         means_t2_n.append(L2(coinpress_mean(X.copy(), c, r, 2, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t3.append(L2(coinpress_mean(X.copy(), c, r, 3, eps, func=multivariate_mean_iterative)-mean))
#         means_t3_n.append(L2(coinpress_mean(X.copy(), c, r, 3, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t4.append(L2(coinpress_mean(X.copy(), c, r, 4, eps, func=multivariate_mean_iterative)-mean))
#         means_t4_n.append(L2(coinpress_mean(X.copy(), c, r, 4, eps, func=multivariate_mean_iterative_n)-mean))
#         means_t10.append(L2(coinpress_mean(X.copy(), c, r, 10, eps, func=multivariate_mean_iterative)-mean))
#         means_t10_n.append(L2(coinpress_mean(X.copy(), c, r, 10, eps, func=multivariate_mean_iterative_n)-mean))
#         #strOut = ','.join([str(means_t1[i]),str(means_t1_n[i]),str(means_t2[i]),str(means_t2_n[i]),str(means_t3[i]),str(means_t3_n[i]),str(means_t4[i]),str(means_t4_n[i])])
#     #return np.mean(non_pr), np.mean(means_t1), np.mean(means_t1_n), np.mean(means_t2), np.mean(means_t2_n), np.mean(means_t3), np.mean(means_t3_n), np.mean(means_t4), np.mean(means_t4_n)
#     return non_pr, means_t1, means_t1_n, means_t2, means_t2_n, means_t3, means_t3_n, means_t4, means_t4_n, means_t10, means_t10_n


# def test_dim(eps,nPaths,nSamples,d0=64,dPoints=5,bGenerate=True,folder='',test_dist=test_gauss,prefix=''):
#     d = d0/2
#     dims = []
#     err_nonpr = []
#     err_t1 = []
#     err_t1n = []
#     err_t2 = []
#     err_t2n = []
#     err_t3 = []
#     err_t3n = []
#     err_t4 = []
#     err_t4n = []
#     err_t10 = []
#     err_t10n = []

#     err_nonpr_paths = []
#     err_t1_paths = []
#     err_t1n_paths = []
#     err_t2_paths = []
#     err_t2n_paths = []
#     err_t3_paths = []
#     err_t3n_paths = []
#     err_t4_paths = []
#     err_t4n_paths = []
#     err_t10_paths = []
#     err_t10n_paths = []
#     for i in range(dPoints):
#         non_pr = []
#         mean_t1 = []
#         mean_t1n = [] 
#         mean_t2 = [] 
#         mean_t2n = []
#         mean_t3 = []
#         mean_t3n = []
#         mean_t4 = []
#         mean_t4n = []
#         d = int(d*2)
#         print(d)
#         dims.append(d)
#         mean = [0.0]*d
#         cov = np.eye(d)
#         c = [0]*d
#         r = 10*np.sqrt(d)
#         non_pr, mean_t1, mean_t1n, mean_t2, mean_t2n, mean_t3, mean_t3n, mean_t4, mean_t4n, mean_t10, mean_t10n = test_dist(mean,cov,eps,nPaths,nSamples,d,c,r,bGenerate=bGenerate,folder='')
#         err_nonpr.append(np.mean(non_pr))
#         err_t1.append(np.mean(mean_t1))
#         err_t2.append(np.mean(mean_t2))
#         err_t3.append(np.mean(mean_t3))
#         err_t4.append(np.mean(mean_t4))
#         err_t10.append(np.mean(mean_t10))
#         err_t1n.append(np.mean(mean_t1n))
#         err_t2n.append(np.mean(mean_t2n))
#         err_t3n.append(np.mean(mean_t3n))
#         err_t4n.append(np.mean(mean_t4n))
#         err_t10n.append(np.mean(mean_t10n))
#         #strOut = ','.join([str(),str(),str(),str(),str(),str(),str(),str(),str()])
#         #strOut_trim = ','.join([str(trim_mean(non_pr,0.1)),str(trim_mean(mean_t1,0.1)),str(trim_mean(mean_t1n,0.1)),str(trim_mean(mean_t2,0.1)),str(trim_mean(mean_t2n,0.1)),str(trim_mean(mean_t3,0.1)),str(trim_mean(mean_t3n,0.1)),str(trim_mean(mean_t4,0.1)),str(trim_mean(mean_t4n,0.1))])
#         #print(strOut)
#         #print(strOut_trim)
#         err_nonpr_paths.append(non_pr)
#         err_t1_paths.append(mean_t1)
#         err_t2_paths.append(mean_t2)
#         err_t3_paths.append(mean_t3)
#         err_t4_paths.append(mean_t4)
#         err_t10_paths.append(mean_t10)
#         err_t1n_paths.append(mean_t1n)
#         err_t2n_paths.append(mean_t2n)
#         err_t3n_paths.append(mean_t3n)
#         err_t4n_paths.append(mean_t4n)
#         err_t10n_paths.append(mean_t10n)
#     strfolder = folder + str(nPaths) + prefix + 'paths_' + str(nSamples) + 'samples_' + str(d0) + '_' +str(dPoints) + '/'
#     write_output(dims,err_nonpr,strfolder,'err_nonpr.txt')
#     write_output(dims,err_t1,strfolder,'err_t1.txt')
#     write_output(dims,err_t1n,strfolder,'err_t1n.txt')
#     write_output(dims,err_t2,strfolder,'err_t2.txt')
#     write_output(dims,err_t2n,strfolder,'err_t2n.txt')
#     write_output(dims,err_t3,strfolder,'err_t3.txt')
#     write_output(dims,err_t3n,strfolder,'err_t3n.txt')
#     write_output(dims,err_t4,strfolder,'err_t4.txt')
#     write_output(dims,err_t4n,strfolder,'err_t4n.txt')
#     write_output(dims,err_t10,strfolder,'err_t10.txt')
#     write_output(dims,err_t10n,strfolder,'err_t10n.txt')

#     write_output(dims,err_nonpr_paths,strfolder,'err_nonpr_paths.txt')
#     write_output(dims,err_t1_paths,strfolder,'err_t1_paths.txt')
#     write_output(dims,err_t1n_paths,strfolder,'err_t1n_paths.txt')
#     write_output(dims,err_t2_paths,strfolder,'err_t2_paths.txt')
#     write_output(dims,err_t2n_paths,strfolder,'err_t2n_paths.txt')
#     write_output(dims,err_t3_paths,strfolder,'err_t3_paths.txt')
#     write_output(dims,err_t3n_paths,strfolder,'err_t3n_paths.txt')
#     write_output(dims,err_t4_paths,strfolder,'err_t4_paths.txt')
#     write_output(dims,err_t4n_paths,strfolder,'err_t4n_paths.txt')
#     write_output(dims,err_t10_paths,strfolder,'err_t10_paths.txt')
#     write_output(dims,err_t10n_paths,strfolder,'err_t10n_paths.txt')

# def test_sigma(eps,nPaths,nSamples,sigma0=10,d=64,sPoints=5,bGenerate=True,folder='',test_dist=test_gauss,prefix=''):
#     sigma = sigma0/10.0
#     sigmas = []
#     err_nonpr = []
#     err_t1 = []
#     err_t1n = []
#     err_t2 = []
#     err_t2n = []
#     err_t3 = []
#     err_t3n = []
#     err_t4 = []
#     err_t4n = []
#     err_t10 = []
#     err_t10n = []

#     err_nonpr_paths = []
#     err_t1_paths = []
#     err_t1n_paths = []
#     err_t2_paths = []
#     err_t2n_paths = []
#     err_t3_paths = []
#     err_t3n_paths = []
#     err_t4_paths = []
#     err_t4n_paths = []
#     err_t10_paths = []
#     err_t10n_paths = []

#     for i in range(sPoints):
#         non_pr = []
#         mean_t1 = []
#         mean_t1n = [] 
#         mean_t2 = [] 
#         mean_t2n = []
#         mean_t3 = []
#         mean_t3n = []
#         mean_t4 = []
#         mean_t4n = []
#         sigma = sigma*10
#         print(sigma)
#         sigmas.append(sigma)
#         mean = [10.0]*d
#         cov = sigma*np.eye(d)
#         c = [5.0]*d
#         r = 18*np.sqrt(d*sigma)
#         non_pr, mean_t1, mean_t1n, mean_t2, mean_t2n, mean_t3, mean_t3n, mean_t4, mean_t4n, mean_t10, mean_t10n = test_dist(mean,cov,eps,nPaths,nSamples,d,c,r,bGenerate=bGenerate,folder='')#test_gauss(mean,cov,eps,nPaths,nSamples,d,c,r,bGenerate=bGenerate,folder='')
#         err_nonpr.append(np.mean(non_pr))
#         err_t1.append(np.mean(mean_t1))
#         err_t2.append(np.mean(mean_t2))
#         err_t3.append(np.mean(mean_t3))
#         err_t4.append(np.mean(mean_t4))
#         err_t10.append(np.mean(mean_t10))
#         err_t1n.append(np.mean(mean_t1n))
#         err_t2n.append(np.mean(mean_t2n))
#         err_t3n.append(np.mean(mean_t3n))
#         err_t4n.append(np.mean(mean_t4n))
#         err_t10n.append(np.mean(mean_t10n))
#         err_nonpr_paths.append(non_pr)
#         err_t1_paths.append(mean_t1)
#         err_t2_paths.append(mean_t2)
#         err_t3_paths.append(mean_t3)
#         err_t4_paths.append(mean_t4)
#         err_t10_paths.append(mean_t10)
#         err_t1n_paths.append(mean_t1n)
#         err_t2n_paths.append(mean_t2n)
#         err_t3n_paths.append(mean_t3n)
#         err_t4n_paths.append(mean_t4n)
#         err_t10n_paths.append(mean_t10n)
#     strfolder = folder + str(nPaths) + prefix + 'paths_' + str(nSamples) + 'samples_' + str(sigma0) + 'sigma0_' +str(sPoints) + '_dim'+str(d)+'/'
#     write_output(sigmas,err_nonpr,strfolder,'err_nonpr.txt')
#     write_output(sigmas,err_t1,strfolder,'err_t1.txt')
#     write_output(sigmas,err_t1n,strfolder,'err_t1n.txt')
#     write_output(sigmas,err_t2,strfolder,'err_t2.txt')
#     write_output(sigmas,err_t2n,strfolder,'err_t2n.txt')
#     write_output(sigmas,err_t3,strfolder,'err_t3.txt')
#     write_output(sigmas,err_t3n,strfolder,'err_t3n.txt')
#     write_output(sigmas,err_t4,strfolder,'err_t4.txt')
#     write_output(sigmas,err_t4n,strfolder,'err_t4n.txt')
#     write_output(sigmas,err_t10,strfolder,'err_t10.txt')
#     write_output(sigmas,err_t10n,strfolder,'err_t10n.txt')

#     write_output(sigmas,err_nonpr_paths,strfolder,'err_nonpr_paths.txt')
#     write_output(sigmas,err_t1_paths,strfolder,'err_t1_paths.txt')
#     write_output(sigmas,err_t1n_paths,strfolder,'err_t1n_paths.txt')
#     write_output(sigmas,err_t2_paths,strfolder,'err_t2_paths.txt')
#     write_output(sigmas,err_t2n_paths,strfolder,'err_t2n_paths.txt')
#     write_output(sigmas,err_t3_paths,strfolder,'err_t3_paths.txt')
#     write_output(sigmas,err_t3n_paths,strfolder,'err_t3n_paths.txt')
#     write_output(sigmas,err_t4_paths,strfolder,'err_t4_paths.txt')
#     write_output(sigmas,err_t4n_paths,strfolder,'err_t4n_paths.txt')
#     write_output(sigmas,err_t10_paths,strfolder,'err_t10_paths.txt')
#     write_output(sigmas,err_t10n_paths,strfolder,'err_t10n_paths.txt')

# def write_output(x,y,folder,filename):
#     if not os.path.isdir(folder):
#         os.makedirs(folder)
#     if (len(np.array(y).shape)) == 1:
#         out = [x,y]
#     else:
#         out = []
#         for i in range(len(x)):
#             out.append([x[i]])
#             out[i].extend(y[i])
#     np.savetxt(folder+filename,np.transpose(out))

# def save_paths(X,nPaths,nSamples,d,folder,filename):
#     if not os.path.isdir(folder):
#         os.makedirs(folder)
#     Y = np.reshape(X,(nPaths,d*nSamples))
#     np.savetxt(folder+filename,np.transpose(Y))

# def load_paths(filename,nPaths,nSamples,d):
#     data = np.transpose(np.genfromtxt(filename))
#     X = np.reshape(data,(nPaths,nSamples,d))
#     return X




    



