import torch
import torch.nn as nn

from models.embedding.cnn import CNN
from models.embedding.textcnn import TextCNN
from models.classifier.mlp import MLP
from models.embedding.resnet import Resnet50
from data_utils import is_textdata

def get_model(args, vocab=None):
    """Build the embedding network, classifier head and optimizer for a dataset.

    The embedding network is selected from ``args.dataset``:

    * text datasets (``is_textdata(args.dataset)`` is true): ``TextCNN``
      over ``vocab``, output width ``3 * args.hidden_dim``, 2 classes;
    * exactly ``'MNIST'``: ``CNN`` with its default input channels, 10 classes;
    * any other name starting with ``'MNIST'``: ``CNN`` whose
      ``input_channels`` is the integer in the last ``'_'``-separated field
      of the name, 10 classes;
    * names starting with ``'bird'`` or ``'celeba'``: ``Resnet50``, 2 classes.

    An ``MLP`` classifier head (``depth=1``) is placed on top of the embedding.
    Both modules are moved to the GPU with ``.cuda()``.

    Args:
        args: Config object; reads ``dataset``, ``hidden_dim``, ``dropout``,
            ``lr`` and ``weight_decay``.
        vocab: Vocabulary forwarded to ``TextCNN``; only used for text datasets.

    Returns:
        ``(model, opt)``. ``model`` is a dict with keys ``'ebd'`` (embedding
        network) and ``'clf_all'`` (classifier head). ``opt`` is a
        ``torch.optim.Adam`` optimizer over the parameters of both modules.

    Raises:
        ValueError: if ``args.dataset`` matches none of the supported datasets,
            or an MNIST variant name does not end in an integer channel count.
    """
    model = {}
    dataset = args.dataset

    # Checked first so a text dataset always gets the TextCNN embedding,
    # whatever its name prefix.
    if is_textdata(dataset):
        model['ebd'] = TextCNN(vocab, num_filters=args.hidden_dim,
                               dropout=args.dropout).cuda()
        out_dim = args.hidden_dim * 3  # 3 different filters
        num_classes = 2
    elif dataset.startswith('MNIST'):
        if dataset == 'MNIST':
            model['ebd'] = CNN(include_fc=True,
                               hidden_dim=args.hidden_dim).cuda()
        else:
            # Variant names carry the number of input channels in their last
            # '_'-separated field.
            max_c = int(dataset.split('_')[-1])
            model['ebd'] = CNN(include_fc=True,
                               hidden_dim=args.hidden_dim,
                               input_channels=max_c).cuda()
        out_dim = args.hidden_dim
        num_classes = 10
    elif dataset.startswith(('bird', 'celeba')):
        model['ebd'] = Resnet50().cuda()
        out_dim = model['ebd'].out_dim
        num_classes = 2
    else:
        # Without this, out_dim/num_classes would be unbound below and fail
        # with an opaque UnboundLocalError.
        raise ValueError(f'Unsupported dataset: {dataset!r}')

    model['clf_all'] = MLP(out_dim, args.hidden_dim, num_classes,
                           args.dropout, depth=1).cuda()

    # A single optimizer updates the embedding and the classifier jointly.
    params = (list(model['ebd'].parameters())
              + list(model['clf_all'].parameters()))
    opt = torch.optim.Adam(params, lr=args.lr,
                           weight_decay=args.weight_decay)

    return model, opt
