import numpy as np
from dataclasses import dataclass
from typing import Optional, Union, Tuple

from sklearn.metrics.pairwise import rbf_kernel
from sklearn.utils import check_random_state


def _make_proxy_point(
    X: np.ndarray,
    strategy: str = "max_corner",
    margin: float = 0.1,
) -> np.ndarray:
    
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError("X must be 2D array")

    if strategy == "max_corner":
        x_max = np.max(X, axis=0)
        x_min = np.min(X, axis=0)
        span = x_max - x_min
        span = np.where(span > 0, span, 1.0)  
        z = x_max + margin * span
        return z

    if strategy == "mean_plus_kstd":
        mu = np.mean(X, axis=0)
        sd = np.std(X, axis=0)
        sd = np.where(sd > 0, sd, 1.0)
        z = mu + margin * sd
        return z

    raise ValueError(f"Unknown proxy strategy: {strategy}")


def _nullspace_via_svd(A: np.ndarray, rcond: float = 1e-12) -> np.ndarray:
    
    A = np.asarray(A)
    if A.size == 0:
        
        return np.eye(A.shape[1], dtype=A.dtype)

    U, s, Vt = np.linalg.svd(A, full_matrices=True)
    if s.size == 0:
        return Vt.T

    tol = rcond * np.max(s)
    rank = int(np.sum(s > tol))
    ns = Vt.T[:, rank:]
    return ns


@dataclass
class KNFST:
    """Kernel null-space (KNFST-style) one-class novelty scorer.

    ``fit`` works in five steps:

    1. Append a synthetic proxy sample (``_make_proxy_point``) to the
       training data.
    2. Build an RBF kernel over the augmented set.
    3. Choose expansion coefficients ``alpha_`` in the null space of the
       centred kernel columns of the training samples. In that direction,
       every training sample has numerically the same projection,
       ``normal_mean_proj_``.
    4. Within that null space, orient ``alpha_`` along the difference between
       the proxy's kernel column and the training samples' mean kernel column.
    5. Normalise ``alpha_`` so that ``alpha_ @ K @ alpha_ == 1``.

    The novelty score of a sample is ``|projection - normal_mean_proj_|``;
    larger values mean further from the training data.
    """

    # --- hyper-parameters ---------------------------------------------------
    # RBF gamma: a number (Python or NumPy scalar), "scale"
    # (1 / (n_features * var)) or "auto" (1 / n_features). For "scale", the
    # variance is taken over the augmented data (training rows + proxy point).
    gamma: Union[str, float] = "scale"
    # Relative singular-value cutoff forwarded to ``_nullspace_via_svd``.
    # NOTE(review): with dtype=float32, cutoffs below ~1e-7 are under the
    # working precision of the SVD.
    rcond: float = 1e-12
    # Proxy construction settings forwarded to ``_make_proxy_point``.
    proxy_strategy: str = "max_corner"
    proxy_margin: float = 0.1
    # If set, ``fit`` trains on a random subset of at most this many rows.
    # This bounds the size of the kernel matrix and the SVD.
    max_train_size: Optional[int] = 3000
    # Seed for the subsampling above.
    random_state: int = 42
    # Dtype used for stored arrays and kernel evaluations.
    dtype: np.dtype = np.float32

    # --- fitted state (None until ``fit`` has run) --------------------------
    # Training rows actually used, with the proxy point appended as last row.
    X_aug_: Optional[np.ndarray] = None
    # Kernel expansion coefficients, one per row of ``X_aug_``.
    alpha_: Optional[np.ndarray] = None
    # Resolved numeric gamma.
    gamma_: Optional[float] = None
    # Mean projection of the training rows.
    normal_mean_proj_: Optional[float] = None
    # Novelty scores of the (possibly subsampled) training rows.
    train_scores_: Optional[np.ndarray] = None

    def _resolve_gamma(self, X: np.ndarray) -> float:
        """Return the numeric RBF gamma to use for the 2-D data ``X``.

        Numeric settings are returned as ``float``. ``"auto"`` gives
        ``1 / n_features``. ``"scale"`` gives ``1 / (n_features * X.var())``,
        with a zero variance replaced by 1.

        Raises
        ------
        ValueError
            If ``gamma`` is neither numeric nor one of the known strings.
        """
        n_features = max(X.shape[1], 1)
        # np.float32 / np.int64 etc. are not subclasses of float / int, so they
        # are listed explicitly (e.g. gamma taken from a NumPy parameter grid).
        if isinstance(self.gamma, (int, float, np.integer, np.floating)):
            return float(self.gamma)
        if self.gamma == "auto":
            return 1.0 / n_features
        if self.gamma == "scale":
            var = float(np.var(X))
            var = var if var > 0 else 1.0
            return 1.0 / (n_features * var)
        raise ValueError(f"Unknown gamma setting: {self.gamma}")

    def fit(self, X: np.ndarray, y=None):
        """Fit the null-space projection on training data ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples, all treated as one ("normal") class.
        y : ignored
            Accepted for scikit-learn style API compatibility.

        Returns
        -------
        KNFST
            ``self``, with the ``*_`` fitted attributes populated.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or has no rows, or if ``max_train_size < 1``.
        """
        X = np.asarray(X, dtype=self.dtype)
        if X.ndim != 2:
            raise ValueError("X must be 2D")
        if X.shape[0] == 0:
            raise ValueError("X must contain at least one sample")
        if self.max_train_size is not None and self.max_train_size < 1:
            raise ValueError("max_train_size must be a positive integer or None")

        rng = check_random_state(self.random_state)
        if self.max_train_size is not None and X.shape[0] > self.max_train_size:
            idx = rng.choice(X.shape[0], size=self.max_train_size, replace=False)
            X = X[idx]

        # Augment with the proxy point as the last row (index n0).
        z = _make_proxy_point(
            X, strategy=self.proxy_strategy, margin=self.proxy_margin
        ).astype(self.dtype)
        X_aug = np.vstack([X, z[None, :]])
        n = X_aug.shape[0]
        n0 = n - 1  # number of training rows

        gamma_val = self._resolve_gamma(X_aug)
        K = rbf_kernel(X_aug, X_aug, gamma=gamma_val).astype(self.dtype)

        K_normal = K[:, :n0]
        # Mean kernel column of the training rows: alpha @ m0 is the mean
        # projection of the training rows.
        m0 = np.mean(K_normal, axis=1)

        if n0 >= 2:
            # Centre each row across the training columns. This equals
            # K_normal @ (I - 11^T / n0) but costs O(n * n0) instead of
            # O(n * n0^2).
            G = K_normal - m0[:, None]
        else:
            # A single training row has no within-class spread, so there are
            # no constraints and the null space is all of R^n.
            G = np.zeros((n, 0), dtype=self.dtype)

        # alpha in null(G^T) <=> alpha @ K_normal is constant across rows.
        U = _nullspace_via_svd(G.T, rcond=self.rcond).astype(self.dtype)

        # Proxy projection minus mean training projection, as a linear form.
        d_vec = (K[:, n0] - m0).astype(self.dtype)

        # Take the unit-coefficient direction in span(U) that maximises
        # alpha @ d_vec (the projection of d_vec onto the null space).
        uTd = U.T @ d_vec
        norm_uTd = float(np.linalg.norm(uTd))
        if norm_uTd < 1e-12:
            # d_vec is (numerically) orthogonal to the null space; fall back
            # to any null direction.
            alpha = U[:, 0].copy()
        else:
            alpha = U @ (uTd / norm_uTd)

        # Normalise to unit norm in feature space: alpha @ K @ alpha == 1.
        denom = float(np.sqrt(np.maximum(alpha @ (K @ alpha), 1e-12)))
        alpha = (alpha / denom).astype(self.dtype)

        self.X_aug_ = X_aug
        self.alpha_ = alpha
        self.gamma_ = float(gamma_val)

        proj_normals = alpha @ K_normal
        self.normal_mean_proj_ = float(np.mean(proj_normals))
        self.train_scores_ = np.abs(proj_normals - self.normal_mean_proj_).astype(self.dtype)

        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Project samples onto the fitted null-space direction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        np.ndarray
            1-D array with one projection value per row of ``X``.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called.
        """
        if self.X_aug_ is None or self.alpha_ is None or self.gamma_ is None:
            raise RuntimeError("Call fit() first.")
        X = np.asarray(X, dtype=self.dtype)
        # k has shape (n_aug, n_samples): one kernel column per query sample.
        k = rbf_kernel(self.X_aug_, X, gamma=self.gamma_).astype(self.dtype)
        return (self.alpha_.reshape(1, -1) @ k).reshape(-1)

    def decision_function(self, X: np.ndarray) -> np.ndarray:
        """Return novelty scores ``|projection - normal_mean_proj_|``.

        Larger values mean the sample projects further from the training
        rows' common projection value.
        """
        proj = self.transform(X)
        return np.abs(proj - float(self.normal_mean_proj_)).astype(self.dtype)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Alias of ``decision_function``: returns continuous novelty scores.

        NOTE: unlike scikit-learn outlier detectors, this does not return
        +1/-1 labels; threshold the scores (e.g. against ``train_scores_``)
        to obtain labels.
        """
        return self.decision_function(X)