from __future__ import annotations
import numpy as np
from dataclasses import dataclass

@dataclass
class EstimatorConfig:
    """Tuning knobs shared by the Monte-Carlo estimators in this module.

    Fields may also be passed positionally, in the order declared below.
    """

    # Number of Gaussian perturbations drawn per estimate.
    n_samples: int = 2_000
    # Target number of median-of-means blocks the samples are split into.
    n_blocks: int = 10
    # Confidence parameter forwarded to the median-of-means bounds.
    delta: float = 0.001
    # When True, perturbations are drawn in (+xi, -xi) pairs.
    antithetic: bool = True
    # Enable importance weighting using ``shift_u`` as the proposal mean shift.
    use_importance: bool = False
    # Proposal mean shift, shape (d,); None means unit weights.
    shift_u: np.ndarray | None = None

def _draw_gaussians(d: int, sigma: float, n: int, rng: np.random.Generator, antithetic: bool):
    if antithetic:
        m = (n + 1)//2
        Xi = rng.normal(0.0, sigma, size=(m, d))
        Xi = np.vstack([Xi, -Xi])[:n]
    else:
        Xi = rng.normal(0.0, sigma, size=(n, d))
    return Xi

def _is_weights(xi: np.ndarray, u: np.ndarray, sigma: float) -> np.ndarray:
    if u is None:
        return np.ones(xi.shape[0], dtype=float)
    dot = (xi @ u) / (sigma**2)
    c = - (np.dot(u, u)) / (2 * sigma**2)
    return np.exp(dot + c)

def _blockify(arr: np.ndarray, n_blocks: int):
    n = len(arr)
    b = max(1, n // n_blocks)
    return [arr[i:i+b] for i in range(0, n, b)]

def mom_mean(values: np.ndarray, n_blocks: int) -> float:
    """Median-of-means estimate of the mean of a 1-D sample.

    Splits ``values`` with ``_blockify``, averages each non-empty block, and
    returns the median of those block averages as a Python float.
    """
    block_avgs = [np.mean(chunk) for chunk in _blockify(values, n_blocks) if len(chunk)]
    return float(np.median(np.array(block_avgs)))

def mom_vector_mean(V: np.ndarray, n_blocks: int) -> np.ndarray:
    """Coordinate-wise median-of-means estimate for the rows of ``V``.

    Each non-empty block (from ``_blockify``) is averaged over axis 0. The
    per-coordinate median of the stacked block averages is then returned.
    """
    per_block = [chunk.mean(axis=0) for chunk in _blockify(V, n_blocks) if len(chunk) > 0]
    stacked = np.stack(per_block, axis=0)
    return np.median(stacked, axis=0)

def mom_mean_bounds(values: np.ndarray, n_blocks: int, delta: float):
    """Two-sided interval around the median-of-means estimate of the mean.

    The center is the median of the block means. The half-width is
    ``s * sqrt(2 * log(2 / delta) / B)``, where ``s`` is the RMS deviation of
    the block means from that median (plus 1e-12 inside the sqrt) and ``B`` is
    the number of non-empty blocks. ``delta`` is clamped below at 1e-12.

    NOTE(review): ``s`` is an empirical spread used as if it were a known scale;
    treat the nominal ``1 - delta`` coverage as approximate.

    Returns:
        ``(lower, upper)``.
    """
    means = np.asarray(
        [np.mean(chunk) for chunk in _blockify(values, n_blocks) if len(chunk) > 0],
        dtype=float,
    )
    mid = float(np.median(means))
    spread = float(np.sqrt(np.mean((means - mid) ** 2) + 1e-12))
    n_means = max(1, len(means))
    safe_delta = max(delta, 1e-12)
    half_width = spread * np.sqrt(2.0 * np.log(2.0 / safe_delta) / n_means)
    return mid - half_width, mid + half_width

def estimate_variance_ucb(z: np.ndarray, reg_fn, sigma: float, cfg: EstimatorConfig,
                          *, rng: np.random.Generator | None = None) -> float:
    """Upper confidence bound on Var[reg_fn(z + xi)] for xi ~ N(0, sigma^2 I_2).

    Perturbs the first two coordinates of ``z`` with ``cfg.n_samples`` Gaussian
    draws and evaluates ``reg_fn`` at each perturbed point. Median-of-means
    bounds on E[Y] and E[Y^2] are then combined into a bound on
    Var[Y] = E[Y^2] - E[Y]^2.

    Args:
        z: base point; only ``z[0]`` and ``z[1]`` are used.
        reg_fn: callable of two scalars whose result ``float()`` accepts.
        sigma: per-coordinate standard deviation of the perturbation.
        cfg: sampling and confidence configuration. When
            ``cfg.use_importance`` is set and ``cfg.shift_u`` is not None,
            samples are drawn from N(shift_u, sigma^2 I) and reweighted back
            to N(0, sigma^2 I).
        rng: optional generator for reproducibility. A fresh unseeded one is
            created when omitted.

    Returns:
        A non-negative float.

    NOTE(review): the two MoM intervals are each built at level ``cfg.delta``.
    By a union bound, they hold jointly with failure probability up to
    ``2 * cfg.delta``.
    """
    if rng is None:
        rng = np.random.default_rng()
    eps = _draw_gaussians(d=2, sigma=sigma, n=cfg.n_samples, rng=rng, antithetic=cfg.antithetic)
    if cfg.use_importance and cfg.shift_u is not None:
        u = np.asarray(cfg.shift_u, dtype=float)
        # Sample from the proposal q = N(u, sigma^2 I) and weight by p/q, where
        # p = N(0, sigma^2 I) is the target. At x = eps + u this ratio is
        # exp(-eps.u/sigma^2 - |u|^2/(2 sigma^2)), i.e. _is_weights at -eps.
        Xi = eps + u
        w = _is_weights(-eps, u, sigma)
    else:
        Xi = eps
        w = np.ones(len(Xi))
    Y = np.array([float(reg_fn(z[0] + xi[0], z[1] + xi[1])) for xi in Xi])
    mu_lcb, mu_ucb = mom_mean_bounds(w * Y, cfg.n_blocks, cfg.delta)
    _, m2_ucb = mom_mean_bounds(w * (Y**2), cfg.n_blocks, cfg.delta)
    # The variance bound must subtract the *smallest* E[Y]^2 consistent with
    # [mu_lcb, mu_ucb]. That is 0 if the interval contains 0; otherwise it is
    # the square of the endpoint nearest 0. Squaring mu_lcb alone overstates
    # E[Y]^2 whenever mu_lcb < 0.
    if mu_lcb <= 0.0 <= mu_ucb:
        mu_sq_lcb = 0.0
    else:
        mu_sq_lcb = min(mu_lcb**2, mu_ucb**2)
    return float(max(0.0, m2_ucb - mu_sq_lcb))

def estimate_grad_norm_ucb(z: np.ndarray, reg_fn, sigma: float, cfg: EstimatorConfig,
                           norm_type: str = "l2",
                           *, rng: np.random.Generator | None = None) -> float:
    """Upper confidence bound on the norm of the Gaussian-smoothed gradient.

    Uses the score-function identity
    grad E[f(z + xi)] = E[(xi / sigma^2) f(z + xi)] with xi ~ N(0, sigma^2 I_2).
    Per-sample gradient estimates are combined by a coordinate-wise median of
    block means. The bound is ``||mu_hat|| + rad``, where ``rad`` scales the RMS
    (in the chosen norm) deviation of block means from ``mu_hat``.

    Args:
        z: base point; only ``z[0]`` and ``z[1]`` are used.
        reg_fn: callable of two scalars whose result ``float()`` accepts.
        sigma: per-coordinate standard deviation of the perturbation.
        cfg: sampling and confidence configuration. When
            ``cfg.use_importance`` is set and ``cfg.shift_u`` is not None,
            samples are drawn from N(shift_u, sigma^2 I) and reweighted back
            to N(0, sigma^2 I).
        norm_type: one of ``"l2"``, ``"l1"``, ``"linf"``.
        rng: optional generator for reproducibility. A fresh unseeded one is
            created when omitted.

    Returns:
        A non-negative float.

    Raises:
        ValueError: if ``norm_type`` is not recognised. This is checked before
            any samples are drawn or ``reg_fn`` is called.
    """
    norm_orders = {"l2": 2, "l1": 1, "linf": np.inf}
    if norm_type not in norm_orders:
        raise ValueError("norm_type must be one of {'l2','l1','linf'}")
    p = norm_orders[norm_type]

    if rng is None:
        rng = np.random.default_rng()
    eps = _draw_gaussians(d=2, sigma=sigma, n=cfg.n_samples, rng=rng, antithetic=cfg.antithetic)
    if cfg.use_importance and cfg.shift_u is not None:
        u = np.asarray(cfg.shift_u, dtype=float)
        # Sample from the proposal q = N(u, sigma^2 I) and weight by p/q, where
        # p = N(0, sigma^2 I) is the target. At x = eps + u this ratio is
        # exp(-eps.u/sigma^2 - |u|^2/(2 sigma^2)), i.e. _is_weights at -eps.
        Xi = eps + u
        w = _is_weights(-eps, u, sigma)
    else:
        Xi = eps
        w = np.ones(len(Xi))

    Y = np.array([float(reg_fn(z[0] + xi[0], z[1] + xi[1])) for xi in Xi])
    # Row i: (x_i / sigma^2) * f(z + x_i) * w_i
    V = (Xi / sigma**2) * Y[:, None] * w[:, None]

    # Block means are computed once and reused for both the centre and the spread.
    block_means = np.stack(
        [bl.mean(axis=0) for bl in _blockify(V, cfg.n_blocks) if len(bl) > 0], axis=0
    )
    mu_hat = np.median(block_means, axis=0)
    diffs = block_means - mu_hat[None, :]

    center = float(np.linalg.norm(mu_hat, ord=p))
    # Per-block deviation measured in the same norm as the centre.
    dev = np.linalg.norm(diffs, ord=p, axis=1)
    s = float(np.sqrt(np.mean(dev**2) + 1e-12))
    B = max(1, len(block_means))
    rad = s * np.sqrt(2.0 * np.log(2.0 / max(cfg.delta, 1e-12)) / B)
    return float(max(0.0, center + rad))
