import torch
from tqdm import tqdm
import numpy as np
import random

# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed every RNG used here (torch CPU and all CUDA devices, NumPy, stdlib
# `random`) with the same fixed value so evaluation runs are repeatable.
torch.manual_seed(222)
torch.cuda.manual_seed_all(222)
np.random.seed(222)
random.seed(222)
# Ask cuDNN for deterministic kernels and disable its autotuner, which could
# otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit Euclidean (L2) norm along ``axis``.

    Args:
        input: tensor to normalise.
        axis: dimension along which the norm is taken (default 1, i.e. each
            row of a 2-D tensor).
        eps: lower bound applied to the norm before dividing. A slice whose
            norm is below ``eps`` (e.g. an all-zero vector) is divided by
            ``eps`` instead, so it maps to a (near-)zero vector rather than NaN.

    Returns:
        Tensor with the same shape as ``input``.
    """
    norm = torch.norm(input, 2, axis, True)
    # Guard the 0/0 case: an all-zero slice would otherwise become NaN and
    # poison every distance later computed from it.
    return torch.div(input, norm.clamp_min(eps))


def evaluate(dataloader, criterion, backbone, head, emb_size, k_accuracy=False, multilabel_accuracy=False,
             demographic_to_labels=None):
    """Evaluate a backbone (and optionally a head) separately for each demographic group.

    Two independent evaluations can be switched on:

    * ``multilabel_accuracy``: run ``head`` on the backbone features and
      accumulate, per group, the mean loss and the classification accuracy
      (``argmax`` of the head output compared with the label).
    * ``k_accuracy``: build an L2-normalised embedding for every sample (the
      backbone output for the input plus that for the input flipped along
      dim 3), then compute leave-one-out nearest-neighbour accuracy with
      ``predictions``.

    Args:
        dataloader: iterable of ``(inputs, labels, sens_attr)`` batches, where
            ``sens_attr`` holds one group value per sample.
        criterion: loss called as ``criterion(outputs, labels)``. Its result
            is indexed per sample, so it presumably uses ``reduction='none'``
            (TODO confirm against the caller).
        backbone: model mapping ``inputs`` to embeddings of width ``emb_size``.
        head: model called as ``head(features, labels)``. Only used (and only
            switched to eval mode) when ``multilabel_accuracy`` is set.
        emb_size: embedding width, i.e. the column count of the feature matrix.
        k_accuracy: whether to compute nearest-neighbour accuracy.
        multilabel_accuracy: whether to compute the head loss and accuracy.
        demographic_to_labels: mapping whose keys are the group values that
            appear in ``sens_attr``. Only its keys are used. ``None`` is
            treated as an empty mapping, so no per-group results are produced.

    Returns:
        ``(loss, acc, acc_k, correct, labels_all, demographic_all)``:

        * ``loss``: per-group mean loss.
        * ``acc``: per-group head accuracy.
        * ``acc_k``: per-group nearest-neighbour accuracy.
        * ``correct``: per-sample 0/1 tensor of nearest-neighbour hits.
        * ``labels_all`` and ``demographic_all``: the collected labels and
          group values.

        The per-group dicts hold 0 for metrics that were not requested. A
        group with no samples gets NaN. When ``k_accuracy`` is False,
        ``correct`` is an empty tensor and both lists are empty.
    """
    if demographic_to_labels is None:
        demographic_to_labels = {}
    groups = list(demographic_to_labels.keys())

    loss = {k: torch.tensor(0.0) for k in groups}
    acc = {k: torch.tensor(0.0) for k in groups}
    count = {k: torch.tensor(0.0) for k in groups}
    acc_k = {k: torch.tensor(0.0) for k in groups}
    # Defined up front so the return below also works when k_accuracy is False.
    correct = torch.zeros(0)

    backbone.eval()
    if multilabel_accuracy:
        head.eval()

    # Embedding batches are collected in a list and concatenated once at the
    # end; concatenating inside the loop would be quadratic in dataset size.
    # The empty (0, emb_size) seed keeps the shape check and dtype promotion
    # the same as a running torch.cat would give.
    feature_batches = [torch.empty(0, emb_size)]
    labels_all = []
    demographic_all = []

    for inputs, labels, sens_attr in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device).long()
        sens_attr = np.array(sens_attr)
        # One boolean mask per group, reused for every per-group statistic.
        masks = {k: sens_attr == k for k in groups}

        with torch.no_grad():
            if multilabel_accuracy:
                features = backbone(inputs)
                outputs = head(features, labels)
                loss_value = criterion(outputs, labels)
                # Predicted class = index of the largest output along dim 1.
                _, predicted = outputs.max(1)
                for k, mask in masks.items():
                    loss[k] += loss_value[mask].sum().cpu()
                    acc[k] += predicted[mask].eq(labels[mask]).sum().cpu().item()
                    count[k] += int(np.count_nonzero(mask))

            if k_accuracy:
                # Test-time augmentation: add the embedding of the input
                # flipped along dim 3. For (N, C, H, W) inputs that is a
                # horizontal flip (assumed layout; TODO confirm).
                inputs_flipped = torch.flip(inputs, [3])
                embed = backbone(inputs) + backbone(inputs_flipped)
                feature_batches.append(l2_norm(embed).cpu())
                labels_all.extend(labels.cpu().tolist())
                demographic_all.extend(sens_attr.tolist())

    if multilabel_accuracy:
        # A group with zero samples divides by 0.0 and yields NaN.
        for k in groups:
            acc[k] = acc[k] / count[k].item()
            loss[k] = loss[k] / count[k].item()

    if k_accuracy:
        feature_matrix = torch.cat(feature_batches, dim=0)
        labels_tensor = torch.tensor(labels_all)
        # Queries and gallery are the same set, giving leave-one-out 1-NN accuracy.
        acc_k, correct = predictions(feature_matrix, labels_tensor, demographic_to_labels,
                                     feature_matrix, labels_tensor, np.array(demographic_all))

    return loss, acc, acc_k, correct, labels_all, demographic_all




def l2_dist(feature_matrix, test_features):
    """Return the matrix of pairwise Euclidean distances.

    Entry ``[i, j]`` is the L2 distance between ``test_features[i]`` and
    ``feature_matrix[j]``. The result has one row per test (query) vector
    and one column per reference vector.
    """
    queries, references = test_features, feature_matrix
    return torch.cdist(queries, references, p=2.0)


def predictions(feature_matrix, labels, demographic_to_labels, test_features, test_labels, test_demographic,
                exclude_self=True):
    """Compute nearest-neighbour (1-NN) label accuracy per demographic group.

    Each query row of ``test_features`` is assigned the label of its closest
    gallery row in ``feature_matrix`` (Euclidean distance). The prediction
    counts as correct when that label equals the query's own label.

    Args:
        feature_matrix: gallery embeddings, one row per sample.
        labels: gallery labels, aligned with the rows of ``feature_matrix``.
        demographic_to_labels: mapping whose keys are the group values to
            report accuracy for. Only its keys are used.
        test_features: query embeddings, one row per sample.
        test_labels: query labels, aligned with the rows of ``test_features``.
        test_demographic: NumPy array of group values, aligned with the rows
            of ``test_features``.
        exclude_self: when True (the default), the queries are the gallery
            itself (same samples, same order). Each query's own row is then
            excluded, which gives leave-one-out accuracy. When False, the
            closest gallery row is used as-is.

    Returns:
        ``(acc_k, correct)``:

        * ``acc_k``: dict mapping each group key to the mean of ``correct``
          over that group's queries (NaN for a group with no queries).
        * ``correct``: float tensor holding 1 for each correctly labelled
          query and 0 otherwise.

    Raises:
        ValueError: if ``exclude_self`` is True and either the numbers of
            query and gallery rows differ or there are fewer than two samples.
    """
    dist_matrix = l2_dist(feature_matrix, test_features)
    n_queries, n_gallery = dist_matrix.shape

    if exclude_self:
        if n_queries != n_gallery:
            raise ValueError(
                "exclude_self=True requires the queries to be the gallery itself "
                f"(got {n_queries} queries and {n_gallery} gallery rows); "
                "pass exclude_self=False for a separate query set"
            )
        if n_gallery < 2:
            raise ValueError("leave-one-out nearest neighbour needs at least two samples")
        # Mask each query's own entry instead of assuming it ranks first.
        # Duplicate embeddings (distance ties) or the less precise
        # matrix-multiply path of cdist can put the query itself second,
        # which would make it its own "nearest neighbour".
        self_mask = torch.eye(n_queries, dtype=torch.bool, device=dist_matrix.device)
        dist_matrix = dist_matrix.masked_fill(self_mask, float("inf"))

    nearest_neighbors = dist_matrix.argmin(dim=1).to(labels.device)
    predicted_labels = labels[nearest_neighbors].cpu()
    # 1.0 where the neighbour's label matches the query's label, else 0.0.
    # The result is a CPU tensor in the default float dtype.
    correct = (predicted_labels == test_labels.cpu()).to(torch.get_default_dtype())

    acc_k = {k: correct[test_demographic == k].mean() for k in demographic_to_labels.keys()}
    return acc_k, correct
