import os
import copy
import time
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import random 

import torch

from models import *
from dataloader import get_dataset, Gaussian_stats
from client_utils import LocalUpdate


parser = argparse.ArgumentParser()

# Federated arguments (notation for the arguments follows the paper).
parser.add_argument('--epochs', type=int, default=10,
                    help="number of rounds of training")
parser.add_argument('--batch_size', type=int, default=64,
                    help="local batch size: B")
parser.add_argument('--num_workers', type=int, default=4,
                    help="number of DataLoader worker processes")
parser.add_argument('--num_users', type=int, default=10,
                    help="number of users: K")
parser.add_argument('--local_ep', type=int, default=10,
                    help="the number of local epochs: E")

parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.5,
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--weight_decay', type=float, default=1e-4,
                    help='SGD weight decay (default: 1e-4)')

# Other arguments.
parser.add_argument('--dataset', type=str, default='income',
                    help="name of dataset")

parser.add_argument('--save_path', default='trained_models', type=str, metavar='PATH',
                    help='root directory for saved checkpoints (default: trained_models)')

# Predictive-attribute mode (the task). The historical misspelled flag and
# attribute name are kept so existing scripts and callers keep working; the
# correctly spelled flag is accepted as an alias writing to the same attribute.
parser.add_argument('--predicitve_attribute_mode', '--predictive_attribute_mode',
                    dest='predicitve_attribute_mode', default=None, type=str)

# Predictive attribute label.
parser.add_argument('--label_y', default=100, type=int)

# Demographic attribute label.
parser.add_argument('--label_a', default=100, type=int)

# Mean and standard deviation of the Gaussian. Parsed as floats so values given
# on the command line are numbers rather than strings; the default stays None.
parser.add_argument('--mu', default=None, type=float)
parser.add_argument('--std', default=None, type=float)

# Robustness weight lambda — TODO confirm exactly how LocalUpdate applies it.
parser.add_argument('--lambda_robust', default=0.1, type=float)

# Gaussian mode.
parser.add_argument('--gaussian_mode', default=None)

# Fairness notion.
parser.add_argument('--fairness_notion', default=None)

# Adversarial-example settings.
parser.add_argument('--step_size', default=0.005, type=float)
parser.add_argument('--epsilon', default=0.05, type=float)
parser.add_argument('--num_steps', default=10, type=int)

# Random seed.
parser.add_argument('--seed', default=0, type=int)


def average_weights(w):
    """Return the element-wise mean of several state dicts (FedAvg).

    Args:
        w: non-empty sequence of state dicts that share the same keys.

    Returns:
        A deep copy of ``w[0]`` (same container type) in which every entry is
        replaced by the mean of that entry across all of ``w``. The inputs are
        left untouched because accumulation happens on the deep copy.
    """
    num_models = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged:
        # Accumulate in place on the copied tensor, then divide once at the end.
        running_total = averaged[name]
        for state in w[1:]:
            running_total += state[name]
        averaged[name] = torch.div(running_total, num_models)
    return averaged

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + current CUDA device) RNGs and request deterministic cuDNN."""
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def main():
    """Parse CLI arguments, run federated adversarial training, and save the global model.

    Every global round, each of ``args.num_users`` clients trains a deep copy of
    the global model via ``LocalUpdate.update_weights_ins_adv``. The returned
    state dicts are averaged with ``average_weights``, loaded into the global
    model, and the result is evaluated with ``global_test``. The final weights
    are saved to ``<save_dir>/model.checkpoint`` as ``{'state_dict': ...}``.

    Invalid configurations terminate the process via ``sys.exit`` with a message.
    Also stores the parsed arguments in the module-level ``args``.
    """
    import sys

    global args
    args = parser.parse_args()

    mode = args.predicitve_attribute_mode
    # The CLI default is None, which used to slip past the 'none' check and crash
    # later inside os.path.join; treat both as "not specified".
    if mode is None or mode == 'none':
        sys.exit("predictive attribute not specified!!!")

    if args.dataset == 'income':
        if mode != 'income_vs_gender':
            sys.exit("unsupported predictive attribute mode for income dataset!!!")
    elif args.dataset == 'compas':
        # Known COMPAS modes pin the demographic label; any other mode keeps --label_a.
        compas_label_a = {
            'crime_vs_african_american': 5,
            'crime_vs_asian': 6,
            'crime_vs_hispanic': 7,
            'crime_vs_native_american': 8,
        }
        args.label_a = compas_label_a.get(mode, args.label_a)
    else:
        sys.exit("wrong dataset!!!")

    print("We are using ", args.dataset, " dataset and performing ", mode)

    # Both values are concatenated into the save path below; None would raise TypeError.
    if args.fairness_notion is None:
        sys.exit("fairness notion not specified!!!")
    if args.gaussian_mode is None:
        sys.exit("gaussian mode not specified!!!")

    save_dir = os.path.join(args.save_path, args.dataset, mode,
                            'adv_fairness_notion' + args.fairness_notion,
                            'gaussian_mode_' + args.gaussian_mode,
                            'lambda_robust_' + str(args.lambda_robust), 'seed_' + str(args.seed))
    os.makedirs(save_dir, exist_ok=True)
    print("we are saving model to: ", save_dir)

    # Seed before loading data so any randomness in get_dataset / Gaussian_stats
    # (e.g. splitting data across users) is reproducible too.
    _seed_everything(args.seed)

    # load dataset and user groups
    user_groups, test_dataset = get_dataset(dataset=args.dataset, num_users=args.num_users,
                                            label_y=args.label_y, label_a=args.label_a)

    if args.gaussian_mode == 'default':
        args.mu = 0
        args.std = 1
    elif args.gaussian_mode == 'manual':
        args.mu, args.std = Gaussian_stats(user_groups)

    # The dataset was validated above, so it is either 'income' or 'compas'.
    global_model = MLP_large() if args.dataset == 'income' else MLP_small()
    global_model.cuda()

    for epoch in range(args.epochs):
        local_weights = []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        # global_test() leaves the model in eval mode, so switch back every round.
        global_model.train()

        for idx in range(args.num_users):
            local_model = LocalUpdate(args=args, user_data=user_groups[idx], user_id=idx)
            w, _ = local_model.update_weights_ins_adv(model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))

        # Replace the global weights with the element-wise mean of the client weights.
        global_model.load_state_dict(average_weights(local_weights))

        # global validation
        global_test(args, global_model, test_dataset)

    global_model_save_name = os.path.join(save_dir, 'model.checkpoint')
    torch.save({'state_dict': global_model.state_dict()}, global_model_save_name)

def global_test(args, global_model, test_dataset):
    """Evaluate ``global_model`` on ``test_dataset``, print and return its accuracy.

    Args:
        args: namespace; only ``batch_size`` and ``num_workers`` are read.
        global_model: classifier producing one score per class along dim 1.
            It is switched to eval mode and left in eval mode on return.
        test_dataset: dataset whose samples are indexable with the model input
            at position 0 and the target class index at position 2.

    Returns:
        Accuracy as a fraction in [0, 1], or ``nan`` if the test set is empty.
    """
    # Evaluate on the device the model already lives on instead of assuming CUDA;
    # parameter-less models fall back to CUDA, the previous hard-coded behaviour.
    first_param = next(global_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers,
                                              pin_memory=device.type == 'cuda')

    global_model.eval()
    total, correct = 0, 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch[0].to(device), batch[2].to(device)

            # Inference
            outputs = global_model(inputs)

            # Prediction: index of the highest score per sample.
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, targets)).item()
            total += len(targets)

    if total == 0:
        # Avoid ZeroDivisionError on an empty test set.
        print("Global Test Accuracy: n/a (empty test set)")
        return float('nan')

    accuracy = correct / total
    print("Global Test Accuracy: {:.2f}%".format(100*accuracy))
    return accuracy

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()