"""Utilities for computing one-sided lower confidence bounds on per-task mean returns."""

import numpy as np
import polars as pl
from scipy import stats


def clopper_pearson(df, min_return, max_return, beta=0.05):
    """
    Computes one-sided Clopper-Pearson lower bounds on per-task success rates.

    Each task's episodes are treated as Bernoulli trials. The number of
    successes k is the sum of ``total_return`` for the task, and the number of
    trials n is the count of ``episode_id`` values for the task. The lower
    bound is the ``beta`` quantile of Beta(k, n - k + 1), so it holds with
    confidence 1 - beta. No upper bound is computed.

    Args:
        df: Polars dataframe with task_id, episode_id, total_return columns.
            total_return is presumably 0/1 per episode; this is not checked
            here beyond the min/max assertion below.
        min_return: Minimum possible return (float); must be 0.0
        max_return: Maximum possible return (float); must be 1.0
        beta: Significance level (float), default 0.05

    Returns:
        Polars dataframe sorted by task_id with columns task_id,
        mean_return (k / n) and lower_bound.

    Raises:
        AssertionError: if min_return != 0.0 or max_return != 1.0.
    """
    assert min_return == 0.0 and max_return == 1.0, (
        "Clopper-Pearson is only valid for Bernoulli returns in [0, 1]"
    )
    # Number of successes (k) and trials (n) per task
    stats_df = (
        df.group_by("task_id")
        .agg(
            [
                pl.col("total_return").sum().alias("k_successes"),
                pl.col("episode_id").count().alias("n_trials"),
            ]
        )
        .sort("task_id")
    )

    # NumPy arrays for the vectorized scipy call below
    k = np.array(stats_df["k_successes"])
    n = np.array(stats_df["n_trials"])

    # --- Lower bound ---
    # beta.ppf(beta, k, n - k + 1) is the inverse CDF of Beta(k, n - k + 1)
    # evaluated at the significance level.
    # The first Beta shape parameter must be > 0, so scipy yields NaN at k = 0.
    # The exact Clopper-Pearson lower bound there is 0. Handle only that case
    # explicitly, so NaNs caused by malformed inputs (e.g. negative totals)
    # are not silently reported as a 0.0 bound.
    lower = np.where(k == 0, 0.0, stats.beta.ppf(beta, k, n - k + 1))

    return pl.DataFrame(
        {
            "task_id": stats_df["task_id"],
            "mean_return": stats_df["k_successes"] / stats_df["n_trials"],
            "lower_bound": lower,
        }
    )


def hoeffding(df, min_return, max_return, beta=0.05):
    """
    Computes one-sided Hoeffding lower bounds on per-task mean returns.

    With confidence 1 - beta, each task's true mean is at least
    ``mean_return - (max_return - min_return) * sqrt(ln(1 / beta) / (2 n))``,
    where n is the number of episodes per task (shared by all tasks).

    Args:
        df: Polars dataframe with task_id, episode_id, total_return columns
        min_return: Minimum possible return (float)
        max_return: Maximum possible return (float)
        beta: Significance level (float), default 0.05
    Returns:
        Polars dataframe sorted by task_id with task_id, mean_return,
        and lower_bound columns.
    Raises:
        AssertionError: if tasks do not all have the same number of episodes.
    """
    # Single pass over the data: episode count and mean return per task.
    per_task = df.group_by("task_id").agg(
        [
            pl.col("episode_id").count().alias("n_episodes"),
            pl.col("total_return").mean().alias("mean_return"),
        ]
    )
    assert per_task.select(pl.col("n_episodes").n_unique()).item() == 1, (
        "All tasks must have the same number of episodes"
    )
    n_episodes = per_task.select(pl.col("n_episodes").first()).item()

    # The deviation depends only on n and the return range, so a single
    # scalar applies to every task.
    hoeffding_bound = (max_return - min_return) * np.sqrt(
        np.log(1 / beta) / (2 * n_episodes)
    )

    return (
        per_task.with_columns(
            (pl.col("mean_return") - hoeffding_bound).alias("lower_bound")
        )
        .select(["task_id", "mean_return", "lower_bound"])
        .sort("task_id")
    )


def empirical_bernstein(df, min_return, max_return, beta=0.05):
    """
    Computes Empirical Bernstein lower bounds for real-valued returns.

    The bound is based on Maurer and Pontil (2009). It adapts to the
    sample variance, often providing tighter bounds than Hoeffding when
    variance is low. The variance is rescaled to the [0, 1] domain using
    ``max_return - min_return``, the deviation is computed there, and the
    deviation is scaled back to the original units before being subtracted
    from the mean.

    Args:
        df: Polars dataframe with task_id, episode_id, total_return columns
        min_return: Minimum possible return (float)
        max_return: Maximum possible return (float)
        beta: Significance level (float), default 0.05
    Returns:
        dataframe sorted by task_id with task_id, mean_return, and lower_bound columns
    Raises:
        AssertionError: if tasks do not all have the same number of episodes.
        ValueError: if there are fewer than 2 episodes per task.
    """

    # 1. Validate sample size consistency
    episodes_per_task = df.group_by("task_id").agg(
        pl.col("episode_id").count().alias("n_episodes")
    )

    assert episodes_per_task.select(pl.col("n_episodes").n_unique()).item() == 1, (
        "All tasks must have the same number of episodes"
    )

    n = episodes_per_task.select(pl.col("n_episodes").first()).item()

    if n < 2:
        raise ValueError(
            "Empirical Bernstein requires at least 2 samples to compute variance."
        )

    # 2. Mean and unbiased variance per task. Named ``task_stats`` so it
    # does not shadow the module-level ``scipy.stats`` import.
    task_stats = df.group_by("task_id").agg(
        [
            pl.col("total_return").mean().alias("mean_return"),
            # Polars .var() computes unbiased sample variance (ddof=1) by default
            pl.col("total_return").var().alias("sample_var"),
        ]
    )

    # 3. Scale the variance to the [0, 1] domain the theorem is stated for:
    # Var(X_scaled) = Var(X) / range^2
    range_span = max_return - min_return

    task_stats = task_stats.with_columns(
        (pl.col("sample_var") / (range_span**2)).alias("scaled_var")
    )

    # 4. Deviation term (epsilon) in the scaled domain:
    #    sqrt(2 * Var * ln(2/beta) / n) + 7 * ln(2/beta) / (3 * (n - 1))
    # ln(2/beta) rather than ln(1/beta) reflects the union bound over
    # estimating both the mean and the variance.
    log_term = np.log(2 / beta)

    task_stats = task_stats.with_columns(
        (
            (2 * pl.col("scaled_var") * log_term / n).sqrt()
            + (7 * log_term) / (3 * (n - 1))
        ).alias("epsilon_scaled")
    )

    # 5. Denormalize the deviation and subtract: mean - epsilon_scaled * range
    task_stats = task_stats.with_columns(
        (pl.col("mean_return") - (pl.col("epsilon_scaled") * range_span)).alias(
            "lower_bound"
        )
    )

    return task_stats.select(["task_id", "mean_return", "lower_bound"]).sort(
        "task_id"
    )


def dkw_mean_lower_bound(df, min_return, max_return, beta=0.05):
    """
    Lower-bounds each task's mean return via the one-sided
    Dvoretzky-Kiefer-Wolfowitz (DKW) inequality.

    The empirical CDF is raised by the DKW margin
    ``eps = sqrt(ln(1 / beta) / (2 n))`` and capped at 1. This gives the most
    pessimistic CDF inside the confidence band. The mean of that CDF,
    ``min_return + integral of (1 - F_upper(x)) dx``, is reported as the
    lower bound.

    Args:
        df: Polars dataframe with task_id, total_return
        min_return: Minimum possible return (a); the integration starts here
        max_return: Maximum possible return (b). Not referenced by the
            computation: beyond the largest sample the raised CDF is 1, so
            that last stretch contributes nothing.
        beta: Significance level (default 0.05)
    Returns:
        dataframe sorted by task_id with task_id, mean_return, and lower_bound
    Raises:
        AssertionError: if tasks do not all have the same number of episodes.
    """
    # Sample size per task; every task must share the same n.
    counts = df.group_by("task_id").agg(
        pl.col("total_return").count().alias("n_episodes")
    )
    n = counts.select(pl.col("n_episodes").first()).item()
    assert counts["n_episodes"].n_unique() == 1, (
        "All tasks must have the same number of episodes"
    )

    # One-sided DKW with Massart's constant: P(sup(F - Fn) > eps) <= exp(-2 n eps^2).
    # Solving exp(-2 n eps^2) = beta gives the margin below.
    margin = np.sqrt(np.log(1 / beta) / (2 * n))

    ret = pl.col("total_return")
    # Left edge of each gap: the previous sample in the task, or min_return
    # for the first one.
    left_edge = ret.shift(1).over("task_id").fill_null(min_return)
    # Rows are sorted ascending within a task, so (running count - 1) / n is
    # the ECDF on the gap ending at this sample: 0 for the first gap, 1/n for
    # the next, and so on. This relies on cum_count starting at 1.
    ecdf_on_gap = (ret.cum_count().over("task_id") - 1) / n
    # Pessimistic CDF: true CDF is at most ECDF + margin (and never above 1).
    raised_cdf = pl.min_horizontal(1.0, ecdf_on_gap + margin)
    # Rectangle area: gap width times the probability mass above the curve.
    gap_area = (ret - left_edge) * (1.0 - raised_cdf)

    rectangles = df.sort(["task_id", "total_return"]).with_columns(
        gap_area.alias("mean_contribution")
    )

    per_task = rectangles.group_by("task_id").agg(
        [
            ret.mean().alias("mean_return"),
            # The integral is measured from min_return, so add it back as the base.
            (pl.col("mean_contribution").sum() + min_return).alias("lower_bound"),
        ]
    )
    return per_task.sort("task_id")
