import os
import numpy as np
import MCENB as MCENB
from utils import read_arff
from datetime import datetime

def _format_mean_std(mean, std):
    """Return ``"<mean>±<std>"`` with both numbers rendered to exactly two decimals.

    The fixed ``.2f`` format zero-pads values such as ``95.5`` to ``95.50`` and
    renders non-finite values as plain ``nan``/``inf``.
    """
    return "{:.2f}±{:.2f}".format(mean, std)


def _print_table(names, rows):
    """Print the ``Dataset``/``MCENB`` header followed by one line per dataset.

    ``names`` and ``rows`` are paired positionally; each row is an iterable of
    already-formatted cells, left-aligned in 12-character columns.
    """
    print("\n{:<20}\t{}".format('Dataset', 'MCENB'))
    for name, row in zip(names, rows):
        print("{:<20}\t".format(name), end='')
        for cell in row:
            print("{:<12}\t".format(cell), end='')
        print('')


def experiment(dir_name, *, run_num=1, max_epoch=2, learning_rate=0.01):
    """
    Run MCENB hold-out validations on every dataset file in ``dir_name``.

    Each regular file in the directory (processed in sorted name order) is loaded
    with ``read_arff`` and evaluated ``run_num`` times by ``MCENB.run``, each run
    with its own random seed. The scores are scaled by 100. After each dataset,
    the mean and population standard deviation (``np.std`` default, ddof=0) are
    printed as ``mean±std``. A final table lists every dataset plus an
    ``Average`` row, which is the mean of the per-dataset (rounded) means.

    :param dir_name: directory containing the dataset files; subdirectories are skipped.
    :param run_num: number of hold-out validations per dataset (must be >= 1).
    :param max_epoch: forwarded to ``MCENB.run``.
    :param learning_rate: forwarded to ``MCENB.run``.
    :raises ValueError: if ``run_num`` is smaller than 1.
    """
    if run_num < 1:
        raise ValueError("run_num must be at least 1, got {}".format(run_num))

    algorithm_num = 1  # number of result columns (only MCENB is evaluated)
    # NOTE: this reseeds NumPy's global RNG so the per-run seeds are reproducible.
    np.random.seed(42)
    random_seeds = np.random.randint(1, 10001, size=run_num)

    # Only regular files are treated as datasets. Sorting gives a stable output
    # order, because os.listdir returns entries in arbitrary order.
    filenames = sorted(
        entry for entry in os.listdir(dir_name)
        if os.path.isfile(os.path.join(dir_name, entry))
    )
    if not filenames:
        print("No dataset files found in {}".format(dir_name))
        return

    dataset_names = []
    formatted_rows = []  # per dataset: list of "mean±std" strings
    mean_rows = []  # per dataset: rounded means, averaged for the final row

    for filename in filenames:
        name = os.path.splitext(filename)[0]
        dataset_names.append(name)
        X, y = read_arff(os.path.join(dir_name, filename), filename)

        run_scores = np.zeros((run_num, algorithm_num))
        for i in range(run_num):
            formatted_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print("============= The " + str(i + 1) + " validation " + formatted_now + " =============")
            result_i = MCENB.run(X, y, max_epoch=max_epoch, learning_rate=learning_rate,
                                 abla_var=0, random_seed=random_seeds[i])
            # Scaled by 100; presumably fraction -> percentage (confirm against MCENB.run).
            run_scores[i] = result_i * 100

        means = np.around(np.mean(run_scores, axis=0), 2)
        stds = np.around(np.std(run_scores, axis=0), 2)
        row = [_format_mean_std(m, s) for m, s in zip(means, stds)]
        formatted_rows.append(row)
        mean_rows.append(means)

        # Per-dataset summary, followed by a blank line.
        _print_table([name], [row])
        print('')

    average = np.around(np.mean(mean_rows, axis=0), 2)
    _print_table(dataset_names + ['Average'],
                 formatted_rows + [["{:.2f}".format(v) for v in average]])


if __name__ == "__main__":
    # Please modify "dir_path" to the directory where your datasets are stored.
    # os.path.join uses the platform's separator, so this works on Windows and
    # POSIX alike. The path is resolved relative to the current working directory.
    dir_path = os.path.join('..', 'dataset', 'real-world')
    experiment(dir_path)