"""
HMM-GLM model implementation.

This module provides the main HMM-GLM model class that integrates the HMM and GLM components.
"""

import numpy as np
import time

class HMMGLMModel:
    """
    HMM-GLM model for latent performance states.

    An HMM component assigns samples to latent states and a GLM component
    models the outcome conditionally on the assigned state. Both components
    are supplied by the caller; this class only orchestrates fitting and
    prediction across them.
    """

    def __init__(self, hmm_component, glm_component):
        """
        Initialize the HMM-GLM model.

        Parameters:
        -----------
        hmm_component : HMMComponent
            HMM component for modeling latent states
        glm_component : GLMComponent
            GLM component for modeling outcomes within states
        """
        self.hmm_component = hmm_component
        self.glm_component = glm_component
        self.is_fitted = False

    def _check_is_fitted(self):
        """Raise ValueError if ``fit`` has not been called yet."""
        if not self.is_fitted:
            raise ValueError("Model has not been fitted yet.")

    def _initial_state_probs(self):
        """Return the HMM component's ``startprob_`` as a float array, or None if absent."""
        startprob = getattr(self.hmm_component, 'startprob_', None)
        if startprob is None:
            return None
        return np.asarray(startprob, dtype=float)

    def _state_weights(self, state_labels):
        """
        Return mixture weights (summing to one) for the given GLM state labels.

        With a ``startprob_`` on the HMM component, each label is weighted by
        ``startprob_[label]`` (labels are assumed to be integer HMM state
        indices) and the weights are renormalised over the labels that
        actually have a GLM model. Otherwise the weights are uniform.
        """
        n_labels = len(state_labels)
        uniform = np.full(n_labels, 1.0 / n_labels)

        startprob = self._initial_state_probs()
        if startprob is None:
            return uniform

        # Index by the state label itself, not by its position among the GLM
        # models: the GLM may lack models for some HMM states.
        weights = np.array([startprob[int(label)] for label in state_labels], dtype=float)
        total = weights.sum()
        if total <= 0:
            # All initial mass sits on states without a GLM model.
            return uniform
        return weights / total

    def fit(self, X_hmm, X_glm, y, contexts=None, sample_weight=None):
        """
        Fit the HMM-GLM model.

        The HMM component is fitted on ``y`` and then used to label every
        sample with a state; the GLM component is fitted on ``X_glm`` and
        ``y`` using those state labels.

        Parameters:
        -----------
        X_hmm : ndarray, shape (n_samples, n_hmm_features)
            Feature matrix for the HMM component (currently not used; the
            HMM component is fitted on ``y``)
        X_glm : ndarray, shape (n_samples, n_glm_features)
            Feature matrix for the GLM component
        y : ndarray, shape (n_samples,)
            Target variable
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights

        Returns:
        --------
        self : object
            Returns self
        """
        start_time = time.time()

        # Step 1: Fit the HMM component
        print("Fitting HMM component...")
        self.hmm_component.fit(y, sample_weight, contexts)

        # Step 2: Predict states using the HMM component
        print("Predicting HMM states...")
        states = self.hmm_component.predict(y, contexts)

        # Step 3: Fit the GLM component using the predicted states
        print("Fitting GLM component...")
        self.glm_component.fit(X_glm, y, states, sample_weight)

        self.is_fitted = True

        end_time = time.time()
        print(f"HMM-GLM model training completed (time: {end_time - start_time:.2f}s)")

        return self

    def predict(self, X_hmm, X_glm, contexts=None):
        """
        Predict outcomes using the HMM-GLM model.

        No outcome is observed at prediction time, so every sample is assigned
        the state the HMM is most likely to start in (``argmax(startprob_)``),
        or state 0 when the HMM component exposes no ``startprob_``.

        Parameters:
        -----------
        X_hmm : ndarray, shape (n_samples, n_hmm_features)
            Feature matrix for the HMM component (currently not used)
        X_glm : ndarray, shape (n_samples, n_glm_features)
            Feature matrix for the GLM component; its length determines
            n_samples
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions (currently not used)

        Returns:
        --------
        y_pred : ndarray, shape (n_samples,)
            Predicted outcomes
        states : ndarray, shape (n_samples,)
            Predicted states

        Raises:
        -------
        ValueError
            If the model has not been fitted.
        """
        self._check_is_fitted()

        # Size from X_glm: it is the array actually paired with the states.
        n_samples = len(X_glm)
        startprob = self._initial_state_probs()
        if startprob is not None:
            states = np.full(n_samples, np.argmax(startprob), dtype=int)
        else:
            states = np.zeros(n_samples, dtype=int)

        y_pred = self.glm_component.predict(X_glm, states)
        return y_pred, states

    def predict_proba(self, X_hmm, X_glm, contexts=None):
        """
        Predict outcome probabilities using the HMM-GLM model.

        The result is a mixture over the states that have a GLM model: each
        state's outcome probabilities are weighted by the HMM's initial state
        probability for that state (renormalised over the available states),
        or uniformly when the HMM component exposes no ``startprob_``.

        Parameters:
        -----------
        X_hmm : ndarray, shape (n_samples, n_hmm_features)
            Feature matrix for the HMM component (currently not used)
        X_glm : ndarray, shape (n_samples, n_glm_features)
            Feature matrix for the GLM component; its length determines
            n_samples
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions (currently not used)

        Returns:
        --------
        y_proba : ndarray, shape (n_samples,)
            Predicted outcome probabilities

        Raises:
        -------
        ValueError
            If the model has not been fitted, or the GLM component holds no
            state models.
        """
        self._check_is_fitted()

        state_labels = list(self.glm_component.models.keys())
        if not state_labels:
            raise ValueError("GLM component has no fitted state models.")

        n_samples = len(X_glm)
        weights = self._state_weights(state_labels)

        y_proba = np.zeros(n_samples)
        for state, weight in zip(state_labels, weights):
            # Ask the GLM for this state's probabilities on every sample.
            state_assignment = np.full(n_samples, state)
            state_outcome_probs = self.glm_component.predict_proba(X_glm, state_assignment)
            y_proba += weight * state_outcome_probs

        return y_proba

    def score(self, X_hmm, X_glm, y, contexts=None):
        """
        Compute the mean accuracy score on the given test data and labels.

        Parameters:
        -----------
        X_hmm : ndarray, shape (n_samples, n_hmm_features)
            Feature matrix for the HMM component
        X_glm : ndarray, shape (n_samples, n_glm_features)
            Feature matrix for the GLM component
        y : ndarray, shape (n_samples,)
            True labels
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        score : float
            Fraction of samples whose prediction equals the true label
        """
        y_pred, _ = self.predict(X_hmm, X_glm, contexts)
        return np.mean(y_pred == np.asarray(y))


