import numpy as np
import time
import gc
from data import Data
from column_generation import ColumnGeneration
# from visualization import plot_graph
import continuous_scores

'''
Main function to run Bayesian network learning for both continuous and discrete data
Handles data loading, parameter setting, and result saving

This function is called in:
run_continuous_data.py for testing a continuous data instance
run_discrete_data.py for testing a discrete data instance
parallel.py for testing all the continuous data instances
'''

def _instance_name(data_type, n, N, d, data_index):
    '''
    Build the base name shared by an instance's input file and its result files.

    Continuous data ('C'): '<n>_<N>_<d>_<data_index>'.
    Discrete data (anything else): 0 -> 'lucas', 1 -> 'alarm', any other index -> 'insurance'.
    '''
    if data_type == 'C':
        return '_'.join(str(part) for part in (n, N, d, data_index))
    return {0: 'lucas', 1: 'alarm'}.get(data_index, 'insurance')


def run(data_type, n, N, d, data_index):
    '''
    Run column generation on one data instance and save the results.

    data_type: 'C' for continuous; any other value is handled as discrete ('D')
    n: number of nodes for continuous data; None for discrete data
    N: sample size for continuous data; None for discrete data
    d: average in-degree (for graph density) for continuous data; None for discrete data
    data_index: for continuous data, index from 0 to 9, for 10 independent instances
                for discrete data, 0 for LUCAS, 1 for ALARM, anything else for INSURANCE

    Side effects:
    - loads the instance from file_dir and prints it
    - writes the total running time to the handle returned by column_generation()
      and then closes that handle
    - saves the matrix returned by column_generation() to save_path + '_graph'
    - for continuous data, resets continuous_scores.cost_dict

    Returns None.
    '''
    # Seed NumPy's global generator with a constant so repeated runs draw the same numbers.
    np.random.seed(0)

    is_continuous = data_type == 'C'

    file_dir = "Enter your file directory of the data"
    file_path = file_dir + _instance_name(data_type, n, N, d, data_index)
    if is_continuous:
        file_path += '.txt' # only the continuous instances carry an extension
    data = Data(data_type='C' if is_continuous else 'D')

    data.load_data(file_path)
    data.print_data()

    save_dir = "Enter your save directory for the results"
    save_path = save_dir + 'DCA_' + _instance_name(data.data_type, n, N, d, data_index)

    regu_Lambda = 0.5 * np.log(data.ndata) # for BIC
    # Available methods:
    #   'DCA' for pure DC Algorithm
    #   'DCA-HC' for adding hill climbing
    #   'MINLP' for exact solver for MINLP formulation
    method = 'DCA'
    time_limit = 10800 # 3 hours
    cg = ColumnGeneration(data, regu_Lambda, method, time_limit, save_path)

    start = time.time()
    graph_mtx, f = cg.column_generation()
    elapsed = round(time.time() - start, 2)

    print('Total time for CG: ', elapsed)
    # NOTE(review): f is assumed to be the open, file-like log handle owned by
    # ColumnGeneration; nothing else in this function closes it, so close it
    # here to flush the last line and avoid leaking one handle per call.
    try:
        f.write('Total time for CG: ' + str(elapsed))
    finally:
        f.close()
    np.savetxt(save_path+'_graph', graph_mtx)

    if is_continuous: # clean
        # Drop the module-level score cache so it does not carry over to the next instance.
        continuous_scores.cost_dict = {}
        # continuous_scores.ldet_G_dict = {}
        # continuous_scores.ldet_H_dict = {}
    gc.collect()


    # plot_graph(graph_mtx)