import sys
from baselines.dp_dg import dp_dg_experiment
from argparse import Namespace
from torch import device
from convertors import get_train_test_index, write_dp_dg_format_train_test_data, DPDG_ADULT_COLUMN_ORDER
import pandas as pd
import numpy as np
import sklearn
import os
import re
from IPython.display import display
from fairpate_tabular.utils import get_disparity
import pdb

def find_best_epoch_number(path):
    """Return the epoch reported by the last "best validation" line of a log.

    Reads ``<path>/log.txt`` and keeps the lines containing
    "has the best validation performance so far". The epoch is parsed from
    the last such line: a ``best: <n>`` token is preferred, and a leading
    ``Epoch <n>`` is used as a fallback.

    Args:
        path: Directory that contains ``log.txt``.

    Returns:
        The parsed epoch number as an ``int``.

    Raises:
        ValueError: If the log has no "best validation" line, or if no epoch
            number can be parsed from the last one.
    """
    marker = "has the best validation performance so far"
    with open(os.path.join(path, "log.txt")) as f:
        best_lines = [line for line in f if marker in line]

    # Without this guard an empty list would surface as an unrelated
    # IndexError instead of the intended "no best epoch" error.
    if not best_lines:
        raise ValueError("No best epoch found")

    last_best = best_lines[-1]
    # Raw strings keep ``\d`` a regex token rather than an invalid str escape.
    match = re.match(r".*best: (\d+).*", last_best)
    if match is None:
        match = re.match(r"Epoch (\d+).*", last_best)
    if match is None:
        raise ValueError("No best epoch found")
    return int(match.group(1))

def evaluate_csv_results():
    """Print the validation and test evaluation rows for the best epoch.

    Uses the module-level ``experiment_name`` (not a parameter) to locate the
    run directory under ``baselines/dp_dg/logs/adult/``, finds the best epoch
    from that run's ``log.txt``, and prints the rows of ``val_eval.csv`` and
    ``test_eval.csv`` whose ``epoch`` column equals it.
    """
    run_dir = f"baselines/dp_dg/logs/adult/{experiment_name}"
    best_epoch = find_best_epoch_number(run_dir)
    epoch_filter = f"epoch == {best_epoch}"

    reports = (("Validation", "val_eval.csv"), ("Test", "test_eval.csv"))
    for position, (title, csv_name) in enumerate(reports):
        if position:
            # Blank line between the two reports.
            print()
        print(title)
        print(pd.read_csv(f"{run_dir}/{csv_name}").query(epoch_filter))


def fill_in_preds(df, fill_in_index, pred_col, csv_path):
    """Copy predictions from a header-less CSV into ``df[pred_col]``.

    The first column of the CSV is read as a plain array and assigned to the
    rows selected by ``fill_in_index``, in the order given by that index.

    Args:
        df: DataFrame to update; it is modified in place.
        fill_in_index: Row labels of ``df`` that receive the predictions.
        pred_col: Name of the column to write into.
        csv_path: Path to a CSV file without a header row; only its first
            column is used.

    Returns:
        The same ``df`` object that was passed in.

    Raises:
        AssertionError: If the values read back from ``df`` differ from the
            values loaded from the CSV.
    """
    # Parse the file once and reuse the array for both the assignment and the
    # read-back check instead of loading the CSV a second time.
    preds = pd.read_csv(csv_path, header=None)[0].values
    df.loc[fill_in_index, pred_col] = preds
    # Sanity check that the assignment landed on the intended rows.
    assert np.all(df.loc[fill_in_index, pred_col].values == preds)
    return df
    
def calculate_accuracy_disparity_metrics(experiment_name, train_index, validation_index, test_index):
    """Print per-split accuracy and disparity metrics for a saved run.

    Loads ``Datasets/Adult/adult_original_purified.csv``, tags every row with
    its split, fills in the "best" and "last" predictions that the run saved
    for the validation and test splits, and prints:

    * group counts per (split, ">50K", "sex") for validation and test,
    * accuracy per split,
    * ``DemParity``, ``EqualityOfOdds`` and ``ErrorParity`` per split, as
      computed by ``get_disparity``.

    Args:
        experiment_name: Name of the run directory under
            ``baselines/dp_dg/logs/adult/``.
        train_index: Row labels of the training split.
        validation_index: Row labels of the validation split.
        test_index: Row labels of the test split.
    """
    run_dir = f"baselines/dp_dg/logs/adult/{experiment_name}"

    source = pd.read_csv("Datasets/Adult/adult_original_purified.csv").reset_index(drop=True)
    results = source.copy(deep=True)
    for pred_col in ("best_pred", "last_pred"):
        results[pred_col] = np.nan

    split_labels = [
        pd.Series("train", index=train_index),
        pd.Series("validation", index=validation_index),
        pd.Series("test", index=test_index),
    ]
    results["split"] = pd.concat(split_labels, axis=0).sort_index()

    # Prediction files exist only for the validation and test splits.
    for split_tag, split_rows in (("val", validation_index), ("test", test_index)):
        for epoch_tag in ("best", "last"):
            results = fill_in_preds(
                results,
                split_rows,
                f"{epoch_tag}_pred",
                f"{run_dir}/adult_split:{split_tag}_seed:1_epoch:{epoch_tag}_pred.csv",
            )

    print("= Dataset stats")
    evaluated = results.query("split in ['test', 'validation']")
    print(evaluated.groupby(["split", ">50K", "sex"]).apply(lambda group: len(group)))

    # Drop the train rows: no predictions were filled in for them above.
    results = results.query("split != 'train'").copy()

    results["sensitive"] = results["sex"].apply(lambda value: value == "Male").astype(int)
    results["label"] = results[">50K"].apply(lambda value: value == "Yes").astype(int)

    def per_split(metric_fn):
        # Evaluate ``metric_fn`` once per split on a fresh groupby.
        return results.groupby("split").apply(metric_fn)

    def accuracy_of(pred_col):
        # Fraction of rows in a group whose prediction equals the label.
        def compute(group):
            hits = (group[pred_col].astype(int) == group["label"]).astype(float)
            return hits.sum() / len(group)
        return compute

    def disparity_of(metric, pred_col, with_label):
        # Wrap ``get_disparity``, passing the label only when the metric uses it.
        def compute(group):
            if with_label:
                return get_disparity(metric, group[pred_col], group["sensitive"], group["label"])
            return get_disparity(metric, group[pred_col], group["sensitive"])
        return compute

    epoch_tags = ("best", "last")

    print("")
    print("= Accuracy")
    for epoch_tag in epoch_tags:
        print(f"== {epoch_tag} model:")
        print(per_split(accuracy_of(f"{epoch_tag}_pred")))

    # (header, metric name, metric needs labels, blank line after section)
    sections = (
        ("= Demographic Parity", "DemParity", False, True),
        ("= Equality of Odds", "EqualityOfOdds", True, True),
        ("= Error Parity", "ErrorParity", True, False),
    )
    for header, metric, with_label, blank_after in sections:
        print("")
        print(header)
        for epoch_tag in epoch_tags:
            print(f"== {epoch_tag} model:")
            print(per_split(disparity_of(metric, f"{epoch_tag}_pred", with_label)))
        if blank_after:
            print("")


# Script entry point: builds the FairPATE and DP-DG configurations, writes the
# Adult data in DP-DG format if it is not on disk yet, derives the
# train/validation/test indices and runs the DP-DG experiment on them.
if __name__ == "__main__":
    # adult_dpfermi = pd.read_csv("Datasets/Adult/adult_original_purified.csv").reset_index(drop=True)
    # adult_dpdg = pd.concat([pd.read_csv("baselines/dp_dg/data/backup/train.csv"), 
    #                         pd.read_csv("baselines/dp_dg/data/backup/test.csv")]).reset_index(drop=True)

    # Names the run; used below to build the DP-DG log directory.
    experiment_name = "run_2"

    # FairPATE-style configuration. In this script only seed, split,
    # undersampling_ratio, dataset, output_col_name, cols_to_norm and
    # sensitive_attributes are read from it (for the data preparation and
    # split calls below).
    fairpate_config = Namespace(
        dataset='adult',
        list_dataset=None,
        num_classes=2,
        output_col_name='>50K',
        split=0.75,
        dem_disparity_interpretation='max_vs_min',
        teacher_query_set_split=0.7,
        num_teachers=4,
        list_num_teachers=None,
        threshold=2,
        list_threshold=None,
        fairness_threshold=0.2,
        list_fairness_threshold=None,
        sigma_threshold=60,
        list_sigma_threshold=None,
        sigma_fair_threshold=0,
        sigma_gnmax=25,
        list_sigma_gnmax=None,
        budget=1000,
        list_budget=None,
        delta=1e-05,
        verbose=True,
        seed=0,
        list_seed=None,
        data_path='./fairpate_tabular/data/',
        min_group_count=50,
        results_dir='.',
        use_optuna=False,
        num_optuna_trials=1000,
        use_stratification=False,
        fairness_metric='DemParity',
        list_fairness_metric=None,
        num_calib=100,
        pate_based_model='fairpate',
        use_inference_time_postprocessing=False,
        undersampling_ratio=None,
        optuna_db_path='.',
        path='./Datasets/Adult/adult_original_purified.csv',
        num_inp_attr=102,
        cols_to_norm=['age',
        'fnlwgt',
        'education-num',
        'capital-gain',
        'capital-loss',
        'hours-per-week'],
        sensitive_attributes=['sex'],
        results_db_path='./fairpate_adult_DemParity_results.parquet',
        gt_fairness=False
        )
    # Configuration handed as-is to dp_dg_experiment.
    dpdg_config = Namespace(
        dataset='adult', 
        algorithm='ERM', 
        root_dir='./baselines/dp_dg/data', 
        enable_privacy=True, 
        enable_fair_privacy=False, 
        apply_noise=False, 
        split_scheme='official', 
        dataset_kwargs={}, 
        download=False, 
        subsample=False, 
        frac=1.0, 
        version=None, 
        loader_kwargs={'num_workers': 1, 'pin_memory': True}, 
        train_loader='standard', 
        uniform_over_groups=False, 
        distinct_groups=None, 
        n_groups_per_batch=4, 
        batch_size=1024,  #instead of 256
        eval_loader='standard', 
        weighted_uniform_iid=True, 
        uniform_iid=None, 
        sample_rate=0.005, 
        clip_sample_rate=None,
        model='logistic_regression', 
        model_kwargs={'in_features': 85}, # 85 instead of 86 since work-class does not have "Never-worked" in this version of the dataset
        transform=None, 
        target_resolution=None, 
        resize_scale=None, 
        max_token_length=None, 
        loss_function='cross_entropy', 
        loss_kwargs={}, 
        groupby_fields=['sex', 'y'], 
        group_dro_step_size=None, 
        coral_penalty_weight=None, 
        irm_lambda=None, 
        irm_penalty_anneal_iters=None, 
        algo_log_metric='accuracy', 
        val_metric='acc_wg', 
        val_metric_decreasing=False, 
        n_epochs=1, # instead of 20s
        optimizer='SGD', 
        lr=0.22360679774997896, 
        weight_decay=0.01, 
        max_grad_norm=None, 
        optimizer_kwargs={'momentum': 0.9}, sigma=5.0, 
        max_per_sample_grad_norm=0.5, 
        delta=1e-05, 
        sigma2=1.0, 
        C0=1.0, 
        scheduler=None, 
        scheduler_kwargs={}, 
        scheduler_metric_split='val',
        scheduler_metric_name=None, 
        process_outputs_function='multiclass_logits_to_pred', 
        evaluate_all_splits=True, 
        eval_splits=[], 
        eval_only=True, 
        eval_epoch=None, 
        device=device(type='cpu'), 
        seed=1, 
        log_dir=f'./baselines/dp_dg/logs/adult/{experiment_name}',
        log_every=50, 
        save_step=None, 
        save_best=True, 
        save_last=True, 
        save_pred=True, 
        no_group_logging=False, 
        use_wandb=False, 
        progress_bar=False, 
        resume=False)

    if not os.path.exists('./baselines/dp_dg/data/adult_v1.0'):
        # this needs to be done at least once. 
        # The train/test split (the seed, etc.) is not important, 
        # but the pre-processing and ordering is.
        write_dp_dg_format_train_test_data(
                seed=fairpate_config.seed,
                split=fairpate_config.split, 
                undersampling_ratio=fairpate_config.undersampling_ratio, 
                
                path='./Datasets/Adult/adult_original_purified.csv', 
                data_path='./fairpate_tabular/data/',
                save_path='./baselines/dp_dg/data/adult_v1.0',
                
                dataset=fairpate_config.dataset, 
                output_col_name=fairpate_config.output_col_name, 

                cols_to_norm=fairpate_config.cols_to_norm,
                sensitive_attributes=fairpate_config.sensitive_attributes,
                column_order=DPDG_ADULT_COLUMN_ORDER,)


    train_test_index = get_train_test_index(
        seed=fairpate_config.seed,
        split=fairpate_config.split, 
        undersampling_ratio=fairpate_config.undersampling_ratio, 
        
        path='./Datasets/Adult/adult_original_purified.csv', 
        data_path='./fairpate_tabular/data/',

        dataset=fairpate_config.dataset, 
        output_col_name=fairpate_config.output_col_name, 

        cols_to_norm=fairpate_config.cols_to_norm,
        sensitive_attributes=fairpate_config.sensitive_attributes)

    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.model_selection`` is loaded as an attribute.
    from sklearn.model_selection import train_test_split

    train_index, test_index = train_test_index
    val_to_all_train_ratio = 3000./43000. #using the same ratio as in dp-dg experiment
    # Carve the validation split out of the training indices, seeded with the
    # same seed as the train/test split above.
    train_index, validation_index = \
        train_test_split(train_index, test_size=val_to_all_train_ratio, random_state=fairpate_config.seed)

    dp_dg_experiment(dpdg_config, train_val_test_index=(train_index, validation_index, test_index))

    # evaluate_csv_results()
    # calculate_accuracy_disparity_metrics(experiment_name, train_index, validation_index, test_index)