"""Weight construction helpers."""

from __future__ import annotations

from typing import Iterable, Sequence

import numpy as np
from scipy.stats import chi2


def _normalize(weights: np.ndarray) -> np.ndarray:
    total = weights.sum()
    if total <= 0:
        return np.ones_like(weights) / len(weights)
    return weights / total


def compute_optimal_weights(
    biases: Sequence[float], variances: Sequence[float], lambd: float
) -> np.ndarray:
    """Return weights inversely proportional to a bias-penalized risk.

    Source ``i`` receives weight proportional to
    ``1 / (variances[i] + lambd * biases[i] ** 2)``. The denominator is
    floored at ``1e-12`` so a zero-risk source cannot cause a division by
    zero. The result is normalized to sum to one via ``_normalize``.

    Args:
        biases: Per-source bias values.
        variances: Per-source variances, aligned with ``biases``.
        lambd: Multiplier on the squared-bias penalty.

    Returns:
        A float array of weights with the same length as the inputs.
    """
    b = np.asarray(biases, dtype=float)
    v = np.asarray(variances, dtype=float)
    penalty = float(lambd) * np.square(b)
    risk = np.maximum(v + penalty, 1e-12)
    return _normalize(np.reciprocal(risk))


def compute_conservative_weights(
    biases: Sequence[float],
    variances: Sequence[float],
    lambd: float = 1.0,
    z_q: float = 1.96,
) -> np.ndarray:
    """Return inverse-risk weights using an inflated (upper-bound) bias.

    Index 0 is handled specially: its bias is used as given. For every other
    source ``i``, the bias magnitude is replaced by
    ``|biases[i]| + z_q * sqrt(variances[i] + variances[0])``. This reads as
    an upper confidence bound on ``|bias_i|``. It assumes the bias of source
    ``i`` is estimated against source 0 from independent estimates; confirm
    against callers.

    Each source then gets weight proportional to
    ``1 / (variances[i] + lambd * adjusted_bias[i] ** 2)``. The denominator is
    floored at ``1e-12``, and the weights are normalized to sum to one.

    Args:
        biases: Per-source bias values.
        variances: Per-source variances, aligned with ``biases``.
        lambd: Multiplier on the squared-bias penalty.
        z_q: Quantile multiplier for the bias bound. The default 1.96 matches
            the previous hard-coded value (the two-sided 95% normal quantile).

    Returns:
        A float array of weights with the same length as the inputs.
    """
    biases = np.asarray(biases, dtype=float)
    variances = np.asarray(variances, dtype=float)
    adj_bias = biases.copy()
    if len(adj_bias) > 1:
        # Standard error of (source i - source 0); guard against tiny negative
        # sums before the square root.
        se = np.sqrt(np.maximum(variances[1:] + variances[0], 0.0))
        adj_bias[1:] = np.abs(biases[1:]) + float(z_q) * se
    denom = variances + float(lambd) * adj_bias**2
    denom = np.clip(denom, 1e-12, None)
    return _normalize(1.0 / denom)


def compute_conservative_variance_weights(
    biases: Sequence[float],
    variances: Sequence[float],
    K_list: Sequence[int],
    lambd: float = 1.0,
    alpha_var: float = 1e-4,
) -> np.ndarray:
    """Return inverse-risk weights using an inflated (upper-bound) variance.

    Index 0 keeps its variance as given. For every other source ``i``, the
    variance is scaled to ``dof_i * variances[i] / chi2.ppf(alpha_var, dof_i)``
    with ``dof_i = max(K_list[i] - 1, 1)``. This has the form of the usual
    one-sided upper confidence bound for a variance estimated from
    ``K_list[i]`` samples. That interpretation assumes ``variances[i]`` is a
    sample variance under a normal model; confirm against callers.

    Each source then gets weight proportional to
    ``1 / (adjusted_variance[i] + lambd * biases[i] ** 2)``. The denominator is
    floored at ``1e-12``, and the weights are normalized to sum to one.

    Args:
        biases: Per-source bias values.
        variances: Per-source variances, aligned with ``biases``.
        K_list: Per-source sample counts, aligned with ``biases``.
        lambd: Multiplier on the squared-bias penalty.
        alpha_var: Lower-tail probability passed to ``chi2.ppf``. Smaller
            values give a larger (more conservative) variance inflation.

    Returns:
        A float array of weights with the same length as the inputs.
    """
    biases = np.asarray(biases, dtype=float)
    variances = np.asarray(variances, dtype=float)
    K_arr = np.asarray(K_list, dtype=int)
    vars_adj = variances.copy()
    if len(vars_adj) > 1:
        # Clamp the degrees of freedom once and use the same value in both the
        # scaling factor and the chi-square quantile. With an unclamped K - 1
        # in the numerator, K == 1 zeroed the variance (and K == 0 made it
        # negative), so such sources received the *largest* weight instead of
        # a conservative one.
        dof = np.maximum(K_arr[1:] - 1, 1)
        chi2_q = chi2.ppf(alpha_var, dof)
        vars_adj[1:] *= dof / np.clip(chi2_q, 1e-12, None)
    denom = vars_adj + float(lambd) * biases**2
    denom = np.clip(denom, 1e-12, None)
    return _normalize(1.0 / denom)


