import argparse
import logging
import os
import random
import sys

import numpy as np
import torch
import wandb


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_cifar100
from fedml_api.model.cv.vgg import vgg16, vgg11
from fedml_api.model.cv.lenet5 import LeNet5, LeNet5_cifar
from fedml_api.model.cv.cnn_cifar10 import cnn_cifar10
from fedml_api.standalone.fedspa.fedspa_api import FedSpaAPI
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.model.cv.resnet import customized_resnet18
from fedml_api.data_preprocessing.EMNIST.data_loader import load_partition_data_emnist
from fedml_api.standalone.fedspa.my_model_trainer import MyModelTrainer



def add_args(parser):
    """Register the FedSpa experiment options on *parser* and return it.

    Help texts use argparse's ``%(default)s`` placeholder so the documented
    default always matches the real one (several hard-coded "default: ..."
    strings had drifted from the actual values).

    Args:
        parser: an ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same ``parser`` object with all options added.
    """
    parser.add_argument('--model', type=str, default='cnn_cifar10', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.1, metavar='PA',
                        help='partition alpha (default: %(default)s)')

    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')

    parser.add_argument('--client_optimizer', type=str, default='sgd',
                        help='optimizer used by each client (default: %(default)s)')

    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: %(default)s)')

    parser.add_argument('--lr_decay', type=float, default=0.998, metavar='LR_decay',
                        help='learning rate decay (default: %(default)s)')

    parser.add_argument('--wd', help='weight decay parameter (default: %(default)s)',
                        type=float, default=5e-4)

    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--client_num_in_total', type=int, default=100, metavar='NN',
                        help='total number of clients (default: %(default)s)')

    parser.add_argument('--frac', type=float, default=0.1, metavar='NN',
                        help='fraction of clients selected each round (default: %(default)s)')

    parser.add_argument('--momentum', type=float, default=0, metavar='NN',
                        help='momentum (default: %(default)s)')

    parser.add_argument('--comm_round', type=int, default=1000,
                        help='how many rounds of communication to run (default: %(default)s)')

    parser.add_argument('--frequency_of_the_test', type=int, default=1,
                        help='how often to run the test (default: %(default)s)')

    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU index; -1 forces CPU')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')

    parser.add_argument('--dense_ratio', type=float, default=0.5,
                        help='final slim ratio')

    parser.add_argument('--anneal_factor', type=float, default=0.5,
                        help='anneal factor for pruning')

    parser.add_argument("--seed", type=int, default=0,
                        help='seed for random, numpy and torch')

    # Boolean switches and experiment knobs; all default to off / neutral.
    parser.add_argument("--public_portion", type=float, default=0)
    parser.add_argument("--erk_power_scale", type=float, default=1)
    parser.add_argument("--dis_gradient_check", action='store_true')
    parser.add_argument("--strict_avg", action='store_true')
    parser.add_argument("--static", action='store_true')
    parser.add_argument("--uniform", action='store_true')
    parser.add_argument("--save_masks", action='store_true')
    parser.add_argument("--different_initial", action='store_true')
    parser.add_argument("--record_mask_diff", action='store_true')
    parser.add_argument("--global_test", action='store_true')
    parser.add_argument("--tag", type=str, default="test")
    return parser


def load_data(args, dataset_name):
    """Load ``dataset_name`` and partition it across the federated clients.

    Side effect: ``args.data_dir`` is extended in place with the dataset's
    sub-directory ("EMNIST", "cifar10" or "cifar100"), as before.

    Args:
        args: parsed command-line namespace; reads ``data_dir``,
            ``partition_method``, ``partition_alpha``, ``client_num_in_total``
            and ``batch_size``.
        dataset_name: one of ``"emnist"``, ``"cifar10"``, ``"cifar100"``.

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``, in the order the
        dataset-specific loader returns them.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name == "emnist":
        loader, sub_dir = load_partition_data_emnist, "EMNIST"
    elif dataset_name == "cifar10":
        loader, sub_dir = load_partition_data_cifar10, "cifar10"
    elif dataset_name == "cifar100":
        loader, sub_dir = load_partition_data_cifar100, "cifar100"
    else:
        # Fail fast: an unknown name used to fall through every branch and
        # crash later with an UnboundLocalError on the result variables.
        raise ValueError(
            "Unsupported dataset %r; expected one of 'emnist', 'cifar10', 'cifar100'"
            % (dataset_name,))

    # os.path.join matches the old "+=" for a data_dir ending in a separator
    # and also handles one that does not (previously "./data" -> "./datacifar10").
    args.data_dir = os.path.join(args.data_dir, sub_dir)
    (train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = loader(args.data_dir, args.partition_method, args.partition_alpha,
                         args.client_num_in_total, args.batch_size)

    dataset = [train_data_num, test_data_num, train_data_global, test_data_global,
               train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]
    return dataset



def create_model(args, model_name, class_num):
    """Instantiate the network selected by ``model_name``.

    Args:
        args: parsed command-line namespace; not read here, kept so the
            call signature stays unchanged.
        model_name: one of ``"lenet5"``, ``"lenet5_cifar"``, ``"cnn_cifar10"``,
            ``"resnet18"``, ``"vgg11"``.
        class_num: number of output classes, forwarded to every constructor
            except ``cnn_cifar10``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is not recognised (previously the
            function silently returned ``None``).
    """
    logging.info("create_model. model_name = %s", model_name)
    if model_name == "lenet5":
        return LeNet5(class_num)
    if model_name == "lenet5_cifar":
        return LeNet5_cifar(class_num)
    if model_name == "cnn_cifar10":
        # NOTE(review): class_num is not passed to cnn_cifar10(); presumably its
        # output size is fixed — confirm before pairing it with cifar100/emnist.
        return cnn_cifar10()
    if model_name == "resnet18":
        return customized_resnet18(class_num=class_num)
    if model_name == "vgg11":
        return vgg11(class_num)
    raise ValueError(
        "Unknown model %r; expected one of 'lenet5', 'lenet5_cifar', "
        "'cnn_cifar10', 'resnet18', 'vgg11'" % (model_name,))


def custom_model_trainer(args, model):
    """Wrap *model* in the FedSpa ``MyModelTrainer`` configured by *args*.

    Note the trainer's constructor takes the model first and the
    configuration second.
    """
    trainer = MyModelTrainer(model, args)
    return trainer


if __name__ == "__main__":
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    parser = add_args(argparse.ArgumentParser(description='FedSpa-standalone'))
    args = parser.parse_args()
    logger.info(args)
    print("torch version {}".format(torch.__version__))
    # --gpu -1 forces the CPU; otherwise use the requested CUDA device when available.
    if args.gpu == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    logger.info(device)

    # Build the run identifier (used as the wandb run name) from the main settings.
    data_partition = args.partition_method
    if data_partition != "homo":
        data_partition += str(args.partition_alpha)
    args.identity = (
        "fedspa-dr{dense}rigl{rigl}-{partition}-static{static}"
        "-shared{shared}-strict_avg{strict}-seed{seed}"
    ).format(
        dense=args.dense_ratio,
        rigl=not args.dis_gradient_check,
        partition=data_partition,
        static=args.static,
        shared=args.public_portion,
        strict=args.strict_avg,
        seed=args.seed,
    )
    if args.save_masks:
        args.identity += "-masks"
    if args.uniform:
        args.identity += "-u"
    if args.different_initial:
        args.identity += "-d"
    if args.global_test:
        args.identity += "-g"

    # Round away float noise before truncating: e.g. 100 * 0.29 evaluates to
    # 28.999999999999996, which int() alone would turn into 28 clients.
    args.client_num_per_round = int(round(args.client_num_in_total * args.frac, 6))
    if args.client_num_per_round < 1:
        parser.error("--frac * --client_num_in_total must select at least one client per round")

    wandb.init(
        project="fedml",
        name=args.identity,
        config=args,
        # settings=wandb.Settings(start_method="fork")
    )

    # Set the random seed. The np.random seed determines the dataset partition.
    # The torch_manual_seed determines the initial weight.
    # We fix these two, so that we can reproduce the result.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning picks kernels non-deterministically; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

    # load data
    dataset = load_data(args, args.dataset)

    # create model.
    # Note if the model is DNN (e.g., ResNet), the training will be very slow.
    # In this case, please use our FedML distributed version (./fedml_experiments/distributed_fedavg)
    # NOTE(review): dataset[-1] is the loader's "class_num" value; taking
    # len(dataset[-1][0]) assumes it is a per-client collection rather than an
    # int — confirm against the data loaders.
    model = create_model(args, model_name=args.model, class_num=len(dataset[-1][0]))
    model_trainer = custom_model_trainer(args, model)
    logging.info(model)

    fedspaAPI = FedSpaAPI(dataset, device, args, model_trainer)
    fedspaAPI.train()
