import time, math, os, copy, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import datasets, transforms
# from scipy.ndimage.interpolation import rotate as scipyrotate
from ema_pytorch import EMA
from networks import *
from ldm.util import instantiate_from_config
from easydict import EasyDict

class Config:
    """Registry of 10-class ImageNet-1k subsets plus shared normalisation tensors.

    Every list below holds 10 ImageNet-1k class indices. ``dict`` maps the short
    subset names accepted by ``get_dataset`` (``"imagenet-<name>"``) to those
    lists. ``get_dataset`` additionally stores the selected list on the shared
    ``config`` instance as ``config.img_net_classes``.
    """
    custom = [1, 199, 388, 294, 340, 932, 327, 765, 928, 486]
    imagenette = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]
    # ["australian_terrier", "border_terrier", "samoyed", "beagle", "shih-tzu", "english_foxhound", "rhodesian_ridgeback", "dingo", "golden_retriever", "english_sheepdog"]
    imagewoof = [193, 182, 258, 162, 155, 167, 159, 273, 207, 229]
    # ["tabby_cat", "bengal_cat", "persian_cat", "siamese_cat", "egyptian_cat", "lion", "tiger", "jaguar", "snow_leopard", "lynx"]
    imagemeow = [281, 282, 283, 284, 285, 291, 292, 290, 289, 287]
    # ["rock_beauty", "clownfish", "loggerhead", "puffer", "stingray", "jellyfish", "starfish", "eel", "anemone", "american_lobster"]
    imageblub = [392, 393, 33, 397, 6, 107, 327, 390, 108, 122]
    # ["peacock", "flamingo", "macaw", "pelican", "king_penguin", "bald_eagle", "toucan", "ostrich", "black_swan", "cockatoo"]
    imagesquawk = [84, 130, 88, 144, 145, 22, 96, 9, 100, 89]
    alyosha = [292, 340, 971, 987, 130, 323, 937, 337, 199, 294]
    # [scotty, brown bear, beaver, husky, bee, panther, terrapin, tiger, badger, duck]
    mascots = [199, 294, 337, 250, 309, 286, 36, 292, 362, 97]
    # ["pineapple", "banana", "strawberry", "orange", "lemon", "pomegranate", "fig", "bell_pepper", "cucumber", "green_apple"]
    fruits = [953, 954, 949, 950, 951, 957, 952, 945, 943, 948]
    # ["bee", "ladys slipper", "banana", "lemon", "corn", "school_bus", "honeycomb", "lion", "garden_spider", "goldfinch"]
    yellow = [309, 986, 954, 951, 987, 779, 599, 291, 72, 11]
    # ["baseball", "basketball", "croquet_ball", "golf_ball", "ping-pong_ball", "rugby_ball", "soccer_ball", "tennis_ball", "volleyball", "puck"]
    imagesport = [429, 430, 522, 574, 722, 768, 805, 852, 890, 746]
    # ["saxophone", "trumpet", "french_horn", "flute", "oboe", "ocarina", "bassoon", "trombone", "panpipe", "harmonica"]
    imagewind = [776, 513, 566, 558, 683, 684, 432, 875, 699, 593]
    # NOTE(review): identical ids (and comment) to `imagewind`; a string-instrument
    # subset was presumably intended. Confirm before using this subset.
    # ["saxophone", "trumpet", "french_horn", "flute", "oboe", "ocarina", "bassoon", "trombone", "panpipe", "harmonica"]
    imagestrings = [776, 513, 566, 558, 683, 684, 432, 875, 699, 593]
    # ["volcano", "alp", "lakeside", "geyser", "coral_reef", "sandbar", "promontory", "seashore", "cliff", "valley"]
    imagegeo = [980, 970, 975, 974, 973, 977, 976, 978, 972, 979]
    # ["axolotl", "tree_frog", "king_snake", "african_chameleon", "iguana", "eft", "fire_salamander", "box_turtle", "american_alligator", "agama"]
    imageherp = [29, 31, 56, 47, 39, 27, 25, 37, 50, 42]
    # ["cheeseburger", "hotdog", "pretzel", "pizza", "french loaf", "icecream", "guacamole", "carbonara", "bagel", "trifle"]
    imagefood = [933, 934, 932, 963, 930, 928, 924, 959, 931, 927]
    # ["fire_engine", "garbage_truck", "forklift", "racer", "tractor", "unicycle", "rickshaw", "steam_locomotive", "bullet_train", "mountain_bike"]
    imagewheels = [555, 569, 561, 751, 866, 880, 612, 820, 466, 671]
    # ["bubble", "piggy_bank", "stoplight", "coil", "kimono", "cello", "combination_lock", "triumphal_arch", "fountain", "cowboy_boot"]
    imagemisc = [971, 719, 920, 506, 614, 486, 507, 873, 562, 514]
    # NOTE(review): the first nine ids are copied from `imagemisc` and do not match
    # the vegetable names below. Confirm the intended ids before using this subset.
    # ["broccoli", "cauliflower", "mushroom", "cabbage", "cardoon", "mashed_potato", "artichoke", "corn", "fountain", "spaghetti_squash"]
    imageveg = [971, 719, 920, 506, 614, 486, 507, 873, 562, 940]
    # ["ladybug", "bee", "monarch", "dragonfly", "mantis", "black_widow", "rhinoceros_beetle", "walking_Stick", "grasshopper", "scorpion"]
    imagebug = [301, 309, 323, 319, 315, 75, 306, 313, 311, 71]
    # ["african_elephant", "red_panda", "camel", "zebra", "guinea_pig", "kangaroo", "platypus", "arctic_fox", "porcupine", "gorilla"]
    imagemammal = [386, 387, 354, 340, 338, 104, 103, 279, 334, 366]
    # ["orca", "great_white_shark", "puffer", "starfish", "loggerhead", "sea lion", "jellyfish", "anemone", "rock crab", "rock beauty"]
    marine = [148, 2, 397, 327, 33, 150, 107, 108, 119, 392]

    # ['Leonberg', 'proboscis monkey, Nasalis larvatus', 'rapeseed', 'three-toed sloth, ai, Bradypus tridactylus', 'cliff dwelling', "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 'hamster', 'gondola', 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 'limpkin, Aramus pictus']
    alpha = [255, 376, 984, 364, 500, 986, 333, 576, 148, 135]
    # ['spoonbill', 'web site, website, internet site, site', 'lorikeet', 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 'earthstar', 'trolleybus, trolley coach, trackless trolley', 'echidna, spiny anteater, anteater', 'Pomeranian', 'odometer, hodometer, mileometer, milometer', 'ruddy turnstone, Arenaria interpres']
    beta = [129, 916,  90, 275, 995, 874, 102, 259, 685, 139]
    # ['freight car', 'hummingbird', 'fireboat', 'disk brake, disc brake', 'bee eater', 'rock beauty, Holocanthus tricolor', 'lion, king of beasts, Panthera leo', 'European gallinule, Porphyrio porphyrio', 'cabbage butterfly', 'goldfinch, Carduelis carduelis']
    gamma = [565,  94, 554, 535,  92, 392, 291, 136, 324,  11]
    # ['ostrich, Struthio camelus', 'Samoyed, Samoyede', 'junco, snowbird', 'Brabancon griffon', 'chickadee', 'sorrel', 'admiral', 'great grey owl, great gray owl, Strix nebulosa', 'hornbill', 'ringlet, ringlet butterfly']
    delta = [  9, 258,  13, 262,  19, 339, 321,  24,  93, 322]
    # ['spindle', 'toucan', 'black swan, Cygnus atratus', 'king penguin, Aptenodytes patagonica', "potter's wheel", 'photocopier', 'screw', 'tarantula', 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 'lycaenid, lycaenid butterfly']
    epsilon = [816,  96, 100, 145, 739, 713, 783,  76, 688, 326]
    # Subset name (the part after "imagenet-" in get_dataset) -> class-index list.
    # The attribute name shadows the builtin only inside this class body; it is
    # kept because callers read `config.dict`.
    dict = {
        "imagenette" : imagenette,
        "imagewoof" : imagewoof,
        "fruits": fruits,
        "yellow": yellow,
        "cats": imagemeow,
        "birds": imagesquawk,
        "geo": imagegeo,
        "food": imagefood,
        "mammals": imagemammal,
        "marine": marine,
        "a": alpha,
        "b": beta,
        "c": gamma,
        "d": delta,
        "e": epsilon
    }

    # Shaped (1, 3, 1, 1) to broadcast over NCHW batches. Placed on the GPU when
    # one is available; falls back to CPU so that importing this module does not
    # crash on machines without CUDA.
    mean = torch.tensor([0.4377, 0.4438, 0.4728]).reshape(1, 3, 1, 1).to('cuda' if torch.cuda.is_available() else 'cpu')
    std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1).to('cuda' if torch.cuda.is_available() else 'cpu')

config = Config()

def get_dataset(dataset, data_path, batch_size = 1, res = None, args = None):
    """Load a dataset and build its test loader (plus per-class train loaders for ImageNet).

    Supported ``dataset`` values:
        ``'CIFAR10'`` / ``'CIFAR100'``: downloaded into ``data_path`` if missing;
            images are normalised with mean/std 0.5 (i.e. to [-1, 1]).
        ``'imagenet-<subset>'``: the 10 ImageNet-1k classes in
            ``config.dict[<subset>]`` (also stored as ``config.img_net_classes``),
            resized and centre-cropped to ``res`` x ``res``.
    Any other value terminates the program via ``exit``.

    Args:
        dataset: dataset name (see above).
        data_path: root directory handed to the torchvision dataset classes.
        batch_size: batch size of the per-class ImageNet train loaders.
        res: output resolution; required for ImageNet subsets.
        args: namespace providing ``batch_test`` (test-loader batch size); required.

    Returns:
        ``(channel, im_size, num_classes, class_names, class_map, class_map_inv,
        mean, std, dst_train, dst_test, testloader, loader_train_dict)``.

    For ImageNet subsets the datasets yield the ORIGINAL ImageNet class indices.
    ``class_map`` converts them to 0..9 and ``class_map_inv`` converts back.
    """

    dst_train_dict = None
    class_map = None
    loader_train_dict = None
    class_map_inv = None

    if dataset in ['CIFAR10', 'CIFAR100']:
        channel = 3
        im_size = (32, 32)
        num_classes = 10 if dataset == 'CIFAR10' else 100
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dataset_class = datasets.CIFAR10 if dataset == 'CIFAR10' else datasets.CIFAR100
        dst_train = dataset_class(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = dataset_class(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes
        class_map = {x:x for x in range(num_classes)}
        class_map_inv = {x: x for x in range(num_classes)}

    elif dataset.startswith("imagenet"):
        subset = dataset.split("-")[1]
        channel = 3
        im_size = (res, res)
        num_classes = 10
        config.img_net_classes = config.dict[subset]
        classes = config.img_net_classes
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), transforms.Resize(res, antialias = True), transforms.CenterCrop(res)])
        dst_train = datasets.ImageNet(data_path, split="train", transform=transform) # no augmentation
        # Convert the (very long) label list to an array once instead of once per class.
        # flatnonzero always returns a 1-D index array, even for a single match.
        train_targets = np.asarray(dst_train.targets)
        dst_train_dict = {c: torch.utils.data.Subset(dst_train, np.flatnonzero(train_targets == cls_id)) for c, cls_id in enumerate(classes)}
        dst_train = torch.utils.data.Subset(dst_train, np.flatnonzero(np.isin(train_targets, classes)))
        loader_train_dict = {c : torch.utils.data.DataLoader(dst_train_dict[c], batch_size=batch_size, shuffle=True, num_workers=4) for c in range(len(classes))}
        dst_test = datasets.ImageNet(data_path, split="val", transform=transform)
        dst_test = torch.utils.data.Subset(dst_test, np.flatnonzero(np.isin(np.asarray(dst_test.targets), classes)))
        # Labels are deliberately left as original ImageNet indices; consumers map
        # them with class_map. (A previous in-place "remap" of `.targets` never
        # worked: `targets` is a Python list, so `list == int` is False and the
        # assignment overwrote targets[0] instead.)
        class_map = {x: i for i, x in enumerate(classes)}
        class_map_inv = {i: x for i, x in enumerate(classes)}
        class_names = None

    else:
        exit('unknown dataset: %s'%dataset)

    testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_test, shuffle=False, num_workers=2)

    return channel, im_size, num_classes, class_names, class_map, class_map_inv, mean, std, dst_train, dst_test, testloader, loader_train_dict
    
class TensorDataset(Dataset):
    """In-memory dataset that pairs an image tensor with a label tensor.

    Both tensors are detached from any autograd graph and the images are cast
    to float. Item ``i`` is the pair ``(images[i], labels[i])``.
    """
    def __init__(self, images, labels): # images: n x c x h x w tensor
        self.images, self.labels = images.detach().float(), labels.detach()

    def __len__(self):
        return self.images.size(0)

    def __getitem__(self, index):
        sample = self.images[index]
        target = self.labels[index]
        return sample, target

# def get_default_convnet_setting():
#     net_width, net_depth, net_act, net_norm, net_pooling = 128, 3, 'relu', 'instancenorm', 'avgpooling'
#     return net_width, net_depth, net_act, net_norm, net_pooling

def get_network(model, channel, num_classes, im_size=(32, 32), dist=True, depth=3, width=128, norm="instancenorm", convnet_pooling = 'avgpooling'):
    """Instantiate the architecture named ``model``.

    The global torch RNG is re-seeded from the wall clock on every call, so
    successive calls produce differently initialised networks. ``depth``,
    ``width``, ``norm`` and ``convnet_pooling`` configure the ConvNet family;
    ``norm`` is also forwarded to ``ResNet18``. An unknown name terminates the
    program via ``exit('DC error: unknown model')``.

    When ``dist`` is true, the network is wrapped in ``nn.DataParallel`` if more
    than one GPU is visible, then moved to ``'cuda'`` (or ``'cpu'`` without GPUs).
    """
    torch.random.manual_seed(int(time.time() * 1000) % 100000)

    def conv_family(net_norm, use_gap=False):
        # All ConvNet variants share every constructor argument except the norm layer
        # (and the GAP head for ConvNetGAP).
        builder = ConvNetGAP if use_gap else ConvNet
        return builder(channel, num_classes, net_width=width, net_depth=depth, net_act='relu',
                       net_norm=net_norm, net_pooling=convnet_pooling, im_size=im_size)

    # Lazy factories: only the selected architecture is ever constructed.
    factories = {
        'AlexNet': lambda: AlexNet(channel, num_classes=num_classes, im_size=im_size),
        'VGG11': lambda: VGG11(channel=channel, num_classes=num_classes),
        'VGG11BN': lambda: VGG11BN(channel=channel, num_classes=num_classes),
        'ResNet18': lambda: ResNet18(channel=channel, num_classes=num_classes, norm=norm),
        'ViT': lambda: ViT(image_size=im_size, patch_size=16, num_classes=num_classes, dim=512,
                           depth=10, heads=8, mlp_dim=512, dropout=0.1, emb_dropout=0.1),
        'AlexNetCIFAR': lambda: AlexNetCIFAR(channel=channel, num_classes=num_classes),
        'ResNet18CIFAR': lambda: ResNet18CIFAR(channel=channel, num_classes=num_classes),
        'VGG11CIFAR': lambda: VGG11CIFAR(channel=channel, num_classes=num_classes),
        'ViTCIFAR': lambda: ViTCIFAR(image_size=im_size, patch_size=4, num_classes=num_classes, dim=512,
                                     depth=6, heads=8, mlp_dim=512, dropout=0.1, emb_dropout=0.1),
        'ConvNet': lambda: conv_family(norm),
        'ConvNetGAP': lambda: conv_family(norm, use_gap=True),
        'ConvNet_BN': lambda: conv_family("batchnorm"),
        'ConvNet_IN': lambda: conv_family("instancenorm"),
        'ConvNet_LN': lambda: conv_family("layernorm"),
        'ConvNet_GN': lambda: conv_family("groupnorm"),
        'ConvNet_NN': lambda: conv_family("none"),
    }

    make = factories.get(model)
    if make is None:
        exit('DC error: unknown model')
    net = make()

    if dist:
        n_gpu = torch.cuda.device_count()
        if n_gpu > 1:
            net = nn.DataParallel(net)
        net = net.to('cuda' if n_gpu > 0 else 'cpu')

    return net

def get_time():
    """Return the current local time as a bracketed ``[YYYY-MM-DD HH:MM:SS]`` string."""
    now = time.localtime()
    return time.strftime("[%Y-%m-%d %H:%M:%S]", now)

def distance_wb(gwr, gws):
    """Layer-wise gradient distance: sum over output units of (1 - cosine similarity).

    ``gwr`` (real) and ``gws`` (synthetic) are gradients of the same parameter and
    must share a shape.
    - 4-D tensors (conv, out*in*h*w) are flattened to ``(out, in*h*w)``.
    - 2-D tensors (linear, out*in) are used as-is.
    - 1-D (batchnorm/instancenorm/groupnorm weights, biases) and 3-D (layernorm,
      C*h*w) tensors are ignored and contribute a zero tensor.
    Other ranks fall through to the cosine formula on the last dimension.
    """
    shape = gwr.shape
    if len(shape) in (1, 3):
        # Norm-layer / bias gradients are deliberately excluded from the distance.
        return torch.tensor(0, dtype=torch.float, device=gwr.device)
    if len(shape) == 4:
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])

    dots = torch.sum(gwr * gws, dim=-1)
    norms = torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1)
    # Small epsilon guards against division by zero for all-zero gradients.
    return torch.sum(1 - dots / (norms + 0.000001))

def match_loss(gw_syn, gw_real, dis_metric):
    """Distance between synthetic and real per-layer gradient lists.

    ``dis_metric`` selects the measure:
        ``'ours'``: sum of ``distance_wb`` over corresponding layers.
        ``'mse'``:  squared L2 distance of all gradients concatenated into one vector.
        ``'cos'``:  1 - cosine similarity of those concatenated vectors.
    Any other value terminates the program via ``exit``.
    """
    n_layers = len(gw_real)

    if dis_metric == 'ours':
        return sum((distance_wb(gw_real[k], gw_syn[k]) for k in range(n_layers)), 0.0)

    if dis_metric in ('mse', 'cos'):
        real_flat = torch.cat([gw_real[k].reshape((-1)) for k in range(n_layers)], dim=0)
        syn_flat = torch.cat([gw_syn[k].reshape((-1)) for k in range(n_layers)], dim=0)
        if dis_metric == 'mse':
            return torch.sum((syn_flat - real_flat)**2)
        denom = torch.norm(real_flat, dim=-1) * torch.norm(syn_flat, dim=-1) + 0.000001
        return 1 - torch.sum(real_flat * syn_flat, dim=-1) / denom

    exit('DC error: unknown distance function')

def get_loops(ipc):
    """Return the empirically tuned ``(outer_loop, inner_loop)`` pair for ``ipc`` images per class.

    Only ipc in {1, 10, 20, 30, 40, 50} is supported; any other value terminates
    the program via ``exit``.
    """
    schedule = {
        1: (1, 1),
        10: (10, 50),
        20: (20, 25),
        30: (30, 20),
        40: (40, 15),
        50: (50, 10),
    }
    if ipc not in schedule:
        exit('DC error: loop hyper-parameters are not defined for %d ipc'%ipc)
    return schedule[ipc]

def epoch(mode, dataloader, net, optimizer, criterion, args, aug, target_iter = None):
    """Run one pass (or exactly ``target_iter`` batches) of training or evaluation.

    Args:
        mode: ``'train'`` enables ``net.train()``, CutMix and optimizer steps; any
            other value puts the net in eval mode. With ``'test'`` on an ImageNet
            subset, labels are mapped to 0..9 via ``config.img_net_classes``.
        dataloader: iterable of batches whose first two items are images and labels.
        net: model to run; moved to ``args.device``.
        optimizer: used only in train mode.
        criterion: loss function called as ``criterion(output, labels)``.
        args: namespace providing ``device`` and ``dataset``; also ``cutmix_p`` /
            ``cutmix_beta`` in train mode and ``dsa_strategy`` / ``dsa_param`` when
            ``aug`` is true.
        aug: apply ``DiffAugment`` to every batch.
        target_iter: if truthy, keep cycling over ``dataloader`` until exactly this
            many batches were processed; otherwise make a single pass.

    Returns:
        ``(mean loss per example, accuracy)``. Accuracy counts only non-CutMix
        batches and is NaN if there were none. Both values are NaN when no
        example was processed (e.g. an empty loader).
    """
    loss_avg, acc_avg, num_exp, num_calc_acc = 0, 0, 0, 0
    net = net.to(args.device)

    if "imagenet" in args.dataset:
        class_map = {x: i for i, x in enumerate(config.img_net_classes)}

    if mode == 'train':
        net.train()
    else:
        net.eval()

    cur_iter = 0
    while True:
        batches_in_pass = 0
        for datum in dataloader:
            batches_in_pass += 1
            img = datum[0].to(args.device)
            lab = datum[1].to(args.device)
            if aug:
                img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)

            if "imagenet" in args.dataset and mode == 'test':
                # ImageNet loaders yield original class indices; map to 0..num_classes-1.
                lab = torch.tensor([class_map[x.item()] for x in lab]).to(args.device)

            n_b = lab.shape[0]

            # CutMix, based on the implementation of IDC
            if mode == 'train' and np.random.rand(1) < args.cutmix_p:
                lam = np.random.beta(args.cutmix_beta, args.cutmix_beta)
                rand_index = torch.randperm(n_b).to(args.device)
                lab_b = lab[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(img.shape, lam)
                img[:, :, bbx1:bbx2, bby1:bby2] = img[rand_index, :, bbx1:bbx2, bby1:bby2]
                # Fraction of each image that still comes from its own sample.
                ratio = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img.shape[-1] * img.shape[-2]))
                output = net(img)
                loss = criterion(output, lab) * ratio + criterion(output, lab_b) * (1. - ratio)
            else:
                output = net(img)
                loss = criterion(output, lab)
                predicted = torch.argmax(output.data, 1)
                correct = (predicted == lab).sum()
                acc_avg += correct.item()
                num_calc_acc += n_b

            loss_avg += loss.item()*n_b
            num_exp += n_b

            if mode == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if target_iter:
                cur_iter += 1
                if cur_iter == target_iter:
                    break

        # Stop after one pass when no iteration target was given, once the target
        # is reached, or when the loader yielded nothing. Cycling an empty loader
        # would otherwise spin forever.
        if not target_iter or cur_iter == target_iter or batches_in_pass == 0:
            break

    if num_exp == 0:
        return np.nan, np.nan
    loss_avg /= num_exp
    acc_avg = acc_avg / num_calc_acc if num_calc_acc > 0 else np.nan

    return loss_avg, acc_avg

def evaluate_synset(it_eval, net, images_train, labels_train, testloader, logger, args, decay="cosine", 
                    return_loss=False, test_epoch=200, aug=True, verbose = False, lr = None):
    """Train ``net`` on ``(images_train, labels_train)`` and track the test accuracy of its EMA.

    Training runs for ``int(args.epoch_eval_train)`` epochs with batch size
    ``args.batch_train``. An EMA copy of the weights (beta 0.995, updated once per
    epoch) is evaluated on ``testloader`` every ``test_epoch`` epochs. The best test
    accuracy and its epoch are written through ``logger.log``.

    LR schedule (``decay``):
        ``"cosine"``: linear warm-up over the first half, cosine annealing over the second.
        ``"step"``: constant LR for the first half, then a single 10x reduction.

    Returns:
        ``(net, acc_train_list, acc_test_list)``, plus ``loss_train_list`` and
        ``loss_test_list`` when ``return_loss`` is true. ``net`` is the raw
        (non-EMA) model.

    Raises:
        ValueError: if ``decay`` is neither ``"cosine"`` nor ``"step"``.
    """
    net = net.to(args.device)
    images_train = images_train.to(args.device)
    labels_train = labels_train.to(args.device)
    Epoch = int(args.epoch_eval_train)
    optimizer = get_optimizer_net(args, net, lr)

    if decay == "cosine":
        sched1 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.0000001, end_factor=1.0, total_iters=Epoch//2)
        sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Epoch//2)

    elif decay == "step":
        lmbda1 = lambda epoch: 1.0
        sched1 = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lmbda1)
        # MultiplicativeLR multiplies the current LR by the lambda on EVERY step.
        # A constant 0.1 would divide the LR by 10 each epoch (to ~0 within a few
        # epochs), so the factor is applied only on sched2's first step.
        lmbda2 = lambda epoch: 0.1 if epoch == 1 else 1.0
        sched2 = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lmbda2)

    else:
        raise ValueError(f'unknown decay schedule: {decay!r}')

    sched = sched1

    ema = EMA(net, beta=0.995, power=1, update_after_step=0, update_every=1)

    criterion = nn.CrossEntropyLoss().to(args.device)

    dst_train = TensorDataset(images_train, labels_train)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    start = time.time()
    acc_train_list = []
    loss_train_list = []
    acc_test_list = []
    loss_test_list = []
    acc_test_max = 0
    acc_test_max_epoch = 0
    # Defined up front so the summary log line also works when Epoch == 0.
    loss_train, acc_train = float('nan'), float('nan')
    for ep in (tqdm(range(Epoch)) if verbose else range(Epoch)):
        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=aug)
        acc_train_list.append(acc_train)
        loss_train_list.append(loss_train)
        ema.update()
        sched.step()
        # Hand over to the second-half schedule at the midpoint.
        if ep == Epoch // 2:
            sched = sched2

        if (ep + 1) % test_epoch == 0:
            with torch.no_grad():
                loss_test, acc_test = epoch('test', testloader, ema, optimizer, criterion, args, aug=False)
            acc_test_list.append(acc_test)
            loss_test_list.append(loss_test)
            if acc_test > acc_test_max:
                acc_test_max = acc_test
                acc_test_max_epoch = ep + 1

    time_train = time.time() - start

    logger.log('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f @ %d epoch' % (get_time(), it_eval, Epoch, int(time_train), loss_train, acc_train, acc_test_max, acc_test_max_epoch))

    if return_loss:
        return net, acc_train_list, acc_test_list, loss_train_list, loss_test_list
    else:
        return net, acc_train_list, acc_test_list

def get_eval_pool(eval_mode, model, model_eval):
    """Return the list of architecture names to evaluate for ``eval_mode``.

    ``model`` is the training architecture. ``[model_eval]`` is returned for any
    unrecognised ``eval_mode``. For ``'S'``, everything from ``'BN'`` onward is
    stripped from ``model``.

    NOTE(review): several names below (e.g. 'ConvNetW32', 'RN18', 'VGG11_big',
    'LeNet') are not handled by ``get_network`` in this file.
    """
    if eval_mode == 'S':  # itself
        return [model[:model.index('BN')]] if 'BN' in model else [model]

    pools = {
        'M': [model, "ResNet18", "VGG11", "AlexNet", "ViT"],  # multiple architectures
        'cross': ["ResNet18", "VGG11", "AlexNet", "ViT"],  # other architectures except for args.model
        'W': ['ConvNetW32', 'ConvNetW64', 'ConvNetW128', 'ConvNetW256'],  # width ablation
        'D': ['ConvNetD1', 'ConvNetD2', 'ConvNetD3', 'ConvNetD4'],  # depth ablation
        'A': ['ConvNetAS', 'ConvNetAR', 'ConvNetAL'],  # activation ablation
        'P': ['ConvNetNP', 'ConvNetMP', 'ConvNetAP'],  # pooling ablation
        'N': ['ConvNetNN', 'ConvNetBN', 'ConvNetLN', 'ConvNetIN', 'ConvNetGN'],  # normalization ablation
        'C': [model, 'ConvNet'],
        "big": [model, "RN18", "VGG11_big", "ViT"],
        "small": [model, "ResNet18", "VGG11", "LeNet", "AlexNet"],
        "ConvNet_Norm": ["ConvNet_BN", "ConvNet_IN", "ConvNet_LN", "ConvNet_NN", "ConvNet_GN"],
        "CIFAR": [model, "AlexNetCIFAR", "ResNet18CIFAR", "VGG11CIFAR", "ViTCIFAR"],
    }
    return pools.get(eval_mode, [model_eval])

class ParamDiffAug():
    """Default hyper-parameters for the differentiable augmentations used by ``DiffAugment``.

    ``aug_mode`` selects how the strategies in a ``DiffAugment`` strategy string
    are combined: ``'M'`` applies all of them, ``'S'`` applies one chosen at random,
    and ``'IDC'`` applies flip/color plus one other strategy chosen at random.
    """
    def __init__(self):
        settings = {
            'aug_mode': 'IDC',  # *M*ultiple, *S*ingle, or *IDC*
            'prob_flip': 0.5,
            'ratio_scale': 1.2,
            'ratio_rotate': 15.0,
            'ratio_crop_pad': 0.125,
            'ratio_cutout': 0.5,  # the size would be 0.5x0.5
            'ratio_noise': 0.05,
            'brightness': 1.0,
            'saturation': 2.0,
            'contrast': 0.5,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

def set_seed_DiffAug(param):
    """Seed torch's global RNG from ``param.latestseed`` and advance the counter.

    A value of -1 means "unseeded": nothing happens. Otherwise the counter is
    incremented after seeding, so consecutive random draws get distinct but
    reproducible seeds.
    """
    seed = param.latestseed
    if seed != -1:
        torch.random.manual_seed(seed)
        param.latestseed = seed + 1

def DiffAugment(x, strategy='', seed = -1, param = None):
    """Apply the differentiable augmentations named in ``strategy`` to the batch ``x``.

    Args:
        x: image batch, passed through the functions registered in ``AUGMENT_FNS``.
        strategy: underscore-separated keys of ``AUGMENT_FNS`` (e.g. ``'color_crop_flip'``).
            ``'None'``/``'none'`` or an empty string returns ``x`` unchanged.
        seed: -1 for independent randomness. Any other value sets
            ``param.batchmode`` (one random draw shared by the whole batch) and is
            used by ``set_seed_DiffAug`` to seed torch's RNG before each draw.
        param: ``ParamDiffAug``-like object. It is mutated (``batchmode`` and
            ``latestseed`` are overwritten) and is required despite the ``None`` default.

    Returns:
        The augmented batch, made contiguous whenever a non-empty strategy was applied.
    """
    if seed == -1:
        param.batchmode = False
    else:
        param.batchmode = True

    param.latestseed = seed

    if strategy == 'None' or strategy == 'none':
        return x

    if strategy:
        if param.aug_mode == 'M': # original
            # Apply every function of every listed strategy, in order.
            for p in strategy.split('_'):
                for f in AUGMENT_FNS[p]:
                    x = f(x, param)
        elif param.aug_mode == 'S':
            # Apply exactly one strategy, chosen with torch's (possibly seeded) RNG.
            pbties = strategy.split('_')
            set_seed_DiffAug(param)
            p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
            print('seed', seed, p)
            for f in AUGMENT_FNS[p]:
                x = f(x, param)
        elif param.aug_mode == 'IDC':
            # Always apply 'flip' and 'color' when listed, then one other strategy.
            pbties = strategy.split('_')
            # print(pbties)
            set_seed_DiffAug(param)
            use_pbties = []
            if 'flip' in pbties:
                for f in AUGMENT_FNS['flip']:
                    x = f(x, param)
                pbties.remove('flip')
                use_pbties.append('flip')
            if 'color' in pbties:
                for f in AUGMENT_FNS['color']:
                    x = f(x, param)
                pbties.remove('color')
                use_pbties.append('color')
            # NOTE(review): do_cutout is never set to True, so 'cutout' is removed
            # from the candidates and never applied in IDC mode. Confirm this is intended.
            do_cutout = False
            if 'cutout' in pbties:
                pbties.remove('cutout')
            if len(pbties) > 0:
                # NOTE(review): Python's `random` module is used here, so this choice is
                # NOT controlled by `seed` / set_seed_DiffAug (unlike mode 'S').
                p = random.choice(pbties)
                for f in AUGMENT_FNS[p]:
                    x = f(x, param)
                use_pbties.append(p)
            if do_cutout:
                for f in AUGMENT_FNS['cutout']:
                    x = f(x, param)
                use_pbties.append('cutout')
            # print(use_pbties)
            # print()
        else:
            exit('Error ZH: unknown augmentation mode.')
        x = x.contiguous()
    return x

# We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
def rand_scale(x, param):
    """Randomly rescale each image along the two spatial axes via an affine grid.

    Independent factors for each axis are drawn uniformly from ``[1/r, r)`` with
    ``r = param.ratio_scale``. With ``param.batchmode`` the first sample's factors
    are reused for the whole batch.
    """
    r = param.ratio_scale
    lo = 1.0/r
    set_seed_DiffAug(param)
    sx = torch.rand(x.shape[0]) * (r - lo) + lo
    set_seed_DiffAug(param)
    sy = torch.rand(x.shape[0]) * (r - lo) + lo

    # One 2x3 affine matrix per sample: diag(sx, sy) with zero translation.
    theta = torch.zeros(x.shape[0], 2, 3, dtype=torch.float)
    theta[:, 0, 0] = sx
    theta[:, 1, 1] = sy
    if param.batchmode: # batch-wise:
        theta[:] = theta[0].clone()

    grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
    return F.grid_sample(x, grid, align_corners=True)

def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree
    """Randomly rotate each image by an angle in ``[-ratio_rotate, ratio_rotate)`` degrees.

    With ``param.batchmode`` the first sample's rotation is applied to the whole batch.
    """
    max_deg = param.ratio_rotate
    set_seed_DiffAug(param)
    angles = (torch.rand(x.shape[0]) - 0.5) * 2 * max_deg / 180 * float(np.pi)

    # Build the 2x3 rotation matrices sample by sample (per-element trig keeps the
    # exact same floating-point values as scalar evaluation).
    rows = []
    for a in angles:
        cos_a = torch.cos(a)
        rows.append([[cos_a, torch.sin(-a), 0],
                     [torch.sin(a), cos_a, 0]])
    theta = torch.tensor(rows, dtype=torch.float)
    if param.batchmode: # batch-wise:
        theta[:] = theta[0].clone()

    grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
    return F.grid_sample(x, grid, align_corners=True)

def rand_flip(x, param):
    """Flip each image along dim 3 with probability ``param.prob_flip``.

    With ``param.batchmode`` the first sample's coin toss decides for the whole batch.
    """
    threshold = param.prob_flip
    set_seed_DiffAug(param)
    coins = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    if param.batchmode: # batch-wise:
        coins[:] = coins[0].clone()
    do_flip = coins < threshold
    return torch.where(do_flip, x.flip(3), x)

def rand_brightness(x, param):
    """Add a per-sample offset drawn from ``[-0.5, 0.5) * param.brightness`` to every pixel.

    With ``param.batchmode`` the first sample's offset is used for the whole batch.
    """
    scale = param.brightness
    set_seed_DiffAug(param)
    shift = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.batchmode:  # batch-wise:
        shift[:] = shift[0].clone()
    return x + (shift - 0.5)*scale

def rand_saturation(x, param):
    """Scale each pixel's deviation from its dim-1 (channel) mean by ``U[0, 1) * param.saturation``.

    The factor is drawn per sample, or shared from the first sample when ``param.batchmode``.
    """
    scale = param.saturation
    channel_mean = x.mean(dim=1, keepdim=True)
    set_seed_DiffAug(param)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.batchmode:  # batch-wise:
        factor[:] = factor[0].clone()
    return (x - channel_mean) * (factor * scale) + channel_mean

def rand_contrast(x, param):
    """Scale each image's deviation from its overall mean by ``U[0, 1) + param.contrast``.

    The factor is drawn per sample, or shared from the first sample when ``param.batchmode``.
    """
    offset = param.contrast
    image_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    set_seed_DiffAug(param)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.batchmode:  # batch-wise:
        factor[:] = factor[0].clone()
    return (x - image_mean) * (factor + offset) + image_mean

def rand_crop(x, param):
    """Randomly translate each image by up to ``param.ratio_crop_pad`` of its size.

    Integer shifts in ``[-shift, shift]`` are drawn per axis and per sample (or
    taken from the first sample when ``param.batchmode``). The image is padded
    with a 1-pixel zero border and source indices are clamped onto it, so
    regions shifted in from outside the image become zeros.
    """
    # The image is padded on its surrounding and then cropped.
    ratio = param.ratio_crop_pad
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    if param.batchmode:  # batch-wise:
        translation_x[:] = translation_x[0].clone()
        translation_y[:] = translation_y[0].clone()
    # indexing='ij' is the historical default; passing it explicitly avoids the
    # deprecation UserWarning newer PyTorch emits when it is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the 1-pixel pad; clamping maps out-of-range sources onto the zero border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

def rand_cutout(x, param):
    """Zero out one rectangle of ``param.ratio_cutout`` x image size in each image.

    The rectangle centre is drawn per sample (or taken from the first sample when
    ``param.batchmode``). Rectangle indices are clamped to the image, so the
    erased area is truncated at the borders.
    """
    ratio = param.ratio_cutout
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    if param.batchmode:  # batch-wise:
        offset_x[:] = offset_x[0].clone()
        offset_y[:] = offset_y[0].clone()
    # indexing='ij' is the historical default; passing it explicitly avoids the
    # deprecation UserWarning newer PyTorch emits when it is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x

# Registry used by DiffAugment: strategy name -> augmentation functions applied in order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    crop=[rand_crop],
    cutout=[rand_cutout],
    flip=[rand_flip],
    scale=[rand_scale],
    rotate=[rand_rotate],
)

# originally in grad_utils.py
import copy, gc
import torchvision as tv

def get_eval_lrs(args):
    """Return the evaluation learning rate for each architecture name.

    ``args.model`` gets 0.01 unless it is also one of the fixed entries below,
    in which case the fixed value wins (it stays the first key either way).
    """
    fixed_lrs = {
        "ResNet18": 0.001,
        "VGG11": 0.0001,
        "AlexNet": 0.001,
        "ViT": 0.001,

        "AlexNetCIFAR": 0.001,
        "ResNet18CIFAR": 0.001,
        "VGG11CIFAR": 0.0001,
        "ViTCIFAR": 0.001,
    }
    return {args.model: 0.01, **fixed_lrs}


########
# Added
########
class Logger():
    """Minimal logger that writes each message to a file and optionally echoes it to stdout.

    The file is opened for writing (truncating existing content) on construction.
    Call ``close()`` or use the logger as a context manager to release the handle.
    """
    def __init__(self, log_file):
        # Explicit UTF-8 so non-ASCII messages don't fail on non-UTF-8 locales.
        self.f = open(log_file, 'w', encoding='utf-8')

    def log(self, *x, sep = ' ', end = '\n', std_output = True):
        """Write ``x`` joined by ``sep`` plus ``end`` and flush; also print it when ``std_output``."""
        output = sep.join(str(xx) for xx in x)
        self.f.write(output + end)
        # Flush right away so the log survives a crash mid-run.
        self.f.flush()
        if std_output:
            print(output)

    def close(self):
        """Close the underlying file; safe to call more than once."""
        if not self.f.closed:
            self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

def get_run_name(args):
    """Build a run identifier from the experiment settings plus a ``_MMDD-HHMM`` timestamp.

    Optional fragments are included only when relevant: ``-lpc`` when ``ldc_mode``
    is not ``'same_ipc'``, ``_init-...`` for non-real init, ``_loop...`` when
    ``args.loop`` exists and is truthy, ``_nbsyn...`` when ``args.num_batch_syn``
    exists, and ``_comment`` when ``args.comment`` is truthy.
    """
    pieces = [f'f{args.f}_ipc{args.ipc}']
    if args.ldc_mode != 'same_ipc':
        pieces.append(f'-{args.lpc}')
    if args.init != 'real':
        pieces.append(f'_init-{args.init}')
    pieces.append(f'_lr-lb{args.lr_latent_base}')
    loop = getattr(args, 'loop', None)
    if loop:
        pieces.append(f'_loop{loop}')
    if hasattr(args, 'num_batch_syn'):
        pieces.append(f'_nbsyn{args.num_batch_syn}')
    if args.comment:
        pieces.append(f'_{args.comment}')
    pieces.append('_' + time.strftime("%m%d-%H%M"))
    return ''.join(pieces)

def print_eval_it_pool(eval_it_pool):
    """Return a short string form of the evaluation-iteration list.

    Sequences of at most five items are shown in full via ``str``. Longer ones
    are abbreviated to their first three items, an ellipsis and the last item.
    """
    if len(eval_it_pool) > 5:
        first, second, third = eval_it_pool[0], eval_it_pool[1], eval_it_pool[2]
        return f'[{first}, {second}, {third}, ..., {eval_it_pool[-1]}]'
    return str(eval_it_pool)

@torch.no_grad()
def build_dataset(args, ae_model, ds, class_map, batch_size = 16, test_max_num = None):
    """Encode a dataset into autoencoder latents and index the samples by class.

    Args:
        args: namespace providing ``num_classes``, ``im_size`` and ``f``.
        ae_model: object exposing ``encode``. Images are moved to CUDA before
            encoding, so a GPU is required.
        ds: indexable dataset returning ``(image, label)`` pairs.
        class_map: maps a raw dataset label to a class index in
            ``range(args.num_classes)``.
        batch_size: number of images encoded per call (full-dataset path only).
        test_max_num: if not None, encode only this many randomly chosen samples
            (drawn without replacement), one at a time.

    Returns:
        ``(latent_all, label_all, indices_class)``:
        - ``latent_all``: latents concatenated along dim 0, on CPU.
        - ``label_all``: a long tensor of mapped labels.
        - ``indices_class``: per-class lists of positions into those tensors.
    """
    print(f'Building dataset of size {len(ds)}')
    latent_all = []
    label_all = []
    indices_class = [[] for c in range(args.num_classes)]
    # Images are resized to im_size scaled by 8 // args.f before encoding.
    scale = 8 // args.f
    enc_size = (args.im_size[0] * scale, args.im_size[1] * scale)

    def _class_index(label):
        # Normalise ints, numpy scalars and 0-d tensors to a Python scalar before
        # the lookup: a tensor key never matches class_map (tensors hash by identity).
        return class_map[torch.as_tensor(label).item()]

    if test_max_num is None:
        total_num = len(ds)
        for i in tqdm(range(math.ceil(total_num / batch_size)), desc = f'Organizing images/latents with batch size {batch_size}'):
            batch_st = i * batch_size
            batch_ed = min(batch_st + batch_size, total_num)
            image = []
            for j in range(batch_st, batch_ed):
                im, label = ds[j]
                image.append(im.unsqueeze(0))
                label_all.append(_class_index(label))
            image = torch.cat(image)
            latent = ae_model.encode(tv.transforms.functional.resize(image, enc_size, antialias = True).cuda())
            # Split the batch back into per-sample (1, ...) latents so both branches build the same list shape.
            latent_all.extend(latent.cpu().split(1))
    else:
        select_index = np.random.choice(len(ds), test_max_num, replace = False)
        for i in tqdm(select_index, desc = 'Organizing images/latents'):
            image, label = ds[i]
            latent = ae_model.encode(tv.transforms.functional.resize(image.cuda().unsqueeze(0), enc_size, antialias = True))
            latent_all.append(latent.cpu())
            label_all.append(_class_index(label))
    for i, lab in tqdm(enumerate(label_all), desc = 'Building indices_class'):
        indices_class[lab].append(i)
    latent_all = torch.cat(latent_all, dim=0)
    label_all = torch.tensor(label_all, dtype=torch.long)
    return latent_all, label_all, indices_class

@torch.no_grad()
def build_image_dataset(args, ds, class_map, test_max_num = None):
    """Collect images from ``ds`` into one tensor and index the samples by class.

    When ``test_max_num`` is truthy, only that many randomly chosen samples are
    used (drawn without replacement). Otherwise every sample is used in order.

    Returns ``(image_all, label_all, indices_class)``:
    - ``image_all``: the images stacked along a new dim 0.
    - ``label_all``: a long tensor of labels mapped through ``class_map``.
    - ``indices_class``: per-class lists of positions into those tensors.
    """
    print(f'Building dataset of size {len(ds)}')
    per_class = [[] for _ in range(args.num_classes)]

    if test_max_num:
        order = np.random.choice(len(ds), test_max_num, replace = False)
    else:
        order = range(len(ds))

    images, labels = [], []
    for idx in tqdm(order, desc = 'Organizing images/latents'):
        img, raw_label = ds[idx]
        images.append(img.unsqueeze(0))
        labels.append(class_map[torch.tensor(raw_label).item()])

    for pos, cls in tqdm(enumerate(labels), desc = 'Building indices_class'):
        per_class[cls].append(pos)

    return torch.cat(images, dim=0), torch.tensor(labels, dtype=torch.long), per_class

def load_autoencoder_from_config(config, ckpt, verbose = False):
    """Instantiate ``config.autoencoder_config`` and load weights from ``ckpt``.

    The state dict is loaded with ``strict=False``. When ``verbose`` is set, any
    missing or unexpected keys are printed. The model is returned in eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(ckpt, map_location = "cpu")
    model = instantiate_from_config(config.autoencoder_config)
    missing, unexpected = model.load_state_dict(state_dict, strict = False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)
    model.eval()
    return model

def prepare_latent(args, get_latents = None):
    ''' initialize the synthetic data

    Returns ``(latent, label_syn)``:
    - ``latent``: a float leaf tensor with ``requires_grad=True`` and shape
      ``(num_classes * lpc, C, latent_size[0], latent_size[1])``.
    - ``label_syn``: a long tensor of class indices laid out as
      ``[0]*lpc + [1]*lpc + ...``.

    Latents start as standard normal noise. When ``args.init == 'real'``, each
    class's slice is overwritten with ``get_latents(c, args.lpc)``, so
    ``get_latents`` is required in that mode.

    Raises:
        ValueError: if ``args.init == 'real'`` and ``get_latents`` is None.
    '''
    if args.init == 'real' and get_latents is None:
        raise ValueError("prepare_latent: args.init == 'real' requires a get_latents callable")
    with torch.no_grad():
        label_syn = torch.tensor([[i for _ in range(args.lpc)] for i in range(args.num_classes)],
                                 dtype = torch.long, requires_grad = False, device = args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]
        latent = torch.randn((args.num_classes * args.lpc, args.C, args.latent_size[0], args.latent_size[1]),
                              dtype = torch.float, requires_grad = True, device = args.device)
        if args.init == 'real':
            for c in range(args.num_classes):
                # Write through .data so the leaf tensor stays the optimisable parameter.
                latent.data[c * args.lpc:(c + 1) * args.lpc] = get_latents(c, args.lpc).detach().data
    return latent, label_syn

def get_optimizer_latent(args, latent):
    """Build an SGD optimizer over ``latent``.

    The learning rate is ``args.lr_latent_base`` scaled by the number of
    latents per class (``args.lpc``); momentum comes from ``args.mom_latent``.
    """
    scaled_lr = args.lr_latent_base * args.lpc
    return torch.optim.SGD(params=[latent], lr=scaled_lr, momentum=args.mom_latent)

def get_optimizer_net(args, net, lr = None):
    """Build an SGD optimizer over ``net``'s parameters.

    ``lr`` overrides ``args.lr_net`` when given (including 0.0). Momentum and
    weight decay come from ``args.mom_net`` and ``args.weight_decay``.
    """
    if lr is None:
        lr = args.lr_net
    return torch.optim.SGD(net.parameters(), lr=lr, momentum=args.mom_net,
                           weight_decay=args.weight_decay)

@torch.no_grad()
def latent_to_image(latent, ae_model, batch_size = 16):
    """Decode ``latent`` in chunks of ``batch_size`` along dim 0.

    Each decoded chunk is clamped to ``[-1, 1]``; the chunks are concatenated
    back along dim 0.
    """
    decoded = [
        ae_model.decode(latent[start:start + batch_size]).clamp(-1.0, 1.0)
        for start in range(0, latent.shape[0], batch_size)
    ]
    return torch.cat(decoded)

@torch.no_grad()
def image_to_latent(images, ae_model, batch_size = 16):
    """Encode ``images`` in chunks of ``batch_size`` along dim 0.

    Returns the encoder outputs concatenated along dim 0.
    """
    # Keep the accumulator and the per-batch result in separate names; reusing
    # one name replaced the list with a Tensor and broke ``.append``.
    latents = []
    for batch_st in range(0, images.shape[0], batch_size):
        batch_images = images[batch_st:batch_st + batch_size]
        latents.append(ae_model.encode(batch_images))
    return torch.cat(latents)

def save(args, latent, ae_model, it = 0):
    """Decode ``latent`` and save the images as a grid to ``image_iter{it}.jpg``.

    Images are decoded one latent at a time and resized to ``args.im_size``.
    The grid has ``args.lpc`` images per row, with pixels mapped from
    ``[-1, 1]``, and is written under ``args.save_path``.
    """
    image_syn = latent_to_image(latent, ae_model, batch_size = 1)
    # antialias=True matches the resize used by eval_and_save when it writes the
    # same image_iter{it}.jpg grid, so both code paths produce identical images.
    image_syn = tv.transforms.functional.resize(image_syn, args.im_size, antialias = True)
    tv.utils.save_image(image_syn.detach().cpu(), os.path.join(args.save_path, f'image_iter{it}.jpg'), nrow = args.lpc, normalize = True, value_range = (-1, 1))

def eval_and_save(args, latent, label_syn, ae_model, logger, testloader=None, model_eval_pool=None, it=0, save = True, verbose = False, use_lr_net = True):
    """Decode the synthetic latents, optionally save them, and evaluate models on them.

    Steps:
    1. Decode ``latent`` to images, using a decode batch size of 1, 4 or 16
       depending on ``args.latent_size[0]``.
    2. Resize the images to ``args.im_size``.
    3. If ``save`` is true, write ``latent-label_iter{it}.pt`` and
       ``image_iter{it}.jpg`` under ``args.save_path``.
    4. For ``it > 0``, evaluate every model in ``model_eval_pool``
       ``args.num_eval`` times. For ``it == 0``, evaluate only the first model,
       once.

    The learning rate is ``args.lr_net`` when ``use_lr_net`` is true; otherwise
    it is looked up per model in ``get_eval_lrs(args)``.

    Returns:
        ``(best_acc, best_std)``: dicts keyed by model name holding the mean and
        std, over runs, of each run's maximum test accuracy. Models that were not
        evaluated keep the value 0.
    """
    # None instead of a mutable [] default; an empty pool evaluates nothing.
    if model_eval_pool is None:
        model_eval_pool = []
    best_acc = {"{}".format(m): 0 for m in model_eval_pool}
    best_std = {m: 0 for m in model_eval_pool}
    if not use_lr_net:
        eval_pool_dict = get_eval_lrs(args)

    # Smaller decode batches for larger latents (presumably to bound decoder memory).
    l2i_bs = 1 if args.latent_size[0] > 64 else 4 if args.latent_size[0] > 32 else 16
    image_syn = latent_to_image(latent, ae_model, batch_size = l2i_bs)
    image_syn = tv.transforms.functional.resize(image_syn, args.im_size, antialias = True)

    # save
    if save:
        torch.save({'latent': latent.detach().cpu(), 'label_syn': label_syn.detach().cpu()}, os.path.join(args.save_path, f'latent-label_iter{it}.pt'))
        tv.utils.save_image(image_syn.detach().cpu(), os.path.join(args.save_path, f'image_iter{it}.jpg'), nrow = args.lpc, normalize = True, value_range = (-1, 1))

    # eval
    for model_eval in (model_eval_pool if it > 0 else model_eval_pool[0:1]):
        lr = args.lr_net if use_lr_net else eval_pool_dict[model_eval]
        logger.log('-------------------------\nEvaluation\nmodel_train = %s (D%d), model_eval = %s (lr %f), iteration = %d' % (args.model, args.test_depth, model_eval, lr, it))
        acc_max_test = []
        for it_eval in range(args.num_eval if it > 0 else 1):
            net_eval = get_network(model_eval, args.channel, args.num_classes, args.im_size,
                                width = args.test_width, depth = args.test_depth, dist = False).to(args.device)  # get a random model
            image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
            # Assumes evaluate_synset returns (net, acc_train, acc_test) with acc_test
            # iterable over evaluation points -- TODO confirm against its definition.
            _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader,
                                                     logger, args=args, aug = args.dsa_eval, verbose = verbose, lr = lr)
            # Drop references to the trained network before the next run.
            del _
            del net_eval
            acc_max_test.append(max(acc_test))
        acc_max_test = np.array(acc_max_test)
        acc_max_test_mean, acc_max_test_std = acc_max_test.mean(), acc_max_test.std()
        best_acc[model_eval] = acc_max_test_mean
        best_std[model_eval] = acc_max_test_std
        logger.log('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------' % (
        len(acc_max_test), model_eval, acc_max_test_mean, acc_max_test_std))

    return best_acc, best_std

def get_lpc(args):
    """Return the number of synthetic latents per class (lpc) for ``args.ldc_mode``.

    Modes:
    - ``'same_ipc'``: ``args.ipc`` unchanged.
    - ``'same_param'``: ``f**2 * channel / C * ipc``, truncated to int.
    - ``'same_storage'``: the ``'same_param'`` value divided by a further 4,
      truncated to int.

    Raises:
        ValueError: for any other ``ldc_mode``. The previous behaviour silently
            returned None, which only failed later, far from the cause.
    """
    if args.ldc_mode == 'same_ipc':
        return args.ipc
    if args.ldc_mode == 'same_param':
        return int((args.f ** 2) * args.channel / args.C * args.ipc)
    if args.ldc_mode == 'same_storage':
        return int((args.f ** 2) * args.channel / args.C / 4  * args.ipc)
    raise ValueError(f'Unknown ldc_mode: {args.ldc_mode!r}')
    
# CutMix
def rand_bbox(size, lam):
    """Sample a random CutMix box for a batch of shape ``size``.

    The box spans ``sqrt(1 - lam)`` of each of dims 2 and 3. It is centred at a
    uniformly drawn point and clipped to the image bounds.

    Returns ``(bbx1, bby1, bbx2, bby2)``: box corners along dim 2 (x) and
    dim 3 (y).
    """
    width, height = size[2], size[3]
    ratio = np.sqrt(1. - lam)
    half_w = int(width * ratio) // 2
    half_h = int(height * ratio) // 2

    # uniform centre; x is drawn before y
    center_x = np.random.randint(width)
    center_y = np.random.randint(height)

    x_lo = np.clip(center_x - half_w, 0, width)
    y_lo = np.clip(center_y - half_h, 0, height)
    x_hi = np.clip(center_x + half_w, 0, width)
    y_hi = np.clip(center_y + half_h, 0, height)
    return x_lo, y_lo, x_hi, y_hi

@torch.no_grad()
def compute_grad_l1(grad):
    if type(grad) in [list, tuple]:
        return sum([g.norm(1) for g in grad])
    if type(grad) is torch.Tensor:
        return grad.norm(1)