import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from network import *
from utils.tools import *

def compute_result(dataloader, net, device, uct=False):
    """Run ``net`` over ``dataloader`` and collect its sign-binarised outputs.

    The network is put into eval mode and is left in that mode on return.
    Each batch is forwarded exactly once with gradient tracking disabled;
    the codes, probabilities and entropies are all derived from that single
    output.

    Args:
        dataloader: Iterable yielding ``(img, cls, _)`` triples; the third
            element is ignored.
        net: Callable model mapping an image batch to raw outputs.
        device: Device the image batch is moved to before the forward pass.
        uct: When truthy, additionally return uncertainty statistics.

    Returns:
        ``(codes, labels)`` when ``uct`` is falsy, otherwise
        ``(codes, labels, entropy, prob)``, where

        * ``codes`` is the element-wise sign of the concatenated outputs,
        * ``labels`` is the concatenation of every ``cls`` batch,
        * ``prob`` is the sigmoid of the raw outputs,
        * ``entropy`` is the Bernoulli entropy of ``prob`` summed over
          dim 1 (assumes outputs are shaped (batch, bits) -- TODO confirm).

        All returned tensors built from network outputs live on the CPU.
    """
    # NOTE(review): `tqdm` and `Bernoulli` are not imported explicitly in
    # this file; presumably they arrive through the star imports -- confirm.
    codes, labels, entropies, probs = [], [], [], []
    net.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for img, cls, _ in tqdm(dataloader):
            labels.append(cls)
            # One forward pass per batch; everything below reuses it.
            output = net(img.to(device))
            codes.append(output.detach().cpu())
            if uct:
                batch_prob = torch.sigmoid(output).detach().cpu()
                entropies.append(
                    torch.sum(Bernoulli(probs=batch_prob).entropy(), dim=1)
                )
                probs.append(batch_prob)
    if uct:
        return (
            torch.cat(codes).sign(),
            torch.cat(labels),
            torch.cat(entropies),
            torch.cat(probs),
        )
    return torch.cat(codes).sign(), torch.cat(labels)

def unique_hashcode(hash_code):
    """Return how many distinct rows the tensor ``hash_code`` contains.

    The tensor is detached, moved to the CPU and converted to a NumPy array;
    rows (slices along axis 0) are then de-duplicated and counted.
    """
    codes = hash_code.detach().cpu().numpy()
    distinct_rows = np.unique(codes, axis=0)
    return distinct_rows.shape[0]

def entrywise_fairness(hash_code):
    """Return, per column, the fraction of rows whose entry equals 1.

    ``hash_code`` is detached, moved to the CPU and converted to a NumPy
    array; the result is a NumPy array with one value per column.
    """
    codes = hash_code.detach().cpu().numpy()
    row_count = codes.shape[0]
    # Count the entries equal to 1 in each column, then normalise by rows.
    ones_per_column = np.sum(codes == 1, axis=0)
    return ones_per_column / row_count

def balance_analysis(train_loader, net, device="cuda"):
    """Print balance statistics for the hash codes ``net`` gives ``train_loader``.

    Three quantities are printed:

    * the sum over code positions of ``|fraction of +1 entries - 0.5|``
      (labelled "Entry-wise fairness"; 0 means every position is balanced),
    * the number of samples and the number of distinct hash codes,
    * the number of distinct label rows.

    Args:
        train_loader: Data loader forwarded to ``compute_result``.
        net: Model forwarded to ``compute_result``.
        device: Device used for the forward passes. Defaults to ``"cuda"``,
            the value that was previously hard-coded.
    """
    train_binary, train_label = compute_result(
        train_loader, net, device=device, uct=False
    )
    # Total absolute deviation of each position's +1 rate from 0.5.
    imbalance = np.abs(entrywise_fairness(train_binary) - 0.5).sum()
    print("Entry-wise fairness:\n", imbalance)
    print("Total Number of training data:", train_binary.size()[0], "\nNumber of unique hashcode:", unique_hashcode(train_binary))
    labels = train_label.detach().cpu().numpy()
    print("Total Number of unique label:", np.unique(labels, axis=0).shape[0])