import argparse
import torch.nn.functional as F
from tqdm import tqdm

from dataset import get_dataset
from evaluation import evaluate
from model import *
from utils import *

import warnings
# NOTE: called without a category or module filter, so *every* warning raised
# in this process is suppressed, including deprecation and numerical warnings.
warnings.filterwarnings('ignore')


def _build_encoder(args):
    """Instantiate the encoder named by ``args.encoder`` on ``args.device``.

    Raises:
        ValueError: if ``args.encoder`` is not one of MLP, GCN, GIN, SAGE.
    """
    if args.encoder == 'MLP':
        return MLP_encoder(args).to(args.device)
    if args.encoder == 'GCN':
        return GCN_encoder_scatter(args).to(args.device)
    if args.encoder == 'GIN':
        return GIN_encoder(args).to(args.device)
    if args.encoder == 'SAGE':
        return SAGE_encoder(args).to(args.device)
    raise ValueError(
        f"Unsupported encoder {args.encoder!r}; expected one of 'MLP', 'GCN', 'GIN', 'SAGE'")


def _encoder_param_groups(encoder, args):
    """Return fresh Adam parameter groups for ``encoder``.

    Must be called every time an optimizer is built because ``.parameters()``
    returns a one-shot generator.
    """
    if args.encoder == 'MLP':
        return [dict(params=encoder.lin.parameters(), weight_decay=args.e_wd)]
    if args.encoder == 'GCN':
        return [
            dict(params=encoder.lin.parameters(), weight_decay=args.e_wd),
            dict(params=encoder.bias, weight_decay=args.e_wd),
        ]
    # GIN and SAGE both expose their trainable layer as ``conv``.
    return [dict(params=encoder.conv.parameters(), weight_decay=args.e_wd)]


def _top_sens_correlated_features(x, sens, num_features):
    """Return indices of the columns of ``x`` most correlated with ``sens``.

    Computes the Pearson correlation between every column of ``x`` and the
    vector ``sens`` and returns the indices of the ``num_features`` columns
    with the largest absolute correlation (clamped to the number of columns).
    """
    sens = sens.float()
    x_centered = x - x.mean(dim=0)
    sens_centered = sens - sens.mean()
    covariance = (x_centered * sens_centered[:, None]).sum(dim=0)
    scale = torch.sqrt((x_centered ** 2).sum(dim=0)) * torch.sqrt((sens_centered ** 2).sum())
    # A constant column gives 0/0 = NaN; treat it as uncorrelated so that
    # topk (which orders NaN as the largest value) does not select it.
    corr = torch.nan_to_num(covariance / scale, nan=0.0)
    num_features = min(num_features, corr.numel())
    _, indices = torch.topk(corr.abs(), num_features)
    return indices


def run(data, args):
    """Train and evaluate an encoder + MLP classifier over ``args.runs`` seeded runs.

    For each run the model parameters are reset, new optimizers are created,
    the model is trained for ``args.epochs`` epochs (``args.c_epochs`` gradient
    steps each), and the test metrics of the epoch with the best validation
    trade-off ``auc_roc + f1 + acc - (parity + equality)`` are recorded.

    Args:
        data: graph data object; this function reads ``x``, ``y``, ``sens``,
            ``edge_index``, ``adj_norm_sp`` and ``train_mask`` and calls
            ``.to()`` / ``.clone()`` on it.
        args: namespace holding the experiment configuration (``runs``,
            ``seed``, ``encoder``, ``epochs``, ``c_epochs``, ``k``, ``lamda``,
            learning rates / weight decays, ``device``, ``sens_idx``,
            ``x_min``, ``x_max``, plus whatever the model classes consume).

    Returns:
        Tuple ``(acc, f1, auc_roc, parity, equality)`` of NumPy arrays, each of
        length ``args.runs``. An entry is NaN if no epoch was evaluated or no
        epoch produced a comparable (non-NaN) validation trade-off.

    Raises:
        ValueError: if ``args.encoder`` names an unsupported encoder.
    """
    pbar = tqdm(range(args.runs), unit='run')

    acc, f1, auc_roc, parity, equality = (np.zeros(args.runs) for _ in range(5))

    data = data.to(args.device)

    classifier = MLP_classifier(args).to(args.device)
    encoder = _build_encoder(args)

    for count in pbar:
        seed_everything(count + args.seed)

        # Training-node masks per sensitive group, and their positive-label subsets.
        s0_mask = torch.logical_and((data.sens == 0), data.train_mask)
        s1_mask = torch.logical_and((data.sens == 1), data.train_mask)
        s0y1_mask = torch.logical_and((data.y == 1), s0_mask)
        s1y1_mask = torch.logical_and((data.y == 1), s1_mask)

        new_data = data.clone()
        if args.k >= 0:
            # k + 1 columns are requested, as in the original selection.
            topk_indices = _top_sens_correlated_features(data.x, data.sens, args.k + 1)
            new_data.x = sens_shuffle(
                data.sens, data.edge_index, data.x, topk_indices,
                args.sens_idx, args.x_max, args.x_min)

        classifier.reset_parameters()
        encoder.reset_parameters()

        # Build the optimizers *after* resetting the parameters and once per
        # run: reusing one Adam instance across runs would carry its moment
        # estimates and step counters from the previous run into this one.
        optimizer_c = torch.optim.Adam(
            [dict(params=classifier.lin.parameters(), weight_decay=args.c_wd)],
            lr=args.c_lr)
        optimizer_e = torch.optim.Adam(
            _encoder_param_groups(encoder, args), lr=args.e_lr)

        # Start from -inf so the first evaluated epoch is always a candidate,
        # and give the test metrics explicit defaults so a run can never
        # crash on, or silently inherit, a previous run's values.
        best_val_tradeoff = float('-inf')
        test_acc = test_f1 = test_auc_roc = float('nan')
        test_parity = test_equality = float('nan')

        for epoch in range(0, args.epochs):
            # train classifier
            classifier.train()
            encoder.train()
            for epoch_c in range(0, args.c_epochs):
                optimizer_c.zero_grad()
                optimizer_e.zero_grad()
                h = encoder(new_data.x, new_data.edge_index, new_data.adj_norm_sp)
                output = classifier(h)
                loss_c = F.binary_cross_entropy_with_logits(
                    output[new_data.train_mask],
                    new_data.y[new_data.train_mask].unsqueeze(1).to(args.device))

                if args.lamda != 0:
                    probs = torch.sigmoid(output)

                    # dp: gap in mean predicted probability between the two groups
                    rate_0 = probs[s0_mask].mean()
                    rate_1 = probs[s1_mask].mean()
                    dp_loss = torch.abs(rate_0 - rate_1)

                    # eo: same gap restricted to positive-label nodes
                    tpr_0 = probs[s0y1_mask].mean()
                    tpr_1 = probs[s1y1_mask].mean()
                    eo_loss = torch.abs(tpr_0 - tpr_1)

                    loss_c = loss_c + args.lamda * (dp_loss + eo_loss)

                loss_c.backward()

                optimizer_e.step()
                optimizer_c.step()

            # evaluate classifier (on the original, un-shuffled features)
            accs, auc_rocs, F1s, tmp_parity, tmp_equality = evaluate(
                data.x, classifier, None, encoder, data, args)

            val_tradeoff = (auc_rocs['val'] + F1s['val'] + accs['val']
                            - (tmp_parity['val'] + tmp_equality['val']))
            if val_tradeoff > best_val_tradeoff:
                best_val_tradeoff = val_tradeoff
                test_acc = accs['test']
                test_auc_roc = auc_rocs['test']
                test_f1 = F1s['test']
                test_parity, test_equality = tmp_parity['test'], tmp_equality['test']

        acc[count] = test_acc
        f1[count] = test_f1
        auc_roc[count] = test_auc_roc
        parity[count] = test_parity
        equality[count] = test_equality

    return acc, f1, auc_roc, parity, equality


if __name__ == '__main__':
    # Command-line entry point: parse the configuration, load the dataset,
    # run the experiment and print mean±std (in percent) of each metric.
    parser = argparse.ArgumentParser()

    # DATASET
    parser.add_argument('--dataset', type=str, default='german',
                        help='dataset name passed to get_dataset')

    # FAIREST
    parser.add_argument('--k', type=int, default=0,
                        help='the k+1 features most correlated with the sensitive '
                             'attribute are passed to sens_shuffle; a negative '
                             'value disables the shuffle')
    parser.add_argument('--lamda', type=float, default=0,
                        help='weight of the parity + equality loss term '
                             '(0 disables it)')

    # EXPERIMENT
    parser.add_argument('--runs', type=int, default=5,
                        help='number of seeded runs')
    # Restrict to the encoders run() knows how to build so that a typo fails
    # at argument parsing instead of deep inside training.
    parser.add_argument('--encoder', type=str, default='GCN',
                        choices=['MLP', 'GCN', 'GIN', 'SAGE'],
                        help='encoder architecture')
    parser.add_argument('--epochs', type=int, default=200,
                        help='training epochs per run')
    parser.add_argument('--c_epochs', type=int, default=5,
                        help='gradient steps per epoch')
    parser.add_argument('--seed', type=int, default=1,
                        help='base seed; run i is seeded with seed + i')

    # HYPER-PARAMETERS
    parser.add_argument('--c_lr', type=float, default=0.001,
                        help='classifier learning rate')
    parser.add_argument('--c_wd', type=float, default=0,
                        help='classifier weight decay')
    parser.add_argument('--e_lr', type=float, default=0.001,
                        help='encoder learning rate')
    parser.add_argument('--e_wd', type=float, default=0,
                        help='encoder weight decay')
    # --dropout / --hidden are not read in this script; presumably they are
    # consumed by the model classes through `args`.
    parser.add_argument('--dropout', type=float, default=0.5,)
    parser.add_argument('--hidden', type=int, default=16,)

    args = parser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data, args.sens_idx, args.x_min, args.x_max = get_dataset(args.dataset)

    # binary classes are 0,1, so num_classes is set to 1 (2 - 1)
    args.num_features, args.num_classes = data.x.shape[1], 2 - 1

    acc, f1, auc_roc, parity, equality = run(data, args)
    print('======' + args.dataset + args.encoder + '======')
    print('Acc:', round(np.mean(acc) * 100, 2), '±', round(np.std(acc) * 100, 2), sep='')
    print('f1:', round(np.mean(f1) * 100, 2), '±', round(np.std(f1) * 100, 2), sep='')
    print('parity:', round(np.mean(parity) * 100, 2), '±', round(np.std(parity) * 100, 2), sep='')
    print('equality:', round(np.mean(equality) * 100, 2), '±', round(np.std(equality) * 100, 2), sep='')
    # auc_roc was computed by run() but never reported; it is printed last so
    # the pre-existing output lines keep their positions.
    print('auc_roc:', round(np.mean(auc_roc) * 100, 2), '±', round(np.std(auc_roc) * 100, 2), sep='')
