# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
# import json
import numpy as np
import os
from copy import deepcopy
# from Core.Generator import Generator

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # classification
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
# import torchvision.datasets as datasets
import scipy.io as sio
import random

# from engine_pretrain_GAN_DT_v3 import train_one_epoch
# from engine_pretrain_GAN_DT_v4 import train_one_epoch
# from engine_pretrain_GAN_DT_v4_1 import train_one_epoch
from engine_pretrain_GAN_DT_v6_1 import train_one_epoch


def seed_torch(seed=1029):
    """Seed every RNG the script relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy, torch (CPU, current CUDA device and all
    CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: integer seed shared by all generators (default 1029).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seeder in seeders:
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# import timm

# assert timm.__version__ == "0.3.2"  # version check
# import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.misc import save_checkpoint

import torchvision
from functions import LinearLrDecay
from eval import Evaluator, get_ydist, get_zdist, get_nsamples, add_imgs, imgs_grid


def get_args_parser():
    """Build the command-line parser for GAN-style MAE pre-training.

    The parser is created with ``add_help=False`` (no ``-h/--help`` of its
    own); presumably the caller wraps it as a parent parser.

    Note: ``main`` additionally reads ``args.dir_img``, ``args.dir_model`` and
    ``args.z_dim``, which are NOT defined here and must be set by the caller.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)

    # Hyperparams
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    # Data parameters
    parser.add_argument('--data_path',
                        default='/data2/myzhao/data/',
                        type=str, help='dataset path')
    parser.add_argument('--num_class', default=10, type=int,
                        help='classification num_class')
    parser.add_argument('--batch_size', default=32,
                        type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--dataloadmode', default='pytorch', type=str,
                        choices=['pytorch', 'dali'],
                        help='Mode of the data loading process')

    # Model parameters
    parser.add_argument('--input_size', default=128, type=int,
                        help='images input size')
    parser.add_argument('--patch_size', default=16, type=int,
                        help='images patch size')
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches).')

    # Customize G parameters  384/12
    # 'type1' / 'type2' are the values main() actually dispatches on (besides
    # 'custom'); they must be listed here or argparse rejects them on the CLI.
    parser.add_argument('--model', default='type2', type=str, metavar='MODEL',
                        choices=['custom', 'type1', 'type2', 'mae_vit_base_patch16', 'mae_vit_large_patch16'],
                        help='Name of Generator model to train')
    parser.add_argument('--GNoise', default='GIt', type=str,
                        choices=['GIn', 'GIt', 'GStyIN', 'GStyLN'],
                        help='The setting of Generator Noise')
    parser.add_argument('--GLastNorm', default='LR', type=str,
                        choices=['NO', 'LN', 'IN', 'LR'],
                        help='The setting of the last norm of the Generator')

    parser.add_argument('--embed_dimG', default=256, type=int,
                        help='Generator encoder embed dim')
    parser.add_argument('--depthG', default=6, type=int,
                        help='Generator encoder depth')
    parser.add_argument('--num_headsG', default=8, type=int,
                        help='Generator encoder num_heads')
    parser.add_argument('--decoder_embed_dimG', default=512, type=int,
                        help='Generator decoder embed_dim: twice that of Generator encoder')
    parser.add_argument('--decoder_depthG', default=6, type=int,
                        help='Generator decoder depth')
    parser.add_argument('--decoder_num_headsG', default=8, type=int,
                        help='Generator decoder num_heads')
    parser.add_argument('--mlp_ratioG', default=2, type=int,
                        help='Generator mlp ratio')
    # NOTE(review): store_true with a True default means this flag can never
    # turn the option off; kept as-is for compatibility.
    parser.add_argument('--token_mid', action='store_true',
                        help='if True, add noise and class token to the end of encoder')
    parser.set_defaults(token_mid=True)
    parser.add_argument('--toImg', default='vit', type=str,
                        choices=['vit', '2conv', '4conv', '3conv', '4convStyle', 'siren', 'cips'],
                        help='The setting of to convert to images')

    # Customize D parameters
    parser.add_argument('--model_dis', default='custom', type=str, metavar='MODEL_D',
                        choices=['custom', 'mae_vit_small', 'mae_vit_base_patch16', 'mae_vit_large_patch16', 'type1'],
                        help='Name of Discriminator model to train')
    parser.add_argument('--DST', default='DIn', type=str,
                        choices=['DIn', 'DIt', 'DSty', 'Dmulhead', 'Dmulclass', 'DSn'],
                        help='The setting of Discriminator ST')
    parser.add_argument('--D_no_padST', action='store_true', help='if true, D-input may have ST-length < L')
    parser.set_defaults(D_no_padST=False)
    parser.add_argument('--DLastNorm', default='LN', type=str,
                        choices=['NO', 'LN', 'LR'],
                        help='The setting of the last norm of the Discriminator')
    parser.add_argument('--Dpadding', default=2, type=int,
                        help='Discriminator padding size')

    parser.add_argument('--embed_dimD', default=256, type=int,
                        help='Discriminator embed dim')
    parser.add_argument('--depthD', default=6, type=int,
                        help='Discriminator encoder depth')
    parser.add_argument('--num_headsD', default=8, type=int,
                        help='Discriminator encoder num_heads')
    parser.add_argument('--mlp_ratioD', default=2, type=int,
                        help='Discriminator mlp ratio')
    parser.add_argument('--drop', default=0.0, type=float,
                        help='dropout ratio for the Linear layer in Discriminator')

    parser.add_argument('--diff_aug', default=None,  # e.g. 'filter,translation,erase_ratio,color,hue'
                        type=str, help='Discriminator diff_aug')
    parser.add_argument('--setGP', default='real', type=str,
                        choices=['none', 'real', 'all'])
    parser.add_argument('--lambdaGP', default=10., type=float,
                        help='Gradient penalty coefficient.')
    # NOTE(review): help text duplicates --isTest; confirm the intended meaning.
    parser.add_argument('--FixPemd', action='store_true', help='if true, test the pretrained model')
    parser.set_defaults(FixPemd=False)

    # Pretraining. store_true flags are only True when passed on the command line.
    parser.add_argument('--isTest', action='store_true', help='if true, test the pretrained model')
    parser.set_defaults(isTest=False)
    # main() loads `<resume>checkpoint_g` / `<resume>checkpoint_d` by plain
    # string concatenation, so this must be a directory ending in a separator.
    parser.add_argument('--resume',
                        default=None,
                        help='resume from checkpoint directory (must end with a path separator)')

    # Loss parameters
    parser.add_argument('--method', default='generate', type=str,
                        choices=['classify', 'generate', 'all'])  # conditional_generate
    parser.add_argument('--communicate', action='store_true', help='if true, enable communications among models')
    parser.set_defaults(communicate=False)
    parser.add_argument('--commu_only_fake', action='store_true', help='if true, communication only update fake G')
    parser.set_defaults(commu_only_fake=False)
    parser.add_argument('--loss', default='standard', type=str,
                        choices=['standard', 'wgangp', 'hinge', 'lsgan', 'vanilla'])

    parser.add_argument('--Sratio', default=-1., type=float,
                        help='Model to data: source ratio. a in [0,1]: constant Sratio. -1: random')
    parser.add_argument('--Tratio', default=-1., type=float,
                        help='Model to data: target ratio. a in [0,1]: constant Tratio. -1: random')
    parser.add_argument('--CommuSratio', default=-1., type=float,
                        help='Model to model communication: source ratio. a in [0,1]: constant CommuSratio. -1: random')

    parser.add_argument('--AProbAlpha', default=0.5, type=float,
                        help='Beta parameter alpha for A.')
    parser.add_argument('--AProbBeta', default=3., type=float,
                        help='Beta parameter beta for A.')
    parser.add_argument('--BProbAlpha', default=3, type=float,
                        help='Beta parameter alpha for B.')
    parser.add_argument('--BProbBeta', default=0.5, type=float,
                        help='Beta parameter beta for B.')
    parser.add_argument('--CProbAlpha', default=3., type=float,
                        help='Beta parameter alpha for C.')
    parser.add_argument('--CProbBeta', default=0.5, type=float,
                        help='Beta parameter beta for C.')

    parser.add_argument('--gen_w_label', action='store_true', help='if true, conditional GAN')
    parser.set_defaults(gen_w_label=False)

    # Optimizer params
    parser.add_argument('--lr', type=float,
                        default=1e-4,
                        metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--glr', type=float,
                        default=1e-4,
                        metavar='LR',
                        help='G learning rate (absolute lr)')
    parser.add_argument('--dlr', type=float,
                        default=1e-4,
                        metavar='LR',
                        help='D learning rate (absolute lr)')
    parser.add_argument('--set_scheduler', default='none', type=str,
                        choices=['none', 'linear', 'mae', 'warmupC'])
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--d_iter', default=1, type=int,
                        help='d_iter Discriminator updates per 1 Generator update')

    parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--Dsteps', type=int, default=1, metavar='N',
                        help='num of steps for D')
    parser.add_argument('--weight_decay', type=float, default=0.00,
                        help='weight decay (default: 0.0)')
    parser.add_argument('--clip_gradG', type=float, default=5.0,
                        help='clip_grad G (default: 5.0)')
    parser.add_argument('--clip_gradD', type=float, default=5.0,
                        help='clip_grad D (default: 5.0)')
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--max_iter', default=5000000, type=int)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    # store_false + default True: passing the flag DISABLES the option.
    parser.add_argument('--norm_pix_loss', action='store_false',
                        help='Pass to DISABLE (per-patch) normalized pixels as targets for computing loss (enabled by default)')
    parser.set_defaults(norm_pix_loss=True)
    # NOTE(review): help text appears copy-pasted from --norm_pix_loss; confirm.
    parser.add_argument('--use_percept', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(use_percept=False)

    # Distributed training parameters
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--showFreq', type=int, default=100, help='Iters to show the generation during train')

    # Testing params
    parser.add_argument('--fid', action='store_true',
                        help='do FID evaluation')
    parser.set_defaults(fid=False)
    parser.add_argument('--num_sample_fid', default=5000, type=int)

    # Output params
    parser.add_argument('--output_dir',
                        default='/data/data/cong/output_celeba_test',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir',
                        default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--print_freq', default=100, type=int)

    return parser


def dataloading_preprocess(mode='dali'):
    """Build the CelebA train/test data loaders for the requested backend.

    Configuration is read from the module-level ``args`` namespace
    (``data_path``, ``input_size``, ``batch_size``, ``num_workers``,
    ``pin_mem``), which must be bound before this is called.

    Training images come from ``<data_path>/img_align_celeba`` and test images
    from ``<data_path>/tst``. Both are resized on the short side and
    centre-cropped to ``args.input_size``, then normalised with mean/std 0.5
    (torchvision) or 128/128 on 0-255 (DALI), i.e. to roughly [-1, 1].
    Only the training set gets random horizontal flips.

    Args:
        mode: ``'pytorch'`` (torchvision ``ImageFolder`` + ``DataLoader``),
            ``'dali'`` (NVIDIA DALI GPU pipelines), or ``'dali1'`` (legacy DALI
            ops API reading a hard-coded directory; the test loader is the
            same object as the train loader).

    Returns:
        tuple: ``(data_loader_train, data_loader_test)``. In the ``pytorch``
        and ``dali`` modes the test loader uses ``2 * args.batch_size``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode == 'pytorch':
        # simple augmentation
        transform_train = transforms.Compose([
            transforms.Resize(args.input_size),
            transforms.CenterCrop(args.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        transform_test = transforms.Compose([
            transforms.Resize(args.input_size),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        dataset_train = torchvision.datasets.ImageFolder(os.path.join(args.data_path, 'img_align_celeba'),
                                                         transform=transform_train)
        dataset_test = torchvision.datasets.ImageFolder(os.path.join(args.data_path, 'tst'),
                                                        transform=transform_test)
        print(dataset_train)

        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            shuffle=True,
            drop_last=True,
        )
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=args.batch_size * 2,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            shuffle=False,
            drop_last=False,
        )

    elif mode == 'dali':

        from nvidia.dali.pipeline import pipeline_def
        import nvidia.dali.types as types
        import nvidia.dali.fn as fn
        from nvidia.dali.plugin.pytorch import DALIGenericIterator

        # DALI needs at least one CPU thread; --num_workers defaults to 0.
        num_workers = 1 if args.num_workers == 0 else args.num_workers

        @pipeline_def(batch_size=args.batch_size, num_threads=num_workers, device_id=0)
        def get_dali_pipeline(data_dir, crop_size):
            images, labels = fn.readers.file(file_root=data_dir, random_shuffle=True, initial_fill=10*args.batch_size)
            # decode data on the GPU
            images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
            images = fn.resize(images, resize_shorter=crop_size)
            images = fn.crop_mirror_normalize(images, crop=(crop_size, crop_size), crop_pos_x=0.5, crop_pos_y=0.5,
                                              mean=128, std=128,
                                              mirror=fn.random.coin_flip())
            return images, labels

        data_loader_train = DALIGenericIterator(
            [get_dali_pipeline(os.path.join(args.data_path, 'img_align_celeba'), args.input_size)],
            ['data', 'label'],
            auto_reset=True,
        )
        # Each iteration yields [{'data': images, 'label': labels}].

        @pipeline_def(batch_size=args.batch_size*2, num_threads=num_workers, device_id=0)
        def get_dali_pipeline_test(data_dir, crop_size):
            images, labels = fn.readers.file(file_root=data_dir, random_shuffle=True, initial_fill=10*args.batch_size)
            # decode data on the GPU
            images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
            images = fn.resize(images, resize_shorter=crop_size)
            images = fn.crop_mirror_normalize(images, crop=(crop_size, crop_size), crop_pos_x=0.5, crop_pos_y=0.5,
                                              mean=128, std=128,
                                              mirror=0)
            return images, labels

        data_loader_test = DALIGenericIterator(
            [get_dali_pipeline_test(os.path.join(args.data_path, 'tst'), args.input_size)],
            ['data', 'label'],
            auto_reset=True,
        )

    elif mode == 'dali1':

        from nvidia.dali.pipeline import Pipeline
        import nvidia.dali.ops as ops
        import nvidia.dali.types as types

        # NOTE(review): hard-coded dataset location, independent of --data_path.
        dataroot = '/mnt/E/Dropbox/Data/celeba/img_align_celeba/'

        class SimplePipeline(Pipeline):
            def __init__(self, batch_size, num_threads, device_id, seed, size):
                super(SimplePipeline, self).__init__(batch_size, num_threads, device_id, seed)
                self.input = ops.FileReader(file_root=dataroot, random_shuffle=True)
                self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)

                self.resize = ops.Resize(resize_shorter=size, device='gpu')
                self.crop = ops.CropMirrorNormalize(crop=(size, size), mean=128, std=128, device='gpu')

            def define_graph(self):
                jpegs, labels = self.input(name="Reader")
                images = self.decode(jpegs)
                images = self.crop(self.resize(images))
                return images

        pipe = SimplePipeline(args.batch_size, 2, 0, 12, args.input_size)
        pipe.build()

        from nvidia.dali.plugin.pytorch import DALIGenericIterator
        data_loader_train = DALIGenericIterator(pipe, ['data'], size=int(pipe.epoch_size("Reader")), auto_reset=True)
        data_loader_test = data_loader_train

    else:
        raise ValueError(f"Unsupported data loading mode: {mode!r} (expected 'pytorch', 'dali' or 'dali1')")

    return data_loader_train, data_loader_test


def _unpack_batch(batch):
    """Return ``(images, labels)`` from one loader batch.

    A torch ``DataLoader`` over ``ImageFolder`` yields ``[images, labels]``,
    while ``DALIGenericIterator`` yields a one-element list holding a
    ``{'data': ..., 'label': ...}`` dict.
    """
    first = batch[0]
    if isinstance(first, dict):
        return first['data'], first['label']
    return first, batch[1]


def _log_generations(net, test_data, gen_z_test, y_test, r_ind, p_list_S, bs_test, epoch, args, log_writer):
    """Run *net* on the fixed test batch and save/log the resulting image grid.

    The grid stacks, in order: a generation with ``Sratio=0``, one generation
    per source ratio in *p_list_S* (all using the patch ordering ``r_ind[0]``),
    and finally the real ``test_data``. Each block holds ``bs_test`` images and
    is rendered ``bs_test`` images per row. The grid is written to
    ``args.dir_img + 'Epoch_<epoch>.jpg'`` and, if a TensorBoard writer is
    configured, logged under ``out_img/test_all``.
    """
    blocks = []
    with torch.no_grad():
        fake_imgs, _, _, _, _ = net(test_data, gen_z_test, label=y_test, Sratio=0., Tratio=1., method='generate')
        blocks.append(fake_imgs)
        for ratio in p_list_S:
            fake_imgs, _, _, _, _ = net(test_data, gen_z_test, label=y_test, r_ind=r_ind[0].repeat(bs_test, 1),
                                        Sratio=ratio, Tratio=1., method='generate')
            blocks.append(fake_imgs)
    blocks.append(test_data)
    grid = torch.cat(blocks, dim=0)

    add_imgs(grid.data, args.dir_img + 'Epoch_%d.jpg' % epoch, nrow=bs_test)
    if log_writer is not None:  # --log_dir is optional (default None)
        log_writer.add_image('out_img/test_all', imgs_grid(grid.data, bs_test), epoch)


def main(args):
    """Train the GAN-style masked-autoencoder generator / discriminator pair.

    Initialises (possibly distributed) mode, seeds the RNGs, and builds the
    data loaders. It then builds or resumes the generator ``gen_net``, its
    averaged copy ``avg_gen_net`` and the discriminator ``dis_net``.

    Each epoch it:
      * calls ``train_one_epoch``;
      * saves test-time generations;
      * optionally computes IS/FID (every 10 epochs) and test accuracy;
      * writes the latest checkpoints.

    Args:
        args: namespace from ``get_args_parser``. Also reads ``args.dir_img``,
            ``args.dir_model`` and ``args.z_dim``, which that parser does not
            define — presumably set by the caller.

    Raises:
        ValueError: if ``args.model`` or ``args.model_dis`` is not supported.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # seed_torch(args.seed)  # stricter alternative: also forces deterministic cuDNN

    cudnn.benchmark = True

    data_loader_train, data_loader_test = dataloading_preprocess(mode=args.dataloadmode)

    if args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # ------------------------------ generator ------------------------------
    # NOTE(review): `models_mae` / `models_vit` are not imported in the visible
    # part of this file; confirm they are bound before main() runs.
    if args.model == 'custom':
        gen_net = models_mae.MaskedAutoencoderViT(
            img_size=args.input_size, patch_size=args.patch_size, embed_dim=args.embed_dimG, depth=args.depthG,
            num_heads=args.num_headsG, decoder_embed_dim=args.decoder_embed_dimG, decoder_depth=args.decoder_depthG,
            decoder_num_heads=args.decoder_num_headsG, mlp_ratio=args.mlp_ratioG,
            norm_layer=models_mae.partial(torch.nn.LayerNorm, eps=1e-6),
            norm_pix_loss=args.norm_pix_loss, num_class=args.num_class, GLastNorm=args.GLastNorm)
    elif args.model == 'type1':
        gen_net = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, num_class=args.num_class)
    elif args.model == 'type2':
        gen_net = models_mae.MaskedAutoencoderViT(
            img_size=args.input_size, patch_size=args.patch_size, in_chans=3, embed_dim=args.embed_dimG,
            depth=args.depthG,
            num_heads=args.num_headsG, decoder_embed_dim=args.decoder_embed_dimG, decoder_depth=args.decoder_depthG,
            decoder_num_heads=args.decoder_num_headsG, mlp_ratio=args.mlp_ratioG,
            norm_layer=models_mae.partial(torch.nn.LayerNorm, eps=1e-6),
            norm_pix_loss=args.norm_pix_loss, num_class=args.num_class, GLastNorm=args.GLastNorm, toImg=args.toImg)
    else:
        raise ValueError(f"Unsupported generator model: {args.model!r}")
    gen_net.to(device)

    ckpt_g = None
    if args.resume is not None:
        ckpt_g = torch.load(args.resume + 'checkpoint_g')
        gen_state = ckpt_g['gen_state_dict']
        # args.start_epoch = ckpt_g['epoch']
        # The stored decoder positional embedding is discarded; with
        # strict=False the freshly initialised one is kept.
        del gen_state['decoder_pos_embed']
        if args.toImg == '4conv':
            # Remap the checkpoint's decoder_pred_head.5 weights to index 11.
            gen_state['decoder_pred_head.11.weight'] = gen_state.pop('decoder_pred_head.5.weight')
            gen_state['decoder_pred_head.11.bias'] = gen_state.pop('decoder_pred_head.5.bias')
        gen_net.load_state_dict(gen_state, strict=False)

    gen_net_without_ddp = gen_net
    avg_gen_net = deepcopy(gen_net).to(device)
    if ckpt_g is not None:
        avg_gen_net.load_state_dict(ckpt_g['avg_gen_state_dict'], strict=False)

    # ---------------------------- discriminator ----------------------------
    if args.model_dis == 'custom':
        dis_net = models_vit.MaskedAutoencoderViT(
            img_size=args.input_size, patch_size=args.patch_size, embed_dim=args.embed_dimD, depth=args.depthD,
            num_heads=args.num_headsD, mlp_ratio=args.mlp_ratioD, diff_aug=args.diff_aug,
            norm_layer=models_vit.partial(torch.nn.LayerNorm, eps=1e-6), num_class=args.num_class,
            DLastNorm=args.DLastNorm, padding=args.Dpadding, drop=args.drop)
    elif args.model_dis == 'type1':
        dis_net = models_vit.__dict__[args.model_dis](diff_aug=args.diff_aug, num_class=args.num_class)
    else:
        raise ValueError(f"Unsupported discriminator model: {args.model_dis!r}")

    dis_net.to(device)

    if args.resume is not None:
        ckpt_d = torch.load(args.resume + 'checkpoint_d')
        # args.start_epoch = ckpt_d['epoch']
        dis_net.load_state_dict(ckpt_d['dis_state_dict'], strict=False)
    dis_net_without_ddp = dis_net

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    if args.dlr is None:  # only base_lr is specified
        args.dlr = args.blr * eff_batch_size / 256

    print("actual lr: %.2e" % args.dlr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    loss_scaler = NativeScaler()

    optimizer_G = torch.optim.AdamW(gen_net_without_ddp.parameters(), lr=args.glr, betas=(0.1, 0.999), eps=1e-8)
    if args.FixPemd:
        # TODO: unfinished. This reloads the generator checkpoint and drops its
        # decoder_pos_embed, but the result is never used (and it fails when
        # --resume is unset).
        dis_net_without_ddp.parameters()
        fix_ckpt = torch.load(args.resume + 'checkpoint_g')
        del fix_ckpt['gen_state_dict']['decoder_pos_embed']
    optimizer_D = torch.optim.AdamW(dis_net_without_ddp.parameters(), lr=args.dlr, betas=(0.1, 0.999), eps=1e-5)

    # Compare by value: `is` on strings is identity and is False for argv input.
    if args.set_scheduler == 'linear':
        scheduler_G = LinearLrDecay(optimizer_G, args.glr, 0.0, 0, args.max_iter * args.d_iter)
        scheduler_D = LinearLrDecay(optimizer_D, args.dlr, 0.0, 0, args.max_iter * args.d_iter)  # DISC
    else:
        scheduler_G = scheduler_D = None

    # Fixed test batch used to visualise progress. Use the next() builtin:
    # DataLoader iterators have no `.next()` method in Python 3.
    bs_test = 10
    test_images, _ = _unpack_batch(next(iter(data_loader_test)))
    test_data = test_images[:bs_test].to(device)
    add_imgs(test_data, args.dir_img + 'real.jpg', nrow=10)
    print('test_data shape', test_data.shape)

    if args.fid:
        ydist = get_ydist(args.num_class, device=device)
        zdist = get_zdist('gauss', args.embed_dimG, device=device)
        x_real, _ = get_nsamples(data_loader_train, args.num_sample_fid)
        evaluator = Evaluator(zdist, ydist,
                              fid_real_samples=x_real, batch_size=args.batch_size, device=device,
                              inception_nsamples=args.num_sample_fid, fid_sample_size=args.num_sample_fid)

    gen_z_test = torch.randn(bs_test, 1, args.z_dim).to(device)
    # Integer division by (bs_test + 1) makes every test label 0.
    # y_test = torch.tensor(np.arange(bs_test) % args.num_class).long().to(device)
    y_test = torch.tensor(np.arange(bs_test) // (bs_test+1)).long().to(device)

    p_list_S = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    L = (args.input_size // args.patch_size) ** 2  # number of patches
    r_ind = torch.cat([torch.randperm(L).to(device) for _ in range(3)], dim=0).view(3, L)

    epoch = -1
    # 0 -------test: image generation (before training) -------
    if args.method in ['generate', 'all']:
        _log_generations(gen_net, test_data, gen_z_test, y_test, r_ind, p_list_S, bs_test, epoch, args, log_writer)

    print(f"Start training for {args.epochs} epochs")
    fid_all, inception_mean_all, acc_all = [], [], []
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):

        # Only samplers that support it (e.g. DistributedSampler) expose set_epoch.
        sampler = getattr(data_loader_train, 'sampler', None)
        if args.distributed and hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        avg_gen_net, gen_net, dis_net = train_one_epoch(
            gen_net, dis_net, data_loader_train,
            optimizer_G, optimizer_D, device, epoch, loss_scaler,
            log_writer=log_writer, lambdaGP=args.lambdaGP,
            args=args, avg_gen_net=avg_gen_net,
            scheduler_G=scheduler_G, scheduler_D=scheduler_D,
            gen_z_test=gen_z_test, test_data=test_data, test_y=y_test
        )

        # 1 ------- image generation (averaged generator) -------
        if args.method in ['generate', 'all']:
            _log_generations(avg_gen_net, test_data, gen_z_test, y_test, r_ind, p_list_S, bs_test, epoch, args,
                             log_writer)

        # 2 ------- FID / IS every 10 epochs -------
        if args.fid and args.method in ['generate', 'all'] and epoch % 10 == 0:
            inception_mean, inception_std, fid = evaluator.compute_inception_score(generator=gen_net,
                                                                                   mask_ratio=0.75)
            inception_mean_all.append(inception_mean)
            fid_all.append(fid)
            print('test epoch %d: IS: mean %.2f, FID: mean %.2f' % (
                epoch, inception_mean, fid))

            FID = np.stack(fid_all)
            Inception_mean = np.stack(inception_mean_all)
            sio.savemat(args.dir_img + '/FID_IS.mat', {'FID': FID, 'Inception_mean': Inception_mean})

        # 3 ------- classification accuracy on the test set -------
        if args.method in ['classify', 'all']:
            with torch.no_grad():
                pred = []
                yall = []
                for batch in data_loader_test:
                    samples, y = _unpack_batch(batch)
                    samples = samples.to(device)
                    # flatten so DALI's per-sample label vectors compare elementwise
                    y = y.to(device).reshape(-1)
                    _, cls_token, _, _, _ = gen_net(samples, mask_ratio=args.mask_ratio, maskType='joint_classify')
                    logi_fake = gen_net.head(cls_token)
                    pred.append(logi_fake.argmax(-1))
                    yall.append(y)
                pred = torch.cat(pred)
                yall = torch.cat(yall)
                acc = (pred == yall).sum() / len(yall)
                print('epoch:%08d, test acc: %.8f' % (epoch, acc.item()))
                acc_all.append(acc.cpu().numpy())

                ACC = np.stack(acc_all)
                sio.savemat(args.dir_img + '/ACC.mat', {'acc': ACC})

        # Only the latest checkpoints are kept (fixed file names).
        save_checkpoint({
            'epoch': epoch + 1,
            'gen_model': args.model,
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_state_dict': gen_net.state_dict(),
        }, 0, args.dir_model, filename="checkpoint_g")
        save_checkpoint({
            'epoch': epoch + 1,
            'gen_model': args.model_dis,
            'dis_state_dict': dis_net.state_dict(),
        }, 0, args.dir_model, filename="checkpoint_d")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    # ---- generator backbone (selected by --GNoise) ----
    # Each branch binds the module-level name `models_mae`; presumably `main`
    # builds the generator from it (TODO confirm).
    if args.GNoise == 'GIn':
        import models_mae_SC_tokenMid_small_In as models_mae  # TODO: update
    elif args.GNoise == 'GIt':
        import models_mae_SC_tokenMid_small_ItNoNoPE as models_mae
    elif args.GNoise == 'GStyIN':
        # The Sty* generators tie the generator embedding width to the decoder
        # width (this also feeds `args.z_dim` below).
        args.embed_dimG = args.decoder_embed_dimG
        import models_mae_SC_tokenMid_small_StyIN as models_mae
    elif args.GNoise == 'GStyLN':
        args.embed_dimG = args.decoder_embed_dimG
        import models_mae_SC_tokenMid_small_StyLN as models_mae
    else:
        # Fail fast: otherwise `models_mae` is never bound, so any later
        # reference to it raises an opaque NameError -- and only after the
        # output directories below have already been created on disk.
        raise ValueError(
            "Unsupported --GNoise %r; expected one of 'GIn', 'GIt', 'GStyIN', 'GStyLN'"
            % (args.GNoise,)
        )

    # ---- discriminator backbone (selected by --DST) ----
    args.multihead = False
    if args.DST == 'DIn':
        import models_vit_D_v2 as models_vit
    elif args.DST == 'DIt':  # Not working
        import models_vit_D_v3_ItInST as models_vit  # TODO: update
    # elif args.DST == 'DSty':
    #     pass  # TODO
    # elif args.DST == 'Dmulhead':
    #     import models_vit_D_v4 as models_vit
    #     args.multihead = True
    #     from engine_pretrain_GAN_DT_v7 import train_one_epoch
    # elif args.DST == 'Dmulclass':
    #     import models_vit_D_v5 as models_vit
    # elif args.DST == 'DSn':
    #     import models_vit_D_v2_sn as models_vit
    else:
        # NOTE(review): for any other --DST value `models_vit` stays unbound;
        # this placeholder is kept deliberately (see TODO).
        pass  # TODO

    # ---- run directory: the name encodes the main hyper-parameters ----
    # NOTE: the patch size ('ps...' and 'PS...') and toImg can each appear
    # twice in the name; kept unchanged so existing run directories and
    # checkpoints still resolve to the same path.
    args.output_dir = args.output_dir + ('test' if args.isTest else 'train') \
                      + args.method + args.GNoise + args.DST + args.loss + str(args.depthG) + \
                      str(args.decoder_depthG) + str(args.depthD) + str(args.Sratio) + str(args.Tratio) \
                      + str(args.toImg)+'ps%s'%args.patch_size
    if args.communicate:
        args.output_dir += str(args.CommuSratio)
    if args.commu_only_fake:
        args.output_dir += 'OF'
    if not args.D_no_padST:
        args.output_dir += 'PadST'
    args.output_dir += 'LastNorm' + args.GLastNorm + args.DLastNorm
    if args.setGP == 'real':
        args.output_dir += 'GPr' + str(args.lambdaGP)
    elif args.setGP == 'all':
        args.output_dir += 'GPrf' + str(args.lambdaGP)
    if args.Dpadding > 0:
        args.output_dir += 'DP' + str(args.Dpadding)
    args.output_dir += 'PS' + str(args.patch_size)
    if args.toImg != 'vit':
        args.output_dir += args.toImg
    args.dir_img = args.output_dir + '/img/'
    args.dir_model = args.output_dir + '/model/'

    if args.log_dir is None:
        args.log_dir = args.output_dir

    # Latent size follows the (possibly overridden) generator embedding width.
    args.z_dim = args.embed_dimG
    # args.lr = 1e-4 if args.use_schedual else 1.5e-4
    if args.output_dir:
        Path(args.dir_img).mkdir(parents=True, exist_ok=True)
        Path(args.dir_model).mkdir(parents=True, exist_ok=True)

    # ---- Beta distributions from which the masking ratios are sampled ----
    from torch.distributions.beta import Beta

    args.SampleSratio = Beta(args.AProbAlpha, args.AProbBeta)  # (0.5, 3)
    args.SampleTratio = Beta(args.BProbAlpha, args.BProbBeta)  # (3, 0.5)
    args.SampleCommuSratio = args.SampleSratio  # (0.5, 3)

    args.SampleSratioCls = Beta(args.CProbAlpha, args.CProbBeta)  # (3, 0.5)

    main(args)
