import numpy as np
import tensorflow as tf
from sklearn.metrics import roc_auc_score, average_precision_score,  roc_curve, f1_score
import matplotlib.pyplot as plt

from src.metrics import compute_auroc, compute_aurc

def softmax(x):
    """Return the softmax of `x`, normalised over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow; this shift does not change the result.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    normaliser = np.sum(exponentials)
    return exponentials / normaliser

def temp_scaling_nll(logits, y, iters=300):
    """Fit a single softmax temperature by minimising the cross-entropy (NLL).

    Args:
        logits: Uncalibrated class scores; converted to a float32 tensor.
            The last dimension is taken as the number of classes.
        y: Integer class labels, one per row of `logits`.
        iters: Number of Adam update steps applied to the temperature.

    Returns:
        The temperature value after `iters` optimisation steps
        (the result of `temp.numpy()`).
    """
    logits = tf.convert_to_tensor(logits, dtype=tf.float32, name='logits')
    temp = tf.Variable(initial_value=1.0, trainable=True, dtype=tf.float32)

    # The one-hot width is pinned to the logits' class dimension: letting
    # to_categorical infer it from max(y) + 1 yields a mismatching shape
    # whenever the highest class index does not occur in `y`.
    # The labels do not depend on the temperature, so encode them once.
    num_classes = logits.shape[-1]
    one_hot_labels = tf.convert_to_tensor(
        tf.keras.utils.to_categorical(y, num_classes=num_classes)
    )

    def compute_temp_loss():
        scaled_logits = tf.math.divide(logits, temp)
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=one_hot_labels, logits=scaled_logits
            )
        )

    optimizer = tf.optimizers.Adam(learning_rate=0.01)
    for _ in range(iters):
        optimizer.minimize(compute_temp_loss, var_list=[temp])
    return temp.numpy()

def temp_scaling_auroc(scores, pred_y, true_label):
    """Grid-search the temperature that maximises `compute_auroc`.

    Temperatures 0.01, 0.02, ..., 5.00 are tried in increasing order. For
    each one, every row of `scores` is divided by the temperature and passed
    through `softmax`; the probability at the predicted class index is used
    as the per-sample score handed to `compute_auroc(true_label, ...)`.

    The search stops early once more than 20 consecutive temperatures fail
    to improve on the best value seen so far.

    Args:
        scores: Iterable of per-sample score vectors.
        pred_y: Predicted class index for each sample.
        true_label: Ground-truth argument forwarded to `compute_auroc`.

    Returns:
        The best temperature found, or None if no temperature produced a
        value strictly greater than 0.
    """
    best_auroc = 0
    best_temp = None
    stale_steps = 0
    for step in range(1, 501):
        candidate = step / 100
        probabilities = np.array([softmax(row / candidate) for row in scores])
        confidence = np.array(
            [probs[cls] for probs, cls in zip(probabilities, pred_y)]
        )
        current = compute_auroc(true_label, confidence)
        if current > best_auroc:
            best_auroc, best_temp, stale_steps = current, candidate, 0
            continue
        # No improvement: spend one unit of patience.
        stale_steps += 1
        if stale_steps > 20:
            break
    return best_temp

def temp_scaling_aurc(scores, pred_y, true_label):
    """Grid-search the temperature that minimises `compute_aurc`.

    Temperatures 0.01, 0.02, ..., 5.00 are tried in increasing order. For
    each one, every row of `scores` is divided by the temperature and passed
    through `softmax`; the probability at the predicted class index is used
    as the per-sample score handed to `compute_aurc(true_label, ...)`.

    The search stops early once more than 20 consecutive temperatures fail
    to improve on the best value seen so far.

    Args:
        scores: Iterable of per-sample score vectors.
        pred_y: Predicted class index for each sample.
        true_label: Ground-truth argument forwarded to `compute_aurc`.

    Returns:
        The best temperature found, or None if no temperature produced a
        value strictly below infinity.
    """
    best_aurc = np.inf
    best_temp = None
    stale_steps = 0
    for step in range(1, 501):
        candidate = step / 100
        probabilities = np.array([softmax(row / candidate) for row in scores])
        confidence = np.array(
            [probs[cls] for probs, cls in zip(probabilities, pred_y)]
        )
        current = compute_aurc(true_label, confidence)
        if current < best_aurc:
            best_aurc, best_temp, stale_steps = current, candidate, 0
            continue
        # No improvement: spend one unit of patience.
        stale_steps += 1
        if stale_steps > 20:
            break
    return best_temp

def compute_opt_threshold(metric, true_label):
    """Pick the threshold on `metric` that best reproduces `true_label`.

    Every observed metric value is tried as a threshold, in ascending order.
    Samples with `metric >= threshold` are labelled 1, the rest 0, and the
    percentage of labels matching `true_label` is computed. Only a strictly
    better percentage replaces the current best, so on ties the smallest
    threshold wins.

    Args:
        metric: NumPy array of per-sample scores.
        true_label: Array of 0/1 reference labels, same length as `metric`.

    Returns:
        The best threshold, or 0.0 if no candidate matched any label.
    """
    best_threshold, best_acc = 0.0, 0.0
    for candidate in np.sort(metric):
        predicted = (metric >= candidate).astype(int)
        n_correct = np.sum(true_label == predicted)
        acc_percent = (n_correct / len(true_label)) * 100
        if acc_percent > best_acc:
            best_acc, best_threshold = acc_percent, candidate
    return best_threshold

def compute_filtering_pr(metric, true_label, threshold):
    """Compute precision and recall (in percent) of a threshold-based filter.

    A sample is predicted positive when `metric <= threshold`.

    Args:
        metric: NumPy array of per-sample scores.
        true_label: NumPy array of 0/1 reference labels.
        threshold: Cut-off; scores at or below it are flagged as positive.

    Returns:
        Tuple `(precision, recall)`, each in percent. A value is 0 when its
        denominator (predicted positives / actual positives) is zero.
    """
    pred_labels = (metric <= threshold).astype(int)
    flagged = pred_labels == 1
    not_flagged = pred_labels == 0

    true_positives = np.sum(flagged & (true_label == 1))
    false_positives = np.sum(flagged & (true_label == 0))
    false_negatives = np.sum(not_flagged & (true_label == 1))

    predicted_positive = true_positives + false_positives
    actual_positive = true_positives + false_negatives

    if predicted_positive > 0:
        precision = true_positives / predicted_positive * 100
    else:
        precision = 0

    if actual_positive > 0:
        recall = true_positives / actual_positive * 100
    else:
        recall = 0

    return precision, recall

def reliability_diagram(metric, true_y, pred_y, n_bins):
    """Bin samples by `metric` and compute per-bin accuracy and calibration gaps.

    The range [0, 1] is split into `n_bins` equal-width bins. A value `v`
    belongs to bin `k` (1-based) when `bins[k-1] < v <= bins[k]`; values
    equal to 0 or greater than 1 therefore fall into no bin and are ignored.

    Args:
        metric: NumPy array of per-sample confidence values.
        true_y: NumPy array of true labels.
        pred_y: NumPy array of predicted labels.
        n_bins: Number of equal-width bins.

    Returns:
        Dict with keys 'bins' (bin edges), 'bin_size', 'bin_counts',
        'bin_acc', 'bin_conf', 'avg_acc', 'avg_conf', 'gaps'
        (|bin_acc - bin_conf|) and 'ece' (count-weighted mean gap, times 100).
    """
    bin_size = 1.0 / n_bins
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    indices = np.digitize(metric, bins, right=True)

    bin_acc = np.zeros(n_bins, dtype=np.float32)
    bin_conf = np.zeros(n_bins, dtype=np.float32)
    bin_counts = np.zeros(n_bins, dtype=np.int32)

    for slot, bin_id in enumerate(range(1, n_bins + 1)):
        members = np.where(indices == bin_id)[0]
        if len(members) == 0:
            # Empty bins keep their zero accuracy, confidence and count.
            continue
        bin_acc[slot] = np.mean(true_y[members] == pred_y[members])
        bin_conf[slot] = np.mean(metric[members])
        bin_counts[slot] = len(members)

    total = np.sum(bin_counts)
    gaps = np.abs(bin_acc - bin_conf)

    # All three summaries are weighted by the number of samples per bin.
    avg_acc = np.sum(bin_acc * bin_counts) / total
    avg_conf = np.sum(bin_conf * bin_counts) / total
    ece = np.sum(gaps * bin_counts) / total * 100

    return {
        'bins' : bins,
        'bin_size' : bin_size,
        'bin_counts' : bin_counts,
        'bin_acc' : bin_acc,
        'bin_conf' : bin_conf,
        'avg_acc' : avg_acc,
        'avg_conf' : avg_conf,
        'gaps' : gaps,
        'ece' : ece,
    }

def plot_reliability_diagram(result, metric):
    """Draw a two-panel reliability diagram and show it with `plt.show()`.

    The top panel shows per-bin accuracy bars with hatched "gap" bars
    stacked between accuracy and confidence, the ECE value as text, and the
    diagonal of perfect calibration. The bottom panel shows the fraction of
    samples per bin plus vertical lines at the average accuracy and average
    confidence.

    Args:
        result: Dict with the keys 'bins', 'bin_size', 'bin_acc', 'bin_conf',
            'gaps', 'ece', 'bin_counts', 'avg_acc' and 'avg_conf' (the
            structure returned by `reliability_diagram`).
        metric: Text used as the x-axis label of the bottom panel.
    """
    # Bar centres: the left edge of every bin shifted by half a bin width.
    positions = result['bins'][:-1] + result['bin_size']/2.0

    fig, axs = plt.subplots(2, 1, figsize=(4,5), dpi=100, sharex=True, gridspec_kw={'height_ratios': [1,0.5]})
    top_ax, bottom_ax = axs[0], axs[1]

    top_ax.bar(positions,
               result['bin_acc'],
               width=result['bin_size'],
               edgecolor='black',
               color='blue',
               linewidth=1,
               label='Accuracy')

    # The gap bar starts at the lower of accuracy/confidence so that it
    # always spans the distance between the two.
    top_ax.bar(positions,
               result['gaps'],
               bottom=np.minimum(result['bin_acc'], result['bin_conf']),
               width=result['bin_size'],
               edgecolor='black',
               color='red',
               linewidth=1,
               hatch="//",
               label='Gap')

    top_ax.text(0.7, 0.1,
                f"ECE={result['ece']:.2f}",
                color="black",
                bbox=dict(facecolor='white', alpha=0.5),
                fontsize=11)

    # Diagonal of perfect calibration.
    top_ax.plot([0,1], [0,1], linestyle = "--", color="gray")
    top_ax.set_xlim(0,1)
    top_ax.grid(True, alpha=0.3)
    top_ax.grid(zorder=0)
    top_ax.set_ylabel('Accuracy', fontsize=11)
    top_ax.legend()

    bottom_ax.bar(positions,
                  result['bin_counts'] / np.sum(result['bin_counts']),
                  width=result['bin_size'],
                  edgecolor='black',
                  color='tab:orange')

    bottom_ax.axvline(x=result['avg_acc'], linestyle="--", color="blue", label='Average Accuracy')
    bottom_ax.axvline(x=result['avg_conf'], linestyle='--', color='red', label='Average Confidence')

    bottom_ax.grid(True)
    bottom_ax.set_xticks(np.linspace(0.0, 1.0, 11))
    bottom_ax.set_xlabel(metric, fontsize=11)
    bottom_ax.set_ylabel('% of Samples', fontsize=11)
    bottom_ax.legend()

    plt.tight_layout()
    plt.show()

def compute_classification_error(metric, true_y):
    """Return the fraction of samples whose arg-max class differs from `true_y`.

    Args:
        metric: Sequence of per-sample score vectors; the index of the
            largest entry is taken as the predicted class.
        true_y: Sequence of true class indices, indexable by sample position.

    Returns:
        `1 - (number of correct predictions) / len(metric)`.
    """
    n_correct = sum(
        1
        for i, sample_scores in enumerate(metric)
        if true_y[i] == np.argmax(sample_scores)
    )
    return 1 - n_correct / len(metric)

def compute_selective_pred_acc(metric, dataset, model, n):
    """Evaluate `model` after discarding the `n` samples with the lowest `metric`.

    Args:
        metric: Per-sample scores, in the same order in which `dataset`
            yields its samples.
        dataset: Iterable of `(x, y)` pairs. It is consumed in a single pass,
            so one-shot iterators and datasets that reshuffle on every
            iteration keep their inputs and targets aligned.
        model: Object exposing `evaluate(x, y, verbose=0)` whose result is
            indexable.
        n: Number of lowest-scoring samples to drop before evaluation.

    Returns:
        Element 1 of the value returned by `model.evaluate` (presumably the
        first compiled metric, e.g. accuracy; confirm against the model's
        compile configuration).
    """
    inputs, targets = [], []
    # Collect both halves of every pair in one pass; iterating the dataset
    # twice would pair x and y from two different traversals.
    for sample, target in dataset:
        inputs.append(sample)
        targets.append(target)
    x = np.array(inputs)
    y = np.array(targets)

    # argsort is ascending, so skipping the first n indices removes the
    # n lowest-scoring samples.
    kept = np.argsort(metric)[n:]
    evaluation = model.evaluate(x[kept], y[kept], verbose=0)
    return evaluation[1]

def calculate_confidence_interval(data, confidence=0.95):
    """Return the mean and the normal-approximation confidence half-width.

    The half-width is `z * std / sqrt(n)`, computed along axis 0, where
    `std` is `np.std` with its default settings and `z` is the two-sided
    standard-normal quantile for `confidence`.

    Args:
        data: NumPy array whose first axis indexes the observations.
        confidence: Two-sided confidence level, strictly between 0 and 1.

    Returns:
        Tuple `(mean, h)`; the interval is `mean - h` to `mean + h`.

    Raises:
        ValueError: If `confidence` is not strictly between 0 and 1.
    """
    from statistics import NormalDist

    if not 0.0 < confidence < 1.0:
        raise ValueError(
            f"confidence must be strictly between 0 and 1, got {confidence!r}"
        )

    if confidence == 0.95:
        # Keep the conventional rounded quantile for the default level so
        # that results computed with the default stay exactly reproducible.
        z = 1.96
    else:
        z = NormalDist().inv_cdf(0.5 + confidence / 2.0)

    n = data.shape[0]
    mean = np.mean(data, axis=0)
    se = np.std(data, axis=0) / np.sqrt(n)
    h = se * z
    return mean, h

def create_percentile_range(min_value, max_value, percentiles=100):
    """Return `percentiles` evenly spaced values strictly inside the range.

    The interval from `min_value` to `max_value` is divided into
    `percentiles + 1` equal steps; the returned list holds the interior
    points only, so neither endpoint is included.

    Args:
        min_value: Lower end of the range (excluded from the result).
        max_value: Upper end of the range (excluded from the result).
        percentiles: Number of interior points to generate.

    Returns:
        List of `percentiles` values in ascending step order.
    """
    span = max_value - min_value
    step = span / (percentiles + 1)
    interior_steps = range(1, percentiles + 1)
    return [min_value + step * k for k in interior_steps]

