# coding=utf-8
import torch
from network import img_network

def get_fea(args):
    """Instantiate the ``img_network`` backbone selected by ``args``.

    Selection rules, checked in order:

    * ``args.dataset == 'dg5'`` always yields ``img_network.DTNBase()``;
      ``args.net`` is not consulted in that case.
    * Otherwise the prefix of ``args.net`` picks the class:
      ``'res'`` -> ``ResBase``, ``'vgg'`` -> ``VGGBase``, ``'ViT'`` ->
      ``ViTBase``, ``'Eff'`` -> ``EfficientBase``, ``'Mix'`` -> ``MLPMixer``.
      The full ``args.net`` string is forwarded to the constructor.

    Args:
        args: Object exposing ``dataset`` and ``net`` attributes.

    Returns:
        The constructed network object.

    Raises:
        ValueError: If ``args.net`` does not start with a supported prefix
            (and the dataset is not ``'dg5'``).
    """
    if args.dataset == 'dg5':
        net = img_network.DTNBase()
    elif args.net.startswith('res'):
        net = img_network.ResBase(args.net)
    elif args.net.startswith('vgg'):
        net = img_network.VGGBase(args.net)
    elif args.net.startswith('ViT'):
        net = img_network.ViTBase(args.net)
    elif args.net.startswith('Eff'):
        net = img_network.EfficientBase(args.net)
    elif args.net.startswith('Mix'):
        net = img_network.MLPMixer(args.net)
    else:
        # Fail loudly with the offending value instead of hitting an
        # UnboundLocalError on the return statement below.
        raise ValueError(
            "Unsupported network {!r}: expected a name starting with one of "
            "'res', 'vgg', 'ViT', 'Eff', 'Mix'".format(args.net)
        )
    return net


def _eval_device(network):
    """Return the device evaluation batches should be moved to."""
    parameters = getattr(network, "parameters", None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            return first_param.device
    # Parameterless network: keep the historical CUDA default when a GPU is
    # present, otherwise fall back to the CPU instead of crashing.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def accuracy(network, loader):
    """Compute the classification accuracy of ``network`` over ``loader``.

    Each item yielded by ``loader`` is indexed as ``data[0]`` (inputs) and
    ``data[1]`` (integer labels). Predictions come from
    ``network.predict(x)``:

    * if the output has a single column it is treated as a binary logit and
      thresholded at 0;
    * otherwise the predicted class is the ``argmax`` over dimension 1.

    The network is switched to eval mode for the duration of the call and
    put back into the mode it was in beforehand, even if evaluation raises.

    Args:
        network: Model exposing ``eval()``, ``train()`` and ``predict(x)``.
        loader: Iterable of batches indexable as ``(inputs, labels, ...)``.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` when ``loader`` yields no samples.
    """
    # Objects without a ``training`` flag are treated as "was training",
    # which matches the previous unconditional ``network.train()`` call.
    was_training = getattr(network, "training", True)
    device = _eval_device(network)
    correct = 0
    total = 0

    network.eval()
    try:
        with torch.no_grad():
            for data in loader:
                x = data[0].to(device).float()
                # Flatten labels so (N,) and (N, 1) layouts compare
                # element-wise instead of broadcasting to an (N, N) matrix.
                y = data[1].to(device).long().reshape(-1)
                p = network.predict(x)

                if p.size(1) == 1:
                    predicted = p.gt(0).reshape(-1).long()
                else:
                    predicted = p.argmax(1)
                correct += predicted.eq(y).sum().item()
                total += len(x)
    finally:
        # Only re-enter train mode if that is where the caller left it.
        if was_training:
            network.train()

    if total == 0:
        # Empty loader: report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return correct / total
