import argparse
import logging
import os
import random
import socket
import sys

import numpy as np
import psutil
import setproctitle
import torch
import wandb
# https://nyu-cds.github.io/python-mpi/05-collectives/
from torch import nn

# add the FedML root directory to the python path

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from fedml_api.model.cv.resnet_fednas.resnet import resnet18


from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.data_preprocessing.Landmarks.data_loader import load_partition_data_landmarks

from fedml_api.distributed.fednas_extension.FedNASAPI import FedML_init, FedML_FedNAS_distributed


from fedml_api.model.cv.darts.model_search_workshop_code import Network


def add_args(parser, argv=None):
    """Register all command-line options of the FedNAS-extension runner and parse them.

    parser : argparse.ArgumentParser
        Parser the options are added to (mutated in place).
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
        exactly as before; pass an explicit list for tests or programmatic use.

    Returns the parsed ``argparse.Namespace`` (not the parser).
    """
    parser.add_argument("--run_id", type=int, default=0)

    # GPU resource-related
    parser.add_argument('--gpu_server_num', type=int, default=1,
                        help='gpu_server_num')

    parser.add_argument('--gpu_num_per_server', type=int, default=4,
                        help='gpu_num_per_server')

    parser.add_argument('--starting_gpu_id', type=int, default=4,
                        help='start_gpu_id')

    # linear: fine-tuning the linear classifier
    # a_linear: fine-tuning the attentive linear classifier
    # nas: fine-tuning the NAS classifier
    # a_nas: fine-tuning the attentive NAS classifier
    parser.add_argument('--design', type=str, default="linear",
                        help='FedNAS Extension design: linear; a_linear; nas; a_nas')

    # data related
    # NOTE: load_data() in this script only accepts 'gld23k' and 'cifar10';
    # the default 'gld' has to be overridden on the command line.
    parser.add_argument('--dataset', type=str, default='gld', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/gld/images',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    parser.add_argument("--img_size", default=32, type=int,
                        help="Resolution size")
    # model
    # NOTE: the __main__ block only dispatches 'fednas_search' and 'train';
    # the default 'personalized_search' has to be overridden.
    parser.add_argument('--stage', type=str, default='personalized_search',
                        help='learning stage: fednas_search; train')

    parser.add_argument('--model', type=str, default='resnet', metavar='N',
                        help='neural network used in training')
    parser.add_argument('--teacher_model', type=str, default='transformer', metavar='N',
                        help='teacher neural network')

    parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')

    parser.add_argument('--layers', type=int, default=4, help='DARTS layers')

    parser.add_argument("--pretrained_dir", type=str,
                        default="./../../../fedml_api/model/cv/pretrained/Transformer/vit/ViT-B_16.npz",
                        help="Where to search for pretrained vit models.")

    # training/algorithm related
    parser.add_argument('--backend', type=str, default="MPI",
                        help='Backend for Server and Client')

    parser.add_argument('--batch_size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: 1)')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=0.001)

    parser.add_argument('--epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')

    parser.add_argument('--is_mobile', type=int, default=0,
                        help='whether the program is running on the FedML-Mobile server side')

    parser.add_argument('--client_optimizer', type=str, default='sgd',
                        help='SGD with momentum; adam')

    parser.add_argument('--client_num_in_total', type=int, default=100, metavar='NN',
                        help='total number of workers in a distributed cluster')

    parser.add_argument('--client_num_per_round', type=int, default=16, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--comm_round', type=int, default=10,
                        help='how many rounds of communication we should use')

    parser.add_argument('--learning_rate', type=float, default=0.01, help='init learning rate')
    parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')

    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')

    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')

    parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    parser.add_argument('--lambda_train_regularizer', type=float, default=1, help='train regularizer parameter')
    parser.add_argument('--lambda_valid_regularizer', type=float, default=1, help='validation regularizer parameter')
    parser.add_argument('--report_freq', type=float, default=20, help='report frequency')

    parser.add_argument('--tau_max', type=float, default=10, help='initial tau')
    parser.add_argument('--tau_min', type=float, default=1, help='minimum tau')

    # NOTE(review): type=str means any value given on the command line (even
    # "False") is a non-empty, truthy string; only the untouched default is False.
    parser.add_argument('--local_finetune', type=str, default=False, help='local fine_tune')

    parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
    parser.add_argument('--arch', type=str, default='FedNAS_V1', help='which architecture to use')

    parser.add_argument('--frequency_of_the_test', type=int, default=5, help='the frequency of the test')

    parser.add_argument('--epochs_for_local_fine_tuning', type=int, default=2, help='epochs_for_local_fine_tuning')
    parser.add_argument('--epochs_for_train', type=int, default=3, help='epochs_for_train')

    parser.add_argument('--path_of_local_model', type=str, default="./checking_points/",
                        help='path_of_local_model')

    parser.add_argument('--local_adapted_local_models', type=str, default="./local_adapted_local_models/",
                        help='local_adapted_local_models')

    # NOTE(review): store_true with default=True means this flag is always True.
    parser.add_argument('--client_sampling', action='store_true', default=True,
                        help='client sampling flag (always True: store_true with default=True)')

    parser.add_argument('--gamma', type=float, default=1,
                        help='gamma value for KL loss')

    parser.add_argument('--beta', type=float, default=1,
                        help='beta value for efficiency loss')
    parser.add_argument('--temperature', type=float, default=1,
                        help='temperature parameter for KL loss')

    parser.add_argument('--pssl_lambda', type=float, default=1,
                        help='pssl_lambda coefficient')

    parser.add_argument('--base_layers', type=float, default=7,
                        help='base_layers')
    parser.add_argument('--personal_layers', type=float, default=1,
                        help='personal_layers')

    parser.add_argument('--path_of_best_global_model', type=str, default="./checking_point/",
                        help='path_of_best_global_model')

    parser.add_argument('--path_of_best_global_arch_parameter', type=str, default="./checking_point/",
                        help='path_of_best_global_arch_parameter')

    parser.add_argument('--is_debug_mode', type=int, default=1, help='debug mode')
    parser.add_argument('--classes_per_client', type=int, default=5,
                        help='number of classes per client passed to the data partitioner')
    return parser.parse_args(argv)


def init_training_device(starting_gpu_id, process_ID, fl_worker_num, gpu_num_per_machine):
    """Select the torch device for the calling MPI process.

    Process 0 (the server) is placed on ``cuda:<starting_gpu_id>``. Every other
    process ``p`` is treated as client ``p - 1``; clients are assigned round-robin
    to ``gpu_num_per_machine`` GPUs numbered from ``starting_gpu_id``. When CUDA
    is not available the CPU device is returned for every process.

    For client processes the client->GPU mapping and the chosen device are logged.
    A KeyError is raised only when CUDA is available and ``process_ID - 1`` is
    outside ``range(fl_worker_num)``.
    """
    # Server process: first GPU of the configured range, no logging.
    if process_ID == 0:
        if torch.cuda.is_available():
            return torch.device("cuda:" + str(starting_gpu_id))
        return torch.device("cpu")

    # Client processes: round-robin over the GPUs of one machine.
    gpu_of_client = {
        idx: starting_gpu_id + idx % gpu_num_per_machine
        for idx in range(fl_worker_num)
    }
    logging.info(gpu_of_client)

    # The mapping is only consulted when CUDA exists, so CPU runs never index it.
    if torch.cuda.is_available():
        chosen = torch.device("cuda:" + str(gpu_of_client[process_ID - 1]))
    else:
        chosen = torch.device("cpu")
    logging.info(chosen)
    return chosen


def _empty_directory(path):
    """Ensure ``path`` exists (creating parents), then delete the files directly inside it."""
    # makedirs creates missing parents too, so nested custom paths work.
    os.makedirs(path, exist_ok=True)
    filelist = os.listdir(path)
    logging.info(" Deleting the following files ")
    logging.info(filelist)
    for name in filelist:
        full_path = os.path.join(path, name)
        if os.path.isdir(full_path):
            # os.remove cannot delete directories; leave nested folders untouched.
            logging.info("Skipping sub-directory %s", full_path)
            continue
        os.remove(full_path)


def clear_directory_of_saved_local_models(args):
    """Reset the on-disk caches of local models before a run starts.

    Creates (if missing) and empties the two directories configured on ``args``:

    * ``args.path_of_local_model`` -- local adapted models saved each round;
    * ``args.local_adapted_local_models`` -- models locally adapted from the best
      global model at the last round.

    Only regular files (and symlinks to files) directly inside each directory are
    removed; sub-directories are skipped.
    """
    # for local adapted models each round
    _empty_directory(args.path_of_local_model)

    # for local adaptation from the best global model at last round
    _empty_directory(args.local_adapted_local_models)


def load_data(dataset_name):
    """Load and partition ``dataset_name`` for federated training.

    Reads the module-level ``args`` namespace created in the ``__main__`` block
    and overwrites ``args.client_num_in_total`` with the client count returned
    by the loader. For ``gld23k`` it also rewrites ``args.data_dir`` to its
    ``images`` sub-directory (so calling this twice would append it twice).

    Supported names: ``"gld23k"`` and ``"cifar10"``.

    Returns:
        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num)

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s", dataset_name)
        fed_train_map_file = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_train.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'data_user_dict/gld23k_user_dict_test.csv')
        # NOTE: mutates the shared args; downstream code receives this updated path.
        args.data_dir = os.path.join(args.data_dir, 'images')

        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num, args.client_num_in_total = load_partition_data_landmarks(dataset=dataset_name,
                                                                            data_dir=args.data_dir,
                                                                            fed_train_map_file=fed_train_map_file,
                                                                            fed_test_map_file=fed_test_map_file,
                                                                            partition_method=None, partition_alpha=None,
                                                                            batch_size=args.batch_size)

        logging.info("data point len = %d", len(test_data_local_dict))
        logging.info("args.client_num_in_total = %d", args.client_num_in_total)

    elif dataset_name == "cifar10":
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num, args.client_num_in_total = load_partition_data_cifar10(dataset_name, args.data_dir,
                                                                          args.partition_method,
                                                                          args.partition_alpha,
                                                                          args.client_num_in_total,
                                                                          args.batch_size,
                                                                          args.classes_per_client)
        logging.debug("args.client_num_in_total = %d", args.client_num_in_total)

    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError("no such dataset, please change the args.dataset")

    return train_data_num, test_data_num, train_data_global, test_data_global, \
           train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num


if __name__ == "__main__":
    # customize the log format
    logging.basicConfig(level=logging.INFO,
                        format='%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s',
                        datefmt='%Y-%m-%d,%H:%M:%S')

    # initialize distributed computing (MPI)
    comm, process_id, worker_number = FedML_init()

    # NOTE: `args` must stay a module-level global because load_data() reads it.
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logging.info(args)

    # Only these stages are dispatched at the end of this script. Validate the
    # choice up front so an unsupported --stage fails before wandb is started,
    # the dataset is loaded and cached local models are deleted.
    supported_stages = ("fednas_search", "train")
    if args.stage not in supported_stages:
        raise ValueError("Learning stage not specified: got %r, expected one of %s"
                         % (args.stage, ", ".join(supported_stages)))

    # to avoid I/O conflict, we make different runs use caching in different folders.
    args.path_of_local_model = args.path_of_local_model + str(args.run_id) + "/"

    str_process_name = "FedNAS:" + str(process_id)
    setproctitle.setproctitle(str_process_name)
    hostname = socket.gethostname()
    logging.info("#############process ID = " + str(process_id) +
                 ", host name = " + hostname + "########" +
                 ", process ID = " + str(os.getpid()) +
                 ", process Name = " + str(psutil.Process(os.getpid())))

    # initialize the wandb machine learning experimental tracking platform (https://www.wandb.com/).
    # Only process 0 (the server) reports to wandb.
    if process_id == 0:
        wandb.init(
            project="fednas_extension",
            name=str(args.dataset) + "_" + str(args.stage) + "_FT_" + str(args.local_finetune) +"_r" + str(args.run_id) + "_basearch_" + str(args.model) + "_l" +
                 str(args.layers) + "_lr" + str(args.lr),
            config=args
        )

    # Per-dataset / per-run / per-design checkpoint file names.
    args.path_of_best_global_model = args.path_of_best_global_model + \
                                     str(args.dataset) + "_r" + str(args.run_id) + "_" + str(args.design) + \
                                     "_best_global_model.pth"
    args.path_of_best_global_arch_parameter = args.path_of_best_global_arch_parameter + \
                                              str(args.dataset) + "_r" + str(args.run_id) + "_" + str(args.design) + \
                                              "_best_global_arch_params.pth"

    # Set the random seed. The np.random seed determines the dataset partition.
    # The torch_manual_seed determines the initial weight.
    # We fix these two, so that we can reproduce the result. The seed is the
    # worker count (presumably identical on every MPI process — confirm).
    seed = worker_number
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = init_training_device(args.starting_gpu_id, process_id, worker_number - 1, args.gpu_num_per_server)

    # load data (this also overwrites args.client_num_in_total)
    train_data_num, test_data_num, train_data_global, test_data_global, \
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num = load_data(args.dataset)
    if process_id == 0:
        clear_directory_of_saved_local_models(args)

    if args.stage == 'fednas_search':
        # DARTS search network with depth args.layers
        criterion = nn.CrossEntropyLoss()
        model = Network(args.init_channels, class_num, args.layers, criterion, device, args)
    else:  # 'train': FedAvg with a ResNet-18 model
        model = resnet18(num_classes=class_num)

    FedML_FedNAS_distributed(process_id, worker_number, device, comm,
                             model, train_data_num, train_data_global, test_data_global,
                             train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args,
                             args.client_num_in_total, None, None)

