import numpy as np
from typing import Dict, Optional, Tuple, List


def _rolling_mean(values: np.ndarray, window_size: int) -> np.ndarray:
    """
    Compute a simple rolling mean with a fixed window size.

    The output is aligned to the right (i.e., mean of the last `window_size` elements
    ending at the current index). For the first positions where there are fewer
    than `window_size` samples available, use the available prefix.
    """
    if window_size <= 1:
        return values.astype(float)

    rolling = np.empty_like(values, dtype=float)
    cumulative_sum = np.cumsum(values, dtype=float)
    for t in range(len(values)):
        start_idx = max(0, t - window_size + 1)
        count = t - start_idx + 1
        total = cumulative_sum[t] - (cumulative_sum[start_idx - 1] if start_idx > 0 else 0.0)
        rolling[t] = total / max(count, 1)
    return rolling


def _ols_fit_loglog(x_positive: np.ndarray, y_positive: np.ndarray) -> Tuple[float, float, float]:
    """
    Fit y = a + b x by ordinary least squares and return (a, b, r2).
    Inputs are already log-transformed and finite.
    """
    n = len(x_positive)
    if n < 2:
        return 0.0, 0.0, 0.0

    sum_x = float(np.sum(x_positive))
    sum_y = float(np.sum(y_positive))
    sum_xx = float(np.sum(x_positive * x_positive))
    sum_xy = float(np.sum(x_positive * y_positive))

    denominator = n * sum_xx - sum_x * sum_x
    if abs(denominator) < 1e-12:
        return 0.0, 0.0, 0.0

    b = (n * sum_xy - sum_x * sum_y) / denominator
    a = (sum_y - b * sum_x) / n

    # R^2
    y_hat = a + b * x_positive
    ss_tot = float(np.sum((y_positive - np.mean(y_positive)) ** 2))
    ss_res = float(np.sum((y_positive - y_hat) ** 2))
    r2 = 1.0 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
    return a, b, r2


def estimate_d_and_C(
    rewards: np.ndarray,
    window_size: int = 4,
    min_points: int = 20,
    trim_quantile: float = 0.0,
    clip_d_range: Tuple[float, float] = (0.0, 3.0),
    smooth_delta_window: int = 3,
    eps: float = 1e-8,
) -> Optional[Dict[str, float]]:
    """
    Estimate (d, C) from a single arm's reward sequence under the definition:
        |mu_t - mu_{t-1}| <= C * t^{-d}

    Approach:
    1) Smooth rewards with a rolling mean to estimate mu_t.
    2) Compute absolute differences delta_t = |mu_t - mu_{t-1}| for t >= 1
       (optionally smoothed again with a rolling mean).
    3) Regress log(delta_t) on log(t) using OLS: log(delta_t) ≈ log C - d log t.
       Intercept = log C, slope = -d.

    Args:
        rewards: Reward sequence; flattened to one dimension.
        window_size: Rolling-mean window used to estimate mu_t.
        min_points: Minimum number of usable (t, delta_t) pairs. At least two
            are always required, because a line cannot be fitted to fewer.
        trim_quantile: If in (0, 0.5) and at least 20 points are available,
            points whose delta lies outside the [q, 1 - q] quantile range are
            dropped before the regression.
        clip_d_range: ``(d_min, d_max)`` bounds applied to the estimated d.
            Either bound may be None; passing None disables clipping.
        smooth_delta_window: Rolling-mean window applied to delta_t when > 1.
        eps: Deltas that are not strictly greater than this are discarded
            (they cannot be log-transformed meaningfully).

    Returns:
        A dict with keys ``d``, ``C``, ``slope``, ``intercept``, ``r2`` and
        ``num_points`` (all floats; ``slope``/``intercept`` are the raw,
        unclipped fit), or None if there is insufficient data or no finite
        positive C can be produced.
    """
    if rewards is None:
        return None

    # With fewer than two points the OLS helper returns its (0, 0, 0)
    # sentinel, which would otherwise be reported as a genuine d=0, C=1 fit.
    required_points = max(min_points, 2)

    rewards = np.asarray(rewards, dtype=float).reshape(-1)
    T = len(rewards)
    if T < max(window_size + 2, required_points + 1):
        return None

    mu_hat = _rolling_mean(rewards, window_size)

    # delta_t for t in [1, T-1], using 1-based t in the inequality
    delta = np.abs(np.diff(mu_hat))
    if smooth_delta_window and smooth_delta_window > 1:
        delta = _rolling_mean(delta, smooth_delta_window)
    t_index = np.arange(1, len(delta) + 1, dtype=float)  # 1,2,...

    # Filter small/zero values and non-finite
    usable = np.isfinite(delta) & (delta > eps)
    delta = delta[usable]
    t_index = t_index[usable]

    if len(delta) < required_points:
        return None

    # Optional trimming to reduce outlier influence
    if 0.0 < trim_quantile < 0.5 and len(delta) >= 20:
        low = np.quantile(delta, trim_quantile)
        high = np.quantile(delta, 1.0 - trim_quantile)
        keep = (delta >= max(low, eps)) & (delta <= max(high, eps))
        if np.any(keep):
            delta = delta[keep]
            t_index = t_index[keep]

    # Log-log regression
    log_t = np.log(t_index)
    log_delta = np.log(delta)

    # Remove any residual non-finite (e.g. log(0) when eps is not positive)
    good = np.isfinite(log_t) & np.isfinite(log_delta)
    log_t = log_t[good]
    log_delta = log_delta[good]

    if len(log_t) < required_points:
        return None

    intercept, slope, r2 = _ols_fit_loglog(log_t, log_delta)

    # According to log(delta_t) ≈ log C - d log t
    # slope = -d, intercept = log C
    d_estimate = float(-slope)
    if not np.isfinite(d_estimate):
        return None
    C_estimate = float(np.exp(intercept)) if np.isfinite(intercept) else float("nan")

    # Light regularization/clipping for d.
    # NOTE: C above comes from the unclipped fit's intercept; it is not
    # re-estimated when d is clipped.
    d_min, d_max = clip_d_range if clip_d_range is not None else (None, None)
    if d_min is not None or d_max is not None:
        d_estimate = float(
            np.clip(
                d_estimate,
                d_min if d_min is not None else -np.inf,
                d_max if d_max is not None else np.inf,
            )
        )

    if not np.isfinite(C_estimate) or C_estimate <= 0:
        # Fallback for C using median of delta * t^d (with the clipped d)
        C_fallback = np.median(delta * (t_index ** d_estimate))
        C_estimate = float(C_fallback) if np.isfinite(C_fallback) and C_fallback > 0 else float("nan")

    if not np.isfinite(C_estimate) or C_estimate <= 0:
        return None

    return {
        "d": d_estimate,
        "C": C_estimate,
        "slope": float(slope),
        "intercept": float(intercept),
        "r2": float(r2),
        "num_points": float(len(log_t)),
    }


def estimate_d_and_C_multi_arm(
    rewards_matrix: np.ndarray,
    window_size: int = 8,
    min_points: int = 12,
    trim_quantile: float = 0.05,
    clip_d_range: Tuple[float, float] = (0.05, 2.0),
    eps: float = 1e-8,
    smooth_delta_window: int = 3,
) -> List[Optional[Dict[str, float]]]:
    """
    Estimate (d, C) for multiple arms. The input is a matrix of shape (K, T).

    Each row is passed independently to ``estimate_d_and_C`` with the given
    settings; see that function for the meaning of the tuning parameters.

    Args:
        rewards_matrix: 2D array-like, one reward sequence per row.
        window_size: Rolling-mean window for the mean estimate.
        min_points: Minimum number of usable points per arm.
        trim_quantile: Quantile used for outlier trimming.
        clip_d_range: ``(d_min, d_max)`` bounds applied to each estimated d.
        eps: Threshold below which differences are discarded.
        smooth_delta_window: Rolling-mean window applied to the absolute
            differences. Defaults to 3, the single-arm default, so existing
            callers see unchanged behavior.

    Returns:
        A list with one entry per arm: the result dict, or None when that
        arm's estimate could not be computed.

    Raises:
        ValueError: If ``rewards_matrix`` is not two-dimensional.
    """
    rewards_matrix = np.asarray(rewards_matrix, dtype=float)
    if rewards_matrix.ndim != 2:
        raise ValueError("rewards_matrix must be 2D with shape (K, T)")

    return [
        estimate_d_and_C(
            arm_rewards,
            window_size=window_size,
            min_points=min_points,
            trim_quantile=trim_quantile,
            clip_d_range=clip_d_range,
            smooth_delta_window=smooth_delta_window,
            eps=eps,
        )
        for arm_rewards in rewards_matrix
    ]


# Public API of this module: the single-arm and multi-arm estimators.
__all__ = ["estimate_d_and_C", "estimate_d_and_C_multi_arm"]


def _generate_synthetic_sequence(
    T: int,
    d_true: float,
    C_true: float = 0.4,
    base: float = 0.5,
    noise_std: float = 0.02,
    seed: Optional[int] = 42,
) -> np.ndarray:
    """
    Generate a synthetic reward sequence whose mean satisfies approximately
        |mu_t - mu_{t-1}| ≈ C_true * t^{-d_true}
    by moving the mean at each step by a signed step of target magnitude.
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros(T, dtype=float)
    mu[0] = base
    for t in range(1, T):
        step_mag = C_true * ((t + 0.0) ** (-d_true))
        step_sign = 1.0 if (t % 2 == 0) else -1.0  # alternate signs to avoid drift/saturation
        mu[t] = mu[t - 1] + step_sign * step_mag
    # do not clip mu to preserve power-law structure
    # sample rewards with noise
    rewards = mu + rng.normal(0.0, noise_std, size=T)
    rewards = np.clip(rewards, 0.0, 1.0)  # final clipping only on observations
    return rewards

def _demo_single_arm():
    """Run the estimator on one synthetic arm and print truth vs. estimate."""
    d_true, C_true = 0.8, 0.30
    rewards = _generate_synthetic_sequence(T=400, d_true=d_true, C_true=C_true)
    result = estimate_d_and_C(rewards, window_size=6, min_points=25)

    print("[Single Arm Demo]")
    print(f"True d = {d_true:.3f}, True C = {C_true:.3f}")

    # Guard clause: the estimator signals failure by returning None.
    if result is None:
        print("Estimator returned None (insufficient data or numerical issue)")
        return

    print(f"Estimated d = {result['d']:.3f}, Estimated C = {result['C']:.3f}")
    print(f"R^2 = {result['r2']:.3f}, points = {int(result['num_points'])}")


def _demo_multi_arm():
    """Run the multi-arm estimator on three synthetic arms and print results."""
    T = 400
    truths = [(0.6, 0.28), (0.8, 0.30), (1.0, 0.32)]  # (d_true, C_true) per arm

    # NOTE: every arm is generated with seed=123, so the noise realization is
    # identical across arms; only the underlying mean path differs.
    rewards_matrix = np.stack(
        [
            _generate_synthetic_sequence(T=T, d_true=d_true, C_true=C_true, seed=123)
            for d_true, C_true in truths
        ],
        axis=0,
    )

    results = estimate_d_and_C_multi_arm(rewards_matrix, window_size=6, min_points=25)
    print("\n[Multi Arm Demo]")
    for i, ((d_true, C_true), est) in enumerate(zip(truths, results)):
        if est is None:
            print(f"Arm {i}: True(d={d_true:.3f}, C={C_true:.3f}) -> Estimator None")
        else:
            print(
                f"Arm {i}: True(d={d_true:.3f}, C={C_true:.3f}) -> "
                f"Estimated(d={est['d']:.3f}, C={est['C']:.3f}, R^2={est['r2']:.3f}, n={int(est['num_points'])})"
            )


if __name__ == "__main__":
    # Script entry point: run the single-arm demo, then the multi-arm demo.
    for _run_demo in (_demo_single_arm, _demo_multi_arm):
        _run_demo()



