import os
import pdb
import copy
import time
import pickle
import argparse
import numpy as np
from tqdm import tqdm

import torch

from models import *
from dataloader import get_dataset
from client_utils import LocalUpdate

from eval_metrics import compute_accurate_metrics, compute_fairness_metrics

# Command-line interface of the evaluation script. main() parses these options
# and uses several of them to rebuild the checkpoint directory path.
parser = argparse.ArgumentParser()

# federated arguments (notation for the arguments follows the paper)
# NOTE: --epochs, --local_ep, --lr, --momentum and --weight_decay are accepted
# here but are not read by main() in this file.
parser.add_argument('--epochs', type=int, default=10,
                    help="number of rounds of training")
parser.add_argument('--batch_size', type=int, default=32,
                    help="local batch size: B")
parser.add_argument('--num_workers', type=int, default=4,
                    help="number of DataLoader worker processes")
parser.add_argument('--num_users', type=int, default=10,
                    help="number of users: K")
parser.add_argument('--local_ep', type=int, default=10,
                    help="the number of local epochs: E")

parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.5,
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='SGD weight decay (default: 1e-4)')

# other arguments
parser.add_argument('--dataset', type=str, default='income',
                    help="name of dataset ('income' or 'compas')")

# root directory the trained checkpoints are loaded from
parser.add_argument('--load_path', default='trained_models', type=str, metavar='PATH',
                    help='root directory of the trained checkpoints to load '
                         '(default: trained_models)')

# predictive attribute mode, e.g. 'income_vs_gender' or 'crime_vs_african_american'
# NOTE: the flag name is misspelled ('predicitve'); it is kept as-is because
# main() reads the value under exactly this attribute name.
parser.add_argument('--predicitve_attribute_mode', default=None, type=str,
                    help="predictive attribute mode, e.g. 'income_vs_gender'")

# predictive attribute label
parser.add_argument('--label_y', default=100, type=int,
                    help='predictive attribute label passed to the dataset loader')

# demographic attribute label
parser.add_argument('--label_a', default=100, type=int,
                    help='demographic attribute label passed to the dataset loader '
                         '(overridden for known compas attribute modes)')

# lambda_robust value (part of the checkpoint directory name)
parser.add_argument('--lambda_robust', default=0.1, type=float,
                    help='lambda_robust value used in the checkpoint directory name')

# gaussian mode (part of the checkpoint directory name)
parser.add_argument('--gaussian_mode', default=None,
                    help='gaussian mode used in the checkpoint directory name')

# fairness notion (part of the checkpoint directory name when --mode is 'adv')
parser.add_argument('--fairness_notion', default=None,
                    help="fairness notion used in the checkpoint directory name "
                         "when --mode is 'adv'")

# which kind of trained model to evaluate
parser.add_argument('--mode', default=None,
                    help="'global_only' or 'adv'")

# adversarial-example settings (not read by main() in this file)
parser.add_argument('--step_size', default=0.01, type=float)
parser.add_argument('--epsilon', default=0.1, type=float)
parser.add_argument('--num_steps', default=10, type=int)

# seed value (not read by main() in this file)
parser.add_argument('--seed', default=0, type=int)


# Demographic-attribute label selected by each known compas attribute mode.
_COMPAS_LABEL_A = {
    'crime_vs_african_american': 5,
    'crime_vs_asian': 6,
    'crime_vs_hispanic': 7,
    'crime_vs_native_american': 8,
}


def _apply_dataset_settings(args):
    """Validate the dataset / attribute-mode pair and set ``args.label_a``.

    For the compas dataset a known attribute mode overrides ``args.label_a``;
    an unknown compas mode leaves the user-supplied value untouched.

    Raises:
        SystemExit: if the dataset is unknown, the attribute mode is missing,
            or the attribute mode does not belong to the income dataset.
    """
    if args.dataset not in ('income', 'compas'):
        raise SystemExit("wrong dataset!!!")

    attribute_mode = args.predicitve_attribute_mode
    # The argparse default is None; the literal string 'none' is also rejected.
    if attribute_mode is None or attribute_mode == 'none':
        raise SystemExit("predictive attribute not specified!!!")

    if args.dataset == 'income':
        if attribute_mode != 'income_vs_gender':
            raise SystemExit("wrong predictive attribute for income dataset: "
                             + attribute_mode)
    else:
        args.label_a = _COMPAS_LABEL_A.get(attribute_mode, args.label_a)


def _resolve_load_dir(args):
    """Return the directory that holds one sub-directory per trained run.

    Raises:
        SystemExit: if ``args.mode`` is not 'global_only' or 'adv', or if an
            option needed to build the path was not given.
    """
    if args.mode == 'global_only':
        mode_dir = 'global_only'
    elif args.mode == 'adv':
        if args.fairness_notion is None:
            raise SystemExit("fairness notion not specified!!!")
        # No separator between prefix and notion: the directory name is kept
        # exactly as this script has always built it.
        mode_dir = 'adv_fairness_notion' + str(args.fairness_notion)
    else:
        raise SystemExit("wrong mode!")

    if args.gaussian_mode is None:
        raise SystemExit("gaussian mode not specified!!!")

    return os.path.join(args.load_path, args.dataset, args.predicitve_attribute_mode,
                        mode_dir, 'gaussian_mode_' + str(args.gaussian_mode),
                        'lambda_robust_' + str(args.lambda_robust))


def _build_model(dataset, device):
    """Instantiate the network used for ``dataset`` and move it to ``device``."""
    if dataset == 'income':
        return MLP_large().to(device)
    if dataset == 'compas':
        return MLP_small().to(device)
    raise SystemExit('Error: unrecognized model')


def _evaluate_checkpoint(load_name, dataset, test_loader, device):
    """Load one checkpoint and compute its metrics on ``test_loader``.

    Each batch is indexed as (inputs, sensitive attributes, targets).

    Returns:
        The tuple ``(acc_mean, acc, acc_std, demo_parity, eq_odds)`` assembled
        from ``compute_accurate_metrics`` and ``compute_fairness_metrics``.
    """
    global_model = _build_model(dataset, device)

    print("=> loading checkpoint '{}'".format(load_name))
    # map_location lets a checkpoint saved on a GPU be evaluated on a CPU host.
    checkpoint = torch.load(load_name, map_location=device)
    global_model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint !! ")

    global_model.eval()

    targets_his = []
    predictions_his = []
    sensitive_attributes_his = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch in test_loader:
            outputs = global_model(batch[0].to(device))
            pred_labels = torch.argmax(outputs, dim=1)

            # Labels are only converted to Python lists, so they stay on the host.
            targets_his += batch[2].tolist()
            predictions_his += pred_labels.tolist()
            sensitive_attributes_his += batch[1].tolist()

    acc_mean, acc, acc_std = compute_accurate_metrics(
        np.array(predictions_his), np.array(sensitive_attributes_his), np.array(targets_his))
    demo_parity, eq_odds = compute_fairness_metrics(
        np.array(predictions_his), np.array(sensitive_attributes_his), np.array(targets_his))
    return acc_mean, acc, acc_std, demo_parity, eq_odds


def main():
    """Evaluate every saved checkpoint of one configuration and print statistics.

    Parses the command line, locates ``<run>/model.checkpoint`` under the
    configuration's checkpoint directory, evaluates each checkpoint on the
    test split and prints the mean and standard deviation of every metric
    across checkpoints.

    Raises:
        SystemExit: on invalid options or when no checkpoint can be found.
    """
    global args
    args = parser.parse_args()

    _apply_dataset_settings(args)
    print("We are using ", args.dataset, " dataset and performing ", args.predicitve_attribute_mode)

    load_dir = _resolve_load_dir(args)
    if not os.path.isdir(load_dir):
        raise SystemExit("No model found!: " + load_dir)

    # Sorted so checkpoints are always processed in the same order.
    load_name_list = [os.path.join(load_dir, seeds_name, 'model.checkpoint')
                      for seeds_name in sorted(os.listdir(load_dir))]
    if not load_name_list:
        raise SystemExit("No model found!: " + load_dir)
    # Fail before any expensive work if a run directory lacks its checkpoint.
    for load_name in load_name_list:
        if not os.path.isfile(load_name):
            raise SystemExit("No model found!: " + load_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The test split does not depend on the checkpoint, so it is built once.
    _, test_dataset = get_dataset(dataset=args.dataset, num_users=args.num_users,
                                  label_y=args.label_y, label_a=args.label_a)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers,
                                              pin_memory=(device.type == 'cuda'))

    results = [_evaluate_checkpoint(load_name, args.dataset, test_loader, device)
               for load_name in load_name_list]
    avg_acc_his, global_acc_his, acc_std_his, demo_parity_his, eq_odds_his = zip(*results)

    # The global accuracy was collected but never reported; it is printed last
    # so the existing output lines keep their text and order.
    report = (
        ("avg acc", avg_acc_his),
        ("acc gap", acc_std_his),
        ("demo parity", demo_parity_his),
        ("eq odds", eq_odds_his),
        ("global acc", global_acc_his),
    )
    for label, values in report:
        print("{}, mean: {:.3f}, std: {:.3f}".format(label, np.mean(values), np.std(values)))

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()