import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import os

def distrs_compute(tr_values, te_values, tr_labels, te_labels, num_bins=50, log_bins=False, plot_name=None, label_encoder=None):
    """Build per-class normalized histograms of train and test values and save a plot.

    For each class index ``i`` this function does the following:
    - It selects the train and test values whose label equals ``i``.
    - It clips them to at least ``1e-10``.
    - It bins them into ``num_bins`` bins spanning the combined range of both
      splits. The bins are log-spaced when ``log_bins`` is set.
    - It weights each sample by ``1 / n``, where ``n`` is that split's sample
      count for the class. Each bin therefore holds the fraction of the class's
      samples that fall in it. A split with no samples gives an all-zero
      histogram.

    One subplot per class is drawn, and the figure is saved to
    ``plot_name + '.png'``.

    Args:
        tr_values, te_values: Per-sample metric values for the train / test split.
        tr_labels, te_labels: Integer class index per sample, aligned with the values.
        num_bins: Number of histogram bins per class.
        log_bins: Use log-spaced bins and a log-scaled x-axis.
        plot_name: Output path without extension. Defaults to
            ``risk_score_debug/metric_distribution``. Missing parent
            directories are created.
        label_encoder: Optional fitted encoder. Only ``len(classes_)`` is used,
            to fix the number of classes.

    Returns:
        A tuple ``(tr_distrs, te_distrs, all_bins)`` with shapes
        ``(C, num_bins)``, ``(C, num_bins)`` and ``(C, num_bins + 1)``.
        Row ``i`` belongs to class index ``i``.

    Raises:
        ValueError: If there are no classes to process.
    """
    # Accept lists as well as arrays: ``labels == i`` only masks element-wise on arrays.
    tr_values = np.asarray(tr_values)
    te_values = np.asarray(te_values)
    tr_labels = np.asarray(tr_labels)
    te_labels = np.asarray(te_labels)

    if label_encoder is not None:
        num_classes = len(label_encoder.classes_)
    else:
        # Rows are looked up by label value downstream (tr_distrs[label]), so the
        # row count must cover the largest label seen in either split, not just
        # the number of distinct train labels.
        all_labels = np.concatenate((tr_labels, te_labels))
        num_classes = int(all_labels.max()) + 1 if all_labels.size else 0
    if num_classes <= 0:
        raise ValueError("distrs_compute: no classes to build distributions for")

    small_delta = 1e-10  # keeps values strictly positive so log10 / logspace are defined
    sqr_num = int(np.ceil(np.sqrt(num_classes)))
    tr_distrs, te_distrs, all_bins = [], [], []

    fig = plt.figure(figsize=(15, 15))
    try:
        for i in range(num_classes):
            # NOTE(review): the clip is only required for log bins. It is also
            # applied to linear bins, which folds any values <= 0 into the
            # lowest bin.
            tr_list = np.clip(tr_values[tr_labels == i], a_min=small_delta, a_max=None)
            te_list = np.clip(te_values[te_labels == i], a_min=small_delta, a_max=None)

            all_list = np.concatenate((tr_list, te_list))
            if all_list.size:
                max_v, min_v = np.max(all_list), np.min(all_list)
            else:
                # The class has no train or test samples. Use a placeholder range
                # so the bins still stack. Both histograms are all zeros, so
                # scoring falls back to the neutral value.
                max_v, min_v = 1.0, small_delta

            plt.subplot(sqr_num, sqr_num, i + 1)
            if log_bins:
                bins = np.logspace(np.log10(min_v), np.log10(max_v + 1e-8), num_bins + 1)
                plt.gca().set_xscale("log")
            else:
                # The +1e-8 makes the top edge strictly above max_v, so the
                # histogram range is never degenerate.
                bins = np.linspace(min_v, max_v + 1e-8, num_bins + 1)

            # Each sample weighs 1/n, so bins hold fractions of this split's samples.
            # max(..., 1) avoids dividing by zero when a split has no samples here.
            h1, _, _ = plt.hist(tr_list, bins=bins, weights=np.ones_like(tr_list) / max(len(tr_list), 1), alpha=0.5, label='Train', color='b')
            h2, _, _ = plt.hist(te_list, bins=bins, weights=np.ones_like(te_list) / max(len(te_list), 1), alpha=0.5, label='Test', color='r')
            plt.title(f'Class {i}')
            plt.legend()

            tr_distrs.append(h1)
            te_distrs.append(h2)
            all_bins.append(bins)

        if plot_name is None:
            plot_name = 'risk_score_debug/metric_distribution'
        # Create the output directory for both the default and user-supplied paths.
        os.makedirs(os.path.dirname(plot_name) or ".", exist_ok=True)
        plt.tight_layout()
        plt.savefig(plot_name + '.png', bbox_inches='tight')
    finally:
        # Always release the figure, even if binning or saving fails.
        plt.close(fig)

    return np.stack(tr_distrs), np.stack(te_distrs), np.stack(all_bins)

def find_index(bins, value):
    """Return the bin index whose half-open interval ``[bins[k], bins[k + 1])`` holds ``value``.

    ``bins`` must be sorted in ascending order. A value below ``bins[0]``
    yields ``-1``, and a value at or above ``bins[-1]`` yields
    ``len(bins) - 1``. Callers are expected to clip the result. ``value`` may
    be a scalar or an array.
    """
    # side="right" places a value equal to an edge into the bin that starts at that edge.
    insertion_point = np.searchsorted(bins, value, side="right")
    return insertion_point - 1

def score_calculate(tr_distr, te_distr, ind):
    """Return the train share ``tr / (tr + te)`` of bin ``ind``.

    If bin ``ind`` is empty in both histograms, the search widens one step at a
    time. At each distance the left neighbour is tried before the right one,
    and the first non-empty in-range bin wins. If every bin is empty, the
    neutral score ``0.5`` is returned.

    ``ind`` itself is indexed directly, with no bounds check.
    """
    size = len(tr_distr)

    def ratio_at(k):
        # Returns None when bin k carries no mass in either histogram.
        mass = tr_distr[k] + te_distr[k]
        return tr_distr[k] / mass if mass > 0 else None

    def neighbours():
        # ind-1, ind+1, ind-2, ind+2, ... (left first at each distance)
        for dist in range(1, size):
            yield ind - dist
            yield ind + dist

    direct = ratio_at(ind)
    if direct is not None:
        return direct

    for k in neighbours():
        if not 0 <= k < size:
            continue
        found = ratio_at(k)
        if found is not None:
            return found
    return 0.5  # nothing anywhere: neutral score

def risk_score_compute(tr_distrs, te_distrs, all_bins, data_values, data_labels):
    """Score each data point by the train share of its class's histogram bin.

    For each ``(value, label)`` pair, the bin containing ``value`` is located
    in ``all_bins[label]``. Out-of-range indices are clipped to the first or
    last bin. The score is ``score_calculate`` applied to that class's train
    and test histograms.

    Args:
        tr_distrs, te_distrs: Per-class histograms, where row ``i`` belongs to
            class ``i``.
        all_bins: Per-class bin edges aligned with the histograms.
        data_values: Values to score.
        data_labels: Class index per value. It is converted with ``int()``.

    Returns:
        A float ndarray with one score per data point. Points whose label has
        no histogram row (negative, or ``>= len(tr_distrs)``) get ``0.5``.

    Raises:
        ValueError: If ``data_values`` and ``data_labels`` differ in length.
    """
    if len(data_values) != len(data_labels):
        raise ValueError(
            f"data_values ({len(data_values)}) and data_labels ({len(data_labels)}) must have the same length"
        )

    num_classes = len(tr_distrs)
    risk_scores = []
    for c_value, raw_label in zip(data_values, data_labels):
        c_label = int(raw_label)
        # Without the lower bound, a negative label (e.g. a -1 "unknown" sentinel)
        # would silently pick the last class through negative indexing.
        if not 0 <= c_label < num_classes:
            risk_scores.append(0.5)
            continue

        c_tr_distr = tr_distrs[c_label]
        c_index = find_index(all_bins[c_label], c_value)
        # Values below the first edge or at/above the last edge map to the outermost bins.
        c_index = np.clip(c_index, 0, len(c_tr_distr) - 1)
        risk_scores.append(score_calculate(c_tr_distr, te_distrs[c_label], c_index))

    return np.array(risk_scores, dtype=float)

def calculate_risk_score(tr_values, te_values, tr_labels, te_labels, data_values, data_labels, num_bins=50, log_bins=False, plot_name=None):
    """Compute a risk score per data point from per-class train/test value distributions.

    Arbitrary class labels are encoded to contiguous indices. Per-class train
    and test histograms are built with ``distrs_compute``, which also writes a
    debug plot. Each data point is then scored with ``risk_score_compute``.

    Args:
        tr_values, te_values: Metric values of the train / test reference samples.
        tr_labels, te_labels: Raw class labels aligned with those values.
        data_values, data_labels: Points to score and their raw class labels.
        num_bins: Histogram bins per class.
        log_bins: Use log-spaced bins.
        plot_name: Output path (without ``.png``) for the distribution plot.
            ``None`` keeps the ``distrs_compute`` default.

    Returns:
        A float ndarray with one score per data point. Data points whose label
        never occurs in the train or test labels get the neutral score ``0.5``.
    """
    le = LabelEncoder()
    # Fit only on labels that have reference samples. If labels seen only in
    # data_labels were encoded, they would create classes with no train/test
    # values, and those classes cannot be binned.
    le.fit(np.concatenate([tr_labels, te_labels]))
    num_classes = len(le.classes_)

    tr_labels_enc = le.transform(tr_labels)
    te_labels_enc = le.transform(te_labels)

    data_labels = np.asarray(data_labels)
    known = np.isin(data_labels, le.classes_)
    # Unseen labels get the out-of-range index num_classes, which
    # risk_score_compute maps to the neutral score 0.5.
    data_labels_enc = np.full(len(data_labels), num_classes, dtype=np.int64)
    if known.any():
        data_labels_enc[known] = le.transform(data_labels[known])

    tr_distrs, te_distrs, all_bins = distrs_compute(
        tr_values, te_values, tr_labels_enc, te_labels_enc,
        num_bins=num_bins, log_bins=log_bins, plot_name=plot_name, label_encoder=le
    )
    return risk_score_compute(tr_distrs, te_distrs, all_bins, data_values, data_labels_enc)
