# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from network import act_network
import numpy as np
import sklearn.metrics as metric

def get_fea(args):
    """Build the feature-extraction network selected by ``args``.

    Args:
        args: namespace read for ``dataset`` and ``use_freq``; when
            ``use_freq`` is truthy, ``mask_spectrum`` and ``freq_type`` are
            read as well.

    Returns:
        ``act_network.ActFreqNetwork`` when ``args.use_freq`` is truthy,
        otherwise ``act_network.ActNetwork``.
    """
    if args.use_freq:
        print('Using frequency network')
        return act_network.ActFreqNetwork(args.dataset, args.mask_spectrum, args.freq_type)
    # Construct only the network that is actually returned, so no throwaway
    # model is built (and no extra parameters are allocated) on the
    # frequency path.
    return act_network.ActNetwork(args.dataset)

def accuracy(network, loader, weights, usedpredict='p'):
    """Evaluate ``network`` on ``loader``: weighted accuracy plus macro metrics.

    The network is switched to eval mode for the pass. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        network: model exposing ``predict(x)``, a ``device`` attribute and
            the ``eval()`` / ``train(mode)`` methods.
        loader: iterable of batches where ``data[0]`` is the input and
            ``data[1]`` holds the integer class labels.
        weights: optional per-sample weights, consumed sequentially in loader
            order; ``None`` gives every sample weight 1.
        usedpredict: currently unused; kept for call-site compatibility.

    Returns:
        ``(acc, [precision, recall, f1])``. ``acc`` is the weighted fraction
        of correct predictions. The list holds macro-averaged precision,
        recall and F1 from ``sklearn.metrics.precision_recall_fscore_support``.
    """
    correct = 0.0
    total = 0.0
    weights_offset = 0
    ys = []
    ypred = []
    device = network.device

    # Remember the caller's mode so it can be restored instead of forcing
    # train mode on return.
    was_training = getattr(network, 'training', True)
    network.eval()
    try:
        with torch.no_grad():
            for data in loader:
                x = data[0].to(device).float()
                y = data[1].to(device).long()
                p = network.predict(x)

                if weights is None:
                    batch_weights = torch.ones(len(x), device=device)
                else:
                    batch_weights = weights[weights_offset:
                                            weights_offset + len(x)].to(device)
                    weights_offset += len(x)

                if p.size(1) == 1:
                    # Single-logit head: squeeze (N, 1) -> (N,) so the
                    # comparison with y (N,) is element-wise rather than
                    # broadcasting to an (N, N) matrix.
                    batch_pred = p.squeeze(1).gt(0).long()
                else:
                    batch_pred = p.argmax(1)

                ys.append(y)
                ypred.append(batch_pred)
                correct += (batch_pred.eq(y).float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        network.train(was_training)

    memory_footprint = torch.cuda.memory_allocated() / (1024 ** 2)  # calculate GPU memory MB
    print(f"Memory footprint on GPU: {memory_footprint} MB")

    y_true = torch.cat(ys, 0).cpu().numpy()
    y_pred = torch.cat(ypred, 0).cpu().numpy()
    # With average='macro' the fourth element (support) is None; it is discarded.
    precision, recall, f1, _ = metric.precision_recall_fscore_support(
        y_true, y_pred, average='macro')
    return correct / total, [precision, recall, f1]
