import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision.models.vision_transformer import _vision_transformer
from torchvision import transforms

# Dataset switch: True selects the MNIST settings, False the CIFAR-10 ones.
is_mnist = True
# Number of output classes requested from every model built below.
num_classes = 10
# Side length (in pixels) of the square input images:
# 28 for MNIST, 32 for CIFAR-10.
if is_mnist:
    im_size = 28
else:
    im_size = 32

# Sample 1000 random small Vision Transformer configurations, record the
# parameter count of each one, and plot the distribution as a histogram.
# After the loop, `model` stays bound to the last sampled architecture.
params = []
for i in range(1000):
    # Depth is drawn from {2, 3, 4} (the upper bound of randint is exclusive).
    layers = np.random.randint(2, 5)
    # Models deeper than 2 layers are restricted to a smaller maximum width.
    dim_max = 16 if layers > 2 else 24
    # Hidden width is a multiple of 8 in [8, dim_max]. The NumPy scalar is
    # cast to a plain int so the model is built from native Python ints.
    dim = int(np.random.choice(np.arange(8, dim_max + 1, 8)))
    # The MLP is 4x the hidden width only for the smallest configurations
    # (2 layers and dim <= 16); otherwise it equals the hidden width.
    mlp_dim = dim * (1 if dim > 16 or layers > 2 else 4)
    # Every candidate width (8, 16, 24) is divisible by both head counts.
    heads = int(np.random.choice([4, 8]))
    model = _vision_transformer(
        patch_size=4,
        num_layers=layers,
        num_heads=heads,
        hidden_dim=dim,
        mlp_dim=mlp_dim,
        num_classes=num_classes,
        image_size=im_size,
        weights=None,
        progress=False,
    )
    params.append(sum(p.numel() for p in model.parameters()))
    # Sanity-check the forward pass on a dummy batch of two 3-channel images.
    # This is only a validity check, so autograd bookkeeping is disabled to
    # avoid building (and discarding) a graph on each of the 1000 iterations.
    with torch.no_grad():
        out = model(torch.randn(2, 3, im_size, im_size))  # 2 x num_classes
plt.hist(params, bins=20, alpha=0.5)
plt.show()

# Mini-batch size.
batch_size = 128
# AdamW over the parameters of whatever `model` is bound to at this point,
# i.e. the last architecture sampled by the loop above.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=8e-4,
    weight_decay=0.1,
)

# Per-channel mean / std handed to Normalize, chosen by dataset.
# NOTE(review): the MNIST statistics are single-channel, while the dummy
# forward pass in this script uses 3-channel input -- confirm how MNIST
# images are brought to 3 channels before reaching the model.
if is_mnist:
    norm_mean = (0.1307,)
    norm_std = (0.3081,)
else:
    norm_mean = (0.49139968, 0.48215827, 0.44653124)
    norm_std = (0.24703233, 0.24348505, 0.26158768)
normalize = transforms.Normalize(norm_mean, norm_std)

# Training-time pipeline, kept as a plain list of transforms (it is not
# wrapped in a Compose here): padded random crop, random horizontal flip,
# RandAugment, then tensor conversion and normalisation.
train_transform = [
    transforms.RandomCrop(im_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(),
    transforms.ToTensor(),
    normalize,
]
