from processing.fourierize import empirical_cf_2d
from processing.music_2d import music_2d
import numpy as np
from mixture.gmm import GMM
import os
from model_selection.gmm_aic import model_selection_gmm_aic
from model_selection.gmm_bic import model_selection_gmm_bic
from multiprocessing import Pool

# n_trials = 2000

# separation_trials = np.zeros(n_trials)
# samplesize_trials = np.zeros(n_trials)
# results_music = np.zeros(n_trials)
# results_aic = np.zeros(n_trials)
# results_bic = np.zeros(n_trials)

# for i in range(n_trials):
#     print(f"Trial {i}/{n_trials}")
#     sample_size = 5 * np.random.rand() + 2
#     n_samples = int(10 ** sample_size)

#     separation_distance = 8.9 * np.random.rand() + 0.1

#     center_1 = [1 + separation_distance, 1]
#     center_2 = [1, 1 + separation_distance]
#     center_3 = [1 + separation_distance, 1 + separation_distance]
#     model = GMM(4, 2, [[1,1], center_1, center_2, center_3], np.eye(2))
#     samples = model.sample(n_samples)

#     domain_x = domain_y = np.linspace(0,1,11)
#     fourier_data = empirical_cf_2d(samples, domain_x, domain_y, modulate_term=np.eye(2))
#     result = music_2d(domain_x, domain_y, fourier_data, 5, (5,5), n_omegas=100, show_plot=False)

#     if result is None:
#         results_music[i] = 0
#     else:
#         results_music[i] = len(result)

#     results_aic[i] = model_selection_gmm_aic(samples, max_components=5)
#     results_bic[i] = model_selection_gmm_bic(samples, max_components=5)

#     separation_trials[i] = separation_distance
#     samplesize_trials[i] = sample_size


# np.save("phase transition/4/separations.npy", separation_trials)
# np.save("phase transition/4/results_music.npy", results_music)
# np.save("phase transition/4/results_aic.npy", results_aic)
# np.save("phase transition/4/results_bic.npy", results_bic)
# np.save("phase transition/4/samplesize.npy", samplesize_trials)

def process_trial(sample_size, separation_distance):
    """Run one trial and report the component counts from three estimators.

    Samples are drawn from a ``GMM`` built with four centers: ``[1, 1]`` and
    three centers shifted by ``separation_distance`` along the first axis,
    the second axis, and both. The same samples are then passed to
    ``music_2d`` (via ``empirical_cf_2d``), ``model_selection_gmm_aic`` and
    ``model_selection_gmm_bic``. A one-line summary is printed.

    Parameters
    ----------
    sample_size : float
        Base-10 exponent of the sample count; ``int(10 ** sample_size)``
        samples are drawn.
    separation_distance : float
        Shift applied to the coordinates of the three non-base centers.

    Returns
    -------
    tuple
        ``(separation_distance, sample_size, results_music, results_aic,
        results_bic)``. ``results_music`` is ``len()`` of the ``music_2d``
        result, or 0 when ``music_2d`` returns ``None``; the other two are
        whatever the AIC/BIC helpers return.
    """
    n_samples = int(10 ** sample_size)

    # All shifted centers share the same displaced coordinate value.
    offset = 1 + separation_distance
    centers = [[1, 1], [offset, 1], [1, offset], [offset, offset]]
    # NOTE(review): np.eye(2) is presumably the shared covariance -- confirm
    # against the GMM constructor.
    mixture = GMM(4, 2, centers, np.eye(2))
    draws = mixture.sample(n_samples)

    # One symmetric 11-point grid on [-1, 1] is used for both axes.
    half_width = 1
    grid = np.linspace(-half_width, half_width, 11)
    cf_values = empirical_cf_2d(draws, grid, grid, modulate_term=np.eye(2))
    estimate = music_2d(
        grid,
        grid,
        cf_values,
        5,
        (5, 5),
        n_omegas=801,
        show_plot=False,
        weight_threshold=0.1,
        min_distance=0.2,
        freq_bound=[8, 8],
    )
    # music_2d may return None; that is recorded as zero components.
    results_music = 0 if estimate is None else len(estimate)

    results_aic = model_selection_gmm_aic(draws, max_components=4)
    results_bic = model_selection_gmm_bic(draws, max_components=4)

    print(f"Sample size: {n_samples}, separation distance: {separation_distance}, music: {results_music}, aic: {results_aic}, bic: {results_bic}")

    return separation_distance, sample_size, results_music, results_aic, results_bic


def _reseed_worker():
    """Pool initializer: reseed NumPy's global RNG inside each worker.

    With the ``fork`` start method every worker inherits a copy of the
    parent's global ``np.random`` state, so code that draws from that global
    state would produce the same random stream in every worker.
    NOTE(review): this matters only if ``GMM.sample`` (or the model-selection
    helpers) use the global NumPy RNG -- not visible from this file; the
    reseed is harmless otherwise.
    """
    np.random.seed()


if __name__ == '__main__':
    # Create the output directory before the (long) computation so that a
    # missing folder cannot cause every result to be lost at save time.
    output_dir = "phase transition/4"
    os.makedirs(output_dir, exist_ok=True)

    n_trials = 96*30
    # Separation is uniform on [0.5, 6.0); sample-size exponent is uniform on
    # [2.0, 7.0), i.e. between 10**2 and 10**7 samples per trial.
    separations = 5.5 * np.random.rand(n_trials) + 0.5
    samplesizes = 5.0 * np.random.rand(n_trials) + 2.0
    print(samplesizes)

    # One (sample_size, separation_distance) argument tuple per trial.
    inputs = list(zip(samplesizes, separations))

    # The context manager guarantees the workers are cleaned up even if a
    # trial raises; close()/join() keep the graceful shutdown on success.
    with Pool(8, initializer=_reseed_worker) as pool:
        results = pool.starmap(process_trial, inputs)
        pool.close()
        pool.join()

    separation_trials = np.zeros(n_trials)
    samplesize_trials = np.zeros(n_trials)
    results_music = np.zeros(n_trials)
    results_aic = np.zeros(n_trials)
    results_bic = np.zeros(n_trials)
    for i, result in enumerate(results):
        separation_trials[i], samplesize_trials[i], results_music[i], results_aic[i], results_bic[i] = result
        print(f"Trial {i}/{n_trials}: {result}")

    np.save(os.path.join(output_dir, "separations.npy"), separation_trials)
    np.save(os.path.join(output_dir, "results_music.npy"), results_music)
    np.save(os.path.join(output_dir, "results_aic.npy"), results_aic)
    np.save(os.path.join(output_dir, "results_bic.npy"), results_bic)
    np.save(os.path.join(output_dir, "samplesize.npy"), samplesize_trials)