"""
Simple LogisticRegression / Ridge model for tabular data.

Used for synthetic experiments where we want simple linear models.
"""

from typing import Optional
import numpy as np

from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.preprocessing import StandardScaler

from models.base import BaseModel


class LogRegModel(BaseModel):
    """
    Simple LogisticRegression (classification) or Ridge (regression) model.

    For synthetic data experiments where we want interpretable linear models.
    Features are standardized with a ``StandardScaler`` fitted on the training
    data before being passed to the underlying scikit-learn estimator.

    Any ``task`` other than ``'classification'`` is handled by ``Ridge``.
    """

    def __init__(
        self,
        task: str = 'classification',
        C: float = 1.0,
        alpha: float = 1.0,
        max_iter: int = 1000,
        random_state: int = 42,
        **kwargs,
    ):
        """
        Initialize LogReg/Ridge model.

        Args:
            task: 'classification' or 'regression' (forwarded to BaseModel).
            C: Inverse regularization strength for LogisticRegression.
            alpha: Regularization strength for Ridge.
            max_iter: Maximum iterations for LogisticRegression.
            random_state: Random seed passed to the underlying estimator.
            **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(task=task, **kwargs)
        self.C = C
        self.alpha = alpha
        self.max_iter = max_iter
        self.random_state = random_state

        # Populated by fit(); None until the model has been fitted.
        self.scaler_ = None
        self.model_ = None

    def _make_estimator(self):
        """Return a fresh, unfitted scikit-learn estimator for ``self.task``."""
        if self.task == 'classification':
            return LogisticRegression(
                C=self.C,
                max_iter=self.max_iter,
                solver='lbfgs',
                random_state=self.random_state,
            )
        return Ridge(
            alpha=self.alpha,
            random_state=self.random_state,
        )

    def _scale_features(self, X: np.ndarray) -> np.ndarray:
        """Ensure the model is fitted, then standardize ``X`` with the training scaler."""
        self._check_fitted()
        return self.scaler_.transform(np.asarray(X))

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'LogRegModel':
        """
        Standardize the features and fit the underlying estimator.

        Args:
            X: 2-D feature matrix (n_samples, n_features).
            y: Targets: class labels for classification, real values otherwise.
            sample_weight: Optional per-sample weights. They are passed to the
                estimator's ``fit`` only; the scaler is fitted unweighted.

        Returns:
            self, to allow call chaining.
        """
        X = np.asarray(X)
        y = np.asarray(y)

        # Fit into locals first and commit only on success, so a failing
        # refit leaves the previously fitted scaler/model pair intact instead
        # of pairing a new scaler with an unfitted estimator.
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        model = self._make_estimator()
        # scikit-learn's fit() treats sample_weight=None as uniform weights,
        # so no separate unweighted branch is needed.
        model.fit(X_scaled, y, sample_weight=sample_weight)

        self.scaler_ = scaler
        self.model_ = model
        self.is_fitted_ = True
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Make predictions.

        Returns class labels in classification mode and continuous Ridge
        outputs otherwise; callers that need labels from a regression fit
        must threshold the outputs themselves (e.g. at 0).
        """
        return self.model_.predict(self._scale_features(X))

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Get class probabilities (classification only).

        Raises:
            NotImplementedError: If the model was configured for regression.
        """
        if self.task != 'classification':
            raise NotImplementedError("predict_proba only for classification")
        return self.model_.predict_proba(self._scale_features(X))

    def __repr__(self) -> str:
        if self.task == 'classification':
            return f"LogRegModel(task='classification', C={self.C})"
        return f"LogRegModel(task='regression', alpha={self.alpha})"
