import sys
import os

# Make the package roots importable when this file is run as a script:
# ``b_dir`` is the parent of the directory containing this file and ``c_dir``
# is one level above that. Each path is appended only if it is missing, so an
# entry that is already present is never duplicated.
b_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
c_dir = os.path.dirname(b_dir)
for _path in (b_dir, c_dir):
    if _path not in sys.path:
        sys.path.append(_path)
del _path

import argparse
from functools import partial
from itertools import product
import json
import joblib
import numpy as np
from dataclasses import dataclass
from typing import Any
from sklearn.metrics import roc_auc_score

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
from ..analysis.calibration import generate_calibration_constants
from ..analysis.rdp_cumulative import analyze_multiclass_confident_fair_gnmax, analyze_multiclass_confident_gnmax
from fairpate_tabular.args import process_arguments
from fairpate_tabular.grid_search import run_grid_search
from fairpate_tabular.optuna_study import run_optuna
from fairpate_tabular.utils import get_disparity, one_hot, sklearn_disparity
from fairpate_tabular.utils import process_data, write_results
from .query import FairPATEQuery, PATEQuery, FairnessProcessor
import pandas as pd
# from imblearn.under_sampling import RandomUnderSampler
import multiprocessing

# Worker-process count for parallel jobs: SLURM's per-task CPU allocation when
# SLURM_CPUS_PER_TASK is set, otherwise all local cores but one. The local
# fallback is clamped to >= 1 because joblib rejects n_jobs=0 (which
# ``cpu_count() - 1`` would give on a single-core machine).
_SLURM_CPUS = os.environ.get("SLURM_CPUS_PER_TASK")
NUM_JOBS = (int(_SLURM_CPUS) if _SLURM_CPUS is not None
            else max(1, multiprocessing.cpu_count() - 1))

# Maximum solver iterations for the LogisticRegression models.
MAX_ITER = 1000

@dataclass
class DataSet:
    """One data split (train or test) as consumed by ``run_experiment``.

    ``features``, ``labels`` and ``sensitives`` are row-aligned: the pipeline
    indexes all three with the same index arrays. ``set`` is the
    backend-specific dataset object forwarded to ``evaluate(..., dataset=...)``.
    It is ``None`` for the ``sklearn`` and ``keras+pytorch`` backends, hence
    ``Any`` rather than ``str``.
    """
    features: np.ndarray
    labels: np.ndarray
    sensitives: np.ndarray
    set: Any  # backend dataset object or None; field name kept for compatibility

@dataclass
class FairPATERunParams:
    """Data and helpers that ``run_experiment`` needs for one run.

    Attributes:
        train: training split.
        test: test split.
        np_rng: random generator used for shuffling (the script passes
            ``np.random.default_rng(args.seed)``).
        generate_model_and_fit: callable
            ``(features, labels, model_id=None) -> fitted model``.
        MAX_ITER: maximum solver iterations for LogisticRegression models.
    """
    train: DataSet
    test: DataSet
    np_rng: Any
    generate_model_and_fit: Any
    MAX_ITER: int

    # TODO: run configuration (use_stratification, num_teachers, backend,
    # threshold, sigma_threshold, sigma_gnmax, budget, delta, fairness_metric,
    # ...) is still read from the argparse namespace. ``get_from_args`` is an
    # unimplemented stub, presumably reserved for pulling such fields in.

    @classmethod
    def from_dict(cls, packed):
        """Build an instance from the nested-dict layout used in ``__main__``.

        ``packed`` must contain ``train`` / ``test`` sub-dicts with keys
        ``features``, ``labels``, ``sensitives`` and ``set``, plus top-level
        ``np_rng``, ``generate_model_and_fit`` and ``MAX_ITER``.
        """
        return cls(
            train=DataSet(**packed['train']),
            test=DataSet(**packed['test']),
            np_rng=packed['np_rng'],
            generate_model_and_fit=packed['generate_model_and_fit'],
            MAX_ITER=packed['MAX_ITER'],
        )

    def get_from_args(self, args):
        """Placeholder: not implemented; does nothing and returns ``None``."""
        pass

def create_teacher_indices(args, packed_data, indices, num_teachers):
    '''Split ``indices`` into ``num_teachers`` near-equal, stratified subsets.

    Starting from one subset holding all of ``indices``, the currently largest
    subset is repeatedly halved with ``train_test_split``. Each split is
    stratified on the joint (label, sensitive) group and seeded with
    ``args.seed``, and halving continues until there are ``num_teachers``
    subsets. Final subset sizes differ by at most one (checked below).

    Args:
        args: namespace providing ``seed``.
        packed_data: run parameters; only ``packed_data.train.labels`` and
            ``packed_data.train.sensitives`` are read.
        indices: array of training-set indices to partition.
        num_teachers: number of subsets; must be an integral power of two >= 2.

    Returns:
        List of ``num_teachers`` index arrays. ``run_experiment`` uses the last
        one as the validation set.

    Raises:
        ValueError: if ``num_teachers`` is not a power of two >= 2.
    '''
    try:
        n_int = int(num_teachers)
    except (TypeError, ValueError):
        n_int = None
    # Any integral power of two >= 2 is accepted; ``n & (n - 1)`` is zero exactly
    # for powers of two.
    if n_int is None or n_int != num_teachers or n_int < 2 or (n_int & (n_int - 1)) != 0:
        raise ValueError("Number of teachers must be a power of 2")

    train_labels = packed_data.train.labels
    train_sensitives = packed_data.train.sensitives

    teacher_indices = [indices]
    while len(teacher_indices) < num_teachers:
        # Always halve the largest subset so sizes stay balanced.
        idx = int(np.argmax([len(x) for x in teacher_indices]))
        largest_teacher_indices = teacher_indices.pop(idx)
        teacherset_1, teacherset_2 = train_test_split(
            largest_teacher_indices, train_size=0.5, shuffle=True,
            stratify=define_joint_label(train_labels[largest_teacher_indices],
                                        train_sensitives[largest_teacher_indices]),
            random_state=args.seed)
        teacher_indices.append(teacherset_1)
        teacher_indices.append(teacherset_2)

    final_sizes = np.array([len(x) for x in teacher_indices])
    # Internal invariant of repeated balanced halving.
    assert final_sizes.max() - final_sizes.min() <= 1
    return teacher_indices

def create_val_and_teacher_indices(args, teachers_indices, train_sensitives, train_labels, np_rng, tt_size):
    """Carve ``teachers_indices`` into one validation set and per-teacher blocks.

    ``teachers_indices`` is viewed as ``args.num_teachers + 1`` consecutive
    blocks of ``tt_size`` entries, followed by a remainder. The first block that,
    together with the remainder, covers ``args.num_sensitives`` distinct
    sensitive values becomes the validation set. Every other block becomes a
    teacher's training set.

    A teacher block whose labels are all identical has one random entry swapped
    with a random different-label entry of an earlier, originally-mixed teacher
    block. Only donors that keep at least two distinct labels after the swap
    are eligible. Teacher blocks are slices (views) of ``teachers_indices``,
    so swaps modify that array in place.

    Args:
        args: namespace providing ``num_teachers`` and ``num_sensitives``.
        teachers_indices: 1-D numpy array of training-set indices.
        train_sensitives: sensitive attribute per training row.
        train_labels: label per training row.
        np_rng: ``numpy.random.Generator`` used to pick swap candidates.
        tt_size: number of indices per block.

    Returns:
        ``(individual_teacher_indices, validation_set_indices)``. The first is
        a dict mapping teacher ids ``0..num_teachers-1`` to index arrays.

    Raises:
        ValueError: if no block yields a valid validation set, or if a
            single-label teacher block cannot be repaired by a swap.
    """
    num_blocks = args.num_teachers + 1
    # Indices past the last full block always go to the validation set.
    remainder_indices = teachers_indices[num_blocks * tt_size:]

    def _block(i):
        """Return block ``i`` as a view into ``teachers_indices``."""
        return teachers_indices[i * tt_size:(i + 1) * tt_size]

    # Validation: first block that (plus the remainder) contains every sensitive group.
    val_block_idx = None
    for candidate in range(num_blocks):
        validation_set_indices = np.concatenate([_block(candidate), remainder_indices])
        if len(np.unique(train_sensitives[validation_set_indices])) == args.num_sensitives:
            val_block_idx = candidate
            break
    if val_block_idx is None:
        raise ValueError("Could not find a validation set with at least one of each sensitive attribute. Try decreasing the number of teachers?")

    # Teachers: every remaining block, each needing more than one distinct label.
    individual_teacher_indices = {}
    blocks_valid_to_swap_from = []  # ids of teachers whose labels were mixed without a swap
    teacher_blocks = (i for i in range(num_blocks) if i != val_block_idx)
    for _id, block_idx in enumerate(teacher_blocks):
        block_indices = _block(block_idx)
        needed_swap = len(np.unique(train_labels[block_indices])) == 1

        while len(np.unique(train_labels[block_indices])) == 1:
            idx_in_block = np_rng.integers(low=0, high=len(block_indices))
            single_label = train_labels[block_indices[idx_in_block]]

            # A donor gives away one entry with a label != single_label and
            # receives a single_label entry. It stays mixed only if it has at
            # least two such entries.
            donors = [b for b in blocks_valid_to_swap_from
                      if np.count_nonzero(train_labels[individual_teacher_indices[b]] != single_label) >= 2]
            if not donors:
                raise ValueError("Could not find a valid swap!")

            new_block_idx = np_rng.choice(donors)
            donor_indices = individual_teacher_indices[new_block_idx]
            candidates = donor_indices[train_labels[donor_indices] != single_label]
            new_idx_in_train = np_rng.choice(candidates, 1)
            new_idx_in_block = np.where(donor_indices == new_idx_in_train)[0][0]
            # Both sides are views, so this swap writes through to teachers_indices.
            block_indices[idx_in_block], donor_indices[new_idx_in_block] = \
                donor_indices[new_idx_in_block], block_indices[idx_in_block]

        if not needed_swap:
            blocks_valid_to_swap_from.append(_id)
        individual_teacher_indices[_id] = block_indices

    return individual_teacher_indices, validation_set_indices

def define_joint_label(labels, sensitives):
    """Encode each (label, sensitive value) pair as a single integer id.

    Ids enumerate the Cartesian product of the sorted unique labels and the
    sorted unique sensitive values in label-major order. With ``S`` distinct
    sensitive values, the pair (i-th label, j-th sensitive value) gets id
    ``i * S + j``. The result is used as a ``stratify`` target for
    ``train_test_split``.

    Args:
        labels: per-row labels (assumed 1-D).
        sensitives: per-row sensitive values aligned with ``labels``.

    Returns:
        numpy array with one integer id per row.
    """
    pair_ids = {}
    for pair in product(np.unique(labels), np.unique(sensitives)):
        pair_ids[pair] = len(pair_ids)
    return np.array([pair_ids[(label, group)] for label, group in zip(labels, sensitives)])
    
def _safe_roc_auc(labels, scores):
    """Return ``roc_auc_score(labels, scores)``, or NaN if ``labels`` has a single class.

    ROC-AUC is undefined with only one class present, and scikit-learn raises
    ``ValueError`` in that case. Returning NaN keeps evaluation of small or
    filtered sets from crashing.
    """
    if len(np.unique(labels)) < 2:
        return float('nan')
    return roc_auc_score(labels, scores)


def evaluate(model, features, labels, sensitives, args, dataset=None, log=None, np_rng=None):
    """Evaluate a model on one split and return ``(accuracy, disparity, coverage, auc)``.

    Two modes:

    * ``args.use_inference_time_postprocessing`` (PATE-S_post): predict hard
      labels, shuffle the rows with ``np_rng``, then keep only the rows accepted
      by ``FairnessProcessor.filter_set``. Metrics use the kept rows, and
      ``coverage`` is the kept fraction. Only ``DemParity`` is supported.
    * otherwise: predict class probabilities. Metrics use all rows and
      ``coverage`` is 1.

    Args:
        model: fitted model exposing ``predict`` (post-processing mode) or
            ``predict_proba``.
        features: per-row model inputs.
        labels: per-row true labels.
        sensitives: per-row sensitive attributes.
        args: experiment namespace (``use_inference_time_postprocessing``,
            ``fairness_metric``, ``num_classes``, ``num_sensitives``,
            ``min_group_count``, ``fairness_threshold``).
        dataset: backend dataset object, forwarded to ``model.predict`` only
            when it is not None.
        log: print-like callable. Defaults to the module-level ``log`` if
            defined, else a no-op.
        np_rng: ``numpy.random.Generator`` for the shuffle. Defaults to the
            module-level ``np_rng`` if defined, else a generator seeded with
            ``args.seed``.

    Returns:
        ``(accuracy, disparity, coverage, auc)``. ``auc`` is NaN when the
        evaluated labels contain a single class.

    Raises:
        NotImplementedError: post-processing with a metric other than ``DemParity``.
    """
    if log is None:
        log = globals().get('log', lambda *x, **y: None)

    if not args.use_inference_time_postprocessing:
        preds_classes = model.predict_proba(features)
        preds_labels = np.argmax(preds_classes, axis=1)
        accuracy = (preds_labels == labels).mean()
        if np.ndim(preds_classes) == 2 and np.shape(preds_classes)[1] == 2:
            # Binary: ROC-AUC must rank by a score. Argmax labels would collapse
            # the ROC curve to a single operating point.
            auc = _safe_roc_auc(labels, np.asarray(preds_classes)[:, 1])
        else:
            auc = _safe_roc_auc(labels, preds_labels)
        disparity = get_disparity(args.fairness_metric, preds_labels, sensitives, labels=labels, num_sensitives=args.num_sensitives, args=args)
        coverage = 1
        return accuracy, disparity, coverage, auc

    # This is the case for PATE-S_post
    if args.fairness_metric != 'DemParity':
        raise NotImplementedError("Only DemParity is supported for now")
    if np_rng is None:
        np_rng = globals().get('np_rng')
    if np_rng is None:
        np_rng = np.random.default_rng(getattr(args, 'seed', None))

    processor = FairnessProcessor(sensitive_group_list=np.unique(sensitives),
                                  num_classes=args.num_classes,
                                  min_group_count=args.min_group_count,
                                  max_fairness_violation=args.fairness_threshold)
    # The sklearn model and the monkey-patched Keras ``predict`` take no
    # ``dataset`` keyword; those backends run with dataset=None.
    if dataset is None:
        predictions = model.predict(features)
    else:
        predictions = model.predict(features, dataset=dataset)
    before_accuracy = (predictions == labels).mean()
    before_disparity = get_disparity(args.fairness_metric, predictions, sensitives, labels=labels, num_sensitives=args.num_sensitives, args=args)
    log("Before (acc, disparity, coverage): ", f"({before_accuracy}, {before_disparity} , 1.0)")

    # Shuffle rows before filtering. Presumably filter_set's choice depends on
    # row order -- TODO confirm.
    premask = np.arange(labels.shape[0], dtype=int)
    np_rng.shuffle(premask)
    predictions = predictions[premask]
    sensitives = sensitives[premask]
    labels = labels[premask]

    mask = processor.filter_set(predictions, sensitives)
    masked_labels = labels[mask]
    masked_sensitives = sensitives[mask]
    masked_predictions = predictions[mask]

    accuracy = (masked_predictions == masked_labels).mean()
    disparity = get_disparity(args.fairness_metric, masked_predictions, masked_sensitives, labels=masked_labels, num_sensitives=args.num_sensitives, args=args)
    coverage = len(masked_labels) / len(labels) if len(labels) else 0.0
    # Only hard predictions exist here, so AUC is computed from them on the kept rows.
    auc = _safe_roc_auc(masked_labels, masked_predictions)
    log("After (acc, disparity, coverage): ", f"({accuracy}, {disparity} , {coverage})")
    return accuracy, disparity, coverage, auc

def _as_run_params(packed_data):
    """Return ``packed_data`` in attribute-access form.

    ``__main__`` builds ``packed_data`` as a nested dict, but the pipeline reads
    it as ``packed_data.train.labels`` etc. Dicts are converted into
    ``FairPATERunParams`` / ``DataSet``. Any other object is returned unchanged.
    """
    if not isinstance(packed_data, dict):
        return packed_data

    def _split(split):
        """Convert one split sub-dict into a ``DataSet`` (pass through otherwise)."""
        return DataSet(**split) if isinstance(split, dict) else split

    return FairPATERunParams(train=_split(packed_data['train']),
                             test=_split(packed_data['test']),
                             np_rng=packed_data['np_rng'],
                             generate_model_and_fit=packed_data['generate_model_and_fit'],
                             MAX_ITER=packed_data['MAX_ITER'])


def run_experiment(args, log, packed_data, results_db=None):
    """Run one (Fair)PATE experiment and report the student model's metrics.

    Steps:
    1. Split training data into teacher, validation and query sets.
    2. Train the teachers and collect their votes on the query set, or load
       saved votes when ``args.skip == 'voting'``.
    3. Optionally carve a calibration set out of the query set
       (``args.gt_fairness``).
    4. Run the privacy analysis to bound the number of answerable queries.
    5. Label the answered queries and train the student.
    6. Evaluate the student on the validation and test splits.

    Args:
        args: experiment configuration (argparse namespace).
        log: print-like logging callable.
        packed_data: ``FairPATERunParams``-like object, or the equivalent nested dict.
        results_db: optional ``pandas.DataFrame``. When given, a results row is appended.

    Returns:
        With ``results_db``: ``(results_db, val_acc, val_disparity, val_coverage,
        val_auc, achieved_eps, max_num_query, num_answered, test_acc,
        test_disparity, test_coverage, test_auc)``.

        Without: ``(val_acc, val_disparity, val_coverage, achieved_eps,
        max_num_query, val_auc, num_answered, test_acc, test_disparity,
        test_coverage, test_auc)``. Note that ``val_auc`` sits in a different
        position here; this is kept for backward compatibility.
    """
    packed_data = _as_run_params(packed_data)

    train_features = packed_data.train.features
    train_labels = packed_data.train.labels
    train_sensitives = packed_data.train.sensitives
    train_set = packed_data.train.set

    test_features = packed_data.test.features
    test_labels = packed_data.test.labels
    test_sensitives = packed_data.test.sensitives
    test_set = packed_data.test.set

    np_rng = packed_data.np_rng
    generate_model_and_fit = packed_data.generate_model_and_fit
    MAX_ITER = packed_data.MAX_ITER

    # --- Teacher / validation / query split ---
    train_indices = np.arange(len(train_features))
    np_rng.shuffle(train_indices)
    if not args.use_stratification:
        teachers_indices, query_set_indices = train_test_split(
            train_indices, train_size=args.teacher_query_set_split, shuffle=True, random_state=args.seed)
        # Teacher train size. Adding one more 'teacher' for validation purposes
        tt_size = len(teachers_indices) // (args.num_teachers + 1)
        if tt_size < 4:
            raise ValueError("Number of teachers is too large for the dataset")
        individual_teacher_indices, validation_set_indices = create_val_and_teacher_indices(
            args, teachers_indices, train_sensitives, train_labels, np_rng, tt_size)
    else:
        # Stratified split. The stratify target must follow the (shuffled)
        # order of train_indices, not the original row order.
        teachers_indices, query_set_indices = train_test_split(
            train_indices, train_size=args.teacher_query_set_split, shuffle=True,
            stratify=define_joint_label(train_labels[train_indices], train_sensitives[train_indices]),
            random_state=args.seed)
        # The last of the num_teachers subsets is held out for validation.
        split_indices = create_teacher_indices(args, packed_data, teachers_indices, args.num_teachers)
        validation_set_indices = split_indices[-1]
        individual_teacher_indices = dict(zip(range(args.num_teachers - 1), split_indices[:-1]))
        tt_size = len(individual_teacher_indices[0])

    log(f"Train: {len(train_indices)}, Each Teacher: {tt_size}, Val: {len(validation_set_indices)}, Query: {len(query_set_indices)}")

    def _iterate(_id):
        """Train teacher ``_id``; return its one-hot query-set votes with a trailing teacher axis."""
        model = generate_model_and_fit(train_features[individual_teacher_indices[_id]], train_labels[individual_teacher_indices[_id]], model_id=_id)
        votes = one_hot(model.predict(train_features[query_set_indices]))[:, :, None]
        if args.backend == 'deep_auc':
            model.destruct()
        return votes

    # --- Teacher votes on the query set ---
    # With stratification, one of the num_teachers subsets is the validation set.
    num_teachers_to_train = args.num_teachers - 1 if args.use_stratification else args.num_teachers
    if args.skip == 'voting':
        train_votes = np.load(f'{args.log_path}/train_votes.npz')["train_votes"]
    else:
        teacher_ids = tqdm(range(num_teachers_to_train), desc="Training Teachers")
        if args.parallel:
            train_votes = joblib.Parallel(n_jobs=args.parallel, prefer="processes")(
                joblib.delayed(_iterate)(_id) for _id in teacher_ids)
        else:
            train_votes = [_iterate(_id) for _id in teacher_ids]
        train_votes = np.concatenate(train_votes, axis=2)
        np.savez(f'{args.log_path}/train_votes', train_votes=train_votes)

    # Sum over the teacher axis to get a per-query vote histogram.
    # Row i corresponds to query_set_indices[i].
    train_votes = train_votes.sum(axis=2)

    # --- Calibration (ground-truth fairness metrics) ---
    if args.gt_fairness:
        # train_votes rows are positions in query_set_indices, not training-set
        # indices. Split positions so that votes and indices stay aligned.
        query_positions = np.arange(len(query_set_indices))
        calib_positions, query_positions = train_test_split(
            query_positions, train_size=args.num_calib, shuffle=True,
            stratify=define_joint_label(train_labels[query_set_indices], train_sensitives[query_set_indices]),
            random_state=args.seed)
        calibration_indices = query_set_indices[calib_positions]

        c_votes = train_votes[calib_positions]
        c_targets = train_labels[calibration_indices]
        c_sensitives = train_sensitives[calibration_indices]
        for_z, but_z = generate_calibration_constants(args.fairness_metric, c_votes=c_votes, c_targets=c_targets, c_sensitives=c_sensitives)

        # With calibration, the query set (and its votes) shrinks accordingly.
        query_set_indices = query_set_indices[query_positions]
        train_votes = train_votes[query_positions]

        log("Size of Calibration Set: ", len(calibration_indices))
    else:
        for_z, but_z = None, None

    raw_votes = train_votes
    sensitives = train_sensitives[query_set_indices]

    # --- Privacy analysis ---
    if args.pate_based_model == 'fairpate':
        (max_num_query, dp_eps, partition, answered, order_opt,
         sensitive_group_count, pos_prediction_one_hot, answered_curr, gaps, pr_answered) = \
            analyze_multiclass_confident_fair_gnmax(votes=raw_votes, sensitives=sensitives,
                                                    threshold=args.threshold,
                                                    fair_threshold=args.fairness_threshold,
                                                    sigma_threshold=args.sigma_threshold,
                                                    sigma_fair_threshold=0,
                                                    sigma_gnmax=args.sigma_gnmax,
                                                    budget=args.budget,
                                                    for_z=for_z,
                                                    but_z=but_z,
                                                    delta=args.delta,
                                                    num_sensitive_attributes=args.num_sensitives,
                                                    file='.', log=lambda *x: None, args=args)
    elif args.pate_based_model in ['pate', 'pateSpre', 'pateSin']:
        (max_num_query, dp_eps, partition, answered, order_opt) = \
            analyze_multiclass_confident_gnmax(votes=raw_votes,
                                               threshold=args.threshold,
                                               sigma_threshold=args.sigma_threshold,
                                               sigma_gnmax=args.sigma_gnmax,
                                               budget=args.budget,
                                               delta=args.delta, file='.')
    else:
        raise NotImplementedError

    log("Maximum #Queries to be answered: ", max_num_query)
    if max_num_query == 0:
        raise ValueError("No queries to be answered. Try increasing epsilon?")
    achieved_eps = dp_eps[max_num_query - 1]

    # --- Student labels ---
    if args.pate_based_model == 'fairpate':
        query = FairPATEQuery(sensitive_group_list=np.unique(sensitives),
                              min_group_count=args.min_group_count,
                              max_fairness_violation=args.fairness_threshold,
                              num_classes=args.num_classes,
                              threshold=args.threshold,
                              sigma_threshold=args.sigma_threshold,
                              sigma_gnmax=args.sigma_gnmax,
                              fairness_metric=args.fairness_metric,
                              dataset=args.dataset)
        answered, student_train_labels = query.create_student_training_set(train_features[query_set_indices], train_sensitives[query_set_indices],
                                                                           raw_votes, for_z=for_z, but_z=but_z)
    else:
        query = PATEQuery(num_classes=args.num_classes,
                          threshold=args.threshold,
                          sigma_threshold=args.sigma_threshold,
                          sigma_gnmax=args.sigma_gnmax)
        answered, student_train_labels = query.create_student_training_set(train_features[query_set_indices],
                                                                           raw_votes)

    # max_num_queries comes from the privacy analysis
    indices_answered = query_set_indices[answered][:max_num_query]
    student_train_sensitives = train_sensitives[indices_answered]

    if args.pate_based_model == 'pateSpre':
        # PATE-S_pre mitigation occurs here:
        log("#Queries before fairness processing: ", len(indices_answered))
        if args.fairness_metric == 'DemParity':
            processor = FairnessProcessor(sensitive_group_list=np.unique(sensitives),
                                          num_classes=args.num_classes,
                                          min_group_count=args.min_group_count,
                                          max_fairness_violation=args.fairness_threshold)
            mask = processor.filter_set(student_train_labels[:max_num_query], student_train_sensitives)
            indices_answered = indices_answered[mask]
            # Filter labels and sensitives with the same mask so they stay
            # aligned with the student features built from indices_answered.
            student_train_labels = np.asarray(student_train_labels)[:max_num_query][mask]
            student_train_sensitives = student_train_sensitives[mask]
            log('#Queries after fairness processing: ', len(indices_answered))
        else:
            raise NotImplementedError("Only DemParity is supported for now")

    log("Actual #Queries answered: ", len(indices_answered))
    student_train_features = train_features[indices_answered]

    # --- Student training ---
    if args.pate_based_model == 'pateSin':
        if args.backend != 'sklearn':
            raise NotImplementedError

        if args.fairness_metric == 'DemParity':
            from fairlearn.reductions import DemographicParity, ExponentiatedGradient
            unmitigated_student_model = LogisticRegression(max_iter=MAX_ITER, random_state=args.seed)
            constraint = DemographicParity()
            student_model = ExponentiatedGradient(unmitigated_student_model, constraint)
            student_model.fit(X=student_train_features, y=student_train_labels[:max_num_query],
                              sensitive_features=student_train_sensitives)
        else:
            raise NotImplementedError("Only DemParity is supported for now")
    else:
        student_model = generate_model_and_fit(student_train_features, student_train_labels[:max_num_query], model_id='S')

    # --- Evaluation ---
    validation_out = evaluate(student_model, train_features[validation_set_indices], train_labels[validation_set_indices],
                              train_sensitives[validation_set_indices], args, dataset=train_set)
    test_out = evaluate(student_model, test_features, test_labels, test_sensitives, args, dataset=test_set)

    student_model_validation_accuracy, validation_disparity, validation_coverage, validation_auc = validation_out
    student_model_test_accuracy, test_disparity, test_coverage, test_auc = test_out

    log(f"Accuracy - Validation: {student_model_validation_accuracy:.3f}, Test: {student_model_test_accuracy:.3f}", end=' ')
    log(f"Disparity - Validation: {validation_disparity:.3f}, Test: {test_disparity:.3f}", end=' ')
    log(f"Coverage - Validation: {validation_coverage:.3f}, Test: {test_coverage:.3f}")
    log(f"AUC - Validation: {validation_auc:.3f}, Test: {test_auc:.3f}")

    if results_db is not None:
        results_db = pd.concat([results_db,
                                pd.DataFrame({'model': args.pate_based_model,
                                              'fairness_metric': args.fairness_metric,
                                              'dataset': args.dataset, 'num_teachers': args.num_teachers,
                                              'threshold': args.threshold,
                                              'fairness_threshold': args.fairness_threshold,
                                              'sigma_threshold': args.sigma_threshold,
                                              'sigma_gnmax': args.sigma_gnmax,
                                              'budget': args.budget,
                                              'delta': args.delta,
                                              'seed': args.seed,
                                              'student_validation_accuracy': student_model_validation_accuracy,
                                              'student_test_accuracy': student_model_test_accuracy,
                                              'validation_disparity': validation_disparity,
                                              'validation_auc': validation_auc,
                                              'test_disparity': test_disparity,
                                              'achieved_eps': achieved_eps,
                                              'max_num_query': max_num_query,
                                              'max_actual_query': len(indices_answered),
                                              'validation_coverage': validation_coverage,
                                              'test_coverage': test_coverage,
                                              'test_auc': test_auc
                                              }, index=[0])], ignore_index=True)
        return results_db, student_model_validation_accuracy, validation_disparity, validation_coverage, validation_auc, achieved_eps, max_num_query, len(indices_answered), student_model_test_accuracy, test_disparity, test_coverage, test_auc
    else:
        return student_model_validation_accuracy, validation_disparity, validation_coverage, achieved_eps, max_num_query, validation_auc, len(indices_answered), student_model_test_accuracy, test_disparity, test_coverage, test_auc

    
if __name__ == '__main__':
    # Command-line entry point: parse arguments, choose a model backend, load the
    # data, then run an Optuna study, a grid search over any --list_* arguments,
    # or a single experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='adult')
    parser.add_argument('--list_dataset', nargs='+', type=str, help='A list of datasets to run the experiment on.')

    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--output_col_name', type=str, default='income')
    parser.add_argument('--split', type=float, default=0.75)

    parser.add_argument('--dem_disparity_interpretation', type=str, default='max_vs_min')

    parser.add_argument('--teacher_query_set_split', type=float, default=0.7)

    parser.add_argument('--num_teachers', type=int, default=4)
    parser.add_argument('--list_num_teachers', nargs='+', type=int, help='A list of number of teachers to run the experiment on.')

    parser.add_argument('--threshold', type=int, default=2)
    parser.add_argument('--list_threshold', nargs='+', type=int, help='A list of thresholds to run the experiment on.')

    parser.add_argument('--fairness_threshold', type=float, default=0.2)
    parser.add_argument('--list_fairness_threshold', nargs='+', type=float, help='A list of fairness thresholds to run the experiment on.')

    parser.add_argument('--sigma_threshold', type=float, default=60)
    parser.add_argument('--list_sigma_threshold', nargs='+', type=float, help='A list of sigma thresholds to run the experiment on.')

    parser.add_argument('--sigma_fair_threshold', type=int, default=0)

    parser.add_argument('--sigma_gnmax', type=float, default=25)
    parser.add_argument('--list_sigma_gnmax', nargs='+', type=float, help='A list of sigma gnmax to run the experiment on.')

    parser.add_argument('--budget', type=float, default=1000)
    parser.add_argument('--list_budget', nargs='+', type=float, help='A list of budgets to run the experiment on.')

    parser.add_argument('--delta', type=float, default=1e-5)
    parser.add_argument('--verbose', action='store_true')

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--list_seed', nargs='+', type=int, help='A list of seeds to run the experiment on.')

    parser.add_argument('--data_path', type=str, default='./fairpate_tabular/data/')
    parser.add_argument('--min_group_count', type=int, default=50)
    parser.add_argument('--results_dir', type=str, default='.', help='Directory to store the results in.')

    parser.add_argument('--use_optuna', action='store_true', help='Whether to use optuna to find the best hyperparameters.')
    parser.add_argument('--num_optuna_trials', type=int, default=1000, help='Number of optuna trials to run.')

    parser.add_argument('--use_stratification', action='store_true', help='Whether to use stratification to split the data.')

    parser.add_argument('--fairness_metric', type=str, default='DemParity', help='Fairness metric to use for the experiment. Can be `DemParity`, `ErrorParity`, or `EqualityOfOdds`.')
    parser.add_argument('--list_fairness_metric', nargs='+', type=str, help='A list of fairness metrics to run the experiment on.')

    parser.add_argument('--num_calib', type=int, default=100, help='Number of calibration samples to use for ground-truth-based fairness metrics.')

    parser.add_argument('--pate_based_model', type=str, default='fairpate', help='What PATE-based model to use. Can be `fairpate`, `pate`, `pateSpre`, `pateSin` or `pateSpost`.')

    parser.add_argument('--use_inference_time_postprocessing', action='store_true', help='Whether to use inference-time postprocessing to mitigate fairness violations.')

    parser.add_argument('--undersampling_ratio', type=float, default=None, help='Ratio of the majority class to the minority class. If None, no undersampling is done.')

    parser.add_argument('--optuna_db_path', type=str, default='.', help='Path to the optuna study db.')

    parser.add_argument('--backend', type=str, default='sklearn', help='Backend to use for the experiment. Can be `sklearn`, `keras+pytorch` or `deep_auc`.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train the student model for.')
    parser.add_argument('--keras_dict', type=json.loads, default='{"optimizer": "adam", "loss": "binary_crossentropy", \
                      "metrics": ["accuracy"]}', help='Dictionary of arguments to pass to the keras model.')

    parser.add_argument('--parallel', type=int, default=None, help='Whether to use parallel processing to train the teachers.')

    parser.add_argument('--deep_auc_dict', type=json.loads, default='{"batch_size": 32, "lr": 0.05, "margin": 1.0, \
                      "epoch_decay": 2e-3, "weight_decay": 1e-5, "eval_every": 200, "epochs": 2, "train_validation_split" : 0.8}', help='Dictionary of arguments to pass to the deep_auc model.')

    parser.add_argument('--skip', type=str, default=None, help='Whether to skip part of the pipeline. Can be `training_teachers`, `training_all`, `voting` or None (default).')

    parser.add_argument('--log_path', type=str, default='./logs/', help='Path to store the logs and artifacts from the experiments.')

    args = process_arguments(parser.parse_args())

    # Module-level ``log``: also picked up by ``evaluate`` when no logger is passed.
    log = print if args.verbose else (lambda *x, **y: None)

    # PATE analysis needs this
    np.random.seed(args.seed)
    # Otherwise we use an rng
    np_rng = np.random.default_rng(args.seed)

    # Initalize results db
    if os.path.exists(args.results_db_path):
        results_db = pd.read_parquet(args.results_db_path)
    else:
        results_db = pd.DataFrame(columns=['dataset', 'model', 'fairness_metric', 'num_teachers', 'threshold',
                                           'fairness_threshold', 'sigma_threshold', 'sigma_gnmax', 'budget',
                                           'delta', 'seed', 'student_validation_accuracy', 'student_test_accuracy',
                                           'validation_disparity', 'test_disparity', 'achieved_eps', 'max_num_query',
                                           'max_actual_query'])

    # Arguments whose name contains 'list' and that were actually supplied drive the grid search.
    args_dict = vars(args)
    list_args = {name: value for name, value in args_dict.items() if 'list' in name}
    non_None_list_args = {name: value for name, value in list_args.items() if value is not None}

    if args.backend == 'sklearn':
        def generate_model_and_fit(train_features, train_labels, model_id=None):
            model = LogisticRegression(max_iter=MAX_ITER, random_state=args.seed).fit(X=train_features, y=train_labels)
            return model
        train_set, test_set = None, None
    elif args.backend == 'keras+pytorch':
        from arch import LogisticRegrressionKeras

        def generate_model_and_fit(train_features, train_labels, model_id=None):
            model = LogisticRegrressionKeras(input_num_attr=args.num_inp_attr)
            model.compile(**args.keras_dict)
            model.fit(x=train_features, y=train_labels, epochs=args.epochs, verbose=0)
            # monkey-patching the predict function to return binary predictions
            model._predict = getattr(model, "predict")
            setattr(model, "predict", lambda x: (model._predict(x, verbose=0) >= 0.5).astype(int).squeeze())
            return model
        train_set, test_set = None, None
    elif args.backend == 'deep_auc':
        from fairpate_tabular.custom_trainers.deep_auc import train_set, test_set, generate_model_and_fit
        generate_model_and_fit = partial(generate_model_and_fit, args=args)
    else:
        raise NotImplementedError

    # Import data
    train_features, train_labels, train_sensitives, test_features, test_labels, test_sensitives = \
        process_data(np_rng, args, log, train_set=train_set, test_set=test_set)
    packed_data = {
        'train': {
            'features': train_features,
            'labels': train_labels,
            'sensitives': train_sensitives,
            'set': train_set
        },
        'test': {
            'features': test_features,
            'labels': test_labels,
            'sensitives': test_sensitives,
            'set': test_set
        },
        'np_rng': np_rng,
        'generate_model_and_fit': generate_model_and_fit,
        'MAX_ITER': MAX_ITER
    }

    if args.use_optuna:
        print("Running Optuna (Ignoring list items)")
        run_optuna(args, run_experiment)
    elif non_None_list_args:
        print("Running multiple experiments")
        run_grid_search(args, packed_data, run_experiment, results_db, non_None_list_args)
    else:
        # run_experiment reads packed_data with attribute access
        # (packed_data.train.labels, ...), so pass the dataclass form.
        run_params = FairPATERunParams(train=DataSet(**packed_data['train']),
                                       test=DataSet(**packed_data['test']),
                                       np_rng=np_rng,
                                       generate_model_and_fit=generate_model_and_fit,
                                       MAX_ITER=MAX_ITER)
        # With a results_db, the first element of the returned tuple is the updated frame.
        results_db = run_experiment(args, log, run_params, results_db)[0]
        results_db.drop_duplicates().to_parquet(args.results_db_path)