import os
import json
import numpy as np
import pandas as pd
from typing import Sequence

from dataset import fetch_dataset, FeatIndex
from eval import Evaluator
from model import LogReg, NeuralNets, SenDrop, SenSR, SenSeI, LCIFR, Project


def run_basic_baseline(
        train_X: np.ndarray,
        train_y: np.ndarray,
        feat_dim: int,
        sen_idx: Sequence,
        evaluator: Evaluator,
        save_dir: str,
        save: bool = True,
) -> None:
    """ Run the plain (non-fairness-aware) baselines and optionally save results.

    Four models are trained in turn on ``(train_X, train_y)``:
    logistic regression, logistic regression wrapped in ``SenDrop``,
    a neural network, and a neural network wrapped in ``SenDrop``.
    Each fitted model is scored with ``evaluator(clf.pred, clf.pred_proba)``.
    When ``save`` is true, the result is written with ``json.dump`` to
    ``lr.json``, ``dis_lr.json``, ``nn.json`` and ``dis_nn.json`` inside ``save_dir``.

    Args:
        train_X: Training features (presumably (n_samples, feat_dim) -- confirm).
        train_y: Training labels.
        feat_dim: Number of input features; used to size the neural networks.
        sen_idx: Indices of the sensitive features passed to ``SenDrop`` as ``drop_idx``.
        evaluator: Callable taking the model's ``pred`` and ``pred_proba``
            methods; its return value is what gets serialised to JSON.
        save_dir: Output directory for the JSON files; created if it does not exist.
        save: Whether to write the results to disk.
    """
    if save and save_dir:
        # Create the output folder before any training so that a missing
        # directory cannot make the first ``open`` fail after a long fit.
        os.makedirs(save_dir, exist_ok=True)

    def _evaluate_and_save(clf, filename: str) -> None:
        # Score a fitted model and, if requested, dump the result as JSON.
        res = evaluator(clf.pred, clf.pred_proba)
        if save:
            with open(os.path.join(save_dir, filename), "w") as f:
                json.dump(res, f)

    print("\n=> Logistic Regression")
    clf = LogReg()
    clf.fit(train_X, train_y)
    _evaluate_and_save(clf, "lr.json")

    print("\n=> Discard Sensitive Features in Logistic Regression")
    clf = SenDrop(drop_idx=sen_idx, model=LogReg())
    clf.fit(train_X, train_y)
    _evaluate_and_save(clf, "dis_lr.json")

    print("\n=> Neural Networks")
    clf = NeuralNets(input_dim=feat_dim)
    clf.fit(train_X, train_y)
    _evaluate_and_save(clf, "nn.json")

    print("\n=> Discard Sensitive Features in Neural Networks")
    # The wrapped network's input_dim excludes the len(sen_idx) dropped columns.
    clf = SenDrop(
        drop_idx=sen_idx,
        model=NeuralNets(input_dim=feat_dim - len(sen_idx)),
    )
    clf.fit(train_X, train_y)
    _evaluate_and_save(clf, "dis_nn.json")


def run_fair_baseline(
        train_X: np.ndarray,
        train_y: np.ndarray,
        train_df: pd.DataFrame,
        feat_dim: int,
        feat2idx: FeatIndex,
        evaluator: Evaluator,
        save_dir: str,
        save: bool = True,
        *,
        rho: float = 1e+5,
        dl2_weight: float = 1.,
        lcifr_epochs: int = 100,
        run_sensr: bool = False,
        sensr_adv_lr: float = 1e-1,
) -> None:
    """ Run the individual-fairness baselines and optionally save results.

    Models are run in this order:
    ``Project`` + logistic regression, ``Project`` + neural network,
    SenSR (only when ``run_sensr`` is true), SenSeI, and LCIFR.
    Each fitted model is scored with ``evaluator(clf.pred, clf.pred_proba)``.
    When ``save`` is true, the result is written with ``json.dump`` to
    ``proj_lr.json``, ``proj_nn.json``, ``sensr.json``, ``sensei.json`` and
    ``lcifr.json`` inside ``save_dir``.

    Per-dataset hyperparameter values recorded for these baselines:
        SenSR  sen_adv_lr: Adult 1e-1; Compas 1e-1; Law School 1e-1; Dutch 1e-1
        SenSeI rho:        Adult 1e+4; Law School 1e+2; Compas 1e+4; Dutch 1e+5
        LCIFR  dl2_weight: Adult 10.;  Law School 1.;   Compas 1.;   Dutch 1.

    Args:
        train_X: Training features (presumably (n_samples, feat_dim) -- confirm).
        train_y: Training labels.
        train_df: Training data as a DataFrame, forwarded to the fair models' ``fit``.
        feat_dim: Number of input features.
        feat2idx: Feature index object supplying ``sen_idx``, ``sen_feat``,
            ``num_idx`` and the ``feat2idx`` name->index mapping.
        evaluator: Callable taking the model's ``pred`` and ``pred_proba``
            methods; its return value is what gets serialised to JSON.
        save_dir: Output directory for the JSON files; created if it does not exist.
        save: Whether to write the results to disk.
        rho: SenSeI ``rho`` (default is the value previously hard-coded).
        dl2_weight: LCIFR ``dl2_weight`` (default is the value previously hard-coded).
        lcifr_epochs: LCIFR ``n_epochs`` (default is the value previously hard-coded).
        run_sensr: Also run SenSR; off by default, matching the original script
            where this run was commented out.
        sensr_adv_lr: SenSR ``sen_adv_lr``.
    """
    sen_idx = feat2idx.sen_idx
    sen_feat = feat2idx.sen_feat
    num_feat_idx = feat2idx.num_idx
    # Keep only the sensitive entries of the full feature-name -> index mapping (for LCIFR).
    sen_feat2idx = {feat: idx for feat, idx in feat2idx.feat2idx.items() if feat in sen_feat}

    if save and save_dir:
        # Create the output folder before any training so that a missing
        # directory cannot make the first ``open`` fail after a long fit.
        os.makedirs(save_dir, exist_ok=True)

    def _evaluate_and_save(clf, filename: str) -> None:
        # Score a fitted model and, if requested, dump the result as JSON.
        res = evaluator(clf.pred, clf.pred_proba)
        if save:
            with open(os.path.join(save_dir, filename), "w") as f:
                json.dump(res, f)

    print("\n=> Project with Logistic Regression")
    clf = Project(base=LogReg())
    clf.fit(X=train_X, y=train_y, train_df=train_df, sen_feat=sen_feat, sen_idx=sen_idx)
    _evaluate_and_save(clf, "proj_lr.json")

    print("\n=> Project with Neural Networks")
    clf = Project(base=NeuralNets(input_dim=feat_dim))
    clf.fit(X=train_X, y=train_y, train_df=train_df, sen_feat=sen_feat, sen_idx=sen_idx)
    _evaluate_and_save(clf, "proj_nn.json")

    if run_sensr:
        print("\n=> Sensr")
        clf = SenSR(
            input_dim=feat_dim,
            # NOTE(review): the "+ 2" comes from the original (disabled) run;
            # its meaning is not visible here -- confirm against model.SenSR.
            sen_dirs_dim=len(sen_idx) + 2,
            n_input=train_X.shape[0],
            sen_adv_lr=sensr_adv_lr,
            enable_full_adv=True,
        )
        clf.fit(X=train_X, y=train_y, train_df=train_df, sen_feat=sen_feat, sen_idx=sen_idx,
                verbose=True)
        _evaluate_and_save(clf, "sensr.json")

    print("\n=> SenSeI")
    clf = SenSeI(
        input_dim=feat_dim,
        n_input=train_X.shape[0],
        rho=rho,
    )
    clf.fit(X=train_X, y=train_y, train_df=train_df, sen_feat=sen_feat, sen_idx=sen_idx, verbose=True)
    _evaluate_and_save(clf, "sensei.json")

    print("\n=> LCIFR")
    clf = LCIFR(sen_feat2idx=sen_feat2idx, num_feat_idx=num_feat_idx, dl2_weight=dl2_weight,
                n_epochs=lcifr_epochs)
    clf.fit(train_X, train_y)
    _evaluate_and_save(clf, "lcifr.json")


if __name__ == "__main__":
    # Script entry point: run the individual-fairness baselines on the OULAD dataset.
    dataset = fetch_dataset("oulad")

    # Train/test splits, both requested with scale="all".
    train_X, train_y = dataset.train_data(scale="all")
    test_X, test_y = dataset.test_data(scale="all")

    # Extra test-split data handed to the evaluator alongside (test_X, test_y);
    # presumably comparable pairs for the individual-fairness metrics -- confirm in eval.Evaluator.
    test_comp_data = dataset.comp_data(batch_size=1024, train=False, scale="all")
    evaluator = Evaluator(test_X, test_y, test_comp_data)

    # Results go to ./save/<dataset name>/.
    run_fair_baseline(
        train_X=train_X,
        train_y=train_y,
        train_df=dataset.train_df,
        feat_dim=dataset.feat_dim,
        feat2idx=dataset.feat_idx,
        evaluator=evaluator,
        save_dir=f"./save/{dataset.name}",
        save=True,
    )
