import matplotlib.pyplot as plt
from sklearn.metrics import f1_score


def compute_integral_with_mean(y, mean):
    """Return a trapezoid-style integral of ``y - mean`` over the unit interval.

    The samples in ``y`` are treated as equally spaced on [0, 1]. The value
    returned is ``spacing * (2 * sum(f) - f[0] - f[-1])`` with
    ``f = y - mean``, i.e. twice the classic trapezoidal-rule estimate.

    Args:
        y: Sequence of curve samples.
        mean: Baseline subtracted from every sample before integrating.

    Returns:
        The (doubled) trapezoidal estimate, or ``y[0]`` unchanged when the
        curve holds a single sample (no baseline is subtracted in that case).
    """
    n_points = len(y)
    if n_points == 1:
        return y[0]

    # Distance between two consecutive abscissas of an even grid on [0, 1].
    spacing = 1 / (n_points - 1)
    centred = [sample - mean for sample in y]
    return spacing * (2 * sum(centred) - centred[0] - centred[-1])

def compute_metrics(y_full_bias, y, mean):
    """Return the absolute ratio between the integrals of two curves.

    Both curves are integrated with ``compute_integral_with_mean`` using the
    same ``mean`` baseline.

    Args:
        y_full_bias: Reference curve; its integral is the denominator.
        y: Curve being compared; its integral is the numerator.
        mean: Baseline subtracted from both curves.

    Returns:
        ``abs(integral(y) / integral(y_full_bias))``.
    """
    observed_area = compute_integral_with_mean(y, mean)
    reference_area = compute_integral_with_mean(y_full_bias, mean)
    return abs(observed_area / reference_area)



"""
    The granularity parameter sets the approximate number of positions at which the performance
    curves are sampled. Sampling the curves instead of evaluating every position approximates
    the metric and reduces the amount of computation needed when the labels list is very long.
"""

def temporal_bias_binary(labels, granularity=100):
    """Compute the bias score of an ordered sequence of binary labels.

    For a series of split positions, two trivial predictors are scored with
    macro F1 against ``labels``: "0 up to the split, 1 afterwards" and the
    opposite. The same "0 then 1" predictor is also scored against a fully
    sorted version of the labels (all 0s followed by all 1s), which gives the
    reference curve. The score is the absolute ratio between the integral of
    the observed curve and the integral of the reference curve, both taken
    relative to the midpoint of the reference curve's end values.

    Args:
        labels: Non-empty sequence of 0/1 labels.
        granularity: Approximate number of split positions to evaluate. When
            it exceeds ``len(labels)``, every position is evaluated.

    Returns:
        The bias score as a non-negative float.

    Raises:
        ValueError: If ``labels`` is empty or ``granularity`` is not positive.
    """
    n_labels = len(labels)
    if n_labels == 0:
        raise ValueError("labels must not be empty")
    if granularity <= 0:
        raise ValueError("granularity must be a positive number")

    n_positive = int(sum(labels))
    bias_labels = [0] * (n_labels - n_positive) + [1] * n_positive

    # Never let the step fall below 1: short label lists are evaluated at
    # every position instead of collapsing to a single sample.
    step = max(1, int(n_labels / granularity))

    y_full_bias = []
    y_neg_then_pos = []
    y_pos_then_neg = []
    for i in range(0, n_labels, step):
        split = i + 1
        neg_then_pos = [0] * split + [1] * (n_labels - split)
        pos_then_neg = [1] * split + [0] * (n_labels - split)
        y_neg_then_pos.append(f1_score(labels, neg_then_pos, average="macro"))
        y_pos_then_neg.append(f1_score(labels, pos_then_neg, average="macro"))
        y_full_bias.append(f1_score(bias_labels, neg_then_pos, average="macro"))

    mean = (y_full_bias[0] + y_full_bias[-1]) / 2

    # Compare against whichever observed curve trends in the same direction
    # (start vs. end value) as the reference curve.
    reference_falls = y_full_bias[0] > y_full_bias[-1]
    reference_rises = y_full_bias[0] < y_full_bias[-1]
    observed_falls = y_neg_then_pos[0] > y_neg_then_pos[-1]
    observed_rises = y_neg_then_pos[0] < y_neg_then_pos[-1]
    if (reference_falls and observed_falls) or (reference_rises and observed_rises):
        return compute_metrics(y_full_bias, y_neg_then_pos, mean)
    return compute_metrics(y_full_bias, y_pos_then_neg, mean)


def temporal_bias_binary_with_plot(labels, granularity=100):
    """Compute the binary bias score and plot the curves it is derived from.

    The score is built from macro-F1 curves of trivial "0 up to the split, 1
    afterwards" (and the opposite) predictors, evaluated against ``labels``
    and against a fully sorted copy of the labels (the reference curve). The
    observed curve used for the score and the reference curve are plotted
    against a normalised [0, 1] axis and shown with ``plt.show()``.

    Args:
        labels: Non-empty sequence of 0/1 labels.
        granularity: Approximate number of split positions to evaluate. When
            it exceeds ``len(labels)``, every position is evaluated.

    Returns:
        The bias score ``tau`` as a non-negative float.

    Raises:
        ValueError: If ``labels`` is empty or ``granularity`` is not positive.
    """
    n_labels = len(labels)
    if n_labels == 0:
        raise ValueError("labels must not be empty")
    if granularity <= 0:
        raise ValueError("granularity must be a positive number")

    n_positive = int(sum(labels))
    bias_labels = [0] * (n_labels - n_positive) + [1] * n_positive

    # A step of at least 1 means short label lists are evaluated everywhere.
    step = max(1, int(n_labels / granularity))

    y_full_bias = []
    y_neg_then_pos = []
    y_pos_then_neg = []
    for i in range(0, n_labels, step):
        split = i + 1
        neg_then_pos = [0] * split + [1] * (n_labels - split)
        pos_then_neg = [1] * split + [0] * (n_labels - split)
        y_neg_then_pos.append(f1_score(labels, neg_then_pos, average="macro"))
        y_pos_then_neg.append(f1_score(labels, pos_then_neg, average="macro"))
        y_full_bias.append(f1_score(bias_labels, neg_then_pos, average="macro"))

    # Normalised abscissas; a single sample would otherwise divide by zero.
    n_samples = len(y_full_bias)
    if n_samples > 1:
        x = [i / (n_samples - 1) for i in range(n_samples)]
    else:
        x = [0.0]

    mean = (y_full_bias[0] + y_full_bias[-1]) / 2

    # Use whichever observed curve trends in the same direction (start vs.
    # end value) as the reference curve, and label it with its real name.
    reference_falls = y_full_bias[0] > y_full_bias[-1]
    reference_rises = y_full_bias[0] < y_full_bias[-1]
    observed_falls = y_neg_then_pos[0] > y_neg_then_pos[-1]
    observed_rises = y_neg_then_pos[0] < y_neg_then_pos[-1]
    if (reference_falls and observed_falls) or (reference_rises and observed_rises):
        observed_curve, observed_name = y_neg_then_pos, "y_neg_then_pos"
    else:
        observed_curve, observed_name = y_pos_then_neg, "y_pos_then_neg"

    tau = compute_metrics(y_full_bias, observed_curve, mean)
    plt.plot(x, observed_curve, label=observed_name)
    plt.plot(x, y_full_bias, label="Fully biased system")
    plt.title(f"tau={tau}")
    plt.legend()
    plt.show()
    return tau



def temporal_bias(labels, granularity=100):
    """Compute the bias score of an ordered sequence of integer class labels.

    With exactly two distinct labels the score is the binary score of the
    sequence (smaller label mapped to 0, larger to 1). With more than two,
    a binary score is computed for every pair of distinct labels, keeping
    only the items that carry one of the two labels, and the harmonic mean
    of these pairwise scores is returned.

    Only labels that actually occur in ``labels`` are paired, so label values
    do not need to be contiguous or start at 0.

    Args:
        labels: Re-iterable sequence of ``int`` class labels.
        granularity: Forwarded to ``temporal_bias_binary``.

    Returns:
        The bias score, or 0 when fewer than two distinct labels are given
        (a message is printed in that case). 0.0 is returned when any
        pairwise score is zero, which is the limit of the harmonic mean.

    Raises:
        AssertionError: If a label is not exactly of type ``int``.
    """
    distinct_labels = set(labels)

    # AssertionError is kept for backward compatibility; raising it explicitly
    # (instead of using ``assert``) keeps the check active under ``python -O``.
    if any(type(label) is not int for label in distinct_labels):
        raise AssertionError("Some given labels are not of type 'int' (text labels are not supported)")

    present_labels = sorted(distinct_labels)

    if len(present_labels) < 2:
        print(f"Not enough labels to compute bias (no label or only one label is given, got {len(present_labels)})")
        return 0

    def binary_view(negative, positive):
        # Keep only items of the two classes, mapped to 0 (negative) / 1 (positive).
        return [
            int(label == positive)
            for label in labels
            if label == negative or label == positive
        ]

    if len(present_labels) == 2:
        return temporal_bias_binary(binary_view(*present_labels), granularity)

    partial_biases = []
    for position, lower in enumerate(present_labels):
        for higher in present_labels[position + 1:]:
            partial_biases.append(
                temporal_bias_binary(binary_view(lower, higher), granularity)
            )

    # A zero pairwise score drives the harmonic mean to zero; return it
    # directly instead of dividing by zero.
    if any(bias == 0 for bias in partial_biases):
        return 0.0

    inverse_biases = [1 / bias for bias in partial_biases]
    return len(inverse_biases) / sum(inverse_biases)
    