from __future__ import annotations

from typing import Callable, Tuple

import numpy as np


def standardize_features(
    X: np.ndarray, *, eps: float = 0.0
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Center and scale each column of *X* using its own mean and std.

    Statistics are reduced over axis 0 with ``keepdims=True``: one mean and
    one std per column. For a 2-D ``X`` they have shape ``(1, n_columns)`` and
    broadcast back against ``X``. ``std`` is NumPy's default population
    standard deviation (``ddof=0``).

    Args:
        X: Array whose columns are standardized independently.
        eps: Columns whose std is ``<= eps`` are treated as constant. They
            get a divisor of 1.0, so they are only centered. The default 0.0
            guards exact zeros only. A small positive value also catches
            constant columns whose computed std is a rounding residue rather
            than exactly 0.

    Returns:
        ``(Xs, mean, std)`` with ``Xs = (X - mean) / std``. The returned
        ``std`` already has the constant-column substitution applied, so it
        can be reused directly on other data.

    Raises:
        ValueError: If ``eps`` is negative.
    """
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps!r}")
    mean = X.mean(axis=0, keepdims=True)
    std = X.std(axis=0, keepdims=True)
    # A constant column can yield a tiny nonzero std (rounding in the mean).
    # Without a tolerance it would be mapped to +/-1 instead of 0.
    std[std <= eps] = 1.0
    Xs = (X - mean) / std
    return Xs, mean, std


def apply_standardization(X: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Return ``(X - mean) / std`` using precomputed statistics.

    Any entry of ``std`` equal to 0.0 is replaced by 1.0 before dividing, so
    zero-variance columns are only centered. The replacement is done on a
    copy, and the caller's ``std`` array is left unmodified.
    """
    divisor = std.copy()
    is_zero = divisor == 0.0
    divisor[is_zero] = 1.0
    centered = X - mean
    return centered / divisor


def fit_source_standardizer(
    X_source: np.ndarray, *, eps: float = 0.0
) -> tuple[np.ndarray, np.ndarray]:
    """Fit a single standardizer on source features and return (mean, std).

    Apply these statistics to BOTH source and target to avoid per-domain drift.

    Statistics are reduced over axis 0 with ``keepdims=True``: one value per
    column. For a 2-D input they have shape ``(1, n_columns)``. ``std`` is the
    population standard deviation (NumPy default ``ddof=0``).

    Args:
        X_source: Source-domain array whose columns are summarized.
        eps: Columns whose std is ``<= eps`` get a std of 1.0, so
            zero-variance columns are only centered when applied. The default
            0.0 guards exact zeros only. A small positive value also catches
            constant columns whose computed std is a rounding residue.

    Returns:
        ``(mean, std)`` with the constant-column substitution already applied
        to ``std``.

    Raises:
        ValueError: If ``eps`` is negative.
    """
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps!r}")
    mean = X_source.mean(axis=0, keepdims=True)
    std = X_source.std(axis=0, keepdims=True)
    # Tolerance-based guard: a constant column's std may be ~1e-17, not 0.
    std[std <= eps] = 1.0
    return mean, std


def standardize_pair(
    X: np.ndarray, Y: np.ndarray, *, eps: float = 0.0
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Standardize *X* and *Y* with statistics pooled over both arrays.

    The rows of ``X`` and ``Y`` are stacked with ``np.vstack``, so they must
    have the same number of columns. One mean and one population std
    (``ddof=0``) are computed per column over the stacked rows and applied to
    both inputs.

    Args:
        X: First array.
        Y: Second array with the same number of columns as ``X``.
        eps: Columns whose pooled std is ``<= eps`` get a divisor of 1.0, so
            they are only centered. The default 0.0 guards exact zeros only.
            A small positive value also catches constant columns whose
            computed std is a rounding residue.

    Returns:
        ``(X_std, Y_std, mean, std)``. The returned ``std`` already has the
        constant-column substitution applied.

    Raises:
        ValueError: If ``eps`` is negative, or from ``np.vstack`` if the
            column counts differ.
    """
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps!r}")
    Z = np.vstack([X, Y])
    mean = Z.mean(axis=0, keepdims=True)
    std = Z.std(axis=0, keepdims=True)
    # A constant pooled column may have a tiny nonzero std from rounding.
    std[std <= eps] = 1.0
    return (X - mean) / std, (Y - mean) / std, mean, std


