import numpy as np
import scipy.stats
from statsmodels.nonparametric.smoothers_lowess import lowess

# ==============================================================================
# 1. IMPLEMENTATION OF CALIBRATION METRICS
# ==============================================================================

def calculate_ici_survival(y_true, y_pred_survival, evaluation_time, frac=0.6, it=0):
    """
    Calculates the Integrated Calibration Index (ICI) for survival data at a specific time.

    The ICI is the mean, over the retained subjects, of the absolute difference
    between a LOWESS-smoothed calibration curve and the ideal 45-degree line,
    i.e. |smoothed observed survival - predicted survival|. Lower is better.

    Subjects censored before ``evaluation_time`` have an unknown status at that
    time and are dropped. No inverse-probability-of-censoring weighting is
    applied, so heavy early censoring may bias the estimate.

    Args:
        y_true (sksurv.util.Surv): Structured array with "event" and "time" fields.
        y_pred_survival (array-like): 1D array of predicted survival probabilities
            at ``evaluation_time``, aligned with ``y_true``.
        evaluation_time (float): The specific time point t to evaluate calibration.
        frac (float): Fraction of the data used for each local LOWESS fit.
        it (int): Number of robustifying LOWESS iterations.

    Returns:
        float: The ICI score, or NaN if fewer than 2 subjects have a known
        status at ``evaluation_time``.
    """
    times = np.asarray(y_true["time"], dtype=float)
    # Cast explicitly: an integer 0/1 event field would otherwise turn the
    # boolean masks below into integer arrays and trigger positional indexing.
    events = np.asarray(y_true["event"], dtype=bool)
    preds = np.asarray(y_pred_survival, dtype=float)

    # Status at t is known if the subject had an event (at any time) or was
    # still under observation at t.
    valid_idx = events | (times >= evaluation_time)
    if np.count_nonzero(valid_idx) < 2:  # Not enough data to compute
        return np.nan

    # Observed outcome: 0 if the event occurred at or before t, 1 otherwise.
    # A subject censored exactly at t is therefore counted as surviving to t.
    died_by_t = events[valid_idx] & (times[valid_idx] <= evaluation_time)
    true_outcome = (~died_by_t).astype(int)
    y_pred_valid = preds[valid_idx]

    # lowess(endog, exog) returns an (n, 2) array: column 0 holds the exog
    # values (predictions) sorted ascending, column 1 the smoothed outcome.
    # The reordering does not affect the mean taken below.
    smoothed_curve = lowess(true_outcome, y_pred_valid, is_sorted=False, frac=frac, it=it)

    # ICI: mean absolute gap between the smoothed outcome and the prediction.
    return np.mean(np.abs(smoothed_curve[:, 1] - smoothed_curve[:, 0]))


def calculate_d_calibration(y_true, y_pred_survival, evaluation_time, n_bins=10):
    """
    Calculates the D-Calibration statistic and p-value at a specific time.

    Takes the predicted survival probabilities of subjects whose event occurred
    at or before ``evaluation_time``, and bins them into ``n_bins`` equal-width
    bins on [0, 1]. It then runs a Pearson chi-squared goodness-of-fit test
    against a uniform distribution, with an equal expected count per bin. A
    non-significant p-value (e.g. > 0.05) means uniformity is not rejected.

    NOTE(review): D-Calibration as published (Haider et al., 2020) evaluates
    each subject's survival curve at that subject's own event time,
    S(T_i | x_i), and also apportions censored subjects across bins. This
    function only tests the values passed in ``y_pred_survival``. Confirm with
    callers which quantity that is.

    Args:
        y_true (sksurv.util.Surv): Structured array with "event" and "time" fields.
        y_pred_survival (array-like): 1D array of predicted survival
            probabilities, aligned with ``y_true``. Values are clipped to
            [0, 1] so every qualifying event lands in a bin.
        evaluation_time (float): The time point t to evaluate calibration.
        n_bins (int): Number of bins to group probabilities (must be >= 2).

    Returns:
        tuple[float, float]: The D-Calibration statistic and the corresponding
        p-value. Returns (NaN, NaN) when fewer than ``n_bins`` qualifying events
        with non-NaN predictions are available.

    Raises:
        ValueError: If ``n_bins`` < 2, because the test needs at least one
        degree of freedom.
    """
    if n_bins < 2:
        raise ValueError("n_bins must be at least 2")

    times = np.asarray(y_true["time"], dtype=float)
    # Cast explicitly: an integer 0/1 event field would otherwise make the
    # mask an integer array and trigger positional (fancy) indexing.
    events = np.asarray(y_true["event"], dtype=bool)
    preds = np.asarray(y_pred_survival, dtype=float)

    # Select subjects who had an event at or before the evaluation time.
    event_idx = events & (times <= evaluation_time)

    # Clip so values marginally outside [0, 1] are not silently dropped by
    # np.histogram. Drop NaN so the expected count matches what gets binned.
    event_preds = np.clip(preds[event_idx], 0.0, 1.0)
    event_preds = event_preds[~np.isnan(event_preds)]

    n_events = event_preds.size
    if n_events < n_bins:  # Not enough events to perform the test
        return np.nan, np.nan

    # np.histogram's last bin is closed, so a prediction of exactly 1.0 is kept.
    observed_counts, _ = np.histogram(event_preds, bins=n_bins, range=(0.0, 1.0))
    expected_count = n_events / n_bins

    # Pearson chi-squared statistic against equal expected counts per bin.
    statistic = float(np.sum((observed_counts - expected_count) ** 2) / expected_count)
    p_value = float(scipy.stats.chi2.sf(statistic, df=n_bins - 1))

    return statistic, p_value


def calculate_ece_survival(y_true, y_pred_survival, evaluation_time, n_bins=10):
    """
    Calculates the L1 Expected Calibration Error (ECE) for survival data at a specific time.

    Predictions are grouped into ``n_bins`` equal-width bins on [0, 1]. For
    each bin, the absolute difference between the observed survival rate and
    the mean predicted survival probability is weighted by the bin's share of
    subjects. The ECE is the sum of these terms. Lower is better.

    The observed survival indicator is 0 if the event occurred at or before
    ``evaluation_time`` and 1 otherwise. Subjects censored before
    ``evaluation_time`` are dropped because their status at that time is
    unknown; no censoring weights are applied.

    Args:
        y_true (sksurv.util.Surv): Structured array with "event" and "time" fields.
        y_pred_survival (array-like): 1D array of predicted survival
            probabilities at ``evaluation_time``, aligned with ``y_true``.
        evaluation_time (float): The time point t to evaluate calibration.
        n_bins (int): Number of bins to group probabilities (must be >= 1).

    Returns:
        float: The ECE score, or NaN if fewer than 2 subjects have a known
        status at ``evaluation_time``.

    Raises:
        ValueError: If ``n_bins`` < 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    times = np.asarray(y_true["time"], dtype=float)
    # Cast explicitly: an integer 0/1 event field would otherwise turn the
    # boolean masks below into integer arrays and trigger positional indexing.
    events = np.asarray(y_true["event"], dtype=bool)
    preds = np.asarray(y_pred_survival, dtype=float)

    # Status at t is known if the subject had an event (at any time) or was
    # still under observation at t.
    valid_idx = events | (times >= evaluation_time)
    if np.count_nonzero(valid_idx) < 2:
        return np.nan

    # A subject censored exactly at t counts as surviving to t.
    died_by_t = events[valid_idx] & (times[valid_idx] <= evaluation_time)
    true_outcome = (~died_by_t).astype(float)
    y_pred_valid = preds[valid_idx]

    # Digitizing against the interior edges maps every prediction to a bin
    # index in 0..n_bins-1. Values outside [0, 1] fall into the end bins.
    bin_limits = np.linspace(0.0, 1.0, n_bins + 1)
    bin_indices = np.digitize(y_pred_valid, bin_limits[1:-1])

    # |mean(outcome) - mean(pred)| * (count / N) == |sum(outcome) - sum(pred)| / N
    # for each bin, so per-bin sums suffice; empty bins contribute 0.
    outcome_sums = np.bincount(bin_indices, weights=true_outcome, minlength=n_bins)
    pred_sums = np.bincount(bin_indices, weights=y_pred_valid, minlength=n_bins)

    return float(np.sum(np.abs(outcome_sums - pred_sums)) / y_pred_valid.size)
