from __future__ import annotations
import numpy as np
from typing import Sequence, Dict, Any

def _draw_kept_samples(
    rng: np.random.Generator,
    mean: float,
    sigma: float,
    needed: int,
    center: float,
    thresh: float | None,
    batch_start: int,
    growth: float,
    max_batches: int,
    batch_cap: int | None,
) -> tuple[np.ndarray, int]:
    """Draw N(mean, sigma^2) batches until `needed` samples pass the filter.

    A sample y passes when ``|y - center| <= thresh``; ``thresh=None`` disables
    filtering.  After every draw that leaves the target unmet, the batch size
    is multiplied by `growth` (rounded up, and limited to `batch_cap` when one
    is given).  Drawing stops once more than `max_batches` draws have failed
    to reach the target.

    Returns
    -------
    (kept, n_generated)
        `kept` holds at most `needed` accepted samples in draw order (fewer
        only if the draw budget ran out); `n_generated` is the total number of
        raw samples drawn.
    """
    batch = int(max(1, batch_start))
    if batch_cap is not None:
        # Never shrink below the caller's chosen starting batch size.
        batch_cap = max(int(batch_cap), batch)

    chunks: list[np.ndarray] = []
    n_kept = 0
    n_generated = 0
    failed_draws = 0

    while True:
        y_raw = mean + rng.normal(scale=sigma, size=batch)
        n_generated += batch

        accepted = y_raw if thresh is None else y_raw[np.abs(y_raw - center) <= thresh]
        if accepted.size:
            # Only the first `needed` accepted samples are ever used; copy the
            # prefix so the (possibly large) raw batch can be freed.
            chunk = accepted[: needed - n_kept].copy()
            chunks.append(chunk)
            n_kept += chunk.size

        if n_kept >= needed:
            break

        failed_draws += 1
        if failed_draws > max_batches:
            break

        batch = int(np.ceil(batch * growth))
        if batch_cap is not None:
            batch = min(batch, batch_cap)

    kept = np.concatenate(chunks) if chunks else np.empty(0, dtype=float)
    return kept, n_generated


def run_lr_topup_svd_per_direction(
    per_dir_kept_schedule: int | Sequence[int],
    p: int = 10,
    n0: int = 200,
    sigma: float = 1.0,
    r: float = 0.5,
    scenario: str = "unbiased",    # "unbiased" | "biased" | "none"
    theta_star: np.ndarray | None = None,
    theta_center: np.ndarray | None = None,
    bias_norm: float = 0.3,
    seed: int = 0,
    raw_batch_start: int = 5000,
    raw_growth: float = 2.0,
    max_batches: int = 100,
    also_return_nofilter: bool = True,
    max_raw_batch: int | None = 10_000_000,
) -> Dict[str, Any]:
    """
    Iterative verifier-guided retraining in linear regression using SVD directions.

    For each round k:
      1) Use the last estimate β_k to generate synthetic responses along each SVD direction v_j:
           y_raw ~ N(mean = v_j^T β_k,  var = σ^2)
      2) Filtered path (if scenario != "none"): keep samples satisfying
           |y_raw - v_j^T θ^0| <= r ||v_j|| + sqrt(2/π) σ
         and keep the first per_dir_kept_schedule[k] accepted samples for each direction.
         If the draw budget runs out first, fall back to averaging whatever was
         collected (or to the current projection v_j^T β_k if nothing was accepted).
      3) Update β_{k+1} by recombining the per-direction means.
      4) Optionally compute a parallel "no-filter" baseline with the same sample sizes.

    Args
    ----
    per_dir_kept_schedule : int | Sequence[int]
        Number of *post-filter* kept samples per direction, per round.
        The number of rounds K equals the length of the sequence; a single
        integer (Python or NumPy) runs exactly one round.  A round with
        target 0 draws nothing and leaves the estimate on its current
        projections.
    p : int
        Feature dimension (= number of directions). Must be >= 1.
    n0 : int
        Real sample size used for the initial least-squares estimator β_0.
        Must be >= 1.  When n0 < p the full set of p right-singular vectors
        is used as directions and β_0 is the `np.linalg.lstsq` solution of
        the underdetermined system.
    sigma : float
        Noise standard deviation.
    r : float
        Filtering radius parameter.
    scenario : {"unbiased","biased","none"}
        Filtering setup:
          - "unbiased": verifier center θ^0 = θ*
          - "biased"  : θ^0 shifted in the direction of θ*
          - "none"    : no filtering; every drawn sample is accepted.  A full
                        raw batch is still drawn and counted in
                        n_generated_by_dir, and only the first
                        per_dir_kept_schedule[k] samples are used.
    theta_star : np.ndarray | None
        True parameter vector of shape (p,). If None, generated randomly
        and scaled to norm sqrt(p).
    theta_center : np.ndarray | None
        Verifier center θ^0. If None, derived according to scenario.
        Ignored when scenario="none".
    bias_norm : float
        Amount of shift added if scenario="biased" and theta_center is None.
    seed : int
        Random seed for reproducibility.
    raw_batch_start : int
        Initial pre-filter batch size per direction (values < 1 are treated as 1).
    raw_growth : float
        Growth factor for pre-filter batch size if not enough samples are accepted.
    max_batches : int
        Maximum number of batch-growth attempts before fallback.
    also_return_nofilter : bool
        If True, compute and return a baseline trajectory without filtering.
    max_raw_batch : int | None
        Upper bound on the size of a single pre-filter batch after growth
        (never below the starting batch size).  Without a bound, geometric
        growth exhausts memory long before `max_batches` is reached when the
        acceptance rate is near zero.  Pass None for unbounded growth.

    Returns
    -------
    dict :
        {
          "theta_star": (p,),
          "theta_center": (p,) or None,
          "beta_hist": (K+1, p),                     # filtered path β_0 ... β_K
          "n_generated_by_dir": (K, p),              # pre-filter counts
          "n_kept_by_dir": (K, p),                   # post-filter kept counts (<= target)
          "n_generated_total": (K,),
          "n_kept_total": (K,),
          "acceptance_rate": (K,),                   # n_kept_total / n_generated_total
          # present only if also_return_nofilter=True:
          "beta_hist_nofilter": (K+1, p),
          "n_used_by_dir_nofilter": (K, p)
        }

        Note that "acceptance_rate" uses the kept counts *after* truncation to
        the per-direction target, so it is a lower bound on the fraction of
        raw samples that passed the filter.

    Raises
    ------
    ValueError
        If the schedule is not a 1-D array of nonnegative ints (or a single
        int), if `scenario` is unknown, if `p`, `n0` or `max_raw_batch` is
        smaller than 1, or if `theta_star` / `theta_center` does not have
        shape (p,).
    """
    # ---- normalize and validate inputs ----
    # np.integer scalars are accepted alongside Python ints; both mean "one round".
    if isinstance(per_dir_kept_schedule, (int, np.integer)):
        per_dir_schedule = np.array([per_dir_kept_schedule], dtype=int)
    else:
        per_dir_schedule = np.asarray(per_dir_kept_schedule, dtype=int)
    if per_dir_schedule.ndim != 1 or np.any(per_dir_schedule < 0):
        raise ValueError("per_dir_kept_schedule must be a 1-D array of nonnegative ints (or a single int).")

    if scenario not in {"unbiased", "biased", "none"}:
        raise ValueError('scenario must be one of {"unbiased","biased","none"}.')

    if p < 1:
        raise ValueError("p must be a positive integer.")
    if n0 < 1:
        raise ValueError("n0 must be a positive integer.")
    if max_raw_batch is not None and max_raw_batch < 1:
        raise ValueError("max_raw_batch must be a positive integer or None.")

    K = int(per_dir_schedule.size)
    rng_master = np.random.default_rng(seed)

    # ---- true parameter θ* ----
    if theta_star is None:
        theta_star = rng_master.normal(size=p)
        theta_star /= (np.linalg.norm(theta_star) + 1e-12)
        theta_star *= np.sqrt(p)
    else:
        theta_star = np.asarray(theta_star, dtype=float).reshape(-1)
        if theta_star.shape != (p,):
            raise ValueError("theta_star must have shape (p,)")

    # ---- initial real data and least-squares estimate β_0 ----
    X0 = rng_master.normal(size=(n0, p))
    y0 = X0 @ theta_star + rng_master.normal(scale=sigma, size=n0)
    beta_k, *_ = np.linalg.lstsq(X0, y0, rcond=None)

    # ---- directions: columns of V from SVD ----
    # The reduced SVD only yields min(n0, p) right-singular vectors, so ask for
    # the full V when n0 < p; otherwise keep the cheaper reduced form.
    Vt = np.linalg.svd(X0, full_matrices=(n0 < p))[2]  # (p, p)
    directions = Vt.T                                  # (p, p)

    # ---- verifier center θ^0 ----
    if scenario == "none":
        theta_center_used = None
    else:
        if theta_center is not None:
            theta_center_used = np.asarray(theta_center, dtype=float).reshape(-1)
            if theta_center_used.shape != (p,):
                raise ValueError("theta_center must have shape (p,)")
        elif scenario == "unbiased":
            theta_center_used = theta_star.copy()
        else:  # "biased"
            delta = bias_norm * theta_star / (np.linalg.norm(theta_star) + 1e-12)
            theta_center_used = theta_star + delta

    # ---- bookkeeping ----
    n_generated_by_dir = np.zeros((K, p), dtype=int)
    n_kept_by_dir      = np.zeros((K, p), dtype=int)
    n_generated_total  = np.zeros(K, dtype=int)
    n_kept_total       = np.zeros(K, dtype=int)
    acceptance_rate    = np.zeros(K, dtype=float)

    beta_hist = [beta_k.copy()]

    if also_return_nofilter:
        beta_k_nf = beta_k.copy()
        beta_hist_nofilter = [beta_k_nf.copy()]
        n_used_by_dir_nofilter = np.zeros((K, p), dtype=int)

    # ---- iterative retraining ----
    for k in range(K):
        per_dir_target = int(per_dir_schedule[k])
        # Each round gets its own generator so rounds are reproducible independently.
        round_rng = np.random.default_rng(seed + 1009 * (k + 1))

        # ----- filtered path -----
        a_coords = np.zeros(p, dtype=float)
        for j in range(p):
            vj = directions[:, j]
            vj_norm2 = float(vj @ vj)
            projection = float(vj @ beta_k)

            if per_dir_target == 0:
                # Nothing to sample this round: keep the current projection.
                a_coords[j] = float(projection / (vj_norm2 + 1e-12))
                continue

            if scenario == "none":
                center, thresh = 0.0, None
            else:
                center = float(vj @ theta_center_used)
                thresh = r * np.sqrt(vj_norm2) + np.sqrt(2.0 / np.pi) * sigma

            kept, n_generated = _draw_kept_samples(
                round_rng,
                projection,
                sigma,
                per_dir_target,
                center,
                thresh,
                raw_batch_start,
                raw_growth,
                max_batches,
                max_raw_batch,
            )

            if kept.size:
                # Covers both the normal case (exactly per_dir_target samples)
                # and the fallback (fewer samples collected).
                a_coords[j] = float(np.mean(kept) / (vj_norm2 + 1e-12))
            else:
                # Fallback with nothing accepted: stay on the current projection.
                a_coords[j] = float(projection / (vj_norm2 + 1e-12))

            n_generated_by_dir[k, j] = n_generated
            n_kept_by_dir[k, j] = kept.size

        beta_k = directions @ a_coords
        beta_hist.append(beta_k.copy())

        n_generated_total[k] = int(n_generated_by_dir[k].sum())
        n_kept_total[k]      = int(n_kept_by_dir[k].sum())
        acceptance_rate[k]   = float(n_kept_total[k] / max(n_generated_total[k], 1))

        # ----- no-filter baseline -----
        if also_return_nofilter:
            round_rng_nf = np.random.default_rng(seed + 1009 * (k + 1) + 7)
            a_coords_nf = np.zeros(p, dtype=float)
            for j in range(p):
                vj = directions[:, j]
                vj_norm2 = float(vj @ vj)
                mean_dir_nf = float(vj @ beta_k_nf)
                if per_dir_target == 0:
                    a_coords_nf[j] = float(mean_dir_nf / (vj_norm2 + 1e-12))
                    continue
                y_nf = mean_dir_nf + round_rng_nf.normal(scale=sigma, size=per_dir_target)
                a_coords_nf[j] = float(np.mean(y_nf) / (vj_norm2 + 1e-12))
                n_used_by_dir_nofilter[k, j] = per_dir_target
            beta_k_nf = directions @ a_coords_nf
            beta_hist_nofilter.append(beta_k_nf.copy())

    out: Dict[str, Any] = {
        "theta_star": theta_star,
        "theta_center": theta_center_used,
        "beta_hist": np.vstack(beta_hist),
        "n_generated_by_dir": n_generated_by_dir,
        "n_kept_by_dir": n_kept_by_dir,
        "n_generated_total": n_generated_total,
        "n_kept_total": n_kept_total,
        "acceptance_rate": acceptance_rate,
    }
    if also_return_nofilter:
        out.update({
            "beta_hist_nofilter": np.vstack(beta_hist_nofilter),
            "n_used_by_dir_nofilter": n_used_by_dir_nofilter,
        })
    return out
