#coding=utf-8

import os,sys
import torch
from network import act_network
import torch.nn.functional as F


def get_fea(args):
    """Build the network for the task configured in ``args``.

    Args:
        args: namespace-like object. Reads ``args.task`` (one of
            ``'cross_people'``, ``'cross_position'``, ``'cross_dataset'``),
            ``args.dataset``, and -- for ``'cross_people'`` only --
            ``args.model_size``.

    Returns:
        A network instance constructed from ``act_network``.

    Raises:
        NotImplementedError: if ``args.task`` is not one of the supported
            tasks (previously this surfaced as an ``UnboundLocalError``).
    """
    if args.task == 'cross_people':
        # Only the cross-people setting selects a model size; any value other
        # than 'small' / 'large' falls back to the default ActNetwork.
        if args.model_size == 'small':
            return act_network.SActNetwork(args.dataset)
        if args.model_size == 'large':
            return act_network.LActNetwork(args.dataset)
        return act_network.ActNetwork(args.dataset)
    if args.task == 'cross_position':
        # ActNetwork is keyed by the dataset name with a 'p' prefix here.
        return act_network.ActNetwork('p' + args.dataset)
    if args.task == 'cross_dataset':
        # ActNetwork is keyed by the task name itself for this setting.
        return act_network.ActNetwork(args.task)
    raise NotImplementedError('unsupported task: {!r}'.format(args.task))

def accuracy(network, loader, weights, usedpredict='p', device='cuda'):
    """Compute the (optionally weighted) classification accuracy of ``network``.

    The network is put in eval mode for the pass and always switched back to
    train mode afterwards, even if an exception is raised.

    Args:
        network: object exposing ``eval()``, ``train()``, ``predict(x)`` and
            ``predict1(x)``; the predict methods return logits of shape
            (B, C). With C == 1, a logit > 0 is read as class 1.
        loader: iterable of batches; ``data[0]`` is the input and ``data[1]``
            the labels (assumed shape (B,) -- TODO confirm against callers).
        weights: ``None`` for uniform weighting, or a tensor of per-sample
            weights consumed sequentially in loader order.
        usedpredict: ``'p'`` uses ``network.predict``; any other value uses
            ``network.predict1``.
        device: device inputs/labels/weights are moved to (default ``'cuda'``,
            same as the previous hard-coded ``.cuda()``).

    Returns:
        float: weighted fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: if the loader yields no samples or the weights sum
            to zero.
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for data in loader:
                x = data[0].to(device).float()
                y = data[1].to(device).long()
                if usedpredict == 'p':
                    p = network.predict(x)
                else:
                    p = network.predict1(x)
                if weights is None:
                    batch_weights = torch.ones(len(x))
                else:
                    batch_weights = weights[weights_offset:weights_offset + len(x)]
                    weights_offset += len(x)
                batch_weights = batch_weights.to(device)
                if p.size(1) == 1:
                    # Flatten (B, 1) -> (B,) so the comparison with y (B,) is
                    # element-wise; comparing (B, 1) with (B,) would broadcast
                    # to a (B, B) matrix and over-count hits.
                    hits = p.reshape(-1).gt(0).eq(y)
                else:
                    hits = p.argmax(1).eq(y)
                correct += (hits.float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        network.train()

    return correct / total

def accuracy_loss(network, loader, weights, usedpredict='p', device='cuda'):
    """Compute the (optionally weighted) accuracy and mean loss of ``network``.

    Both metrics are weighted by the same per-sample weights and normalised
    by the total weight. With ``weights=None`` they are the plain accuracy and
    the mean per-sample loss over the whole loader. The network is put in eval
    mode for the pass and always switched back to train mode afterwards.

    Args:
        network: object exposing ``eval()``, ``train()``, ``predict(x)`` and
            ``predict1(x)``; the predict methods return logits of shape
            (B, C). With C == 1, a logit > 0 is read as class 1.
        loader: iterable of batches; ``data[0]`` is the input and ``data[1]``
            the labels (assumed shape (B,) -- TODO confirm against callers).
        weights: ``None`` for uniform weighting, or a tensor of per-sample
            weights consumed sequentially in loader order.
        usedpredict: ``'p'`` uses ``network.predict``; any other value uses
            ``network.predict1``.
        device: device inputs/labels/weights are moved to (default ``'cuda'``,
            same as the previous hard-coded ``.cuda()``).

    Returns:
        tuple(float, float): ``(accuracy, mean_loss)``. The loss is
        cross-entropy for multi-logit outputs and binary cross-entropy with
        logits for single-logit outputs.

    Raises:
        ZeroDivisionError: if the loader yields no samples or the weights sum
            to zero.
    """
    correct = 0
    total = 0
    loss_sum = 0.0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for data in loader:
                x = data[0].to(device).float()
                y = data[1].to(device).long()
                if usedpredict == 'p':
                    p = network.predict(x)
                else:
                    p = network.predict1(x)
                if weights is None:
                    batch_weights = torch.ones(len(x))
                else:
                    batch_weights = weights[weights_offset:weights_offset + len(x)]
                    weights_offset += len(x)
                batch_weights = batch_weights.to(device)
                if p.size(1) == 1:
                    # Single-logit output: flatten to (B,) so comparisons with
                    # y are element-wise. Use BCE-with-logits, which matches
                    # the "logit > 0 means class 1" decision rule (a 1-class
                    # cross_entropy would reject label 1).
                    logits = p.reshape(-1)
                    hits = logits.gt(0).eq(y)
                    per_sample_loss = F.binary_cross_entropy_with_logits(
                        logits, y.float(), reduction='none')
                else:
                    hits = p.argmax(1).eq(y)
                    per_sample_loss = F.cross_entropy(p, y, reduction='none')
                correct += (hits.float() * batch_weights).sum().item()
                # Accumulate over all batches, weighted consistently with
                # `total`; the old code kept only the last batch's loss.
                loss_sum += (per_sample_loss * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        network.train()

    return correct / total, loss_sum / total
