import copy
import time

from ebgmcr import RandomComponentMixtureSynthesizer

from evaluate_synthesized_benchmarks import (DefaultDatasetArgument,
                                             NMF_baseline,
                                             SparseNMF_baseline,
                                             BayesNMF_baseline,
                                             ICA_baseline,
                                             MCR_ALS_baseline)

def search_baselines(data, scanning_components=None, baselines=None):
    """Time each baseline decomposition method over a range of component counts.

    Every baseline is called as ``baseline(data, count)`` once for each
    ``count`` in ``scanning_components``. Its return value is discarded;
    only the elapsed time (measured with ``time.perf_counter``) is kept.

    Args:
        data: Input passed unchanged to every baseline.
        scanning_components: Iterable of component counts to try. ``None``
            (the default) means no counts, so every method maps to ``{}``.
        baselines: Optional mapping ``{method_name: callable(data, count)}``.
            Defaults to the five baselines imported from
            ``evaluate_synthesized_benchmarks``: NMF, Sparse-NMF, Bayes-NMF,
            ICA and MCR-ALS, run in that order for each count.

    Returns:
        dict mapping each method name to ``{count: elapsed_seconds}``.
    """
    if baselines is None:
        # Resolved lazily so the module-level baselines are only needed when
        # the caller does not supply their own.
        baselines = {
            'NMF': NMF_baseline,
            'Sparse-NMF': SparseNMF_baseline,
            'Bayes-NMF': BayesNMF_baseline,
            'ICA': ICA_baseline,
            'MCR-ALS': MCR_ALS_baseline,
        }
    if scanning_components is None:
        # Avoid a shared mutable default argument.
        scanning_components = ()

    runtime = {method: {} for method in baselines}
    for count in scanning_components:
        for method, baseline in baselines.items():
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            baseline(data, count)
            # Each method's own start/end pair is used, so no method can
            # report another method's runtime.
            runtime[method][count] = time.perf_counter() - start

    return runtime

def repeat_and_collect_baselines(component_number, datafold, interval, signal_to_nosie_ratio=20., repeat_time=5):
    """Benchmark the baselines ``repeat_time`` times on synthesized datasets.

    Each repeat builds a ``RandomComponentMixtureSynthesizer`` from a copy of
    ``DefaultDatasetArgument``, draws ``component_number * datafold`` samples
    from it, and times every baseline via ``search_baselines`` for component
    counts ``interval, 2*interval, ...`` up to
    ``component_number + interval - 1``.

    Args:
        component_number: Value written to the dataset config's
            ``'component_number'`` key; also bounds the scanned counts.
        datafold: Multiplier giving the number of samples drawn per repeat.
        interval: Step between scanned component counts (must be non-zero
            for ``range``).
        signal_to_nosie_ratio: Value written to the dataset config's
            ``'signal_to_nosie_ratio'`` key (spelling matches that key).
        repeat_time: Number of repeats (an int).

    Returns:
        dict mapping method name to ``{count: [elapsed_seconds, ...]}``,
        with one entry per repeat in each list.
    """
    collected_runtime = {'NMF': {}, 'Sparse-NMF': {}, 'Bayes-NMF': {}, 'ICA': {}, 'MCR-ALS': {}}

    # Both depend only on the arguments, so compute them once.
    data_number = component_number * datafold
    scanning_components = list(range(interval, component_number + interval, interval))

    for _ in range(repeat_time):
        # Deep-copy so the shared default config is never mutated.
        dataset_config = copy.deepcopy(DefaultDatasetArgument)
        dataset_config['component_number'] = component_number
        dataset_config['signal_to_nosie_ratio'] = signal_to_nosie_ratio
        # Rebuilt every repeat; presumably this yields freshly randomized
        # data each time -- confirm against RandomComponentMixtureSynthesizer.
        data_sampler = RandomComponentMixtureSynthesizer(**dataset_config)
        data = data_sampler(data_number)

        single_runtime = search_baselines(data, scanning_components)

        for method, per_count in single_runtime.items():
            for count, elapsed in per_count.items():
                collected_runtime.setdefault(method, {}).setdefault(count, []).append(elapsed)

    return collected_runtime

 
