import argparse
import logging
import os
import random
import socket
import sys
import traceback

import numpy as np
import psutil
import setproctitle
import torch
from mpi4py import MPI

import wandb

# add the FedML root directory to the python path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../../")))

from fedml_api.distributed.fedssl.FedSSLAPI import FedML_FedSSL_distributed
from fedml_api.distributed.utils.gpu_mapping import mapping_processes_to_gpu_device_from_yaml_file
from fedml_api.data_preprocessing.FederatedEMNIST.data_loader import load_partition_data_federated_emnist
from fedml_api.data_preprocessing.fed_cifar100.data_loader import load_partition_data_federated_cifar100
from fedml_api.data_preprocessing.fed_shakespeare.data_loader import load_partition_data_federated_shakespeare
from fedml_api.data_preprocessing.shakespeare.data_loader import load_partition_data_shakespeare
from fedml_api.data_preprocessing.stackoverflow_lr.data_loader import load_partition_data_federated_stackoverflow_lr
from fedml_api.data_preprocessing.stackoverflow_nwp.data_loader import load_partition_data_federated_stackoverflow_nwp
from fedml_api.data_preprocessing.MNIST.data_loader import load_partition_data_mnist
from fedml_api.data_preprocessing.ImageNet.data_loader import load_partition_data_ImageNet

from fedml_api.data_preprocessing.Landmarks_per.data_loader import load_partition_data_landmarks

from fedml_api.data_preprocessing.cifar10_per.data_loader import load_partition_data_cifar10_ssl, \
    load_partition_data_cifar10_ssl_linear_eval
from fedml_api.data_preprocessing.cifar100_per.data_loader import load_partition_data_cifar100_ssl

from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10_for_ssl

from fedml_api.distributed.fedavg.FedAvgAPI import FedML_init

from fedml_api.model.cv.ssl import get_ssl_model


def add_args(parser):
    """
    Register every command-line option of this experiment and parse them.

    parser : argparse.ArgumentParser
    return the argparse.Namespace produced by ``parser.parse_args()``
    (the options are read from ``sys.argv``)
    """
    # Training settings
    parser.add_argument('--run_id', type=int, default=0, metavar='N',
                        help='run id')

    parser.add_argument('--model', type=str, default='resnet18', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--num_class', type=int, default=10,
                        help='number of classes')

    parser.add_argument('--image_size', type=int, default=224,
                        help='input image size')

    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    parser.add_argument('--client_num_in_total', type=int, default=1000, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--client_num_per_round', type=int, default=4, metavar='NN',
                        help='number of workers')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--accumulation_steps', type=int, default=8,
                        help='accumulation_steps')

    parser.add_argument('--path_of_local_model', type=str, default="./checking_points/",
                        help='path_of_local_model')

    parser.add_argument('--client_optimizer', type=str, default='adam',
                        help='SGD with momentum; adam')

    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate (default: 0.05)')

    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=1e-4)

    parser.add_argument('--perFedAvg_is_first_order', help='perFedAvg_is_first_order', action='store_true')

    parser.add_argument('--epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--comm_round', type=int, default=800,
                        help='how many round of communications we should use')

    parser.add_argument('--is_mobile', type=int, default=0,
                        help='whether the program is running on the FedML-Mobile server side')

    parser.add_argument('--frequency_of_the_test', type=int, default=1,
                        help='the frequency of the algorithms')

    parser.add_argument('--gpu_server_num', type=int, default=1,
                        help='gpu_server_num')

    parser.add_argument('--gpu_num_per_server', type=int, default=4,
                        help='gpu_num_per_server')

    parser.add_argument('--gpu_mapping_file', type=str, default="gpu_mapping.yaml",
                        help='the gpu utilization file for servers and clients. If there is no '
                             'gpu_util_file, gpu will not be used.')

    parser.add_argument('--gpu_mapping_key', type=str, default="mapping_default",
                        help='the key in gpu utilization file')

    # personalization
    # m1: model interpolation
    parser.add_argument('--using_personalized_data', type=int, default=1,
                        help='whether to use the personalized data loaders (0: no; otherwise: yes)')

    parser.add_argument('--personalized_model_path', type=str, default="./checkpoint",
                        help='path of the personalized model')

    parser.add_argument('--pssl_optimizer', type=str, default="pFedMe",
                        help="personalization_method: None; pFedMe; ditto; perFedAvg")

    parser.add_argument('--pssl_lambda', type=float, default=1,
                        help='lambda value used by the personalization method')

    parser.add_argument('--personal_local_epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally for the personal model')

    # The default is fractional, so the value must be parsed as a float;
    # with type=int a command-line value such as 0.01 is rejected by argparse.
    parser.add_argument('--pfedme_beta', type=float, default=0.001, metavar='BETA',
                        help='moving averaging beta value.')

    # self-supervised related

    parser.add_argument('--ssl_method', type=str, default="simsiam",
                        help='self-supervised learning method: byol, simclr, moco, etc')

    parser.add_argument('--ssl_is_supervised', type=int, default=0,
                        help='is supervised training')

    parser.add_argument('--ssl_is_linear_eval', type=int, default=0,
                        help='is linear evaluation')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')
    args = parser.parse_args()
    return args


def load_data(args, dataset_name):
    """
    Load the partitioned dataset selected by ``dataset_name``.

    args : parsed command-line namespace; ``args.client_num_in_total`` is overwritten
           for the datasets whose loader reports its own number of clients
    dataset_name : dataset identifier; names without a dedicated branch are served
                   by the CIFAR SSL loaders
    return the list [train_data_num, test_data_num, train_data_global, test_data_global,
                     train_data_local_num_dict, train_data_local_dict,
                     test_data_local_dict, class_num]
    """
    # Loaders called with the batch size only; they return the client count first.
    batch_size_loaders = {
        "mnist": load_partition_data_mnist,
        "shakespeare": load_partition_data_shakespeare,
    }
    # Loaders called with (args.dataset, args.data_dir); they return the client count first.
    dataset_dir_loaders = {
        "femnist": load_partition_data_federated_emnist,
        "fed_shakespeare": load_partition_data_federated_shakespeare,
        "fed_cifar100": load_partition_data_federated_cifar100,
        "stackoverflow_lr": load_partition_data_federated_stackoverflow_lr,
        "stackoverflow_nwp": load_partition_data_federated_stackoverflow_nwp,
    }

    if dataset_name in batch_size_loaders or dataset_name in dataset_dir_loaders:
        logging.info("load_data. dataset_name = %s" % dataset_name)
        if dataset_name in batch_size_loaders:
            loaded = batch_size_loaders[dataset_name](args.batch_size)
        else:
            loaded = dataset_dir_loaders[dataset_name](args.dataset, args.data_dir)

        (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = loaded
        # The loader decides how many clients exist for these datasets.
        args.client_num_in_total = client_num

    elif dataset_name == "ILSVRC2012":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = load_partition_data_ImageNet(
            dataset=dataset_name, data_dir=args.data_dir,
            partition_method=None, partition_alpha=None,
            client_number=args.client_num_in_total, batch_size=args.batch_size)

    elif dataset_name == "gld23k_per":
        train_map_csv = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_train.csv')
        test_map_csv = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_test.csv')
        image_dir = os.path.join(args.data_dir, 'images')

        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num, client_total) = load_partition_data_landmarks(
            dataset=dataset_name,
            data_dir=image_dir,
            ssl_method=args.ssl_method,
            fed_train_map_file=train_map_csv,
            fed_test_map_file=test_map_csv,
            partition_method=None, partition_alpha=None,
            batch_size=args.batch_size,
            gradient_acc_steps=args.accumulation_steps)
        # The landmarks loader also reports the number of clients.
        args.client_num_in_total = client_total

        logging.info("data point len = %d" % len(test_data_local_dict))
        logging.info("args.client_num_in_total = %d" % args.client_num_in_total)

    else:
        # CIFAR family; any other unrecognised name uses the CIFAR-10 SSL loader.
        if dataset_name == "cifar100":
            ssl_loader = load_partition_data_cifar100_ssl
        elif dataset_name != "cifar10":
            ssl_loader = load_partition_data_cifar10_ssl
        elif args.using_personalized_data == 0:
            ssl_loader = load_partition_data_cifar10_for_ssl
        elif args.ssl_is_linear_eval:
            ssl_loader = load_partition_data_cifar10_ssl_linear_eval
        else:
            ssl_loader = load_partition_data_cifar10_ssl

        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = ssl_loader(args.dataset, args.data_dir, args.ssl_method,
                                 args.partition_method, args.partition_alpha,
                                 args.client_num_in_total, args.batch_size,
                                 args.accumulation_steps)

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]


def create_model(args, model_name, output_dim):
    """
    Build the model through ``get_ssl_model``.

    args : parsed command-line namespace; ``args.num_class`` is overwritten with ``output_dim``
    model_name : name of the network, forwarded to ``get_ssl_model``
    output_dim : number of output classes
    return the object returned by ``get_ssl_model``
    """
    summary = "create_model. model_name = %s, output_dim = %s" % (model_name, output_dim)
    logging.info(summary)

    # Record the class count on args before handing it to the model factory.
    args.num_class = output_dim
    return get_ssl_model(args.num_class, args.ssl_method, args.pssl_optimizer, model_name)


if __name__ == "__main__":

    # Bring up MPI before anything else: every later step depends on the rank of this process.
    comm, process_id, worker_number = FedML_init()

    # Read the experiment configuration from the command line.
    args = add_args(argparse.ArgumentParser())

    # Give the process a recognisable title that carries its rank.
    setproctitle.setproctitle("Distributed Training:" + str(process_id))

    # Log format: pid, timestamp with milliseconds, module/line, function name, message.
    # DEBUG is verbose; use logging.INFO for quieter runs.
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s',
        datefmt='%Y-%m-%d,%H:%M:%S')

    logging.info(args)

    # Fix every random source so that a run can be reproduced:
    # the np.random seed determines the dataset partition and the torch seed
    # determines the initial weights.
    for seed_everything in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_everything(0)

    hostname = socket.gethostname()
    logging.info("#############process ID = %s, host name = %s########, process ID = %s, process Name = %s"
                 % (process_id, hostname, os.getpid(), psutil.Process(os.getpid())))

    # Only rank 0 reports to the wandb experiment tracking platform (https://www.wandb.com/).
    if process_id == 0:
        wandb.init(project="self-supervised-fl", entity="automl", config=args,
                   name="SSFL-%s-%s_%s" % (args.ssl_method, args.pssl_optimizer, args.ssl_is_linear_eval))

    # Please check "GPU_MAPPING.md" to see how to define the topology
    logging.info("process_id = %d, size = %d" % (process_id, worker_number))

    device = mapping_processes_to_gpu_device_from_yaml_file(
        process_id, worker_number, args.gpu_mapping_file, args.gpu_mapping_key)

    # Load the data; the returned list is unpacked into its eight components.
    dataset = load_data(args, args.dataset)
    (train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num) = dataset

    # Create the model; the last entry of the dataset list is the number of classes.
    model = create_model(args, model_name=args.model, output_dim=class_num)

    # Hand everything over to the distributed FedSSL entry point.
    FedML_FedSSL_distributed(
        process_id, worker_number, device, comm,
        model, train_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args)