from processing.fourierize import empirical_cf_2d
from processing.music_2d import music_2d
from mixture.gmm import GMM
from reduction.pca import PCA
from mixture.em import EM
import numpy as np
import os
import time
from multiprocessing import Pool
from scipy.stats import wasserstein_distance_nd


def process_trial(n_samples, center):
    """Run one trial comparing three estimators of a 3-component mixture's means.

    A ``GMM`` is built with means ``-center``, ``0`` and ``+center`` and
    covariance ``np.eye(ndim)``; ``n_samples`` points are drawn from it and the
    means are then estimated with:

    * ``music_2d`` applied to ``empirical_cf_2d`` data of a 2-D PCA projection
      (11x11 grid over [-1, 1]^2), lifted back with ``pca.inverse_transform``;
    * ``EM.estimate_fixed_sigma`` in the full space with ``sigma=np.eye(ndim)``;
    * ``EM.estimate_fixed_sigma`` on a 2-D PCA projection with
      ``sigma=np.eye(2)``, lifted back with ``pca.inverse_transform``.

    Args:
        n_samples: number of points drawn from the mixture.
        center: 1-D array giving the offset of the two outer means. Its length
            sets the ambient dimension (this script's driver uses 100).

    Returns:
        Tuple ``(music_err, em_err, em_pca_err, music_time, em_time,
        em_pca_time)``: the ``scipy.stats.wasserstein_distance_nd`` between each
        weighted set of estimated means and the true means (uniform weights
        1/3), followed by the elapsed seconds spent in each method.
    """
    center = np.asarray(center)
    # Derive the dimension from the input instead of hard-coding it, so the
    # function works for any center length.
    ndim = center.shape[0]
    # Ground truth used for scoring: three equally weighted means.
    true_centers = [-center, np.zeros(ndim), center]
    true_weights = [1 / 3, 1 / 3, 1 / 3]

    model = GMM(3, ndim, [-center, np.zeros(ndim), center], np.eye(ndim))
    em = EM(3, eps=1e-6)
    em_pca = EM(3, eps=1e-6)
    samples = model.sample(n_samples)

    # --- MUSIC on a 2-D PCA projection ---------------------------------------
    music_start = time.perf_counter()
    pca = PCA(2)
    pca.fit(samples)
    reduced_samples = pca.transform(samples)
    Omega = 1
    domain_x = domain_y = np.linspace(-Omega, Omega, 11)
    fourier_data = empirical_cf_2d(reduced_samples, domain_x, domain_y, modulate_term=np.eye(2))
    music_points, music_weights = music_2d(
        domain_x, domain_y, fourier_data, 3, (3, 3),
        n_omegas=300, freq_bound=[-3, 3], show_plot=False, centralShift=True,
        weight_threshold=1 / 6, min_distance=0.5, returnWeights=True,
    )
    # Lift the 2-D estimates back into the ambient space.
    music_centers = pca.inverse_transform(np.asarray(music_points))
    music_time = time.perf_counter() - music_start
    # Computed once and reused for both the log line and the return value.
    music_err = wasserstein_distance_nd(
        music_centers, true_centers, u_weights=music_weights, v_weights=true_weights
    )
    print(f"music finished in {music_time} seconds, with error {music_err}")

    # --- EM in the full ambient space ----------------------------------------
    em_start = time.perf_counter()
    model_em = em.estimate_fixed_sigma(samples, sigma=np.eye(ndim))
    em_time = time.perf_counter() - em_start
    em_centers, em_weights = model_em.centers, model_em.weights
    print(f"em finished in {em_time} seconds")

    # --- EM on a 2-D PCA projection ------------------------------------------
    # PCA is refit inside the timed region, so this timing includes the
    # projection step just as MUSIC's does.
    em_pca_start = time.perf_counter()
    pca = PCA(2)
    pca.fit(samples)
    reduced_samples = pca.transform(samples)
    model_em_pca = em_pca.estimate_fixed_sigma(reduced_samples, sigma=np.eye(2))
    em_centers_pca, em_weights_pca = model_em_pca.centers, model_em_pca.weights
    em_centers_pca = pca.inverse_transform(em_centers_pca)
    em_pca_time = time.perf_counter() - em_pca_start
    print(f"em_pca finished in {em_pca_time} seconds")

    print(f"Sample size: {n_samples}===========================")
    print(len(music_centers), len(em_centers))
    print(len(music_weights), len(em_weights))

    em_err = wasserstein_distance_nd(
        em_centers, true_centers, u_weights=em_weights, v_weights=true_weights
    )
    em_pca_err = wasserstein_distance_nd(
        em_centers_pca, true_centers, u_weights=em_weights_pca, v_weights=true_weights
    )
    return music_err, em_err, em_pca_err, music_time, em_time, em_pca_time

if __name__ == "__main__":
    # Sweep: n_trials independent trials at each sample size 10k, 20k, ..., 200k.
    n_trials = 96
    n_workers = 96
    ndim = 100  # length of each trial's center vector passed to process_trial
    radius = 1
    sample_sizes = np.arange(10000, 210000, 10000)
    # records[size_idx, trial_idx] = (music_err, em_err, em_pca_err,
    #                                 music_time, em_time, em_pca_time)
    records = np.zeros((len(sample_sizes), n_trials, 6))
    out_path = os.path.join("comparisonEM", "records_three_1_1_pca.npy")
    # np.save does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # Use one pool for the whole sweep instead of re-forking n_workers processes
    # per sample size. The context manager also tears the pool down if a trial
    # raises.
    # Each worker reseeds NumPy's global RNG from OS entropy on start-up. Under
    # the "fork" start method, children otherwise inherit an identical RNG
    # state, so any np.random draws in GMM.sample / EM (not visible here —
    # NOTE(review): confirm they use np.random) would repeat across trials.
    with Pool(n_workers, initializer=np.random.seed) as pool:
        for i, n_samples in enumerate(sample_sizes):
            # Each trial's outer means are a random direction scaled to `radius`.
            centers = np.random.uniform(low=-1, high=1, size=(n_trials, ndim))
            centers = centers / np.linalg.norm(centers, axis=1, keepdims=True) * radius
            inputs = [(n_samples, trial_center) for trial_center in centers]
            records[i] = pool.starmap(process_trial, inputs)

            # Mean of each of the six recorded metrics over all trials.
            print(*records[i].mean(axis=0))

            # Checkpoint after every sample size so a partial sweep is kept.
            np.save(out_path, records)




    


    
    