"""
GLM component for the HMM-GLM framework.

This module provides classes for the GLM component of the HMM-GLM framework,
including logistic and Poisson regression models.
"""

import numpy as np
from abc import ABC, abstractmethod
from sklearn.linear_model import LogisticRegression, PoissonRegressor
from sklearn.dummy import DummyClassifier, DummyRegressor

class GLMComponent(ABC):
    """
    Common interface for the GLM part of the HMM-GLM framework.

    Subclasses must implement ``fit``, ``predict`` and ``predict_proba``;
    this base class only records the shared configuration.
    """

    def __init__(self, features=None, random_state=None):
        """
        Record the configuration shared by all GLM components.

        Parameters:
        -----------
        features : list of str, optional
            Names of the features the GLM should use
        random_state : int, optional
            Seed stored for reproducibility
        """
        # ``models`` starts out as None; the base class never fills it in.
        self.features, self.random_state, self.models = (
            features,
            random_state,
            None,
        )

    @abstractmethod
    def fit(self, X, y, states, sample_weight=None):
        """
        Train the component on observations labelled with HMM states.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        y : ndarray, shape (n_samples,)
            Target variable
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component
        sample_weight : ndarray, shape (n_samples,), optional
            Per-sample weights

        Returns:
        --------
        self : object
            The fitted component (implementations return ``self``)
        """

    @abstractmethod
    def predict(self, X, states):
        """
        Produce outcome predictions from features and HMM states.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component

        Returns:
        --------
        y_pred : ndarray, shape (n_samples,)
            Predicted outcomes
        """

    @abstractmethod
    def predict_proba(self, X, states):
        """
        Produce outcome probabilities from features and HMM states.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component

        Returns:
        --------
        y_proba : ndarray, shape (n_samples,)
            Predicted outcome probabilities
        """

class LogisticGLMComponent(GLMComponent):
    """
    GLM component for binary outcomes: one logistic regression per HMM state.

    Fitted estimators are kept in ``self.models``, a dict keyed by state.
    A state whose training targets contain a single class gets a constant
    ``DummyClassifier`` instead of a logistic regression.
    """

    def __init__(self, features=None, C=1.0, class_weight=None, random_state=None):
        """
        Set up the logistic regression GLM component.

        Parameters:
        -----------
        features : list of str, optional
            Names of the features the GLM should use
        C : float, optional (default=1.0)
            Inverse regularization strength passed to LogisticRegression
        class_weight : str or dict, optional
            Class weights passed to LogisticRegression
        random_state : int, optional
            Seed passed to LogisticRegression
        """
        super().__init__(features, random_state)
        self.C = C
        self.class_weight = class_weight
        # Per-state estimators, filled in by fit().
        self.models = {}

    def fit(self, X, y, states, sample_weight=None):
        """
        Train one classifier per distinct value found in ``states``.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        y : ndarray, shape (n_samples,)
            Binary target variable (0 or 1)
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component
        sample_weight : ndarray, shape (n_samples,), optional
            Per-sample weights (only forwarded to the logistic regression)

        Returns:
        --------
        self : object
            Returns self
        """
        for state in np.unique(states):
            # Restrict every input to the rows belonging to this state.
            in_state = (states == state)
            X_sub = X[in_state]
            y_sub = y[in_state]
            w_sub = None if sample_weight is None else sample_weight[in_state]

            if len(X_sub) == 0:
                print(f"No samples for state {state}")
                continue

            if len(np.unique(y_sub)) != 1:
                # Both classes present: regular logistic regression.
                # liblinear is chosen because per-state datasets are small.
                clf = LogisticRegression(
                    C=self.C,
                    class_weight=self.class_weight,
                    random_state=self.random_state,
                    solver='liblinear',
                )
                self.models[state] = clf
                clf.fit(X_sub, y_sub, sample_weight=w_sub)

                print(f"State {state} model trained: n_samples = {len(X_sub)}")
                continue

            # Only one class observed: fall back to a constant predictor.
            constant_value = y_sub[0]
            print(f"State {state} only has class {constant_value}. "
                  f"Creating a dummy classifier that always predicts this class.")
            fallback = DummyClassifier(strategy="constant",
                                       constant=constant_value)
            self.models[state] = fallback
            fallback.fit(X_sub, y_sub)

        return self

    def predict(self, X, states):
        """
        Predict binary outcomes using the model of each sample's state.

        Samples whose state has no fitted model keep the initial value 0.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component

        Returns:
        --------
        y_pred : ndarray, shape (n_samples,)
            Predicted binary outcomes (0 or 1)
        """
        n_rows = len(X)
        y_pred = np.zeros(n_rows)

        for state in np.unique(states):
            in_state = (states == state)
            X_sub = X[in_state]

            # Skip states with no rows or without a fitted model.
            if len(X_sub) == 0 or state not in self.models:
                continue

            y_pred[in_state] = self.models[state].predict(X_sub)

        return y_pred

    def predict_proba(self, X, states):
        """
        Predict the probability of class 1 using each sample's state model.

        Samples whose state has no fitted model keep the initial value 0.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assigned to each sample by the HMM component

        Returns:
        --------
        y_proba : ndarray, shape (n_samples,)
            Predicted success probabilities
        """
        n_rows = len(X)
        y_proba = np.zeros(n_rows)

        for state in np.unique(states):
            in_state = (states == state)
            X_sub = X[in_state]

            # Skip states with no rows or without a fitted model.
            if len(X_sub) == 0 or state not in self.models:
                continue

            estimator = self.models[state]
            if not isinstance(estimator, DummyClassifier):
                # Second column of predict_proba is taken as class 1.
                y_proba[in_state] = estimator.predict_proba(X_sub)[:, 1]
            else:
                # Constant fallback: its constant label doubles as probability.
                y_proba[in_state] = estimator.constant

        return y_proba

class PoissonGLMComponent(GLMComponent):
    """
    Poisson regression GLM component for count outcomes.

    One estimator is fitted per HMM state and stored in ``self.models``,
    a dict keyed by state. A state whose training targets are all equal
    gets a constant ``DummyRegressor`` instead of a Poisson regression.
    """

    def __init__(self, features=None, alpha=1.0, random_state=None):
        """
        Initialize the Poisson regression GLM component.

        Parameters:
        -----------
        features : list of str, optional
            Names of features to use in the GLM
        alpha : float, optional (default=1.0)
            Regularization strength passed to PoissonRegressor
        random_state : int, optional
            Stored on the instance for interface compatibility with the
            other GLM components. It is not forwarded to PoissonRegressor,
            which has no ``random_state`` parameter.
        """
        super().__init__(features, random_state)
        self.alpha = alpha
        # Per-state estimators, filled in by fit().
        self.models = {}

    def fit(self, X, y, states, sample_weight=None):
        """
        Fit a separate Poisson regression model for each state.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        y : ndarray, shape (n_samples,)
            Count target variable (non-negative integers)
        states : ndarray, shape (n_samples,)
            State assignments from the HMM component
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights (only forwarded to the Poisson regression)

        Returns:
        --------
        self : object
            Returns self
        """
        # Get unique states
        unique_states = np.unique(states)

        # Fit a model for each state
        for state in unique_states:
            # Get samples for this state
            mask = (states == state)
            X_state = X[mask]
            y_state = y[mask]
            weights_state = sample_weight[mask] if sample_weight is not None else None

            # Check if there are enough samples
            if len(X_state) > 0:
                # Check if all samples have the same value
                if len(np.unique(y_state)) == 1:
                    # Use a dummy regressor that always predicts the same value
                    constant_value = y_state[0]
                    print(f"State {state} only has value {constant_value}. "
                          f"Creating a dummy regressor that always predicts this value.")
                    self.models[state] = DummyRegressor(strategy="constant",
                                                        constant=constant_value)
                    self.models[state].fit(X_state, y_state)
                else:
                    # Use Poisson regression. PoissonRegressor does not accept
                    # a random_state keyword; passing one raises TypeError, so
                    # only the regularization strength is forwarded.
                    self.models[state] = PoissonRegressor(alpha=self.alpha)
                    self.models[state].fit(X_state, y_state, sample_weight=weights_state)

                    print(f"State {state} model trained: n_samples = {len(X_state)}")
            else:
                print(f"No samples for state {state}")

        return self

    def predict(self, X, states):
        """
        Predict count outcomes based on features and states.

        Samples whose state has no fitted model keep the initial value 0.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assignments from the HMM component

        Returns:
        --------
        y_pred : ndarray, shape (n_samples,)
            Predicted count outcomes
        """
        y_pred = np.zeros(len(X))

        # Get unique states
        unique_states = np.unique(states)

        # Predict for each state
        for state in unique_states:
            # Get samples for this state
            mask = (states == state)
            X_state = X[mask]

            # Check if there are samples and if we have a model for this state
            if len(X_state) > 0 and state in self.models:
                y_pred[mask] = self.models[state].predict(X_state)

        return y_pred

    def predict_proba(self, X, states):
        """
        Predict expected counts based on features and states.

        For Poisson regression, this is the same as predict().

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        states : ndarray, shape (n_samples,)
            State assignments from the HMM component

        Returns:
        --------
        y_proba : ndarray, shape (n_samples,)
            Predicted expected counts
        """
        return self.predict(X, states)


