import numpy as np
import torch


def otsu_thresholding_1d(data, bins=256):
    """
    Apply Otsu's method to the histogram of a data vector.

    The data are binned with ``np.histogram`` and every bin index ``k`` is
    tried as a split: class 1 holds bins ``0..k`` and class 2 holds bins
    ``k+1..`` (each bin represented by its centre and weighted by its
    relative frequency). The split minimising the weighted within-class
    (intra-class) variance is selected; by the law of total variance this is
    the same split that maximises the between-class (inter-class) variance.

    Parameters:
    data (array_like): Data values (flattened by ``np.histogram``).
    bins (int or sequence or str): Forwarded to ``np.histogram``; defaults
        to 256 equal-width bins.

    Returns:
    tuple: ``(intra_class_variance, inter_class_variance)`` evaluated at the
    optimal split. The two values sum to the variance of the binned data.

    Raises:
    ValueError: If ``data`` is empty.
    """
    data = np.asarray(data)
    if data.size == 0:
        raise ValueError("otsu_thresholding_1d requires at least one value")

    def ratio(num, den):
        # num / den, defined as 0 where a class is empty (den <= 0).
        return np.divide(num, den, out=np.zeros_like(num, dtype=float), where=den > 0)

    hist, bin_edges = np.histogram(data, bins=bins)
    probability = hist / data.size
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Centre on the global mean: all variances are shift-invariant, and
    # centring avoids cancellation in E[x^2] - E[x]^2 for offset data.
    centered = bin_centers - np.sum(probability * bin_centers)

    # Class weights and prefix moments for every candidate split k.
    weight_1 = np.cumsum(probability)
    weight_2 = 1.0 - weight_1
    first_moment = np.cumsum(probability * centered)
    second_moment = np.cumsum(probability * centered ** 2)
    total_first = first_moment[-1]
    total_second = second_moment[-1]

    mean_1 = ratio(first_moment, weight_1)
    mean_2 = ratio(total_first - first_moment, weight_2)

    # Variance of each class about ITS OWN mean for split k:
    # E_class[x^2] - mean_class^2 (clipped at 0 to absorb rounding noise).
    variance_1 = np.maximum(ratio(second_moment, weight_1) - mean_1 ** 2, 0.0)
    variance_2 = np.maximum(ratio(total_second - second_moment, weight_2) - mean_2 ** 2, 0.0)

    intra_class_variance = weight_1 * variance_1 + weight_2 * variance_2
    inter_class_variance = weight_1 * weight_2 * (mean_1 - mean_2) ** 2

    # Optimal split minimises the intra-class variance.
    optimal_threshold_index = np.argmin(intra_class_variance)
    return intra_class_variance[optimal_threshold_index], inter_class_variance[optimal_threshold_index]


def otsu_thresholding_unbiased(data):
    """
    Vectorised Otsu split of a 1D data vector using unbiased class variances.

    The values are sorted in descending order and every split into a head
    (the first ``k`` values, ``k = 1..n-1``) and a tail (the remaining
    ``n - k`` values) is evaluated at once, without a Python loop.

    Parameters:
    data (array_like): Data values; flattened and cast to float.

    Returns:
    tuple: ``(optimal_threshold_index, intra_class_variance, inter_class_variance)``
        - ``optimal_threshold_index``: argmax of ``inter_class_variance``;
          split ``i`` puts the ``i + 1`` largest values in class 1.
        - ``intra_class_variance[i] = w1 * s1^2 + w2 * s2^2`` with class
          weights ``w = size / n`` and ``s^2`` the unbiased (ddof=1) sample
          variance; a single-element class contributes variance 0.
        - ``inter_class_variance[i] = w1 * w2 * (mean_1 - mean_2)^2``.
        Both arrays have length ``n - 1``.

    Raises:
    ValueError: If ``data`` has fewer than two values.
    """
    # Cast to float so squaring integer input cannot overflow.
    values = np.sort(np.asarray(data, dtype=float).ravel())[::-1]
    n = values.size
    if n < 2:
        raise ValueError("otsu_thresholding_unbiased requires at least two values")

    # Centre on the global mean: variances and mean differences are
    # shift-invariant, and centring limits cancellation in sum-of-squares.
    centered = values - values.mean()
    prefix_sum = np.cumsum(centered)
    prefix_sq = np.cumsum(centered ** 2)

    # Class sizes / weights for splits k = 1..n-1.
    counts_1 = np.arange(1, n, dtype=float)
    counts_2 = n - counts_1
    weights_1 = counts_1 / n
    weights_2 = 1 - weights_1

    head_sum = prefix_sum[:-1]
    tail_sum = prefix_sum[-1] - head_sum
    head_sq = prefix_sq[:-1]
    tail_sq = prefix_sq[-1] - head_sq

    means_1 = head_sum / counts_1
    means_2 = tail_sum / counts_2

    # Sums of squared deviations about each class mean.
    ss_1 = head_sq - head_sum ** 2 / counts_1
    ss_2 = tail_sq - tail_sum ** 2 / counts_2

    # Unbiased variance; defined as 0 for a single-element class. Dividing by a
    # tiny epsilon instead would amplify float rounding noise into large (even
    # negative) variances. Clip remaining rounding noise at 0.
    variance_1 = np.divide(ss_1, counts_1 - 1, out=np.zeros_like(ss_1), where=counts_1 > 1)
    variance_2 = np.divide(ss_2, counts_2 - 1, out=np.zeros_like(ss_2), where=counts_2 > 1)
    variance_1 = np.maximum(variance_1, 0.0)
    variance_2 = np.maximum(variance_2, 0.0)

    intra_class_variance = weights_1 * variance_1 + weights_2 * variance_2
    inter_class_variance = weights_1 * weights_2 * (means_1 - means_2) ** 2

    # Optimal split maximises the between-class separation.
    optimal_threshold_index = np.argmax(inter_class_variance)

    return optimal_threshold_index, intra_class_variance, inter_class_variance


def get_intra_inter(group_1, group_2):
    """
    Compute intra- and inter-class variance for an explicit two-group split.

    Parameters:
    group_1 (array_like): Values of the first group (flattened).
    group_2 (array_like): Values of the second group (flattened).

    Returns:
    tuple: ``(intra_class_variance, inter_class_variance)`` where
        - intra is the size-weighted mean of the per-group sample variances
          (ddof=1); a single-element group contributes 0 instead of NaN;
        - inter is the size-weighted mean squared deviation of the group
          means from the overall mean.

    Raises:
    ValueError: If either group is empty.
    """
    group_1 = np.asarray(group_1, dtype=float).ravel()
    group_2 = np.asarray(group_2, dtype=float).ravel()
    n1, n2 = group_1.size, group_2.size
    if n1 == 0 or n2 == 0:
        raise ValueError("get_intra_inter requires two non-empty groups")
    n_total = n1 + n2

    mean_1 = group_1.mean()
    mean_2 = group_2.mean()

    # Sample variance (ddof=1) is undefined for one element; treat it as 0
    # rather than letting NaN propagate into the intra-class variance.
    variance_1 = np.var(group_1, ddof=1) if n1 > 1 else 0.0
    variance_2 = np.var(group_2, ddof=1) if n2 > 1 else 0.0

    # Intra-class variance: group variances weighted by group size.
    intra_class_variance = (n1 * variance_1 + n2 * variance_2) / n_total

    # Overall mean from the group means (no need to concatenate the data).
    overall_mean = (n1 * mean_1 + n2 * mean_2) / n_total

    # Inter-class variance: size-weighted spread of the group means.
    inter_class_variance = (n1 * (mean_1 - overall_mean) ** 2 + n2 * (mean_2 - overall_mean) ** 2) / n_total

    return intra_class_variance, inter_class_variance


def filter_generated_max(val_x, gen_x, max_threshold=20):
    """
    Keep generated samples lying farther from every validation sample than
    the validation set's largest (capped) nearest-neighbour distance.

    The threshold is ``max_i clip(min_{j != i} ||val_x[i] - val_x[j]||, 0, max_threshold)``.
    With a single validation sample the threshold is 0.

    Parameters:
    val_x (np.ndarray): Validation samples, one per row.
    gen_x (np.ndarray): Generated samples, one per row (same column count).
    max_threshold (float): Cap applied to each nearest-neighbour distance
        before taking the maximum (default 20).

    Returns:
    np.ndarray: Rows of ``gen_x`` whose distance to the closest validation
    sample is strictly greater than the threshold.
    """
    val_x = np.asarray(val_x)
    gen_x = np.asarray(gen_x)

    # Compute every distance in float64 with the direct (non-matmul) kernel so
    # the threshold and the per-sample distances share the same precision;
    # torch.Tensor(...) would downcast to float32 and the matmul path loses
    # accuracy for near-coincident points.
    exact_mode = "donot_use_mm_for_euclid_dist"
    val_t = torch.as_tensor(val_x, dtype=torch.float64)
    gen_t = torch.as_tensor(gen_x, dtype=torch.float64)

    val_dists = torch.cdist(val_t, val_t, compute_mode=exact_mode).numpy()
    # Raise the zero self-distances to the matrix maximum so each column
    # minimum is the distance to the nearest *other* validation sample.
    val_dists = val_dists + np.eye(val_x.shape[0]) * np.max(val_dists)
    nearest = np.clip(np.min(val_dists, axis=0), 0, max_threshold)
    threshold = np.max(nearest)

    # Distance from each generated sample to its closest validation sample.
    gen_dists = torch.cdist(gen_t, val_t, compute_mode=exact_mode).min(dim=1).values.numpy()
    return gen_x[gen_dists > threshold]