import configparser
import sys

from fair_clustering import fair_clustering, baseline_fair_clustering, baseline_ijcai, strictly_fair_means, baseline_orl, refine_strict, vanilla_kmeans
from util.configutil import read_list

from util.clusteringutil import subsample_data, read_data

import tqdm
import time

# Experiment settings: one INI section per trial. The path is resolved
# relative to the current working directory.
config_file = "config/sample.ini"
config = configparser.ConfigParser(converters={'list': read_list})
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did read; fail loudly here instead of with a bare KeyError on the
# section lookup later.
if not config.read(config_file):
    sys.exit("Could not read config file {!r}".format(config_file))

# Add your own section to `config/sample.ini` and pass its name as the first
# command-line argument to run your own trial (defaults to "bank").
config_str = "bank" if len(sys.argv) == 1 else sys.argv[1]
if config_str not in config:
    sys.exit("Section [{}] not found in {}; available sections: {}".format(
        config_str, config_file, ", ".join(config.sections())))

print("Using config_str = {}".format(config_str))

# Read the experiment parameters from the selected section (looked up once).
section = config[config_str]
data_dir = section.get("data_dir")
dataset = section.get("dataset")
clustering_config_file = section.get("config_file")
num_clusters = [int(k) for k in section.getlist("num_clusters")]
deltas = [float(d) for d in section.getlist("deltas")]
# getint() returns None when the key is absent; None (or 0) disables the
# subsampling cap below.
max_points = section.getint("max_points")
# NOTE(review): `violating` / `violation` are not used by the code visible
# here; kept in case later code relies on them.
violating = config["DEFAULT"].getboolean("violating")
violation = config["DEFAULT"].getfloat("violation")

# Merge the clustering config into the same parser that is handed to
# read_data(). read() silently ignores unreadable files and raises TypeError
# on None, so report both cases clearly.
if not clustering_config_file or not config.read(clustering_config_file):
    sys.exit("Could not read clustering config file {!r} (key `config_file` "
             "in section [{}])".format(clustering_config_file, config_str))
df = read_data(config, dataset)

# Cap the number of rows at max_points when a positive cap is configured.
if max_points and len(df) > max_points:
    df = subsample_data(df, max_points)

# Values passed as `sample=` to fair_clustering for every (delta, k) pair.
SAMPLE_FRACTIONS = (1, 0.5, 0.2, 0.1)

# Rounding is disabled for every run; hoisted out of the loops since it never
# changes between iterations.
rounding = False

for delta in deltas:
    print('delta = ', delta)
    for n_clusters in tqdm.tqdm(num_clusters):
        for times in SAMPLE_FRACTIONS:
            print('times = ', times)

            # Time a single call of our method; perf_counter is the clock
            # intended for measuring short durations.
            t1 = time.perf_counter()
            print('\n\n\n\n our method')
            fair_clustering(dataset, df, clustering_config_file, n_clusters,
                            delta, rounding, sample=times)
            t2 = time.perf_counter()
            print('the overall time', t2 - t1)

            # Baselines imported at the top can be run here for comparison:
            #   baseline_fair_clustering(dataset, df, clustering_config_file,
            #                            n_clusters, delta, rounding)
            #   baseline_ijcai(...)  -> result dict; result['objective'] is
            #                           the cost before rounding
            #   vanilla_kmeans(...)
            #   baseline_orl(...), strictly_fair_means(...),
            #   refine_strict(...)   -> (cost, assignment)
            # where `...` is (dataset, df, clustering_config_file,
            # n_clusters, delta).