import argparse
import logging
import os
import random
import socket
import sys

import numpy as np
import psutil
import setproctitle
import torch
import wandb

# add the FedML root directory to the python path

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from fedml_api.data_preprocessing.Landmarks.data_loader import load_partition_data_landmarks
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10

from fedml_api.model.cv.resnet_fednas.resnet import resnet18
from fedml_api.model.cv.resnet_per.meta_resnet import meta_resnet18

from fedml_api.distributed.fedper.FedAvgAPI import FedML_init, FedML_FedPer_distributed


def add_args(parser, argv=None):
    """Register this script's command-line options on *parser* and parse them.

    :param parser: an ``argparse.ArgumentParser`` to add the options to.
    :param argv: optional list of argument strings to parse instead of
        ``sys.argv[1:]`` (``None`` keeps the default ``argparse`` behaviour).
    :return: the parsed ``argparse.Namespace`` (not the parser).
    """
    # distributed computing configuration
    parser.add_argument('--run_id', type=int, default=0, metavar='N',
                        help='run id')

    parser.add_argument('--backend', type=str, default="MPI",
                        help='Backend for Server and Client')

    # Training settings
    parser.add_argument('--model', type=str, default='mobilenet', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    parser.add_argument('--client_num_in_total', type=int, default=1000, metavar='NN',
                        help='total number of clients')

    parser.add_argument('--client_num_per_round', type=int, default=4, metavar='NN',
                        help='number of clients per communication round')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--client_optimizer', type=str, default='adam',
                        help='SGD with momentum; adam')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--wd', help='weight decay', type=float, default=0.001)

    parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication we should use')

    parser.add_argument('--is_mobile', type=int, default=0,
                        help='whether the program is running on the FedML-Mobile server side')

    parser.add_argument('--frequency_of_the_test', type=int, default=5,
                        help='how often the test (evaluation) is run')

    parser.add_argument('--gpu_server_num', type=int, default=1,
                        help='gpu_server_num')

    parser.add_argument('--starting_gpu_id', type=int, default=4,
                        help='start_gpu_id')

    parser.add_argument('--gpu_num_per_server', type=int, default=4,
                        help='gpu_num_per_server')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')

    # personalization
    parser.add_argument('--classes_per_client', type=int, default=5,
                        help='personalization dataset: label number per client')

    parser.add_argument('--personalized_model_path', type=str, default="./checkpoint",
                        help='root directory for personalized models; a '
                             '<per_optimizer>_<run_id> sub-directory is created under it')

    parser.add_argument('--per_optimizer', type=str, default="FedAvg",
                        help='personalization optimizer. Default: FedAvg; other options: Ditto; perFedAvg')

    parser.add_argument('--personal_local_epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally for the personal model')

    # to save memory
    parser.add_argument('--accumulation_steps', type=int, default=8,
                        help='accumulation_steps')

    # only for Ditto (values tried so far: 1; being tested: 0.1, 0.01)
    parser.add_argument('--pssl_lambda', type=float, default=1,
                        help='lambda coefficient, used only by the Ditto personalization optimizer')

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)


def load_data(dataset_name):
    """Load and partition the federated dataset named *dataset_name*.

    Reads its remaining settings from the module-level ``args`` namespace that
    the ``__main__`` block creates (``data_dir``, ``batch_size``,
    ``partition_method``, ``partition_alpha``, ``client_num_in_total``,
    ``classes_per_client``). Side effects on ``args``:

    * ``args.client_num_in_total`` is overwritten with the client count
      returned by the data loader;
    * for "gld23k", ``args.data_dir`` is rewritten to ``<data_dir>/images``.

    :param dataset_name: "gld23k" or "cifar10".
    :return: tuple ``(train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num)``.
    :raises ValueError: if *dataset_name* is not supported.
    """
    # Dispatch on the parameter (not args.dataset) so the function does what its
    # argument says; the only caller passes args.dataset, so behaviour there is unchanged.
    if dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s", dataset_name)
        fed_train_map_file = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_train.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_test.csv')
        # The images sit in <data_dir>/images; the rewritten path is kept on args.
        args.data_dir = os.path.join(args.data_dir, 'images')

        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num, args.client_num_in_total) = load_partition_data_landmarks(
            dataset=dataset_name,
            data_dir=args.data_dir,
            fed_train_map_file=fed_train_map_file,
            fed_test_map_file=fed_test_map_file,
            partition_method=None, partition_alpha=None,
            batch_size=args.batch_size)
        # NOTE(review): this logs the number of entries in test_data_local_dict,
        # not a count of data points, despite the message text.
        logging.info("data point len = %d", len(test_data_local_dict))
        logging.info("args.client_num_in_total = %d", args.client_num_in_total)
    elif dataset_name == "cifar10":
        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num, args.client_num_in_total) = load_partition_data_cifar10(
            dataset_name, args.data_dir, args.partition_method,
            args.partition_alpha, args.client_num_in_total,
            args.batch_size, args.classes_per_client)
        logging.debug("args.client_num_in_total = %d", args.client_num_in_total)
    else:
        # ValueError subclasses Exception, so callers catching Exception still work.
        raise ValueError("no such dataset, please change the args.dataset")

    return (train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num)


def init_training_device(starting_gpu_id, process_ID, fl_worker_num, gpu_num_per_machine):
    """Choose the ``torch.device`` this MPI process should train on.

    Rank 0 gets GPU ``starting_gpu_id``. A client rank ``r`` (1-based) gets GPU
    ``starting_gpu_id + (r - 1) % gpu_num_per_machine``, i.e. clients cycle
    round-robin over ``gpu_num_per_machine`` GPUs. When CUDA is unavailable
    every rank gets the CPU.

    :param starting_gpu_id: index of the first GPU to use.
    :param process_ID: MPI rank of this process (0 is handled separately).
    :param fl_worker_num: number of client processes (ranks 1..fl_worker_num).
    :param gpu_num_per_machine: number of GPUs the clients cycle through.
    :return: the selected ``torch.device``.
    """
    if process_ID == 0:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{starting_gpu_id}")
        return torch.device("cpu")

    # Client index (rank - 1) -> GPU id, built for every client so it can be logged.
    gpu_of_client = {
        idx: starting_gpu_id + idx % gpu_num_per_machine
        for idx in range(fl_worker_num)
    }
    logging.info(gpu_of_client)

    # The mapping is only consulted when CUDA is present.
    if torch.cuda.is_available():
        chosen = torch.device(f"cuda:{gpu_of_client[process_ID - 1]}")
    else:
        chosen = torch.device("cpu")
    logging.info(chosen)
    return chosen


if __name__ == "__main__":

    # Initialize distributed computing (MPI): communicator, this process's rank, world size.
    comm, process_id, worker_number = FedML_init()

    # Prefix every log line with the process id so interleaved MPI output stays readable.
    logging.basicConfig(level=logging.DEBUG,
                        format=str(
                            process_id) + ' - %(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')

    # Parse the command-line parameters. `args` must remain a module-level global:
    # load_data() reads it.
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logging.info(args)

    # Customize the process name.
    str_process_name = "FedPer:" + str(process_id)
    setproctitle.setproctitle(str_process_name)

    hostname = socket.gethostname()
    logging.info("#############process ID = " + str(process_id) +
                 ", host name = " + hostname + "########" +
                 ", process ID = " + str(os.getpid()) +
                 ", process Name = " + str(psutil.Process(os.getpid())))

    # Initialize wandb experiment tracking (https://www.wandb.com/) on rank 0 only.
    if process_id == 0:
        wandb.init(
            project="fednas_extension",
            name="Fedper(d)" + str(args.dataset) + "_" + str(args.partition_method) + "r" + str(args.comm_round) + "-e" + str(
                args.epochs) + "-lr" + str(
                args.lr),
            config=args
        )

    # Fix the random seeds so the result can be reproduced. The np.random seed
    # determines the dataset partition; the torch seed determines the initial
    # weights. The seed is the MPI world size.
    seed = worker_number
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # GPU arrangement: please customize init_training_device() according to your own topology.
    # The GPU server list is configured in "mpi_host_file".
    # Example: with 4 machines of two GPUs each, and an FL network of 8 workers plus a
    # central worker, the processes are assigned to machines as follows:
    # machine 1: worker0, worker4, worker8;
    # machine 2: worker1, worker5;
    # machine 3: worker2, worker6;
    # machine 4: worker3, worker7;
    # That is, workers are assigned in the order of the machine list.
    logging.info("process_id = %d, size = %d" % (process_id, worker_number))
    device = init_training_device(args.starting_gpu_id, process_id, worker_number - 1, args.gpu_num_per_server)

    # Load data.
    # Sanity check (May 2021): with the same seed, the per-client index partitions
    # logged by data_loader.py (partition_data_byclass) were identical for FedNAS
    # and FedPer (e.g. client 0: [29 30 35 ... 59983 59987 59991],
    # client 1: [4 5 32 ... 59941 59974 59998]), so the data distribution matches.
    dataset = load_data(args.dataset)
    [train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset

    # Model: perFedAvg needs the meta-learning variant of ResNet-18.
    if args.per_optimizer == "perFedAvg":
        model = meta_resnet18(num_classes=class_num)
        logging.info("Meta resnet called")
    else:
        model = resnet18(num_classes=class_num)

    # Every MPI process runs this. makedirs(exist_ok=True) avoids the
    # exists()/mkdir() race between ranks (FileExistsError) and also creates a
    # missing parent directory such as the default "./checkpoint".
    args.personalized_model_path = os.path.join(args.personalized_model_path,
                                                str(args.per_optimizer) + "_" + str(args.run_id))
    os.makedirs(args.personalized_model_path, exist_ok=True)

    # Start distributed FedPer training.
    FedML_FedPer_distributed(process_id, worker_number, device, comm,
                             model, train_data_num, train_data_global, test_data_global,
                             train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args)
