import warnings
from dataclasses import dataclass
from typing import Literal, Optional, Tuple, Union

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.svm import OneClassSVM

@dataclass
class OCSVMConfig:
    """Hyper-parameters for ``fit_ocsvm_detector``.

    Values are validated on construction so that an invalid setting fails
    immediately with a descriptive ``ValueError`` instead of surfacing later
    from inside scikit-learn / NumPy during fitting.
    """

    # ``nu`` forwarded to ``OneClassSVM``; must lie in (0, 1].
    nu: float = 0.05
    # Kernel name forwarded to ``OneClassSVM``.
    kernel: Literal["rbf", "linear", "poly", "sigmoid"] = "rbf"
    # Kernel coefficient forwarded to ``OneClassSVM`` ("scale", "auto", or a float).
    gamma: Union[str, float] = "scale"
    # When True, a StandardScaler is fitted on the flattened training elements
    # and applied before the SVM (both at fit time and at scoring time).
    standardize: bool = True
    # Upper bound on the number of pooled training elements used for fitting;
    # larger pools are randomly subsampled without replacement. None = use all.
    max_train_elements: Optional[int] = 20000
    # If set, the outlier threshold is this quantile of the decision scores on
    # the fitting set (must lie in [0, 1]). None = use the SVM's own predict().
    calibrate_fpr: Optional[float] = 0.05
    # Seed for the RNG that performs the training subsample.
    random_state: int = 0

    def __post_init__(self) -> None:
        """Validate numeric ranges; raise ValueError on invalid settings."""
        if not 0.0 < self.nu <= 1.0:
            raise ValueError(f"nu must be in the interval (0, 1], got {self.nu!r}.")
        if self.max_train_elements is not None and self.max_train_elements < 1:
            raise ValueError(
                "max_train_elements must be a positive integer or None, "
                f"got {self.max_train_elements!r}."
            )
        if self.calibrate_fpr is not None and not 0.0 <= self.calibrate_fpr <= 1.0:
            raise ValueError(
                "calibrate_fpr must be in the interval [0, 1] or None, "
                f"got {self.calibrate_fpr!r}."
            )


@dataclass
class OCSVMOutlierDetector:
    """A fitted One-Class SVM outlier detector over fixed-shape elements.

    Inputs to ``score`` / ``predict_outliers`` must have shape
    ``(K, *element_shape)``; each element is flattened, checked for NaN/Inf,
    optionally standardized with ``scaler``, and then passed to ``ocsvm``.
    """

    # Scaler fitted on the flattened training elements, or None when no
    # standardization is applied.
    scaler: Optional[StandardScaler]
    # Fitted one-class SVM operating on the flattened (and optionally scaled)
    # elements.
    ocsvm: OneClassSVM
    # Shape of a single element (everything after the leading batch axis).
    element_shape: Tuple[int, ...]
    # Decision-score cutoff: elements scoring strictly below it are outliers.
    # None means the SVM's own predict() boundary is used instead.
    threshold: Optional[float]

    def _prepare(self, X_elements: np.ndarray) -> np.ndarray:
        """Flatten, validate and (if configured) scale elements for the SVM."""
        X_flat = _flatten_elements(X_elements, self.element_shape)
        X_flat = _ensure_float(X_flat)
        if self.scaler is not None:
            X_flat = self.scaler.transform(X_flat)
        return X_flat

    def score(self, X_elements: np.ndarray) -> np.ndarray:
        """Return the SVM decision-function value for each element.

        Args:
            X_elements: array of shape ``(K, *element_shape)``.

        Returns:
            1-D array of length ``K``; lower values are more outlying.

        Raises:
            ValueError: on a shape mismatch or NaN/Inf values.
        """
        return self.ocsvm.decision_function(self._prepare(X_elements)).ravel()

    def predict_outliers(self, X_elements: np.ndarray) -> np.ndarray:
        """Return a boolean mask of length ``K`` marking outlying elements.

        If ``threshold`` is set, an element is an outlier when its decision
        score is strictly below it; otherwise the SVM's ``predict`` label of
        ``-1`` marks an outlier.

        Raises:
            ValueError: on a shape mismatch or NaN/Inf values.
        """
        X_ready = self._prepare(X_elements)
        if self.threshold is not None:
            return self.ocsvm.decision_function(X_ready).ravel() < self.threshold
        # Only the predict() path is needed here; scores are not computed.
        return self.ocsvm.predict(X_ready) == -1


def _ensure_float(X: np.ndarray) -> np.ndarray:
    """Return ``X`` as a floating-point array, rejecting non-finite values.

    Arrays that already have a floating dtype are returned with that dtype
    unchanged; any other dtype (int, bool, ...) is converted to float64.

    Raises:
        ValueError: if any value is NaN or +/-Inf.
    """
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.floating):
        # float64 rather than float32: float32 has a 24-bit significand, so
        # integers beyond 2**24 would be silently rounded (16777217 -> 16777216).
        X = X.astype(np.float64)
    if not np.isfinite(X).all():
        raise ValueError(
            "Input contains NaN or Inf values. "
            "Please clean or impute missing values before using OC-SVM."
        )
    return X


def _flatten_elements(X: np.ndarray, element_shape: Tuple[int, ...]) -> np.ndarray:
    """Reshape a stack of elements into a 2-D ``(K, D)`` feature matrix.

    ``X`` must have shape ``(K, *element_shape)``. Each element becomes one
    row of length ``D = prod(element_shape)`` (``D == 1`` for an empty
    ``element_shape``).

    Raises:
        ValueError: if ``X`` has the wrong number of dimensions, or if its
            trailing dimensions differ from ``element_shape``.
    """
    arr = np.asarray(X)
    n_dims_wanted = len(element_shape) + 1

    if arr.ndim != n_dims_wanted:
        raise ValueError(
            f"Expected array with {n_dims_wanted} dimensions for element_shape={element_shape}, "
            f"got shape {arr.shape} with {arr.ndim} dimensions."
        )

    trailing_dims = tuple(arr.shape[1:])
    if trailing_dims != tuple(element_shape):
        raise ValueError(
            f"Expected element shape {element_shape}, got {trailing_dims}."
        )

    n_features = int(np.prod(element_shape))
    return arr.reshape(arr.shape[0], n_features)


def fit_ocsvm_detector(
    X_train: np.ndarray,
    config: Optional[OCSVMConfig] = None,
) -> OCSVMOutlierDetector:
    """Fit a One-Class SVM outlier detector on pooled training elements.

    All ``M * n_obs`` elements of ``X_train`` are pooled into one training
    set. If ``config.max_train_elements`` is set and smaller than the pool, a
    random subset of that size is drawn without replacement (seeded by
    ``config.random_state``). The chosen elements are flattened, optionally
    standardized, and used to fit the SVM.

    When ``config.calibrate_fpr`` is set, the detector's threshold is the
    ``calibrate_fpr`` quantile of the decision scores on the fitting set
    itself (in-sample). Roughly that fraction of fitting elements therefore
    score below it.

    Args:
        X_train: array of shape ``(M, n_obs, *element_shape)``.
        config: fitting options; defaults to ``OCSVMConfig()``.

    Returns:
        An ``OCSVMOutlierDetector`` bound to ``element_shape``.

    Raises:
        ValueError: if ``X_train`` has fewer than 3 dimensions, contains no
            elements, or the elements used for fitting contain NaN/Inf.
            Only the (possibly subsampled) fitting elements are checked.
    """
    if config is None:
        config = OCSVMConfig()

    X_train = np.asarray(X_train)

    # Validate input shape: (M, n_obs, *element_shape)
    if X_train.ndim < 3:
        raise ValueError(
            f"Expected X_train with shape (M, n_obs, *element_shape), "
            f"got shape {X_train.shape} with {X_train.ndim} dimensions."
        )

    M, n_obs = X_train.shape[:2]
    element_shape = tuple(X_train.shape[2:])
    total_elements = M * n_obs
    if total_elements == 0:
        # Fail here with a clear message rather than inside the SVM fit.
        raise ValueError(
            f"X_train contains no elements to fit on (shape {X_train.shape})."
        )

    # Pool every observation of every sample into one (total, *element_shape) stack.
    X_pool = X_train.reshape(total_elements, *element_shape)

    # Subsample to bound the cost of fitting the SVM.
    rng = np.random.default_rng(config.random_state)
    if config.max_train_elements is not None and total_elements > config.max_train_elements:
        idx = rng.choice(total_elements, size=config.max_train_elements, replace=False)
        X_fit = X_pool[idx]
    else:
        X_fit = X_pool

    X_fit_flat = _flatten_elements(X_fit, element_shape)
    X_fit_flat = _ensure_float(X_fit_flat)

    scaler = None
    if config.standardize:
        scaler = StandardScaler()
        X_fit_flat = scaler.fit_transform(X_fit_flat)

    ocsvm = OneClassSVM(
        kernel=config.kernel,
        nu=config.nu,
        gamma=config.gamma,
    )
    ocsvm.fit(X_fit_flat)

    threshold = None
    if config.calibrate_fpr is not None:
        scores_fit = ocsvm.decision_function(X_fit_flat).ravel()
        # Plain float to match the detector's Optional[float] field.
        threshold = float(np.quantile(scores_fit, config.calibrate_fpr))

    return OCSVMOutlierDetector(
        scaler=scaler,
        ocsvm=ocsvm,
        element_shape=element_shape,
        threshold=threshold,
    )


def clean_observations(
    detector: OCSVMOutlierDetector,
    X_obs: np.ndarray,
    seed: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Replace detected outliers with randomly resampled inliers.

    Args:
        detector: fitted detector; only its ``score`` and
            ``predict_outliers`` methods are used.
        X_obs: array whose first axis indexes observations (for an
            ``OCSVMOutlierDetector``, shape ``(K, *element_shape)``).
        seed: seed for the NumPy Generator that picks replacement inliers.

    Returns:
        ``(X_clean, outlier_mask, scores)``:
        - ``X_clean`` is a copy of ``X_obs`` in which every outlier row is
          overwritten by an inlier row chosen uniformly at random (with
          replacement). ``X_obs`` itself is not modified.
        - ``outlier_mask`` is the mask from ``detector.predict_outliers``.
        - ``scores`` are from ``detector.score``.

    Warns:
        RuntimeWarning: if every observation is flagged as an outlier. No
            replacement is possible then, so ``X_clean`` is an unmodified copy.
    """
    X_obs = np.asarray(X_obs)

    scores = detector.score(X_obs)
    outlier_mask = detector.predict_outliers(X_obs)

    inlier_idx = np.where(~outlier_mask)[0]
    outlier_idx = np.where(outlier_mask)[0]

    X_clean = X_obs.copy()

    if outlier_idx.size == 0:
        # Nothing to replace (also covers empty input).
        return X_clean, outlier_mask, scores

    if inlier_idx.size == 0:
        # No inliers to sample from; report it instead of silently doing nothing.
        warnings.warn(
            "All observations were flagged as outliers; "
            "no replacement was possible and they are returned unchanged.",
            RuntimeWarning,
            stacklevel=2,
        )
        return X_clean, outlier_mask, scores

    rng = np.random.default_rng(seed)
    replacement_idx = rng.choice(inlier_idx, size=outlier_idx.size, replace=True)
    X_clean[outlier_idx] = X_obs[replacement_idx]
    return X_clean, outlier_mask, scores

# Public API of this module: the config and detector types, plus the two
# entry points for fitting a detector and cleaning observations with it.
__all__ = [
    "OCSVMConfig", "OCSVMOutlierDetector",
    "fit_ocsvm_detector", "clean_observations",
]
