import argparse
import numpy as np
from dataset import libsvm_loader, data_partition
from utils import gradient_x, gradient_alpha, dr_subproblem
import time 

# --------------------------------------------------------------------------- #
# Parse command line arguments (CLAs):
# --------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(
    description='Federated AUC Maximization Problem',
)

# Problem / data set selection.
parser.add_argument(
    '--dataset', type=str, default='a9a',
    help='name of dataset',
)
parser.add_argument(
    '--reg_lambda', type=float, default=0.001,
    help='coefficient of regularization',
)
parser.add_argument(
    '--algorithm', type=str, default='DRAUC',
    help='name of algorithm',
)

# Generic optimisation settings.
parser.add_argument(
    '--num_epochs', type=int, default=1000,
    help='number of epochs to train',
)
parser.add_argument(
    '--batch_size', type=int, default=40,
    help='batch size',
)
parser.add_argument(
    '--learning_rate', type=float, nargs='+', default=[0.0001, 0.1],
    help='learning rate',
)

# Federated setup.
parser.add_argument(
    '--fed_n', type=int, default=20,
    help='number of workers',
)
parser.add_argument(
    '--fed_sampling', action='store_true',
    help='using worker sampling',
)
parser.add_argument(
    '--fed_local', type=int, default=5,
    help='number of local epochs',
)
parser.add_argument(
    '--fed_s', type=int, default=5,
    help='number of sample workers',
)
parser.add_argument(
    '--fed_k', type=int, default=120,
    help='number of local update iterations',
)

# Hyper-parameters of the DR algorithm.
parser.add_argument(
    '--dr_beta', type=float, default=0.1,
    help='hyper-parameter of DR algorithm',
)
parser.add_argument(
    '--dr_epsilon', type=float, default=0.95,
    help='hyper-parameter of DR algorithm',
)
parser.add_argument(
    '--dr_tolerance', type=float, default=0.001,
    help='hyper-parameter of DR algorithm',
)
parser.add_argument(
    '--dr_maxiter', type=int, default=50,
    help='hyper-parameter of DR algorithm',
)

# Gradient oracle and reporting.
parser.add_argument(
    '--stochastic', action='store_true',
    help='using stochastic gradient oracle',
)
parser.add_argument(
    '--print_freq', type=int, default=10,
    help='frequency to print train stats',
)
parser.add_argument(
    '--out_fname', type=str, default='DRAUC.csv',
    help='name of the output file',
)
# --------------------------------------------------------------------------- #


def get_AUC(data, label, w):
    """Return the AUC of the linear scorer ``data @ w`` on a labelled sample.

    Samples whose label equals 1 are positives; every other label value is
    treated as a negative. A (positive, negative) pair counts as correctly
    ranked only when the positive's score is strictly greater than the
    negative's, so tied scores count as mis-ranked.

    NOTE(review): the conventional AUC gives tied pairs half credit; the
    pessimistic tie handling here is kept on purpose so that previously
    logged values stay comparable.

    Args:
        data: 2-D array-like with one sample per row.
        label: 1-D array-like of labels aligned with the rows of ``data``.
        w: 1-D array-like weight vector with one entry per column of ``data``.

    Returns:
        float: ``1 - mis_ranked_pairs / (num_pos * num_neg)``, or NaN when
        there is no positive or no negative sample (the AUC is undefined).
    """
    scores = np.dot(data, w)
    is_pos = np.asarray(label) == 1

    pos_scores = scores[is_pos]
    neg_scores = np.sort(scores[~is_pos])

    num_pair = pos_scores.size * neg_scores.size
    if num_pair == 0:
        # Only one class (or no data at all): avoid dividing by zero.
        return float('nan')

    # For every positive, count the negatives scoring at least as high.
    # side='left' gives the number of negatives strictly below the positive.
    below = np.searchsorted(neg_scores, pos_scores, side='left')
    num_miss = int((neg_scores.size - below).sum())
    return 1 - (num_miss / num_pair)


def get_val(data, label, w, a, b, alpha, args):
    """Evaluate the regularised objective at ``(w, a, b, alpha)``.

    With ``p`` the fraction of samples labelled 1 and ``s_i = <data_i, w>``,
    the returned value is the sum of:

    * ``(1 - p) * mean(pos_i * (s_i - a)^2)``
    * ``p * mean(neg_i * (s_i - b)^2)``
    * ``2 * (1 + alpha) * mean(s_i * (p * neg_i - (1 - p) * pos_i))``
    * ``p * (1 - p) * (1 - alpha^2)``
    * ``args.reg_lambda * ||w||_1``

    where ``pos_i = (1 + label_i) / 2`` and ``neg_i = (1 - label_i) / 2``.
    These are 0/1 indicators only for labels in {-1, +1}; any other label
    value yields a fractional weight.

    NOTE(review): this looks like the min-max (saddle-point) form of the
    squared-loss AUC surrogate with an L1 penalty -- confirm against the
    reference the project follows.

    Args:
        data: 2-D array-like with one sample per row.
        label: 1-D array-like of labels aligned with the rows of ``data``.
        w: 1-D array-like weight vector.
        a: scalar subtracted from the scores of positive samples.
        b: scalar subtracted from the scores of negative samples.
        alpha: scalar multiplier appearing in the last two smooth terms.
        args: object exposing ``reg_lambda``, the L1 coefficient.

    Returns:
        The objective value as a floating-point scalar.
    """
    data = np.asarray(data)
    label = np.asarray(label)
    w = np.asarray(w)

    # Vectorised count instead of the builtin ``sum`` over a NumPy array.
    p = np.count_nonzero(label == 1) / data.shape[0]

    pos_index = (1 + label) / 2
    neg_index = (1 - label) / 2

    # Matrix-vector product; avoids materialising the n-by-d array ``data * w``.
    linear = data @ w

    term1 = (1 - p) * np.mean(pos_index * (linear - a) ** 2)
    term2 = p * np.mean(neg_index * (linear - b) ** 2)
    term3 = 2 * (1 + alpha) * np.mean(linear * (p * neg_index - (1 - p) * pos_index))
    coupling = p * (1 - p) * (1 - alpha * alpha)
    penalty = args.reg_lambda * np.sum(np.abs(w))
    return term1 + term2 + term3 + coupling + penalty


def DRAUC(data, label, data_part, label_part, args):
    """Run the federated DR training loop and log progress to a CSV file.

    Every worker ``k`` owns a row ``x[k]`` / ``y[k]`` of length
    ``data.shape[1] + 2`` (the weights ``w`` followed by the scalars ``a``
    and ``b``) plus a scalar ``alpha[k]``; ``z`` is the shared vector of the
    same length. One outer iteration:

    1. sets ``x <- x - y + z`` for the participating workers,
    2. lets each of them update ``y[k]`` and ``alpha[k]`` through
       ``dr_subproblem``,
    3. averages ``2 * y - x`` over the participants into ``z``,
    4. soft-thresholds the weight part of ``z`` with threshold
       ``reg_lambda * dr_beta``.

    With ``args.fed_sampling`` the participants are ``args.fed_s`` indices
    drawn by ``np.random.randint``, i.e. with replacement, so a worker may
    be picked more than once per round. Otherwise all ``args.fed_n`` workers
    take part.

    Every ``args.print_freq`` iterations a row
    ``iteration,time,oracle,fval,AUC`` is appended to the output file.

    Args:
        data: full 2-D sample matrix, used for evaluation and for ``p``.
        label: labels aligned with ``data``.
        data_part: per-worker data, indexed by worker id.
        label_part: per-worker labels, indexed by worker id.
        args: parsed command-line namespace.

    Returns:
        None. Results are written to ``args.out_fname`` (or
        ``args.algorithm + '.csv'`` when that is empty).
    """
    lam = args.reg_lambda
    beta = args.dr_beta
    log_every = args.print_freq
    csv_path = args.out_fname or args.algorithm + '.csv'

    # Fraction of samples labelled 1 in the full data set.
    p = sum(label == 1) / data.shape[0]

    dim = data.shape[1] + 2  # w, a, b
    x = np.zeros([args.fed_n, dim])
    y = np.copy(x)
    z = np.zeros(dim)
    alpha = np.zeros(args.fed_n)

    elapsed_time = 0.0
    oracle = 0

    if args.fed_sampling:
        num_agent = args.fed_s
    else:
        num_agent = args.fed_n

    def solve_worker(k, x_all):
        # Update worker k's y row and alpha in place; return the oracle
        # count reported by dr_subproblem.
        y[k, :], alpha[k], used = dr_subproblem(
            data_part[k], label_part[k], p, y[k, :], alpha[k], x_all[k, :], args
        )
        return used

    # Start the log with the header and the iteration-0 placeholder row.
    with open(csv_path, 'w') as f:
        f.write('iteration,time,oracle,fval,AUC\n')
        f.write('0,0.00,0,0.0000,0.0000\n')

    for step in range(1, args.num_epochs + 1):
        tic = time.time()

        if args.fed_sampling:
            # Drawn with replacement: duplicates in ``active`` are possible.
            active = np.random.randint(0, args.fed_n, args.fed_s)
            x[active, :] = x[active, :] - y[active, :] + z
            for k in active:
                oracle += solve_worker(k, x)
            z = np.mean(2 * y[active, :] - x[active, :], axis=0)
        else:
            x = x - y + z
            for k in range(args.fed_n):
                oracle += solve_worker(k, x)
            z = np.mean(2 * y - x, axis=0)

        # Soft-threshold the weights only; the trailing a, b stay untouched.
        weights = z[:-2]
        z[:-2] = np.sign(weights) * np.maximum(np.abs(weights) - lam * beta, 0)

        # Wall-clock time of the round divided by the number of participants.
        # NOTE(review): presumably emulates workers running in parallel.
        elapsed_time += (time.time() - tic) / num_agent

        if step % log_every != 0:
            continue

        w, a, b = z[:-2], z[-2], z[-1]
        AUC = get_AUC(data, label, w)
        fval = get_val(data, label, w, a, b, np.mean(alpha), args)
        with open(csv_path, '+a') as f:
            f.write('%d,%.2f,%d,%.4f,%.4f\n' % (step, elapsed_time, oracle, fval, AUC))


def main():
    """Parse the command line, load and partition the data, then run DRAUC.

    ``args.fed_k`` is always recomputed here as
    ``1 + (fed_local * n_samples) // (fed_n * batch_size)``, so a value
    passed through ``--fed_k`` is overwritten.
    """
    args = parser.parse_args()

    data, label = libsvm_loader(args.dataset)
    data_part, label_part = data_partition(data, label, args.fed_n)

    samples_over_local_epochs = args.fed_local * data.shape[0]
    samples_per_step = args.fed_n * args.batch_size
    args.fed_k = 1 + samples_over_local_epochs // samples_per_step

    DRAUC(data, label, data_part, label_part, args)


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()

