from fairsc import Experiment, discover_groups, low_rank_approx
from fairsc.generators import gen_trade
from fairsc.algorithms import fair_sc, normal_sc
from fairsc.evaluations import compute_individual_balance
import numpy as np
import os
from shutil import rmtree


class TradeExperiment(Experiment):
    """Spectral-clustering experiment on the trade graph.

    All control variables are supplied as keyword arguments at construction
    time and read back in ``_run``.
    """

    def __init__(self, exp_dir: str, seed: int, name: str, **kwargs):
        super().__init__(exp_dir, seed, name, **kwargs)
        # Keep the configuration so ``_run`` can read its control variables.
        self.kwargs = kwargs

    def _run(self):
        """Cluster the trade graph and record per-node cluster and balance.

        Required keys in ``self.kwargs``: ``data_dir``, ``num_clusters``,
        ``use_fair_sc``, ``normalize_laplacian``, ``normalize_evec`` and
        ``normalize_balance``.

        Optional keys, consulted only when ``use_fair_sc`` is true:

        * ``fair_mat_post_process_op``: callable applied to the fairness
          matrix before clustering. If missing or ``None``, no
          post-processing is done.
        * ``post_process_op_args``: dict of keyword arguments for that
          callable. Defaults to no arguments.

        Fills ``self.output[i]`` with ``"<cluster>,<balance>"`` for every
        node ``i`` and ``self.output['AvgBalance']`` with the average
        balance.
        """
        kw = self.kwargs

        # Get control variables
        data_dir = kw['data_dir']
        num_clusters = kw['num_clusters']
        use_fair_sc = kw['use_fair_sc']
        normalize_laplacian = kw['normalize_laplacian']
        normalize_evec = kw['normalize_evec']
        normalize_balance = kw['normalize_balance']

        # Get the graph
        adj_mat, fair_mat = gen_trade(data_dir)
        # Balances below are computed against the unprocessed fairness
        # matrix, so keep an independent copy before any post-processing.
        original_fair_mat = fair_mat.copy()

        # Run clustering (fair variant optionally post-processes fair_mat first)
        if use_fair_sc:
            post_process_op = kw.get('fair_mat_post_process_op')
            if post_process_op is not None:
                post_process_op_args = kw.get('post_process_op_args') or {}
                fair_mat = post_process_op(fair_mat, **post_process_op_args)
            clusters = fair_sc(adj_mat, fair_mat, num_clusters, normalize_laplacian, normalize_evec)
        else:
            clusters = normal_sc(adj_mat, num_clusters, normalize_laplacian, normalize_evec)

        # Compute balances
        balances, avg_balance = compute_individual_balance(clusters, original_fair_mat, normalize_balance)

        # Save the output: one "<cluster>,<balance>" entry per node
        for i in range(adj_mat.shape[0]):
            self.output[i] = str(clusters[i]) + ',' + str(balances[i])
        self.output['AvgBalance'] = str(avg_balance)


# Prepare configurations.
# ``config_common`` holds the defaults; each entry of ``configs`` overrides a
# subset of them, and ``use_config`` selects which override set is run.
config_common = {
    'data_dir': './data/trade/',
    'use_fair_sc': True,
    'normalize_laplacian': True,
    'normalize_evec': False,
    'num_clusters': 2,
    'normalize_balance': True
}
configs = [
    {
        'use_fair_sc': False
    },
    {
        'fair_mat_post_process_op': None
    },
    {
        'fair_mat_post_process_op': discover_groups,
        'post_process_op_name': 'discover_groups',
        'post_process_op_args': {'num_groups': 2}
    },
    {
        'fair_mat_post_process_op': low_rank_approx,
        'post_process_op_name': 'low_rank_approx',
        'post_process_op_args': {'rank': 2}
    }
]
use_config = 0
n_sims = 10
# One seed per simulation. NOTE: the global NumPy RNG is not seeded in this
# file, so the seed list differs between invocations.
seeds = [np.random.randint(1000000) for _ in range(n_sims)]
base_dir = './Results'
name = 'default'

# Create base directory (and any missing parents) if necessary; exist_ok
# avoids the check-then-create race of a separate exists() test.
os.makedirs(base_dir, exist_ok=True)

# Start from a fresh experiment directory: discard results of a previous run
exp_dir = os.path.join(base_dir, name)
if os.path.exists(exp_dir):
    rmtree(exp_dir)
os.mkdir(exp_dir)

# Combine the configuration: the selected overrides win over the defaults
curr_config = {**config_common, **configs[use_config]}

# Start the experiments
for sim, seed in enumerate(seeds):
    experiment = TradeExperiment(os.path.join(exp_dir, 'sim-' + str(sim)), seed, 'Trade', **curr_config)
    experiment.run()
    print('Done:', sim + 1, 'of', n_sims)
