import os
import json
import numpy as np
import torch
from tqdm import tqdm

def pairwise_f1(y_true, y_preds):
    '''
    Soft F1 score between every row of ``y_true`` and every row of ``y_preds``.

    y_true shape a*s
    y_preds shape f*s

    Both inputs are moved to CPU and cast to float32, so bool / integer / float
    indicator tensors are all accepted. TP/FP/FN are computed as products of
    the (possibly fractional) values, so 0/1 inputs give the usual counts.

    Returns a float32 CPU tensor of shape a*f whose entry [i, j] is the F1 of
    prediction row j against target row i (epsilon-smoothed, so 0 instead of
    NaN when a row has no positives).
    '''
    eps = 1e-7
    truth = y_true.unsqueeze(1).cpu().float()   # a*1*s
    preds = y_preds.unsqueeze(0).cpu().float()  # 1*f*s

    # Broadcasting a*1*s against 1*f*s gives a*f*s; reducing the last axis
    # yields one count per (target, prediction) pair.
    true_pos = (truth * preds).sum(dim=-1)
    false_pos = ((1 - truth) * preds).sum(dim=-1)
    false_neg = (truth * (1 - preds)).sum(dim=-1)

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    return 2 * (precision * recall) / (precision + recall + eps)


def _pairwise_f1(y_true, y_preds , threshold=0.1):
    '''
    y_true shape a*s
    y_preds shape f*s
    '''
    y_true = y_true.unsqueeze(1).cpu().to(torch.float32)  # a*1*s
    y_preds = y_preds.unsqueeze(0).cpu().to(torch.float32)  # 1*f*s
    tp = ((y_true == 1) & (y_preds == 1)).sum(-1)
    tn = ((y_true == 0) & (y_preds == 0)).sum(-1)
    fn = ((y_true == 1) & (y_preds == 0)).sum(-1)
    fp = ((y_true == 0) & (y_preds == 1)).sum(-1)

    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = (2 * precision * recall) / (precision + recall + epsilon)
    return f1

def features_to_binary_features(features, bits, neg_features=True):
    """Expand integer-coded features into boolean indicator columns.

    For every code ``c`` in ``range(2 ** bits)`` this emits the indicator
    ``features == c`` and, when ``neg_features`` is true, its complement
    ``features != c``. The pieces are concatenated along dim 1 in the order
    ``[pos_0, neg_0, pos_1, neg_1, ...]`` (only ``pos_*`` when
    ``neg_features`` is false).

    features: tensor with at least 2 dimensions (``torch.cat(..., dim=1)``
        requires it); values are presumably codes in ``[0, 2 ** bits)`` —
        TODO confirm with callers.
    bits: number of bits per code; ``2 ** bits`` codes are enumerated.
    neg_features: whether to also emit the complement of each indicator.

    Returns a bool tensor on the same device as ``features``, with dim 1 equal
    to ``features.shape[1] * 2 ** bits`` (doubled when ``neg_features``).
    """
    binary_features = []
    for c in range(2 ** bits):
        pos_feature = features == float(c)
        binary_features.append(pos_feature)
        if neg_features:
            # Logical complement of the equality mask; identical to
            # ``features != c`` (including for NaN) without a second comparison.
            binary_features.append(~pos_feature)
    return torch.cat(binary_features, dim=1)

def get_image_properties(
    paths,
    im2attr_path=r'/home/chaimb/AmeenNips21/Ours/datasets/VOC2008',
    attr_cnt=64,
):
    """Build an attribute matrix for the given images from ``pascal.pickle``.

    paths: image paths; only the last '/'-separated component (the file name)
        is used as the key into the pickled mapping.
    im2attr_path: directory containing ``pascal.pickle`` (defaults to the
        location that used to be hard-coded).
    attr_cnt: number of attribute rows to allocate (previously fixed at 64);
        an image with more attributes than this raises IndexError.

    Returns ``(image_to_index, attributes)``:
        image_to_index: dict mapping each image file name to its column in
            ``attributes`` (a repeated name maps to its last occurrence).
        attributes: float numpy array of shape ``(attr_cnt, len(paths))``;
            column j holds ``int(value)`` for each attribute of ``paths[j]``,
            in the iteration order of that image's entry in the pickle. Rows
            beyond the image's attribute count stay 0.
    """
    import pickle

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file — only point ``im2attr_path`` at a trusted dataset directory.
    with open(os.path.join(im2attr_path, 'pascal.pickle'), 'rb') as handle:
        image_to_attrs = pickle.load(handle)

    attributes = np.zeros((attr_cnt, len(paths)))
    image_to_index = {}

    for column, path in enumerate(paths):
        image = path.split('/')[-1]
        image_to_index[image] = column
        image_attrs = image_to_attrs[image]
        # Row index is the attribute's position in iteration order, not its key.
        for att_idx, att in enumerate(image_attrs):
            attributes[att_idx, column] = int(image_attrs[att])
    return image_to_index, attributes

def attribute_f1_seq(attributes, binary_features):
    """Cross F1 between ground-truth attribute rows and binary feature rows.

    attributes: tensor whose rows are per-attribute indicators over samples
        (presumably 0/1 values, shape a*s — TODO confirm with callers).
    binary_features: tensor whose rows are 0/1 (or bool) feature indicators
        over the same samples (f*s).

    Two directions are scored with ``pairwise_f1``:
      1. for each attribute, the best F1 against any feature or any feature's
         complement;
      2. for each feature, the best F1 against any attribute or any
         attribute's complement.
    Returns the harmonic mean of the two averages as a 0-d tensor, or 0 when
    both averages are 0 (instead of NaN from 0/0).

    Raises AssertionError if ``binary_features`` contains values other than 0/1.
    """
    assert ((binary_features == 1) | (binary_features == 0)).all()

    # Complement via ``1 - x`` rather than ``~x``: on non-bool tensors ``~`` is
    # a bitwise NOT (e.g. ~uint8(1) == 254), which silently corrupts the F1.
    # Hoisted out of the loop because it does not depend on the row.
    features_neg = 1 - binary_features.to(torch.float32)

    # Direction 1: best match for every attribute among features / complements.
    f1_scores = []
    for att in tqdm(attributes):
        att_row = att.unsqueeze(0)
        f1_score = pairwise_f1(att_row, binary_features)
        f1_score_neg = pairwise_f1(att_row, features_neg)
        f1_scores.append(max(f1_score.max(), f1_score_neg.max()))
    f1_scores = torch.stack(f1_scores)

    # Direction 2: best match for every feature among attributes / complements.
    attributes = attributes.type(torch.ByteTensor)
    attributes_neg = 1 - attributes  # stays uint8; 0 <-> 1 for 0/1 inputs
    f1_scores2 = []
    for feat in tqdm(binary_features):
        feat_row = feat.unsqueeze(0)
        _f1_score = pairwise_f1(feat_row, attributes)
        _f1_score_neg = pairwise_f1(feat_row, attributes_neg)
        f1_scores2.append(max(_f1_score.max(), _f1_score_neg.max()))
    f1_scores2 = torch.stack(f1_scores2)

    f1_1 = f1_scores.mean()
    f1_2 = f1_scores2.mean()
    denom = f1_1 + f1_2
    if denom == 0:
        # Harmonic mean of two zeros is conventionally 0, not NaN.
        return torch.zeros_like(denom)
    return (2 * f1_1 * f1_2) / denom
