import numpy as np
import torch
from typing import Optional, Tuple

def _to_2d_torch(x, device: Optional[str] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Coerce ``x`` into a 2D ``torch.Tensor``.

    Args:
        x: A ``torch.Tensor`` or anything ``torch.as_tensor`` accepts
            (e.g. a NumPy array or nested list).
        device: Target device. When ``None`` an existing tensor is left where
            it is and non-tensor input gets ``torch.as_tensor``'s default.
        dtype: Optional dtype applied after the tensor has been built/moved.

    Returns:
        A 2D tensor: 1D input becomes a column ``(n, 1)``; input with more
        than two dims keeps its first axis and flattens the rest ``(n, -1)``.

    Raises:
        ValueError: If the input is 0-dimensional (a scalar).
    """
    if not isinstance(x, torch.Tensor):
        # Non-tensor input (numpy / array-like) is wrapped by torch.as_tensor.
        tensor = torch.as_tensor(x, device=device)
    elif device is None:
        tensor = x
    else:
        tensor = x.to(device)

    # The dtype cast is deliberately a separate step after construction so the
    # conversion path stays the same for tensor and non-tensor input.
    if dtype is not None:
        tensor = tensor.to(dtype)

    rank = tensor.ndim
    if rank == 2:
        return tensor
    if rank == 1:
        # Treat a flat vector as n samples of a single feature.
        return tensor.unsqueeze(1)
    if rank > 2:
        # Keep the leading (sample) axis and flatten everything else.
        return tensor.reshape(tensor.shape[0], -1)

    # Only rank 0 reaches this point.
    raise ValueError(f"Expected 2D data (n, d); got shape {tuple(tensor.shape)}")

@torch.no_grad()
def check_surround_assumption(
    acts: "np.ndarray | torch.Tensor",
    epsilon: float,
    delta: float,
    n_omegas: int = 2048,
    seed: int = 0,
    return_min_omega: bool = True,
):
    """
    Monte-Carlo check of: for all unit omega, P( (X-mean)^T omega > epsilon ) > delta.

    The rows of ``acts`` are centered by their column-wise mean and projected
    onto ``n_omegas`` random unit directions. For every direction the fraction
    of rows whose projection is strictly greater than ``epsilon`` is computed;
    the check passes when the smallest such fraction is strictly greater than
    ``delta``. Only the sampled directions are tested, so a passing result is
    evidence for the assumption, not a proof.

    Args:
        acts: 2D data of shape (n, d) as a NumPy array / array-like, or a
            torch.Tensor (detached and copied to host memory if necessary).
        epsilon: Projection threshold.
        delta: Required lower bound on the proportion.
        n_omegas: Number of random unit directions to sample (>= 1).
        seed: Seed for ``np.random.default_rng``.
        return_min_omega: Whether to also return the minimizing direction,
            the mean vector and the index of the minimizing direction.

    Returns:
        If ``return_min_omega`` is True, the tuple
        ``(holds, min_prop, proportions, omega_min, mean_vec, min_idx)``;
        otherwise ``(holds, min_prop, proportions)``, where

            - holds: bool, ``min_prop > delta``.
            - min_prop: float, smallest proportion over the sampled directions.
            - proportions: array of shape (n_omegas,), one proportion per direction.
            - omega_min: array of shape (d,), the direction attaining ``min_prop``.
            - mean_vec: array of shape (d,), column-wise mean of ``acts``.
            - min_idx: int, index of ``omega_min`` among the sampled directions.

    Raises:
        ValueError: If ``acts`` is not 2D, has no rows, or ``n_omegas < 1``.
    """
    if isinstance(acts, torch.Tensor):
        # np.asarray() fails on tensors that require grad or live on an
        # accelerator, so detach and bring the data to host memory first.
        acts = acts.detach().cpu()
        if acts.dtype == torch.bfloat16:
            acts = acts.float()  # NumPy has no bfloat16 dtype
        acts = acts.numpy()

    X = np.asarray(acts)
    if X.ndim != 2:
        raise ValueError(f"`acts` must be 2D (n,d). Got shape {X.shape}")

    n, d = X.shape
    if n == 0:
        # The mean of zero rows is NaN, which would make every statistic
        # below meaningless; fail loudly instead.
        raise ValueError("`acts` must contain at least one sample (n >= 1)")
    if n_omegas < 1:
        raise ValueError(f"`n_omegas` must be >= 1. Got {n_omegas}")

    mean_vec = X.mean(axis=0)
    centered = X - mean_vec

    # Normalized Gaussian samples give random directions on the unit sphere;
    # the small constant guards against division by a zero norm.
    rng = np.random.default_rng(seed)
    omegas = rng.standard_normal(size=(n_omegas, d))
    omegas /= (np.linalg.norm(omegas, axis=1, keepdims=True) + 1e-12)

    # projections: (n, n_omegas)
    projs = centered @ omegas.T
    proportions = (projs > epsilon).mean(axis=0)  # length n_omegas

    min_idx = int(np.argmin(proportions))
    min_prop = float(proportions[min_idx])
    holds = bool(min_prop > delta)

    if return_min_omega:
        return holds, min_prop, proportions, omegas[min_idx], mean_vec, min_idx
    return holds, min_prop, proportions


def check_a_infty_convergence(
    alpha_sequence,
    window: int = 25,
    rtol: float = 1e-3,
    atol: float = 1e-8,
):
    """
    Empirically check whether A_infty = lim_{N->infty} N * exp(alpha_N) appears to converge.

    The i-th entry of the flattened input (0-based) is taken as alpha_N with
    N = i + 1, and A_N = N * exp(alpha_N) is inspected over the last
    ``window`` entries.

    Args:
        alpha_sequence: Iterable of alpha_N values (1D). Can be list/np/torch; converted to
            float numpy. Tensors are detached and copied to host memory first, so tensors
            that require grad or live on an accelerator are accepted.
        window: Number of tail points to use for the stability check (clamped to
            [1, len(alpha_sequence)]; ``None`` uses the whole sequence).
        rtol: Relative tolerance for tail flatness/slope.
        atol: Absolute tolerance for tail flatness/slope.

    Returns:
        dict with:
            - exists: bool, whether the tail looks numerically converged. True when the
              tail estimate is finite and either the tail spread (max deviation or
              standard deviation) or the absolute tail slope is within
              ``rtol * max(|estimate|, 1) + atol``.
            - estimate: float, mean of the tail section of A_N.
            - tail_range: float, max |A_N - estimate| over the tail.
            - slope: float, linear slope of tail A_N vs N (should be ~0 when converged);
              0.0 when the tail has a single point.
            - A_values: full numpy array of A_N for further inspection.

    Raises:
        ValueError: If ``alpha_sequence`` is empty or contains non-finite values.
    """
    if isinstance(alpha_sequence, torch.Tensor):
        # np.asarray() fails on tensors that require grad or are not on the
        # CPU; detach, move to host and cast to float64 before converting.
        alpha_sequence = alpha_sequence.detach().cpu().to(torch.float64).numpy()

    alpha = np.asarray(alpha_sequence, dtype=float).reshape(-1)
    if alpha.size == 0:
        raise ValueError("alpha_sequence must be non-empty")
    if not np.all(np.isfinite(alpha)):
        raise ValueError("alpha_sequence contains non-finite values")

    N = np.arange(1, alpha.size + 1, dtype=float)
    A_vals = N * np.exp(alpha)

    # Clamp the window so that it always covers between 1 and all points.
    w = int(window) if window is not None else alpha.size
    w = max(1, min(alpha.size, w))
    tail = A_vals[-w:]
    Ns_tail = N[-w:]

    estimate = float(np.mean(tail))
    tail_range = float(np.max(np.abs(tail - estimate)))
    tail_std = float(np.std(tail))

    if w >= 2:
        # slope of tail A_N vs N to detect drift
        slope = float(np.polyfit(Ns_tail, tail, 1)[0])
    else:
        # A line cannot be fitted through a single point.
        slope = 0.0

    # Tolerance is relative to the estimate, but never below the scale of 1.
    scale = max(abs(estimate), 1.0)
    tol = rtol * scale + atol
    stable = (tail_range <= tol) or (tail_std <= tol)
    flat_slope = abs(slope) <= tol
    # NOTE(review): a flat slope alone is accepted here, so a tail that
    # oscillates around a constant level (large spread, ~zero slope) is
    # reported as converged -- confirm this leniency is intended.
    exists = bool(np.isfinite(estimate) and np.isfinite(tail_range) and (stable or flat_slope))

    return {
        "exists": exists,
        "estimate": estimate,
        "tail_range": tail_range,
        "slope": slope,
        "A_values": A_vals,
    }
