"""
TF-IDF model for text classification and regression.

Combines TF-IDF vectorization with sklearn linear models
(LogisticRegression for classification, Ridge for regression).
"""

from typing import Union, Optional, List
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression, Ridge

from models.base import BaseModel


class TfidfModel(BaseModel):
    """
    TF-IDF + Linear Model for text tasks.
    
    Supports both classification (LogisticRegression) and regression (Ridge).
    Automatically handles vectorization and prediction in a single interface.
    
    Example:
        >>> model = TfidfModel(task='classification', max_features=60000)
        >>> model.fit(train_texts, train_labels, sample_weight=weights)
        >>> predictions = model.predict(test_texts)
    """

    # Task names this model can build an estimator for.
    _SUPPORTED_TASKS = ('classification', 'regression')
    
    def __init__(
        self,
        task: str = 'classification',
        max_features: int = 60000,
        ngram_range: tuple = (1, 2),
        min_df: int = 2,
        # Classification params
        logreg_C: float = 1.0,
        logreg_max_iter: int = 5000,
        logreg_solver: str = 'saga',
        # Regression params
        ridge_alpha: float = 1.0,
        **kwargs
    ):
        """
        Initialize TF-IDF model.
        
        Args:
            task: 'classification' or 'regression'
            max_features: Maximum vocabulary size
            ngram_range: N-gram range for TF-IDF (default: unigrams + bigrams)
            min_df: Minimum document frequency for terms
            logreg_C: Inverse regularization strength for LogisticRegression
            logreg_max_iter: Max iterations for LogisticRegression
            logreg_solver: Solver for LogisticRegression ('lbfgs', 'saga', etc.)
            ridge_alpha: Regularization strength for Ridge regression
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Raises:
            ValueError: If ``task`` is not one of the supported task names.
        """
        # Fail fast: previously any unknown task silently fell back to Ridge
        # regression, e.g. a typo in 'classification' trained a regressor.
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(
                f"task must be one of {self._SUPPORTED_TASKS}, got {task!r}"
            )
        super().__init__(task=task)
        
        self.max_features = max_features
        self.ngram_range = ngram_range
        self.min_df = min_df
        self.logreg_C = logreg_C
        self.logreg_max_iter = logreg_max_iter
        self.logreg_solver = logreg_solver
        self.ridge_alpha = ridge_alpha
        
        # Vectorizer is created once here and (re)fit on every call to fit().
        self.vectorizer_ = TfidfVectorizer(
            max_features=max_features,
            ngram_range=ngram_range,
            min_df=min_df,
            lowercase=True,
        )
        
        # Estimator is created in fit() according to self.task.
        self.model_ = None

    @staticmethod
    def _to_text_array(X: Union[List[str], np.ndarray]) -> np.ndarray:
        """
        Return ``X`` as a flat 1-D object array of documents.

        ``dtype=object`` keeps references to the original Python strings.
        Building a fixed-width unicode array instead (``np.array(list_of_str)``)
        sizes every element to the longest document at 4 bytes per character,
        so a single very long text inflates memory for the whole corpus.
        Elements are passed through as-is rather than being stringified.
        """
        return np.asarray(X, dtype=object).ravel()

    def _build_estimator(self) -> Union[LogisticRegression, Ridge]:
        """
        Create an unfitted sklearn estimator matching ``self.task``.

        Raises:
            ValueError: If ``self.task`` is not a supported task name.
        """
        if self.task == 'classification':
            return LogisticRegression(
                C=self.logreg_C,
                max_iter=self.logreg_max_iter,
                solver=self.logreg_solver,
                n_jobs=-1,
            )
        if self.task == 'regression':
            return Ridge(
                alpha=self.ridge_alpha,
                random_state=0,
            )
        raise ValueError(
            f"task must be one of {self._SUPPORTED_TASKS}, got {self.task!r}"
        )
    
    def fit(
        self,
        X: Union[List[str], np.ndarray],
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'TfidfModel':
        """
        Fit TF-IDF vectorizer and linear model.
        
        Args:
            X: Text data (list of strings or array of strings)
            y: Target labels (classification) or values (regression)
            sample_weight: Optional per-example weights
            
        Returns:
            self (fitted model)

        Raises:
            ValueError: If ``self.task`` is not a supported task name.
        """
        X = self._to_text_array(X)
        y = np.asarray(y)

        # Clear the flag before mutating any fitted state, so a failure part-way
        # through a refit cannot leave a refit vectorizer marked as fitted.
        self.is_fitted_ = False

        X_vec = self.vectorizer_.fit_transform(X)
        self.model_ = self._build_estimator()

        # Both LogisticRegression.fit and Ridge.fit accept sample_weight=None.
        self.model_.fit(X_vec, y, sample_weight=sample_weight)
        
        self.is_fitted_ = True
        return self
    
    def predict(
        self,
        X: Union[List[str], np.ndarray],
    ) -> np.ndarray:
        """
        Make predictions on text data.
        
        Args:
            X: Text data (list of strings or array of strings)
            
        Returns:
            Predicted class labels for classification, continuous values for
            regression.
        """
        self._check_fitted()
        X_vec = self.vectorizer_.transform(self._to_text_array(X))
        return self.model_.predict(X_vec)
    
    def predict_proba(
        self,
        X: Union[List[str], np.ndarray],
    ) -> np.ndarray:
        """
        Get predicted class probabilities (classification only).
        
        Args:
            X: Text data (list of strings or array of strings)
            
        Returns:
            Class probabilities, shape (N, n_classes); (N, 2) for binary labels.

        Raises:
            NotImplementedError: If the model was built for regression.
        """
        if self.task != 'classification':
            raise NotImplementedError("predict_proba only available for classification")
        
        self._check_fitted()
        X_vec = self.vectorizer_.transform(self._to_text_array(X))
        return self.model_.predict_proba(X_vec)
    
    def get_params(self) -> dict:
        """Get model hyperparameters (shared ones plus those for ``self.task``)."""
        params = super().get_params()
        params.update({
            'max_features': self.max_features,
            'ngram_range': self.ngram_range,
            'min_df': self.min_df,
        })
        if self.task == 'classification':
            params.update({
                'logreg_C': self.logreg_C,
                'logreg_max_iter': self.logreg_max_iter,
                'logreg_solver': self.logreg_solver,
            })
        else:
            params.update({
                'ridge_alpha': self.ridge_alpha,
            })
        return params
    
    def get_feature_names(self) -> Optional[np.ndarray]:
        """
        Get vocabulary feature names (requires fitted model).
        
        Returns:
            Array of feature names or None if not fitted
        """
        # getattr: is_fitted_ may not exist yet if fit() has never been called.
        if not getattr(self, 'is_fitted_', False):
            return None
        return self.vectorizer_.get_feature_names_out()
    
    def get_vocabulary_size(self) -> Optional[int]:
        """
        Get actual vocabulary size (requires fitted model).
        
        Returns:
            Vocabulary size or None if not fitted
        """
        # getattr: is_fitted_ may not exist yet if fit() has never been called.
        if not getattr(self, 'is_fitted_', False):
            return None
        return len(self.vectorizer_.vocabulary_)
