import argparse
import pdb
import fedmitr
import dfkd
import shutil
from models.nets import *
from train import *
from utils_kd import *
import warnings
# Keep console/log output focused on training progress: silence Python warnings and
# raise PIL's logger threshold to WARNING (this script configures the root logger at DEBUG).
warnings.simplefilter(action='ignore')
logging.getLogger(name='PIL').setLevel(level=logging.WARNING)
# Legacy CUDA float-tensor type alias; not referenced in the visible part of this file.
FloatTensor = torch.cuda.FloatTensor


def get_args():
    """Define and parse the command-line options for this script from ``sys.argv``.

    Options are grouped by concern: general paths/seed, data partitioning, model,
    federated setup, local optimisation, distillation, and per-method hyper-parameters
    (DENSE, CO-Boosting, FedFTG and this project's own method).

    Returns:
        argparse.Namespace: the parsed options.
    """

    def _str2bool(value):
        """Convert a CLI token such as ``true``/``False``/``1``/``no`` to a bool.

        ``type=bool`` must not be used for this: ``bool('False')`` is ``True`` because
        every non-empty string is truthy. An empty string maps to ``False`` (as it did
        with ``bool``); unrecognised tokens are rejected with a usage error.
        """
        if isinstance(value, bool):  # non-string defaults are passed through unchanged
            return value
        token = str(value).strip().lower()
        if token in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--GPU', type=_str2bool, default=torch.cuda.is_available(),
                        help='use GPU or CPU (true/false, 1/0, yes/no)')
    parser.add_argument('--data_dir', type=str, required=False, default="./data/", help="Data directory")
    parser.add_argument('--log_dir', type=str, required=False, default="./logs/", help='Log directory path')
    parser.add_argument('--model_dir', type=str, required=False, default="./save_models/", help='Model directory path')
    parser.add_argument('--seed', type=int, default=0, help="Random seed")
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--type', default="train", type=str, help='train or test.')
    parser.add_argument('--type_gen', default="old", type=str, help='old or new.')

    # Data Setting
    parser.add_argument('--dataset', type=str, default='MiniImageNet', help='dataset used for training')
    parser.add_argument('--img_size', type=int, default=224, help='train input image size')
    parser.add_argument('--num_classes', type=int, default=200, help='the number of class')

    parser.add_argument('--partition', type=str, default='iid',
                        help='the data partitioning strategy:iid,Dirichlet,non_iid')
    parser.add_argument('--beta', type=float, default=0.5,
                        help='The parameter for the dirichlet distribution for data partitioning')
    parser.add_argument('--n_label', type=int, default=0, help='label number of every client')

    # Model Setting
    parser.add_argument('--model', type=str, default='ViT16', help='neural network used in training')
    parser.add_argument('--model_norm', type=str, default='bn', choices=['none', 'bn', 'in', 'ln', 'gn'])
    parser.add_argument('--pretrained', action='store_true', help='use pre-trained model')

    # Method Setting
    parser.add_argument('--local_method', type=str, default='vanilla', help='local train method ')
    parser.add_argument('--server_method', type=str, default='fedavg', help='server aggregation method')

    # Fl Setting
    parser.add_argument('--n_clients', type=int, default=10, help='number of clients')
    parser.add_argument('--ratio', type=float, default=0.4)
    parser.add_argument('--comm_round', type=int, default=1, help='number of maximum communication round')
    parser.add_argument('--local_epochs', type=int, default=50, help='number of local epochs')

    # Train hyperparameter
    parser.add_argument('--bs', type=int, default=128, help='input images batch size for training')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
    parser.add_argument('--interval', type=int, default=10)

    # Distill Setting
    parser.add_argument('--kd_lr', type=float, default=0.01, help='kd learning rate')
    parser.add_argument('--kd_epochs', type=int, default=10, help='kd epochs')
    parser.add_argument("--T", type=int, default=20, help="Distillation temperature")
    parser.add_argument('--kd_bs', type=int, default=128, help='input images batch size for kd')
    parser.add_argument('--use_test', action='store_true', help='use part of test data for distill(no use)')
    parser.add_argument('--teacher_logits', type=str, default='None', help='Ensemble soft or not')

    # Generator Setting
    parser.add_argument('--save_data_dir', default="", help='path to datapool')
    parser.add_argument('--lr_g', default=1e-3, type=float, help='adam: generator learning rate')
    parser.add_argument('--g_steps', default=20, type=int, help='number of iterations for generation')
    parser.add_argument('--synthesis_batch_size', default=64, type=int)
    parser.add_argument('--syn_numbers', default=1, type=int)

    # DENSE Setting
    parser.add_argument('--div', default=1.0, type=float, help='scaling factor for div loss')
    parser.add_argument('--ce', default=1.0, type=float, help='scaling factor for ce loss')
    parser.add_argument('--bn', default=1.0, type=float, help='scaling factor for BN regularization')
    parser.add_argument("--I_den", type=int, default=1, help="inner iterations of training the global model")

    # CO-BOOSTING Setting
    parser.add_argument('--odseta', default=8, type=float)
    parser.add_argument('--weighted', action='store_true')
    parser.add_argument('--wa_steps', default=5, type=int)
    parser.add_argument('--mu', default=0.01, type=float)
    parser.add_argument('--wdc', default=0.99, type=float)

    # FedFTG Setting
    parser.add_argument("--beta_md", type=float, default=1.0, help="hyper-parameter of L_md loss")
    parser.add_argument("--beta_div", type=float, default=1.0, help="hyper-parameter of L_div loss")
    parser.add_argument("--I_ftg", type=int, default=1, help="inner iterations of training the global model")

    # Ours Setting (sparse model inversion)
    parser.add_argument('--cut', type=int, default=0)
    parser.add_argument('--prune_it', nargs='+', type=int,
                        help='the iteration indexes for inversion stopping; '
                             '-1: to densely invert data; '
                             't1 t2 ... tn: to sparsely invert data and perform inversion stopping at t1, t2, ..., tn')
    parser.add_argument('--prune_ratio', nargs='+', type=float,
                        help='the proportion of patches to be pruned relative to the current remaining patches; '
                             '0: to densely invert data; '
                             'r1 r2 ... rn: progressively stopping the inversion of a fraction (r1, r2, ..., rn)'
                             ' of patches at iterations (t1, t2, ..., tn), respectively')
    parser.add_argument('--SMI_iterations', default=100, type=int)

    # Ours Setting (loss weights)
    parser.add_argument("--alpha1", type=float, default=1.0, help="alpha1 loss")
    parser.add_argument("--alpha2", type=float, default=1.0, help="alpha2 loss")
    parser.add_argument("--adv", type=float, default=0.0, help="adv loss")

    return parser.parse_args()


# Resolution the generators / synthesizers below are built for. It is hard-coded here;
# --img_size is not consulted on the distillation path.
_IMG_SIZE = 224
# Patch size passed to the ViT-specific FedMITR / DeepInversion code.
_PATCH_SIZE = 16
# Tokens per image: (224 // 16) ** 2 = 196 patches plus one extra token (presumably CLS) = 197.
_NUM_TOKENS = (_IMG_SIZE // _PATCH_SIZE) ** 2 + 1


def _announce(logger, message):
    """Print ``message`` to stdout and also write it to ``logger`` at INFO level."""
    print(message)
    logger.info(message)


def _remove_dir_if_exists(path):
    """Delete the directory tree at ``path`` if it exists; return True if something was removed.

    ``None`` is accepted and ignored, so optional data-pool paths can be passed unconditionally.
    """
    if path is not None and os.path.exists(path):
        shutil.rmtree(path)
        return True
    return False


def _report_best(logger, acc, epoch, bst_acc, bst_epo):
    """Fold the accuracy ``acc`` of 0-based ``epoch`` into the running best and report it.

    Returns:
        tuple: the updated ``(bst_acc, bst_epo)``; ``bst_epo`` is 1-based.
    """
    if acc > bst_acc:
        bst_acc, bst_epo = acc, epoch + 1
    _announce(logger, f"epoch{epoch + 1}: best acc:{bst_acc:.4f} in epoch{bst_epo}")
    return bst_acc, bst_epo


def _fedmitr_train_epoch(args, loader, teacher, student, criterion_ce, criterion_kd, optimizer, device):
    """Run one FedMITR fine-tuning pass of ``student`` over ``loader``.

    Each batch ``datas`` is indexed as ``datas[0] = (images1, labels1, ...)`` and
    ``datas[1] = (images2, ...)``. The per-batch loss is

        alpha1 * CE(student(images1), labels1)
        + alpha2 * KD(student(images2), teacher(images2))
        + KD(student(images1 + images2), teacher(images1 + images2))

    Both KD terms give the models only token index 0 plus the patch indices that
    ``find_non_zero_patches`` reports for ``images2``. The CE term uses all
    ``_NUM_TOKENS`` tokens. Teacher forwards run under ``torch.no_grad()``.

    Returns:
        list[float]: the total loss of every batch, in iteration order.
    """
    # NOTE(review): compute_accuracy runs between epochs and (assumption - it is not visible
    # here) may switch the model to eval mode; make sure updates happen in training mode.
    student.train()
    batch_losses = []
    for datas in tqdm(loader, file=sys.stdout):
        # Input of the "inv" KD term: element-wise sum of both image batches, summed before
        # moving to the device.
        mixed = (datas[0][0] + datas[1][0]).to(device)
        images1, labels1 = datas[0][0].to(device), datas[0][1].to(device)
        images2 = datas[1][0].to(device)

        # Token-index tensors for the forward passes below, built once per batch.
        all_ids = torch.arange(_NUM_TOKENS).repeat(images1.shape[0], 1).to(device)
        pos_ids = torch.arange(_NUM_TOKENS).repeat(images2.shape[0], 1).to(device)
        # Index 0 (presumably the ViT class token) is always kept, followed by the patch
        # indices find_non_zero_patches selects for images2.
        kept_ids = torch.cat([torch.zeros(images2.shape[0], 1, dtype=torch.long).to(device),
                              find_non_zero_patches(images=images2, patch_size=_PATCH_SIZE)], dim=1)

        student.zero_grad()
        output = student(images1, all_ids, all_ids)
        with torch.no_grad():
            output_t = teacher(images2, pos_ids, kept_ids)
            output_t_inv = teacher(mixed, pos_ids, kept_ids)
        output_s = student(images2, pos_ids, kept_ids)
        output_s_inv = student(mixed, pos_ids, kept_ids)

        # Element 0 of each model output is what the losses consume (presumably the logits).
        loss_ce = criterion_ce(output[0], labels1)
        loss_kd = criterion_kd(output_s[0], output_t[0].detach())
        loss_inv = criterion_kd(output_s_inv[0], output_t_inv[0].detach())
        loss = args.alpha1 * loss_ce + args.alpha2 * loss_kd + loss_inv
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return batch_losses


if __name__ == '__main__':
    # ---- Preliminaries ---------------------------------------------------------------
    args = get_args()
    device = torch.device('cuda:0' if args.GPU else 'cpu')
    set_random_seed(args)

    # ---- Logging: one log file per (model, dataset, partition) configuration ---------
    # logging.basicConfig is a no-op while the root logger already has handlers, so drop
    # any that are installed before configuring the file log.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    partition = args.partition
    if partition == 'Dirichlet':
        partition += str(args.beta)
    log_folder = os.path.join(args.log_dir, args.model, args.dataset, partition)
    os.makedirs(log_folder, exist_ok=True)
    log_path = get_log_name(args) + '.log'
    logging.basicConfig(
        filename=os.path.join(log_folder, log_path),
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w')
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    for key in sorted(vars(args)):
        logger.info("{}: {}".format(key, getattr(args, key)))

    # ---- Model checkpoint directory ---------------------------------------------------
    model_folder = os.path.join(args.model_dir, args.dataset, args.model)
    model_name = get_model_name(args)
    os.makedirs(model_folder, exist_ok=True)

    # ---- Data ---------------------------------------------------------------------------
    _announce(logger, "Partitioning data")
    train_dataset, test_dataset, proxy_dataset, _, train_transform, normalizer = load_data(args)

    user_groups, traindata_cls_counts = partition_data(args)
    _announce(logger, 'Data statistics: %s' % str(traindata_cls_counts))

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.bs, shuffle=False,
                                              num_workers=args.num_workers, pin_memory=True)

    # ---- Models: one independent copy of the initial global model per client ----------
    logger.info("Initializing nets")
    global_model = get_global_model(args)
    models = [copy.deepcopy(global_model) for _ in range(args.n_clients)]

    if args.type == 'train':
        _announce(logger, '########start train#######')
        if args.comm_round == 1:
            local_train(args, train_dataset, test_loader, user_groups, models, model_folder, model_name, device)
        else:
            # Only a single communication round is implemented by this script.
            message = f"comm_round={args.comm_round} is not supported (only 1); nothing was trained"
            print(message)
            logger.warning(message)

    elif args.type == 'test':
        _announce(logger, '########start test#######')
        # Sample the participating clients.
        n_sampled = int(args.n_clients * args.ratio)
        if n_sampled < 1:
            raise ValueError(f"n_clients * ratio = {args.n_clients} * {args.ratio} selects no client")
        clients = random.sample(range(args.n_clients), n_sampled)

        # Load each sampled client's checkpoint: keep a model copy for the ensemble
        # teacher and a state_dict copy for weight averaging.
        local_models = []
        local_weights = []
        for idx in clients:
            save_model_path = os.path.join(model_folder, f'{model_name}_{idx}.pth')
            # map_location lets checkpoints saved on a GPU be loaded on a CPU-only run.
            global_model.load_state_dict(torch.load(save_model_path, map_location=device))
            global_model.to(device)
            local_models.append(copy.deepcopy(global_model))
            local_weights.append(copy.deepcopy(global_model.state_dict()))
        # Student initialisation from the averaged client weights (presumably FedAvg).
        global_model.load_state_dict(average_weights(local_weights))
        global_model.to(device)
        ensemble_model = Ensemble(local_models)
        teacher = ensemble_model
        if args.server_method == 'fedavg':
            compute_accuracy(global_model, test_loader, 'fedavg', device)
            compute_accuracy(ensemble_model, test_loader, 'Ensemble', device)
            sys.exit()

        # Scratch directory for synthesised data; start from an empty pool.
        data_pool_path = os.path.join(args.save_data_dir, args.dataset, f'{partition}_seed{args.seed}')
        if _remove_dir_if_exists(data_pool_path):
            print('remove')
        criterion_kd = KLDiv(T=args.T)
        optimizer_kd = torch.optim.SGD(global_model.parameters(), args.kd_lr, weight_decay=1e-4, momentum=0.9)
        bst_acc = -1
        bst_epo = -1
        data_pool_path2 = None

        try:
            if args.type_gen == 'old':
                # Generator-based data-free KD (DENSE / FedFTG / CO-Boosting).
                generator = dfkd.generator.Generator(nz=256, ngf=64, img_size=_IMG_SIZE, nc=3).to(device)
                if args.server_method == 'dense':
                    synthesizer_cls = dfkd.dense.DENSESynthesizer
                elif args.server_method == 'fedftg':
                    synthesizer_cls = dfkd.fedftg.FEDFTGSynthesizer
                elif args.server_method == 'coboosting':
                    # CO-Boosting distils from a weighted ensemble; start from uniform client weights.
                    n_teachers = len(local_models)
                    ww = torch.full((n_teachers, 1), 1.0 / n_teachers).to(device)
                    teacher = WEnsemble(local_models, ww)
                    synthesizer_cls = dfkd.coboost.COBOOSTSynthesizer
                else:
                    raise ValueError(f"unsupported --server_method {args.server_method!r} for --type_gen old")
                synthesizer = synthesizer_cls(
                    teacher=teacher, mdl_list=local_models, student=global_model, generator=generator, nz=256,
                    num_classes=args.num_classes,
                    img_size=(3, _IMG_SIZE, _IMG_SIZE), iterations=args.g_steps, lr_g=args.lr_g,
                    synthesis_batch_size=args.synthesis_batch_size, sample_batch_size=args.kd_bs,
                    adv=args.div, bn=args.bn, oh=args.ce, criterion=criterion_kd,
                    transform=train_transform,
                    save_dir=data_pool_path, normalizer=normalizer, args=args)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_kd, T_max=args.kd_epochs)

                for epoch in range(args.kd_epochs):
                    for _ in range(args.syn_numbers):
                        synthesizer.synthesize(cur_ep=epoch)
                    # Re-read the teacher each epoch in case the synthesizer replaced or updated it.
                    teacher = synthesizer.teacher.to(device)
                    dfkd.kd_train(synthesizer, [global_model, teacher], criterion_kd, optimizer_kd)
                    scheduler.step()
                    acc = compute_accuracy(global_model, test_loader, f'{args.server_method}', device)
                    bst_acc, bst_epo = _report_best(logger, acc, epoch, bst_acc, bst_epo)

            elif args.type_gen == 'new':
                # Model-inversion based synthesis (DeepInversion baseline or FedMITR).
                if args.server_method not in ('DeepInversion', 'FedMITR'):
                    raise ValueError(f"unsupported --server_method {args.server_method!r} for --type_gen new")
                mi_kwargs = dict(args=args, clients=clients, teachers=local_models, student=global_model,
                                 num_classes=args.num_classes,
                                 img_shape=(3, _IMG_SIZE, _IMG_SIZE), iterations=args.SMI_iterations,
                                 patch_size=_PATCH_SIZE, lr_g=0.25,
                                 synthesis_batch_size=args.synthesis_batch_size,
                                 sample_batch_size=args.kd_bs,
                                 adv=args.adv, bn=0, oh=1, tv1=0, tv2=0.0001, l2=0,
                                 save_dir=data_pool_path, transform=train_transform,
                                 normalizer=normalizer, device=device, bnsource='resnet50v1',
                                 init_dataset=None)
                if args.server_method == 'FedMITR':
                    # FedMITR keeps a second (relabelled) pool and trains with its own CE + KD loop.
                    data_pool_path2 = os.path.join(args.save_data_dir, args.dataset,
                                                   f'{partition}_seed{args.seed}_relabel')
                    if _remove_dir_if_exists(data_pool_path2):
                        print('remove')
                    criterion_ce = nn.CrossEntropyLoss().to(device)
                    optimizer = torch.optim.SGD(global_model.parameters(), lr=args.kd_lr, momentum=0.9,
                                                weight_decay=1e-4)
                    mi_kwargs.update(save_dir2=data_pool_path2, test_loader=test_loader)
                synthesizer = fedmitr.mi.MI(**mi_kwargs)

                for epoch in range(args.kd_epochs):
                    _announce(logger, f"Epoch {epoch + 1} Generating data...")
                    for _ in range(args.syn_numbers):
                        synthesizer.synthesize(num_patches=_NUM_TOKENS, prune_it=args.prune_it,
                                               prune_ratio=args.prune_ratio)
                    calibrate_dataloader = synthesizer.sample()
                    _announce(logger, "Fine-tuning and Knowledge transfer with generated data...")

                    if args.server_method == 'DeepInversion':
                        fedmitr.kd_train(epoch, calibrate_dataloader, teacher, global_model,
                                         criterion_kd, optimizer_kd, device)
                    else:
                        epoch_losses = _fedmitr_train_epoch(args, calibrate_dataloader, teacher, global_model,
                                                            criterion_ce, criterion_kd, optimizer, device)
                        if epoch_losses:
                            mean_loss = sum(epoch_losses) / len(epoch_losses)
                            logger.info(f"epoch{epoch + 1}: mean train loss:{mean_loss:.4f}")

                    acc = compute_accuracy(global_model, test_loader, f'{args.server_method}', device)
                    bst_acc, bst_epo = _report_best(logger, acc, epoch, bst_acc, bst_epo)

            else:
                raise ValueError(f"unknown --type_gen {args.type_gen!r}; expected 'old' or 'new'")
        finally:
            # Synthetic data pools are scratch space: remove them even if distillation fails.
            _remove_dir_if_exists(data_pool_path)
            _remove_dir_if_exists(data_pool_path2)

    else:
        raise ValueError(f"unknown --type {args.type!r}; expected 'train' or 'test'")
