from generator import SyntheticImageGenerator
import torch

dataset = "ImageNet10"
ipc = 10

# Architecture settings that depend only on the dataset:
# (num_classes, img_size, kernel_size, stride, padding).
DATASET_SPECS = {
    "CIFAR10": (10, 32, 2, 2, 0),
    "SVHN": (10, 32, 2, 2, 0),
    "CIFAR100": (100, 32, 2, 2, 0),
    "CIFAR100_cl": (20, 32, 4, 2, 1),
    "TinyImageNet": (200, 64, 2, 2, 0),
    "ImageNet10": (10, 224, 4, 2, 1),
}

# Datasets that reuse another dataset's generator sizes.
GENERATOR_ALIASES = {"SVHN": "CIFAR10"}

# Generator sizes per (dataset, ipc): (hdims, num_seed_vec, num_decoder).
# The trailing number on each entry is the effective ipc the author recorded
# for that configuration (presumably the `our_ipc` this script prints).
GENERATOR_SPECS = {
    ("CIFAR10", 1): ([6, 9, 12], 13, 8),  # 1.0046875
    ("CIFAR10", 10): ([6, 9, 12], 160, 12),  # 10.28828125
    ("CIFAR10", 50): ([6, 12], 200, 16),  # 50.1921875
    ("CIFAR100", 1): ([6, 9, 12], 16, 8),  # 1.01921875
    ("CIFAR100", 10): ([6, 9, 12], 160, 12),  # 10.028828125
    # Alternative tried for CIFAR100_cl/20: hdims=[12], num_seed_vec=10,
    # num_decoder=8.
    ("CIFAR100_cl", 20): ([8, 13, 18], 200, 16),  # 20.264583333333334
    ("TinyImageNet", 1): ([6, 9, 12], 16, 8),  # 1.01921875
    ("TinyImageNet", 10): ([6, 12], 40, 16),  # 10.00240234375
    ("ImageNet10", 1): ([3, 3, 3], 64, 14),  # 1.00240234375
    ("ImageNet10", 10): ([4, 6], 80, 14),  # 10.005422247023809
}


def lookup_config(dataset_name, images_per_class):
    """Return the generator configuration for a dataset / ipc budget.

    Args:
        dataset_name: One of the keys of ``DATASET_SPECS``.
        images_per_class: Target images-per-class budget (``ipc``).

    Returns:
        A tuple ``(num_classes, img_size, kernel_size, stride, padding,
        hdims, num_seed_vec, num_decoder)``. ``hdims`` is a fresh list, so
        callers may mutate it without corrupting the table.

    Raises:
        ValueError: if the dataset is unknown or no generator size is
            defined for the requested ``images_per_class``.
    """
    try:
        dataset_spec = DATASET_SPECS[dataset_name]
    except KeyError:
        raise ValueError(
            f"unsupported dataset {dataset_name!r}; "
            f"expected one of {sorted(DATASET_SPECS)}"
        ) from None

    table_key = GENERATOR_ALIASES.get(dataset_name, dataset_name)
    try:
        layer_dims, seed_count, decoder_count = GENERATOR_SPECS[
            (table_key, images_per_class)
        ]
    except KeyError:
        supported = sorted(i for (d, i) in GENERATOR_SPECS if d == table_key)
        raise ValueError(
            f"no generator configuration for dataset {dataset_name!r} "
            f"with ipc={images_per_class!r}; supported ipc values: {supported}"
        ) from None

    return (*dataset_spec, list(layer_dims), seed_count, decoder_count)


# Fail fast with a clear message instead of a later NameError when the
# (dataset, ipc) combination has no configuration.
(num_classes, img_size, kernel_size, stride, padding,
 hdims, num_seed_vec, num_decoder) = lookup_config(dataset, ipc)

# Instantiate the generator for the selected configuration, then delete its
# encoders so their weights are not counted below.
# NOTE(review): presumably the encoders are not part of the stored synthetic
# set -- confirm against SyntheticImageGenerator.
a = SyntheticImageGenerator(
    num_classes,
    (img_size, img_size),
    num_seed_vec,
    num_decoder,
    hdims,
    kernel_size,
    stride,
    padding,
)
del a.encoders

# Total number of scalar parameters left in the generator.
num_params = sum(param.numel() for _, param in a.named_parameters())

# Express the parameter count as an equivalent number of 3-channel
# img_size x img_size images per class.
our_ipc = num_params / (num_classes * 3 * img_size * img_size)
print(our_ipc)

# Percentage by which that equivalent budget exceeds the target ipc.
overparam = 100 * (our_ipc - ipc) / ipc
print(overparam)