from processing.fourierize import empirical_cf_2d
from processing.music_2d import music_2d
import numpy as np
from mixture.gmm import GMM
import os
from model_selection.gmm_aic import model_selection_gmm_aic
from model_selection.gmm_bic import model_selection_gmm_bic
from multiprocessing import Pool


def process_trial(sample_size, separation_distance):
    """Run a single phase-transition trial and report each method's estimate.

    Draws ``int(10 ** sample_size)`` points from a ``GMM`` whose two means,
    ``[1, 1]`` and ``[1 + d/sqrt(2), 1 + d/sqrt(2)]``, lie exactly
    ``separation_distance`` apart along the diagonal. It then estimates
    the mixture from those samples three ways: MUSIC on the empirical
    characteristic function, and AIC / BIC model selection.

    Args:
        sample_size: base-10 logarithm of the number of samples to draw.
        separation_distance: Euclidean distance between the two means.

    Returns:
        ``(separation_distance, sample_size, music_count, aic, bic)``.
        ``music_count`` is ``len`` of the ``music_2d`` result, or 0 when
        ``music_2d`` returns ``None``. ``aic`` / ``bic`` are whatever the
        model-selection helpers return (presumably a selected component
        count -- confirm against ``model_selection``).
    """
    n_draws = int(10 ** sample_size)

    # Shift the second mean equally along both axes so the Euclidean gap is d.
    offset = 1 + separation_distance / np.sqrt(2)
    centers = [[1, 1], [offset, offset]]
    # NOTE(review): the leading (2, 2) args and np.eye(2) presumably mean
    # 2 components, 2 dimensions, identity covariance -- confirm in GMM.
    mixture = GMM(2, 2, centers, np.eye(2))
    draws = mixture.sample(n_draws)

    # 7 equally spaced frequencies per axis on [-sqrt(3), sqrt(3)].
    grid_limit = np.sqrt(3)
    grid_x = np.linspace(-grid_limit, grid_limit, 7)
    grid_y = np.linspace(-grid_limit, grid_limit, 7)
    cf_values = empirical_cf_2d(draws, grid_x, grid_y, modulate_term=np.eye(2))

    peaks = music_2d(
        grid_x,
        grid_y,
        cf_values,
        3,
        (4, 4),
        n_omegas=401,
        show_plot=False,
        freq_bound=(6, 6),
        weight_threshold=0.2,
        min_distance=0.3,
    )
    music_count = 0 if peaks is None else len(peaks)

    aic_choice = model_selection_gmm_aic(draws, max_components=3)
    bic_choice = model_selection_gmm_bic(draws, max_components=3)

    print(f"Separation distance: {separation_distance}, Sample size: {sample_size}, MUSIC: {music_count}, AIC: {aic_choice}, BIC: {bic_choice}")
    return separation_distance, sample_size, music_count, aic_choice, bic_choice


if __name__ == '__main__':
    # Randomized sweep: each trial gets its own (log10 sample size, separation) pair.
    n_trials = 96 * 30
    separations = 5.9 * np.random.rand(n_trials) + 0.1  # uniform on [0.1, 6.0)
    samplesizes = 4 * np.random.rand(n_trials) + 2      # log10 sizes, uniform on [2, 6)
    print(len(samplesizes))

    # Create the output directory before the (long) sweep, so a missing
    # folder does not make np.save fail after all trials have been computed.
    out_dir = os.path.join("phase transition", "2")
    os.makedirs(out_dir, exist_ok=True)

    inputs = list(zip(samplesizes, separations))

    # Under the 'fork' start method every worker inherits the parent's global
    # NumPy RNG state, so all workers would replay the same random stream.
    # np.random.seed() with no argument reseeds from OS entropy in each worker.
    # NOTE(review): this matters only if GMM.sample / the model-selection
    # helpers draw from the global NumPy RNG -- presumably they do; confirm.
    with Pool(8, initializer=np.random.seed) as pool:
        # starmap blocks until every trial has finished, so the terminate()
        # performed on leaving the `with` block cannot cut work short.
        results = pool.starmap(process_trial, inputs)

    # One row per trial -> one contiguous float array per returned field.
    table = np.asarray(results, dtype=float).T.copy()
    separation_trials, samplesize_trials, results_music, results_aic, results_bic = table

    np.save(os.path.join(out_dir, "separations.npy"), separation_trials)
    np.save(os.path.join(out_dir, "results_music.npy"), results_music)
    np.save(os.path.join(out_dir, "results_aic.npy"), results_aic)
    np.save(os.path.join(out_dir, "results_bic.npy"), results_bic)
    np.save(os.path.join(out_dir, "samplesize.npy"), samplesize_trials)