import torch
from torch import nn
from vit_pytorch import ViT


class Model(nn.Module):
    """Vision Transformer classifier whose output size is chosen by dataset name.

    Wraps ``vit_pytorch.ViT`` and exposes both the raw encoder output and
    the arg-max class prediction from ``forward``.
    """

    # Number of target classes for each supported ``args.dataset`` value.
    _NUM_CLASSES = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'imagenet': 1000,
    }

    def __init__(self, args):
        """Build the ViT encoder from the run configuration.

        Args:
            args: Configuration object providing ``dataset`` (one of the
                keys of ``_NUM_CLASSES``), ``vit_patch_size`` and ``depth``.

        Raises:
            ValueError: If ``args.dataset`` is not a supported dataset name.
        """
        super().__init__()

        self.args = args

        # Fail early with a clear message instead of hitting an unbound
        # local variable when the dataset name is not recognised.
        try:
            num_class = self._NUM_CLASSES[args.dataset]
        except (KeyError, TypeError):
            supported = ', '.join(sorted(self._NUM_CLASSES))
            raise ValueError(
                f"Unsupported dataset {args.dataset!r}; "
                f"expected one of: {supported}"
            ) from None

        # NOTE(review): image_size is fixed at 32 for every dataset,
        # including 'imagenet' -- confirm inputs are resized upstream.
        self.encoder = ViT(
            image_size=32,
            patch_size=args.vit_patch_size,
            num_classes=num_class,
            dim=1024,
            depth=args.depth,
            heads=16,
            mlp_dim=2048,
            dropout=0.1,
            emb_dropout=0.1,
        )

    def forward(self, x, device):
        """Run the encoder and derive hard class predictions.

        Args:
            x: Input batch passed straight to the encoder (presumably
                images of shape (batch, channels, 32, 32) -- TODO confirm).
            device: Unused; kept so existing call sites keep working.

        Returns:
            A tuple ``(y_pred_linear, y_pred)``: the encoder output and its
            arg-max along dimension 1.
        """
        y_pred_linear = self.encoder(x)
        y_pred = y_pred_linear.argmax(dim=1)
        return y_pred_linear, y_pred

