import collections
import torch
import torch.nn as nn
from timm.models import create_model
from timm.models import create_model
from models.vit_models import *


def create_cnngnn(log, vit_model_name='deit_base_patch16_224', vit_layers=12, in_channel=37,
                  nb_classes=3, drop=0.0, drop_path=0.1, ckpt_path='./teacher/deit_base_patch16_224-b5f2ef4d.pth'):
    """Build a freshly initialised model through timm's ``create_model`` factory.

    The registry name is formed as ``f"{vit_model_name}_L{vit_layers}"``
    (e.g. ``deit_base_patch16_224_L12``). Variants with an ``_L<n>`` suffix
    are presumably registered by the ``models.vit_models`` star import at
    the top of this file. NOTE(review): verify the requested name exists there.

    Despite the function's name, this returns whatever ``create_model``
    builds for that registry name. The visible code builds no CNN/GNN parts.

    Args:
        log: Object with a ``write(str)`` method. Receives one progress
            message; no trailing newline is appended.
        vit_model_name: Base timm model name.
        vit_layers: Depth suffix appended as ``_L<vit_layers>``.
        in_channel: Forwarded to the model constructor as ``in_channel``.
        nb_classes: Forwarded as ``num_classes``.
        drop: Forwarded as ``drop_rate``.
        drop_path: Forwarded as ``drop_path_rate``.
        ckpt_path: Currently UNUSED. No checkpoint is loaded here, and
            ``pretrained`` is always ``False``. TODO(review): load it, or
            remove it once callers stop passing it.

    Returns:
        The model instance produced by ``create_model``.
    """
    layer_suffix = '_L' + str(vit_layers)
    model_name = vit_model_name + layer_suffix
    log.write(f"Creating model: {model_name}")

    # Keyword arguments forwarded verbatim to the model's constructor.
    # drop_block_rate is explicitly disabled.
    model_kwargs = dict(
        in_channel=in_channel,
        pretrained=False,
        num_classes=nb_classes,
        drop_rate=drop,
        drop_path_rate=drop_path,
        drop_block_rate=None,
    )
    return create_model(model_name, **model_kwargs)