import numpy as np


# FAISS is an optional dependency. FAISS_AVAILABLE records whether the import
# succeeded so that code in this module can decide whether to use it.
try:
    import faiss  
    FAISS_AVAILABLE = True
except Exception:
    # Any failure while importing is treated as "FAISS is not available".
    FAISS_AVAILABLE = False


class DVM_AD:
    """Anomaly scorer based on discriminant projections and nearest-neighbour distance.

    ``fit`` augments the training data with one artificial reference point
    (placed in its own class), derives a projection matrix from the within-class
    and total scatter of the augmented data, and stores the projected points.
    ``predict`` scores a sample by its Euclidean distance, in the projected
    space, to the nearest stored point.

    Attributes set by ``fit``:
        npd: projection matrix of shape (n_features, n_discriminants).
        basepoint_X: projected augmented training data (float32 when
            ``faiss_float32`` is true).
        filtered_eigvals: generalized eigenvalues of the retained directions.
        faiss_index: FAISS ``IndexFlatL2`` over ``basepoint_X`` or ``None``.
    """

    def __init__(
        self,
        mode: str = "both",
        eps: float = 0.1,
        artificial_mode: str = "max",
        use_faiss: bool = True,
        faiss_float32: bool = True,
        chunk_size: int = 4096,
    ):
        """Configure the scorer.

        Args:
            mode: which generalized eigenvalues to keep: ``"min"`` (below
                ``eps``), ``"max"`` (above ``1 - eps``) or ``"both"``.
            eps: tail width for the eigenvalue selection, in (0, 0.5).
            artificial_mode: how the reference point is built: per-feature
                ``"max"``, per-feature ``"min"``, or the training sample
                ``"farthest"`` from the mean.
            use_faiss: use FAISS for the nearest-neighbour search when the
                library could be imported.
            faiss_float32: cast projected data to float32.
            chunk_size: number of query rows processed at once by the NumPy
                fallback search.

        Raises:
            ValueError: if any argument is outside its allowed set or range.
        """
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if mode not in {"min", "max", "both"}:
            raise ValueError("mode must be 'min' / 'max' / 'both'")
        if artificial_mode not in {"max", "min", "farthest"}:
            raise ValueError("Invalid artificial_mode")
        if not 0.0 < eps < 0.5:
            raise ValueError("eps should be in (0, 0.5) for two-tailed selection.")
        if int(chunk_size) < 1:
            raise ValueError("chunk_size must be a positive integer")

        self.mode = mode
        self.eps = float(eps)
        self.artificial_mode = artificial_mode

        self.use_faiss = bool(use_faiss) and FAISS_AVAILABLE
        self.faiss_float32 = bool(faiss_float32)
        self.chunk_size = int(chunk_size)

        # Fitted state; populated by fit().
        self.npd = None
        self.basepoint_X = None
        self.filtered_eigvals = None
        self.faiss_index = None

    def _construct_reference_point(self, X: np.ndarray) -> np.ndarray:
        """Return the artificial reference point for ``X`` per ``artificial_mode``.

        Raises:
            ValueError: if ``artificial_mode`` is not a recognised value.
        """
        if self.artificial_mode == "max":
            return np.max(X, axis=0)
        if self.artificial_mode == "min":
            return np.min(X, axis=0)
        if self.artificial_mode == "farthest":
            mu = np.mean(X, axis=0)
            idx = np.argmax(np.linalg.norm(X - mu, axis=1))
            return X[idx]
        raise ValueError("Invalid artificial_mode")

    def compute_discriminants(self, X: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Compute projection directions from within-class and total scatter.

        Solves the generalized problem ``S_W v = lambda * S_T v`` (``S_W``:
        within-class scatter, ``S_T``: total scatter) and keeps the directions
        whose eigenvalue falls in the tail(s) selected by ``mode`` and ``eps``.
        If no eigenvalue qualifies, all directions are kept.

        Args:
            X: data matrix of shape (n_samples, n_features).
            y: class labels, one per row of ``X``.

        Returns:
            Array of shape (n_features, n_selected) with unit-norm columns.
            The matching eigenvalues are stored in ``filtered_eigvals``.

        Raises:
            ValueError: on malformed input or an invalid ``mode``.
        """
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y).reshape(-1)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if y.shape[0] != X.shape[0]:
            raise ValueError("X and y must contain the same number of samples")

        # Within-class scatter: centre every class on its own mean.
        within = np.empty_like(X)
        for c in np.unique(y):
            members = y == c
            within[members] = X[members] - X[members].mean(axis=0)
        S_W = within.T @ within

        # Total scatter: centre on the global mean.
        total = X - X.mean(axis=0)
        S_T = total.T @ total

        eigvals_S, Q_S = np.linalg.eigh(S_T)

        # Eigenvalues of S_T below this cutoff are treated as exactly zero
        # (same role as the pseudo-inverse of the eigenvalue matrix).
        top = eigvals_S[-1] if eigvals_S.size else 0.0
        cutoff = max(top, 0.0) * max(S_T.shape[0], 1) * np.finfo(np.float64).eps
        in_range = eigvals_S > cutoff
        inv_sqrt = np.zeros_like(eigvals_S)
        inv_sqrt[in_range] = 1.0 / np.sqrt(eigvals_S[in_range])

        # pinv(D) @ Q^T S_W Q is not symmetric, so eigh() must not be applied to
        # it directly (eigh only reads one triangle).  The whitened matrix
        # D^-1/2 Q^T S_W Q D^-1/2 is symmetric and has the same eigenvalues.
        whitened = inv_sqrt[:, None] * (Q_S.T @ S_W @ Q_S) * inv_sqrt[None, :]
        whitened = 0.5 * (whitened + whitened.T)  # remove rounding asymmetry
        eigvals, eigvecs = np.linalg.eigh(whitened)

        eps = self.eps
        if self.mode == "max":
            mask = eigvals > (1.0 - eps)
        elif self.mode == "min":
            mask = eigvals < eps
        elif self.mode == "both":
            mask = (eigvals < eps) | (eigvals > (1.0 - eps))
        else:
            raise ValueError("mode must be 'min' / 'max' / 'both'")

        # Nothing selected: fall back to every direction rather than an empty basis.
        if not np.any(mask):
            mask = np.ones_like(eigvals, dtype=bool)

        self.filtered_eigvals = eigvals[mask]

        # Undo the whitening.  Null directions of S_T keep unit scale so they
        # remain available as (eigenvalue 0) projection directions.
        back_scale = np.where(in_range, inv_sqrt, 1.0)
        theta = Q_S @ (back_scale[:, None] * eigvecs[:, mask])
        norms = np.linalg.norm(theta, axis=0)
        norms[norms == 0.0] = 1.0
        return theta / norms

    def fit(self, X_train: np.ndarray, y_train: "np.ndarray | None" = None):
        """Fit the projection and store the projected base points.

        Args:
            X_train: training data of shape (n_samples, n_features).
            y_train: optional class labels (any dtype ``np.unique`` can sort);
                when omitted, all training samples share one class.

        Returns:
            ``self``.

        Raises:
            ValueError: if ``X_train`` is not 2-D, is empty, or its length
                differs from ``y_train``.
        """
        X_train = np.asarray(X_train)
        if X_train.ndim != 2:
            raise ValueError("X_train must be a 2-D array of shape (n_samples, n_features)")
        n_samples = X_train.shape[0]
        if n_samples == 0:
            raise ValueError("X_train must contain at least one sample")

        if y_train is None:
            codes = np.zeros(n_samples, dtype=int)
        else:
            labels = np.asarray(y_train).reshape(-1)
            if labels.shape[0] != n_samples:
                raise ValueError("X_train and y_train must contain the same number of samples")
            # Integer-encode labels so non-numeric classes work too; only the
            # grouping matters downstream.
            codes = np.unique(labels, return_inverse=True)[1].reshape(-1)

        x_ref = self._construct_reference_point(X_train)

        # The reference point gets a class of its own.
        new_label = int(codes.max()) + 1

        X_aug = np.vstack([X_train, x_ref])
        y_aug = np.append(codes, new_label)

        self.faiss_index = None  # drop any index from a previous fit
        self.npd = self.compute_discriminants(X_aug, y_aug)

        Z_aug = X_aug @ self.npd
        if self.faiss_float32:
            Z_aug = Z_aug.astype(np.float32)
        self.basepoint_X = Z_aug

        if self.use_faiss:
            m = self.basepoint_X.shape[1]
            index = faiss.IndexFlatL2(m)
            index.add(np.ascontiguousarray(self.basepoint_X, dtype=np.float32))
            self.faiss_index = index

        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Project ``X`` with the fitted projection matrix.

        Raises:
            RuntimeError: if the model has not been fitted.
        """
        if self.npd is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        return np.asarray(X) @ self.npd

    @staticmethod
    def _min_l2_to_set_chunked(A: np.ndarray, B: np.ndarray, chunk_size: int = 4096) -> np.ndarray:
        """Return, for each row of ``A``, the Euclidean distance to its nearest row of ``B``.

        Rows of ``A`` are processed ``chunk_size`` at a time so that only a
        (chunk_size, len(B)) distance matrix is held in memory.

        Raises:
            ValueError: if ``chunk_size`` is not positive, or ``B`` is empty.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        A = np.asarray(A)
        B = np.asarray(B)
        # Integer inputs would truncate distances (and may overflow when squared).
        if not np.issubdtype(A.dtype, np.floating):
            A = A.astype(np.float64)
        if not np.issubdtype(B.dtype, np.floating):
            B = B.astype(np.float64)
        if B.shape[0] == 0:
            raise ValueError("B must contain at least one point")

        nA = A.shape[0]
        norm_B = np.sum(B ** 2, axis=1)[None, :]

        out = np.empty(nA, dtype=A.dtype)
        for s in range(0, nA, chunk_size):
            e = min(s + chunk_size, nA)
            Ac = A[s:e]
            norm_A = np.sum(Ac ** 2, axis=1)[:, None]
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
            dist2 = norm_A + norm_B - 2.0 * (Ac @ B.T)
            # Rounding can push tiny distances below zero; clamp before sqrt.
            out[s:e] = np.sqrt(np.maximum(dist2.min(axis=1), 0.0))
        return out

    def predict(self, X_test: np.ndarray) -> np.ndarray:
        """Return one anomaly score per sample: distance to the nearest base point.

        Args:
            X_test: samples of shape (n_samples, n_features); a single 1-D
                sample is also accepted.

        Raises:
            RuntimeError: if the model has not been fitted.
        """
        if self.basepoint_X is None:
            raise RuntimeError("Model not fitted. Call fit() first.")

        Z = np.atleast_2d(self.transform(X_test))
        if self.faiss_float32:
            Z = Z.astype(np.float32, copy=False)

        if self.use_faiss and self.faiss_index is not None:
            D, _ = self.faiss_index.search(np.ascontiguousarray(Z, dtype=np.float32), k=1)
            # IndexFlatL2 reports squared distances; clamp so rounding noise
            # cannot produce NaN, matching the NumPy fallback.
            return np.sqrt(np.maximum(D[:, 0], 0.0))

        return self._min_l2_to_set_chunked(Z, self.basepoint_X, chunk_size=self.chunk_size)

    def convergence_stats(self, X: np.ndarray) -> dict:
        """Summarise how the projection of ``X`` spreads around the base-point centroid.

        Returns:
            A dict with keys:
                ``mean_distance``, ``max_distance``, ``min_distance``,
                ``std_distance``: statistics of the distances from the projected
                samples to the centroid of the stored base points;
                ``scatter_ratio``: mean squared distance of the projected samples
                divided by that of the base points (NaN if the latter is zero);
                ``n_discriminants``: number of retained directions, or ``None``.

        Raises:
            RuntimeError: if the model has not been fitted.
            ValueError: if ``X`` contains no samples.
        """
        if self.basepoint_X is None:
            raise RuntimeError("Model not fitted. Call fit() first.")

        centroid = self.basepoint_X.mean(axis=0)
        Z = np.atleast_2d(self.transform(X))
        if Z.shape[0] == 0:
            raise ValueError("X must contain at least one sample")
        if self.faiss_float32:
            Z = Z.astype(np.float32, copy=False)

        dists = np.linalg.norm(Z - centroid, axis=1)
        sq_train = np.mean(np.sum((self.basepoint_X - centroid) ** 2, axis=1))
        sq_proj = np.mean(np.sum((Z - centroid) ** 2, axis=1))

        return {
            "mean_distance": float(dists.mean()),
            "max_distance": float(dists.max()),
            "min_distance": float(dists.min()),
            "std_distance": float(dists.std()),
            "scatter_ratio": (np.nan if sq_train == 0 else float(sq_proj / sq_train)),
            "n_discriminants": (None if self.filtered_eigvals is None else int(self.filtered_eigvals.size)),
        }
