"""
Dynamic States HMM implementation.

This module provides a dynamic states HMM that automatically selects the optimal
number of states based on a model selection criterion.
"""

import numpy as np
from sklearn.model_selection import KFold
from ..core.hmm_glm import CategoricalHMMComponent

class DynamicStatesHMMComponent(CategoricalHMMComponent):
    """
    Dynamic States HMM component that selects the optimal number of states.

    One ``CategoricalHMMComponent`` is fitted for every state count in
    ``[min_states, max_states]``. Each candidate is scored with the configured
    criterion and the parameters of the best candidate are copied onto this
    instance.

    Attributes:
    -----------
    n_states_selected : int
        State count of the winning candidate (``min_states`` before fitting)
    models : dict
        Fitted candidate models keyed by their number of states
    model_scores : dict
        Criterion value of every candidate keyed by its number of states
    """

    # Accepted values for ``criterion``.
    _SUPPORTED_CRITERIA = ('bic', 'aic', 'cv')
    # Criteria for which a lower value means a better model.
    _MINIMISED_CRITERIA = ('aic', 'bic')

    def __init__(self, min_states=2, max_states=5, criterion='bic', n_categories=2,
               random_state=None, n_contexts=None):
        """
        Initialize the dynamic states HMM component.

        Parameters:
        -----------
        min_states : int, optional (default=2)
            Minimum number of states to consider
        max_states : int, optional (default=5)
            Maximum number of states to consider
        criterion : str, optional (default='bic')
            Model selection criterion ('bic', 'aic', or 'cv')
        n_categories : int, optional (default=2)
            Number of categories (2 for binary outcomes)
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables for context-aware transitions
        """
        # The parent is initialised with min_states; fit() overwrites n_states
        # with the selected value afterwards.
        super().__init__(min_states, n_categories, random_state, n_contexts)

        self.min_states = min_states
        self.max_states = max_states
        self.criterion = criterion
        self.n_states_selected = min_states
        self.models = {}
        self.model_scores = {}

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the dynamic states HMM component to data.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If ``criterion`` is not one of 'bic', 'aic' or 'cv', or if
            ``min_states`` is greater than ``max_states``
        """
        # Fail fast: reject bad configuration before any (expensive) fitting.
        if self.criterion not in self._SUPPORTED_CRITERIA:
            raise ValueError(f"Unknown criterion: {self.criterion}")
        if self.min_states > self.max_states:
            raise ValueError(
                f"min_states ({self.min_states}) must not exceed "
                f"max_states ({self.max_states})"
            )

        # Start from a clean slate so candidates from an earlier fit() call
        # (possibly with a different state range) cannot win the selection.
        self.models = {}
        self.model_scores = {}

        # Fit and score one candidate per state count.
        for n_states in range(self.min_states, self.max_states + 1):
            model = CategoricalHMMComponent(n_states, self.n_categories,
                                          self.random_state, self.n_contexts)
            model.fit(X, sample_weight, contexts)

            self.models[n_states] = model
            self.model_scores[n_states] = self._score_candidate(
                model, X, sample_weight, contexts)

        # Lower is better for AIC/BIC, higher is better for the CV score.
        # Ties resolve to the smallest state count (first inserted key).
        if self.criterion in self._MINIMISED_CRITERIA:
            pick = min
        else:
            pick = max
        self.n_states_selected = pick(self.model_scores, key=self.model_scores.get)

        # Copy parameters from the best model
        best_model = self.models[self.n_states_selected]
        self.n_states = best_model.n_states
        self.startprob_ = best_model.startprob_

        if self.context_aware:
            self.alpha_ = best_model.alpha_
            self.beta_ = best_model.beta_
            if hasattr(best_model, 'context_transmats_'):
                self.context_transmats_ = best_model.context_transmats_
        else:
            self.transmat_ = best_model.transmat_

        self.emissionprob_ = best_model.emissionprob_

        return self

    def _score_candidate(self, model, X, sample_weight=None, contexts=None):
        """
        Evaluate a fitted candidate with the configured selection criterion.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights (only used by the 'cv' criterion)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        score : float
            BIC, AIC or cross-validation score, depending on ``criterion``
        """
        if self.criterion == 'bic':
            return self._compute_bic(model, X, contexts)
        if self.criterion == 'aic':
            return self._compute_aic(model, X, contexts)
        if self.criterion == 'cv':
            return self._compute_cv_score(model, X, sample_weight, contexts)
        raise ValueError(f"Unknown criterion: {self.criterion}")

    @staticmethod
    def _count_free_parameters(model, contexts=None):
        """
        Count the free parameters of a fitted model for AIC/BIC penalties.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        n_params : int
            Number of free parameters
        """
        n_states = model.n_states
        n_categories = model.n_categories

        # - Initial state probabilities: n_states - 1 (sum to 1 constraint)
        # - Transition probabilities: n_states * (n_states - 1) (each row sums to 1)
        # - Emission probabilities: n_states * (n_categories - 1) (each row sums to 1)
        n_params = ((n_states - 1)
                    + n_states * (n_states - 1)
                    + n_states * (n_categories - 1))

        # Context-dependent transitions add n_states * n_states parameters
        # for each context variable.
        if model.context_aware and contexts is not None:
            n_params += n_states * n_states * contexts.shape[1]

        return n_params

    def _compute_bic(self, model, X, contexts=None):
        """
        Compute the Bayesian Information Criterion (BIC) for the model.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        bic : float
            BIC score (lower is better)
        """
        n_samples = len(X)
        log_likelihood = self._compute_log_likelihood(model, X, contexts)
        n_params = self._count_free_parameters(model, contexts)

        return -2 * log_likelihood + n_params * np.log(n_samples)

    def _compute_aic(self, model, X, contexts=None):
        """
        Compute the Akaike Information Criterion (AIC) for the model.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        aic : float
            AIC score (lower is better)
        """
        log_likelihood = self._compute_log_likelihood(model, X, contexts)
        n_params = self._count_free_parameters(model, contexts)

        return -2 * log_likelihood + 2 * n_params

    def _compute_cv_score(self, model, X, sample_weight=None, contexts=None, n_folds=5):
        """
        Compute cross-validation score for the model.

        A fresh model with the same hyper-parameters as ``model`` is fitted on
        each training fold and scored by its log-likelihood on the held-out
        fold; the fold scores are averaged.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model (only its n_states, n_categories and n_contexts
            are used)
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        n_folds : int, optional (default=5)
            Number of cross-validation folds

        Returns:
        --------
        cv_score : float
            Cross-validation score (higher is better)
        """
        # Fancy indexing below needs arrays; accept any array-like input.
        X = np.asarray(X)
        if sample_weight is not None:
            sample_weight = np.asarray(sample_weight)
        if contexts is not None:
            contexts = np.asarray(contexts)

        # NOTE(review): shuffled folds sub-sample the sequence in time, which
        # distorts the transition structure an HMM relies on. Contiguous folds
        # (shuffle=False) may be more appropriate - confirm with the authors.
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=self.random_state)

        fold_scores = []
        for train_index, test_index in kf.split(X):
            if sample_weight is not None:
                sample_weight_train = sample_weight[train_index]
            else:
                sample_weight_train = None

            if contexts is not None:
                contexts_train = contexts[train_index]
                contexts_test = contexts[test_index]
            else:
                contexts_train = None
                contexts_test = None

            # Refit from scratch on the training fold.
            cv_model = CategoricalHMMComponent(model.n_states, model.n_categories,
                                             self.random_state, model.n_contexts)
            cv_model.fit(X[train_index], sample_weight_train, contexts_train)

            # Held-out log-likelihood is the fold score.
            fold_scores.append(
                self._compute_log_likelihood(cv_model, X[test_index], contexts_test))

        return np.mean(fold_scores)

    def _compute_log_likelihood(self, model, X, contexts=None):
        """
        Compute log-likelihood of the data given the model.

        Runs the scaled forward algorithm: the forward vector is renormalised
        at every step to avoid underflow, and the log of each normalising
        constant is accumulated. The sum of those logs is the log-likelihood
        of the whole sequence.

        Parameters:
        -----------
        model : CategoricalHMMComponent
            Fitted HMM model
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        log_likelihood : float
            Log-likelihood; ``-inf`` if the sequence has zero probability
            under the model, ``0.0`` for an empty sequence
        """
        # Ensure X is a 1D array of integer category indices (a boolean array
        # would otherwise be interpreted as a mask by the lookup below).
        obs = np.asarray(X).ravel()
        if not np.issubdtype(obs.dtype, np.integer):
            obs = obs.astype(np.intp)

        n_samples = len(obs)
        if n_samples == 0:
            # An empty sequence has probability 1 under any model.
            return 0.0

        # emission_probs[t, i] = P(X[t] | state i)
        emission_probs = np.asarray(model.emissionprob_)[:, obs].T

        # Transition matrices: one per time step when context-aware, otherwise
        # a single matrix shared by every step (no need to tile it).
        time_varying = model.context_aware and contexts is not None
        if time_varying:
            from ..core.context_transitions import compute_context_dependent_transitions
            transition_probs = compute_context_dependent_transitions(
                model.alpha_, model.beta_, contexts)
        else:
            transition_probs = np.asarray(model.transmat_)

        # Base case
        forward = np.asarray(model.startprob_, dtype=float) * emission_probs[0]
        scale = np.sum(forward)
        if not scale > 0:
            return -np.inf
        log_likelihood = np.log(scale)
        forward = forward / scale

        # Recursive case
        for t in range(1, n_samples):
            if time_varying:
                step_transmat = transition_probs[t - 1]
            else:
                step_transmat = transition_probs

            # forward[j] = sum_i forward_prev[i] * A[i, j] * B[j, X[t]]
            forward = (forward @ step_transmat) * emission_probs[t]

            # Normalise to prevent underflow, keeping the log of the scale
            # factor: the product of all scale factors is P(X | model).
            scale = np.sum(forward)
            if not scale > 0:
                return -np.inf
            log_likelihood += np.log(scale)
            forward = forward / scale

        return float(log_likelihood)


