from __future__ import print_function
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim

import numpy as np
from eval_metrics import compute_fairness_metrics

import pdb

def _pgd_whitebox(model, X, a, y, args, iter):
    batch_size = X.shape[0]

    model.eval()
    X_pgd = X.clone().detach().requires_grad_(True)

    # random start
    random_noise = 0.001 * torch.randn(X_pgd.shape).cuda()
    temp = X_pgd.data + random_noise
    X_pgd = temp.clone().detach().requires_grad_(True)

    for _ in range(args.num_steps):
        opt = torch.optim.SGD([X_pgd], lr=1e-3)
        opt.zero_grad()
        
        with torch.enable_grad():
            model.train()
            outputs = model(X_pgd)
            probs = torch.softmax(outputs, dim=-1)
            predictions = torch.argmax(outputs, dim=-1)

            mask = probs[:, 1] >= 0.5

            if args.fairness_notion == "dp":
                pred_dis = torch.sum((probs[:, 1] * mask) * (a==1)) / (torch.sum((a==1)) + 1e-6) \
                            - torch.sum((probs[:, 1] * mask) * (a==0)) / (torch.sum((a==0)) + 1e-6)
                
                loss_pgd = torch.abs(pred_dis)

            elif args.fairness_notion == "eqodds":
                # |P(y_hat=1|y=1,a=0) - P(y_hat=1|y=1,a=1)| + |P(y_hat=1|y=0,a=0) - P(y_hat=1|y=0,a=1)|
                pred_dis_1 = torch.sum(probs[:, 1] * a * y) / (torch.sum(a * y) + 1e-6) \
                            - torch.sum(probs[:, 1] * (1 - a) * y) / (torch.sum((1 - a) * y) + 1e-6)
                
                pred_dis_2 = torch.sum(probs[:, 1] * a * (1 - y)) / (torch.sum(a * (1 - y)) + 1e-6) + \
                            - torch.sum(probs[:, 1] * (1 - a) * (1 - y)) / (torch.sum((1 - a) * (1 - y)) + 1e-6)
                
                pred_dis = torch.abs(pred_dis_1) + torch.abs(pred_dis_2)
                loss_pgd = pred_dis

        loss_pgd.backward()

        eta = args.step_size * X_pgd.grad.data.sign() # this is gradient ascent
        X_pgd = (X_pgd.data + eta).clone().detach().requires_grad_(True)
        eta = torch.clamp(X_pgd.data - X.data, -args.epsilon, args.epsilon)
        X_pgd = (X.data + eta).clone().detach().requires_grad_(True) 
    X_pgd = X_pgd.requires_grad_(False)

    return X_pgd