import logging
import torch

# Module-level logger registered in the logging hierarchy so records propagate
# to handlers configured by the application (logging.Logger(...) would not).
logger = logging.getLogger(__name__)


def get_prf(res):
    """Compute precision, recall and F1 from confusion counts.

    Follows the GERBIL convention
    (https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure).
    When there are no true positives, all three scores are 1.0 if there were
    also no false positives and no false negatives. Otherwise they are all 0.0.

    Args:
        res: mapping with counts under ``"TP"``, ``"FP"`` and ``"FN"``.

    Returns:
        A ``(precision, recall, f1)`` tuple.
    """
    tp = res["TP"]
    if tp == 0:
        # Nothing found: perfect only if nothing was wrongly found or missed.
        nothing_wrong = res["FP"] == 0 and res["FN"] == 0
        score = 1.0 if nothing_wrong else 0.0
        return score, score, score

    fp, fn = res["FP"], res["FN"]
    prec = 1.0 * tp / (tp + fp)
    rec = 1.0 * tp / (tp + fn)
    return prec, rec, 2 * prec * rec / (prec + rec)


def gen_micro_macro_result(res):
    """Aggregate per-class confusion counts into micro and macro P/R/F1.

    Args:
        res: a list of per-class count dicts (each with ``"TP"``, ``"FP"``,
            ``"FN"`` and ``"TN"``), or a result dict that stores such a list
            under ``'class'`` (the shape built by the accuracy functions in
            this module).

    Returns:
        Dict with ``micro_precision``, ``micro_recall``, ``micro_f1``,
        ``macro_precision``, ``macro_recall`` and ``macro_f1``, each rounded
        to 3 decimals.

        - Micro scores are ``get_prf`` of the summed counts.
        - Macro scores are the unweighted mean of the per-class ``get_prf``
          scores.
        - With an empty class list there is nothing to average. The macro
          scores then fall back to the micro scores (``get_prf`` of all-zero
          counts) instead of raising ``ZeroDivisionError``.
    """
    if 'class' in res:
        res = res['class']

    total = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    per_class_scores = []
    for counts in res:
        for key in total:
            total[key] += counts[key]
        per_class_scores.append(get_prf(counts))

    micro_precision, micro_recall, micro_f1 = get_prf(total)

    if per_class_scores:
        # zip(*...) regroups [(p, r, f), ...] into (all p), (all r), (all f).
        macro_precision, macro_recall, macro_f1 = (
            sum(scores) / len(per_class_scores)
            for scores in zip(*per_class_scores)
        )
    else:
        macro_precision, macro_recall, macro_f1 = (
            micro_precision, micro_recall, micro_f1
        )

    return {
        "micro_precision": round(micro_precision, 3),
        "micro_recall": round(micro_recall, 3),
        "micro_f1": round(micro_f1, 3),
        "macro_precision": round(macro_precision, 3),
        "macro_recall": round(macro_recall, 3),
        "macro_f1": round(macro_f1, 3),
    }


def null_accuracy_function(outputs, label, config, result=None):
    """No-op metric: ignores every argument and always returns ``None``.

    It keeps the same signature as the other accuracy functions, so it can be
    plugged in where a metric callable is required but none should be computed.
    """

def simple_rank_accuracy(pos, neg, config, result=None):
    """Count pairs whose positive score strictly beats the matching negative.

    Args:
        pos: sequence of scores for the positive items.
        neg: sequence of scores for the negative items, indexed in step with
            ``pos``.
        config: unused; kept for a uniform metric-function signature.
        result: accumulator from a previous call, or ``None`` to start fresh.

    Returns:
        The (updated) ``result`` dict:

        - ``'total'``: number of pairs seen.
        - ``'total_acc'``: number of pairs where ``pos > neg``.
        - ``'inst'``: per-pair 1/0 outcomes, in order.
    """
    if result is None:
        result = {'inst': [], 'total': 0, 'total_acc': 0}

    for idx, pos_score in enumerate(pos):
        result['total'] += 1
        won = pos_score > neg[idx]
        result['inst'].append(1 if won else 0)
        if won:
            result['total_acc'] += 1
    return result

def single_label_top1_accuracy(outputs, label, config, result=None):
    """Accumulate top-1 accuracy and per-class confusion counts.

    Args:
        outputs: score tensor. The prediction for row ``i`` is the index of
            its maximum along dim 1, and ``outputs.size(1)`` is taken as the
            number of classes.
        label: indexable of gold class indices, one per row of ``outputs``.
        config: unused; kept for a uniform metric-function signature.
        result: accumulator from a previous call, or ``None`` to start fresh.

    Returns:
        The (updated) ``result`` dict:

        - ``'inst'``: ``[predicted, gold]`` int pairs.
        - ``'class'``: per-class TP/FN/FP/TN counts. TN is never incremented
          here.
        - ``'total_acc'``: number of correct predictions.
    """
    if result is None:
        result = {'inst': [], 'class': [], 'total_acc': 0}

    _, predicted = torch.max(outputs, dim=1)

    # Grow the per-class table so every class index has a counter dict.
    per_class = result['class']
    missing = outputs.size(1) - len(per_class)
    per_class.extend(
        {"TP": 0, "FN": 0, "FP": 0, "TN": 0} for _ in range(missing)
    )

    for idx, pred in enumerate(predicted):
        guess = int(pred)
        gold = int(label[idx])
        if guess != gold:
            per_class[guess]["FP"] += 1
            per_class[gold]["FN"] += 1
        else:
            per_class[guess]["TP"] += 1
            result['total_acc'] += 1
        result['inst'].append([guess, gold])

    return result


def multi_label_accuracy(outputs, label, config, result=None):
    """Accumulate exact-match accuracy and per-class counts for multi-label output.

    A class is predicted for a row when its score is ``>= min(0, row max)``.
    Scores ``>= 0`` count as positive. If every score in a row is negative,
    the highest-scoring class(es) are still predicted, so no row is left
    without a prediction. Gold labels count as positive when ``>= 0.5``.

    Args:
        outputs: 2-D score tensor of shape ``(batch, nr_classes)``.
        label: tensor of gold labels with the same shape as ``outputs``.
        config: unused; kept for a uniform metric-function signature.
        result: accumulator from a previous call, or ``None`` to start fresh.

    Returns:
        The (updated) ``result`` dict:

        - ``'inst'``: one ``[predicted, gold]`` pair of 0/1 long tensors per
          row.
        - ``'class'``: per-class TP/FN/FP/TN counts.
        - ``'total_acc'``: number of rows whose predicted label set equals
          the gold set exactly.

    Raises:
        ValueError: if the class dimension (dim 1) or the batch size (dim 0)
            of ``outputs`` and ``label`` differ.
    """
    if label.size(1) != outputs.size(1):
        raise ValueError('Input dimensions of labels and outputs must match.')
    if label.size(0) != outputs.size(0):
        # Otherwise broadcasting could silently pair rows incorrectly.
        raise ValueError('Batch sizes of labels and outputs must match.')

    if result is None:
        result = {'class': [], 'inst': [], 'total_acc': 0}

    nr_classes = outputs.size(1)
    while len(result['class']) < nr_classes:
        result['class'].append({"TP": 0, "FN": 0, "FP": 0, "TN": 0})

    # An empty batch contributes nothing; the row-wise max below needs rows.
    if outputs.size(0) == 0:
        return result

    outputs = outputs.detach()
    labels = label.detach()

    # Per-row threshold min(0, row max), computed for the whole batch at once.
    threshold = outputs.max(dim=1, keepdim=True)[0].clamp(max=0)
    predicted = (outputs >= threshold).long()
    gold = (labels.float() >= 0.5).long()

    result['inst'].extend([p, g] for p, g in zip(predicted, gold))
    # For 0/1 vectors, |P & G| == |G| == |P| is the same as P == G.
    result['total_acc'] += int((predicted == gold).all(dim=1).sum())

    tp = (gold * predicted).sum(dim=0).tolist()
    fn = (gold * (1 - predicted)).sum(dim=0).tolist()
    fp = ((1 - gold) * predicted).sum(dim=0).tolist()
    tn = ((1 - gold) * (1 - predicted)).sum(dim=0).tolist()
    for counts, n_tp, n_fn, n_fp, n_tn in zip(result['class'], tp, fn, fp, tn):
        counts["TP"] += n_tp
        counts["FN"] += n_fn
        counts["FP"] += n_fp
        counts["TN"] += n_tn

    return result


def single_label_top2_accuracy(outputs, label, config, result=None):
    """Top-2 accuracy metric; not implemented.

    Kept so configurations that reference it fail loudly rather than with a
    ``NameError``. The earlier draft body sat after an unconditional
    ``raise`` and was unreachable. It was also marked as buggy and used a
    list-shaped ``result`` that no other metric here uses, so it was removed.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('single_label_top2_accuracy is not implemented yet')
