import argparse
import logging
import os
import random
import sys

import numpy as np
import torch
import wandb

# Put the directory three levels above the current working directory on
# sys.path *before* the first fedml_api import, so every import below resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
from fedml_api.model.cv.cnn_cifar10 import cnn_cifar10
from fedml_api.model.cv.vgg import vgg11, vgg16
from fedml_api.model.cv.lenet5 import LeNet5
from fedml_api.standalone.ditto.ditto_api import DittoAPI
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_cifar100
from fedml_api.model.cv.resnet import  customized_resnet18
from fedml_api.data_preprocessing.EMNIST.data_loader import load_partition_data_emnist
from fedml_api.standalone.ditto.my_model_trainer import MyModelTrainer


def add_args(parser):
    """Register the command-line options of this experiment script.

    Defaults shown in help texts are rendered with argparse's
    ``%(default)s`` placeholder so the help output cannot drift from the
    actual default values.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, with all options added.
    """
    # Training settings
    parser.add_argument('--model', type=str, default='resnet18', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--momentum', type=float, default=0, metavar='N',
                        help='momentum')

    parser.add_argument('--data_dir', type=str, default='./../../../data/',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: %(default)s)')

    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')

    parser.add_argument('--client_optimizer', type=str, default='sgd',
                        help='SGD with momentum; adam')

    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: %(default)s)')

    parser.add_argument('--lr_decay', type=float, default=0.998, metavar='LR_decay',
                        help='learning rate decay (default: %(default)s)')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=5e-4)

    parser.add_argument('--epochs', type=int, default=3, metavar='EP',
                        help='how many epochs for global model will be trained locally')

    parser.add_argument('--local_epochs', type=int, default=2, metavar='EP',
                        help='how many epochs for local model will be trained locally')

    parser.add_argument('--client_num_in_total', type=int, default=100, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--frac', type=float, default=0.1, metavar='NN',
                        help='selection fraction each round')

    parser.add_argument('--comm_round', type=int, default=1000,
                        help='how many round of communications we should use')

    parser.add_argument('--frequency_of_the_test', type=int, default=1,
                        help='the frequency of the algorithms')

    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')
    parser.add_argument("--tag", type=str, default="test")

    parser.add_argument("--lamda", type=float, default=0.5)

    parser.add_argument("--seed", type=int, default=0)
    return parser


def load_data(args, dataset_name):
    """Load and partition the dataset selected by ``dataset_name``.

    Supported names are ``"emnist"``, ``"cifar10"`` and ``"cifar100"``.
    As a side effect, ``args.data_dir`` is extended in place with the
    dataset's sub-directory (``EMNIST``, ``cifar10`` or ``cifar100``).

    Args:
        args: parsed options; reads ``data_dir``, ``partition_method``,
            ``partition_alpha``, ``client_num_in_total`` and ``batch_size``.
        dataset_name: which dataset to load.

    Returns:
        A list with the eight values produced by the dataset loader, in the
        order ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
            ``args`` is left untouched in that case.
    """
    if dataset_name == "emnist":
        subdir, loader = "EMNIST", load_partition_data_emnist
    elif dataset_name == "cifar10":
        subdir, loader = "cifar10", load_partition_data_cifar10
    elif dataset_name == "cifar100":
        subdir, loader = "cifar100", load_partition_data_cifar100
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # when the result list is assembled.
        raise ValueError(
            "unsupported dataset {!r}; expected one of 'emnist', 'cifar10', "
            "'cifar100'".format(dataset_name)
        )

    # os.path.join tolerates a data_dir given without a trailing separator;
    # with a trailing separator the result equals plain concatenation.
    args.data_dir = os.path.join(args.data_dir, subdir)

    (train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = loader(args.data_dir, args.partition_method,
                         args.partition_alpha, args.client_num_in_total,
                         args.batch_size)

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            class_num]



def create_model(args, model_name,class_num):
    """Instantiate the network identified by ``model_name``.

    Recognised names are ``"lenet5"``, ``"cnn_cifar10"``, ``"resnet18"`` and
    ``"vgg11"``. ``class_num`` is forwarded to every constructor except
    ``cnn_cifar10``, which is called without arguments. ``args`` is accepted
    but not used here.

    Returns:
        The constructed model, or ``None`` when ``model_name`` is not
        recognised.
    """
    message = "create_model. model_name = %s" % (model_name)
    logging.info(message)

    if model_name == "lenet5":
        return LeNet5(class_num=class_num)
    if model_name == "cnn_cifar10":
        return cnn_cifar10()
    if model_name == "resnet18":
        return customized_resnet18(class_num=class_num)
    if model_name == "vgg11":
        return vgg11(class_num)

    # Unknown architecture: the caller receives None.
    return None


def custom_model_trainer(args, model):
    """Wrap ``model`` in a ``MyModelTrainer`` configured with ``args``."""
    trainer = MyModelTrainer(model, args)
    return trainer


if __name__ == "__main__":
    # Script entry point: parse options, set up logging/wandb, seed the RNGs,
    # load the data, build the model and run Ditto training.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    parser = add_args(argparse.ArgumentParser(description='FedAvg-standalone'))
    args = parser.parse_args()
    print("torch version{}".format(torch.__version__))
    logger.info(args)

    # Only request a CUDA device when one is actually usable; otherwise run
    # on the CPU instead of failing later when tensors are moved.
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.gpu))
    else:
        logger.warning("CUDA is not available; falling back to CPU")
        device = torch.device("cpu")
    logger.info(device)
    logging.info("running at device{}".format(device))

    # Run identity used as the wandb run name: algorithm, partition, seed.
    data_partition = args.partition_method
    if data_partition != "homo":
        data_partition += str(args.partition_alpha)
    args.identity = "ditto" + "-" + data_partition + "-seed" + str(args.seed)

    # int() truncates, so a small --frac could select zero clients per round;
    # keep at least one.
    args.client_num_per_round = int(args.client_num_in_total * args.frac)
    if args.client_num_per_round < 1:
        logger.warning("client_num_in_total * frac < 1; using 1 client per round")
        args.client_num_per_round = 1

    wandb.init(
        project="fedml",
        name=args.identity,
        config=args
        # settings=wandb.Settings(start_method="fork")
    )

    # Set the random seed. The np.random seed determines the dataset partition.
    # The torch_manual_seed determines the initial weight.
    # We fix these two, so that we can reproduce the result.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    # load data
    dataset = load_data(args, args.dataset)

    # create model.
    # Note if the model is DNN (e.g., ResNet), the training will be very slow.
    # In this case, please use our FedML distributed version (./fedml_experiments/distributed_fedavg)
    # The number of classes is taken as the length of entry 0 of the last
    # element returned by load_data.
    model = create_model(args, model_name=args.model, class_num=len(dataset[-1][0]))
    model_trainer = custom_model_trainer(args, model)
    logging.info(model)

    ditto_api = DittoAPI(dataset, device, args, model_trainer)
    ditto_api.train()