import argparse
import logging
import os
import sys
import os.path as osp

import numpy as np
import torch
from mpi4py import MPI


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
from data_preprocessing.data_loading import load_partition_target_data, load_partition_target_data_UD, load_source_data
from fedxdd.FedXDDAPI import FedML_init, FedML_FedXDD_distributed, FedML_FedXDD_DA_completed
from model import LiteResidualModule
import model.network as network


from torchvision.models import * # important! all resnets contained
# Module-level MPI handles, created at import time from the world communicator.
# NOTE: `comm` is rebound by FedML_init() inside the __main__ block of this
# file; `rank` and `size` are not used by any active code visible here.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # index of this process within COMM_WORLD
size = comm.Get_size()  # total number of processes in COMM_WORLD

def _str2bool(value):
    """Convert a command-line token into a real boolean for argparse.

    ``type=bool`` treats every non-empty string (including ``"False"`` and
    ``"0"``) as ``True``; this converter parses the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def add_args(parser):
    """
    Register every option of this script on ``parser`` and parse the command line.

    parser : argparse.ArgumentParser
    return the ``argparse.Namespace`` produced by ``parser.parse_args()``
    (note: the parsed arguments, not the parser itself)
    """
    # Training settings
    parser.add_argument('--partition_method', type=str, default='homo', metavar='N',
                        help='how to partition the dataset on local workers')
    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')
    parser.add_argument('--batch_size', type=int, default=8, metavar='N',
                        help='input batch size for training (default: 8)')
    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=1e-3)
    parser.add_argument('--epochs_client', type=int, default=3, metavar='EP',
                        help='how many epochs will be trained locally')
    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')
    parser.add_argument('--client_number', type=int, default=1, metavar='NN',
                        help='number of workers in a distributed cluster')
    parser.add_argument('--comm_round', type=int, default=30,
                        help='how many round of communications we shoud use')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu')
    # The help text previously advertised a default of 0.001; the real default is 0.3.
    parser.add_argument('--rho', type=float, default=0.3, metavar='R', help='learning rate (default: 0.3)')

    # Knowledge distillation
    # The help text previously advertised a default of 3.0; the real default is 2.0.
    parser.add_argument('--temperature', default=2.0, type=float, help='Input the temperature: default(2.0)')
    parser.add_argument('--epochs_server', type=int, default=3, metavar='EP',
                        help='how many epochs will be trained on the server side')
    parser.add_argument('--alpha', default=0.8, type=float, help='Input the relative weight: default(0.8)')
    parser.add_argument('--optimizer', default="SGD", type=str, help='optimizer: SGD, Adam, etc.')
    parser.add_argument('--whether_training_on_client', default=1, type=int)
    parser.add_argument('--whether_distill_on_the_server', default=0, type=int)
    parser.add_argument('--running_name', default="default", type=str)
    parser.add_argument('--sweep', default=0, type=int)
    parser.add_argument('--multi_gpu_server', action='store_true')
    parser.add_argument('--test', action='store_true',
                        help='test mode, only run 1-2 epochs to test the bug of the program')
    parser.add_argument('--gpu_num_per_server', type=int, default=1,
                        help='gpu_num_per_server')

    # Lite domain adaptation config
    parser.add_argument('--root_path', type=str, default='../data/')
    parser.add_argument('--src', type=str, default='amazon')
    parser.add_argument('--tar', type=str, default='webcam')
    parser.add_argument('--max_epoch', type=int, default=30, help="max iterations")
    parser.add_argument('--interval', type=int, default=30)
    parser.add_argument('--worker', type=int, default=4, help="number of workers")
    parser.add_argument('--dset', type=str, default='office',
                        choices=['VISDA-C', 'office', 'office-home', 'office-caltech', 'imageCLEF'])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--net', type=str, default='resnet50',
                        help="resnet50, resnet101", choices=['resnet50', 'resnet101'])
    parser.add_argument('--compact_net', type=str, default='SHOTresnet18',
                        help="Type of compact net, SHOTresnet18 or SHOTresnet34",
                        choices=['SHOTresnet18', 'SHOTresnet34'])
    parser.add_argument('--seed', type=int, default=2020, help="random seed")
    parser.add_argument('--DA_max_epoch', type=int, default=2, help="maximal number of epoch of DA")
    parser.add_argument('--is_feedback', action="store_true", help="enable knowledge feedback")
    parser.add_argument('--kd_ratio', type=float, default=0.3, help="weight of KD loss")
    parser.add_argument('--percent', type=float, default=1.0, help="percent of target data used in DA")

    # Boolean switches use _str2bool so that "--gent False" really disables the option.
    parser.add_argument('--gent', type=_str2bool, default=True)
    parser.add_argument('--ent', type=_str2bool, default=True)
    parser.add_argument('--threshold', type=int, default=0)
    parser.add_argument('--cls_par', type=float, default=0.3)
    parser.add_argument('--ent_par', type=float, default=1.0)
    parser.add_argument('--lr_decay1', type=float, default=0.1)
    parser.add_argument('--lr_decay2', type=float, default=1.0)

    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--epsilon', type=float, default=1e-5)
    parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"])
    parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
    parser.add_argument('--output', type=str, default='./ckps/target/lite')
    parser.add_argument('--output_src', type=str, default='./ckps/source/')
    parser.add_argument('--da', type=str, default='uda', choices=['uda'])
    parser.add_argument('--issave', type=_str2bool, default=True)

    # Lite residual module configs
    parser.add_argument('--lite_residual_downsample', type=int, default=2)
    parser.add_argument('--lite_residual_expand', type=int, default=1)
    parser.add_argument('--lite_residual_groups', type=int, default=2)
    parser.add_argument('--lite_residual_ks', type=int, default=5)
    # NOTE(review): store_true combined with default=True means this flag can
    # never be switched off from the command line; kept as-is for compatibility.
    parser.add_argument('--random_init_lite_residual', action='store_true', default=True)
    parser.add_argument('--exp_idx', type=int, default=0)
    parser.add_argument('--server_record_name', type=str, default='src_acc_list')
    parser.add_argument('--mixed_record_name', type=str, default='mix_acc_list')
    parser.add_argument('--client_record_name', type=str, default='tar_acc_list')

    args = parser.parse_args()
    return args


def init_training_device(process_ID, fl_worker_num, gpu_num_per_machine):
    """Return the torch device this process should compute on.

    The three arguments are accepted for interface compatibility but are not
    consulted: the result is the generic "cuda" device whenever CUDA is
    available and the "cpu" device otherwise.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def load_compact_model(path, net):
    """Load a serialized compact model from ``path`` onto the CPU.

    ``net`` must start with ``'res'`` or ``'SHOT'``; any other name raises
    ``NotImplementedError`` before the file is touched. The object returned by
    ``torch.load`` is tagged with ``modelname = net`` and handed back.
    """
    # Reject unsupported architecture names up front.
    if net[0:3] != 'res' and net[0:4] != 'SHOT':
        raise NotImplementedError

    loaded = torch.load(path, map_location='cpu')
    loaded.modelname = net
    return loaded

def op_copy(optimizer):
    """Record each parameter group's current ``'lr'`` under the key ``'lr0'``.

    The groups in ``optimizer.param_groups`` are modified in place and the same
    optimizer object is returned.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def load_large_model(args):
    """Build the large target model and restore its weights from disk.

    Reads ``net``, the ``lite_residual_*`` options, ``classifier``, ``layer``,
    ``bottleneck``, ``class_num``, ``output_dir_tar`` and ``cls_par`` from
    ``args``. Only backbones whose name starts with ``'res'`` are supported;
    anything else raises ``NotImplementedError``.

    Returns the tuple ``(netF, netB, netC)`` with every module switched to
    eval mode. State dicts are loaded with ``map_location='cpu'``.
    """
    # Only ResNet-style backbones are handled.
    if args.net[0:3] != 'res':
        raise NotImplementedError
    netF = network.ResBase(res_name=args.net)

    # Attach the lite residual branches to the backbone before loading weights.
    LiteResidualModule.insert_lite_residual_resnet_torch(
        netF,
        args.lite_residual_downsample,
        'bilinear',
        args.lite_residual_expand,
        args.lite_residual_ks,
        'relu',
        args.lite_residual_groups,
    )

    netB = network.feat_bootleneck(
        type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck)
    netC = network.feat_classifier(
        type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck)

    parts = (('F', netF), ('B', netB), ('C', netC))

    # Checkpoints are named target_<F|B|C>_par_<cls_par>.pt inside output_dir_tar.
    for tag, module in parts:
        checkpoint = args.output_dir_tar + '/target_' + tag + '_par_' + str(args.cls_par) + '.pt'
        module.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    for _, module in parts:
        module.eval()

    return netF, netB, netC


if __name__ == "__main__":
    # Initialize distributed computing (MPI); this rebinds the module-level `comm`.
    comm, process_id, worker_number = FedML_init()

    # Parse the command-line options of this script.
    parser = argparse.ArgumentParser()
    args = add_args(parser)
    seed = args.seed
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    logging.info(args)

    root_path = args.root_path
    domain = {'src': str(args.src), 'tar': str(args.tar)}
    batch_size = args.batch_size
    server_batch_size = 64
    client_number = args.client_number
    percent = args.percent

    # Domain names and class count for every value accepted by --dset.
    # 'VISDA-C' used to be accepted by the parser but never set class_num, which
    # made load_large_model fail later with an AttributeError.
    dataset_info = {
        'office-home': (['Art', 'Clipart', 'Product', 'RealWorld'], 65),
        'office': (['amazon', 'dslr', 'webcam'], 31),
        'office-caltech': (['amazon', 'caltech', 'dslr', 'webcam'], 10),
        'imageCLEF': (['C', 'I', 'P'], 12),
        # VisDA-2017 classification has 12 categories.
        # NOTE(review): the domain names are unused in this script; confirm
        # them against data_preprocessing.data_loading before relying on them.
        'VISDA-C': (['train', 'validation'], 12),
    }
    names, args.class_num = dataset_info[args.dset]

    # Task tag built from the upper-cased first letters of the source and target names.
    args.name = args.src[0].upper() + args.tar[0].upper()

    args.output_dir_src = osp.join(args.output_src, args.da, args.dset, args.net, args.src[0].upper())
    # output_dir and output_dir_tar point at the same directory.
    args.output_dir = osp.join(args.output, args.da, args.dset, args.net, args.name)
    args.output_dir_tar = args.output_dir

    folder_name = (
        "{}_".format(args.net) + args.compact_net
        + '_client_iter_' + str(args.epochs_client)
        + '_rho_' + str(args.rho)
        + '_alpha_' + str(args.alpha)
        + '_batch_' + str(batch_size)
        + '_round_' + str(args.comm_round)
        + '_percent_' + str(percent)
        + '_seed_' + str(args.seed)
    )
    save_dir = osp.join(args.dset, args.name, 'DA_completed', folder_name)

    # exist_ok avoids the check-then-create race between MPI processes and
    # removes the dependency on a shell `mkdir -p`.
    os.makedirs(save_dir, exist_ok=True)

    args.server_record_name = osp.join(save_dir, args.server_record_name)
    args.client_record_name = osp.join(save_dir, args.client_record_name)
    args.mixed_record_name = osp.join(save_dir, args.mixed_record_name)

    # NOTE: init_training_device currently ignores the process topology and
    # simply picks CUDA when available; customize it for multi-GPU hosts.
    logging.info("process_id = %d, size = %d", process_id, worker_number)
    device = init_training_device(process_id, worker_number - 1, args.gpu_num_per_server)

    compact_model_path = './{}_'.format(args.compact_net) + args.src + '.pt'

    # Load the compact model (every process).
    compact_model = load_compact_model(compact_model_path, args.compact_net)

    # Load the source-domain data (every process).
    source_data_num, source_data = load_source_data(root_path, domain['src'], server_batch_size)

    if process_id == 0:
        # Rank 0 acts as the server: no local target data and no large model.
        train_data_local = None
        train_data_local_num = None
        test_data_local = None
        total_num = None
        netF, netB, netC = None, None, None
        savename = 'server_acc'
    else:
        # Every other rank is a client that owns a partition of the target data.
        total_train_num, train_data_local_num, train_data_local, test_data_local = \
            load_partition_target_data_UD(root_path, domain['tar'], batch_size, client_number, process_id,
                                          percent=percent, seed=seed)
        savename = 'client_acc'

    print(savename)
    args.out_file = open(osp.join(save_dir, 'log_' + savename + '.txt'), 'a')

    try:
        if process_id != 0:
            netF, netB, netC = load_large_model(args)
            netF.to(device)
            netB.to(device)
            netC.to(device)

        torch.manual_seed(seed)

        # Start distributed training.
        FedML_FedXDD_DA_completed(args, process_id, worker_number, device, comm, compact_model, netF, netB, netC,
                                  source_data, train_data_local, train_data_local_num, test_data_local)
    finally:
        # The log handle was previously never closed.
        args.out_file.close()
