# coding=utf-8
import torch
from network import img_network


def get_fea(args):
    """Instantiate the feature-extractor backbone named by ``args.net``.

    * ``'LeNet'`` -> ``img_network.LeNetBase()``
    * ``'DTN'``   -> ``img_network.DTNBase()``
    * names starting with ``'res'`` -> ``img_network.ResMix(args)`` when
      ``args.bacAug`` is contained in ``args.MixAlg``, otherwise
      ``img_network.ResBase(args)``
    * anything else (including names starting with ``'vgg'``) ->
      ``img_network.VGGBase(args)``
    """
    arch = args.net
    if arch == 'LeNet':
        return img_network.LeNetBase()
    if arch == 'DTN':
        return img_network.DTNBase()
    if arch.startswith('res'):
        res_cls = img_network.ResMix if args.bacAug in args.MixAlg else img_network.ResBase
        return res_cls(args)
    # 'vgg*' names and any unrecognised name share the VGG backbone.
    return img_network.VGGBase(args)


def accuracy(network, loader):
    """Return the fraction of correctly classified samples yielded by ``loader``.

    The network is evaluated in eval mode under ``torch.no_grad()``. Its
    previous training/eval mode is restored afterwards, even if an error is
    raised part-way through.

    Each batch is unpacked with ``deal_special_case``, so two-view batches
    contribute both views (and their duplicated labels) to the count.

    Args:
        network: a ``torch.nn.Module`` mapping a batch of inputs to logits.
            A single-column output is treated as a binary head: a logit
            > 0 predicts class 1. Otherwise the arg-max over dim 1 is the
            prediction.
        loader: an iterable of batches accepted by ``deal_special_case``.

    Returns:
        float: ``correct / total``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    correct = 0
    total = 0

    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            for data in loader:
                x, y = deal_special_case(data)
                p = network(x)

                if p.size(1) == 1:
                    # Flatten the (B, 1) logits to (B,) so the comparison
                    # with the labels is element-wise rather than
                    # broadcasting to a (B, B) matrix.
                    pred = p.squeeze(1).gt(0).long()
                else:
                    pred = p.argmax(1)
                # Align label shape with the predictions (e.g. (B, 1) -> (B,)).
                correct += pred.eq(y.reshape(pred.shape)).sum().item()
                total += len(x)
    finally:
        network.train(was_training)
    return correct / total

def deal_special_case(data, device='cuda'):
    """Unpack an ``(inputs, labels)`` batch, cast it, and move it to ``device``.

    Special case: when ``data[0]`` is a list, the batch comes from the
    original supervised-contrastive setup (source and target algorithms are
    supcon and forAug is None). There, ``datautil.image_train`` yields the
    images as ``[image0, image1]``: two ``imgutil.image_train`` transforms
    of the same original images. The labels, however, still have one entry
    per original image. The two views are concatenated along the batch
    dimension and the labels are repeated to match. This lets accuracy be
    computed during training, e.g.::

        for item in ['train', 'valid']:
            acc_record[item] = np.mean(np.array([modelopera.accuracy(
                Alg_model, eval_loaders[eval_name_dict[item][task_id]])]))

    Only the first two views are used.

    Args:
        data: indexable batch whose element 0 is the input tensor (or a list
            of two view tensors) and whose element 1 is the label tensor.
        device: target device for both tensors. The default ``'cuda'``
            matches the previous hard-coded ``.cuda()`` behaviour.

    Returns:
        tuple: ``(x, y)`` with ``x`` as float and ``y`` as long, both on
        ``device``.
    """
    inputs, labels = data[0], data[1]
    if isinstance(inputs, list):
        x = torch.cat([inputs[0], inputs[1]]).to(device).float()
        y = torch.cat([labels, labels]).to(device).long()
    else:
        x = inputs.to(device).float()
        y = labels.to(device).long()
    return x, y
