"""
Combined Context HMM implementation.

This module provides a combined context HMM that integrates multiple context variables
into a unified context representation for transition modeling.
"""

import numpy as np
from ..core.hmm_glm import CategoricalHMMComponent
from ..core.context_transitions import compute_context_dependent_transitions

class CombinedContextHMMComponent(CategoricalHMMComponent):
    """
    Combined Context HMM component that integrates multiple context variables.

    Several context variables are collapsed into a single combined context
    (one value per sample) using weights learned from their correlation with
    the outcome. The combined context is then handed to the parent
    ``CategoricalHMMComponent`` as a one-column context matrix.

    Attributes set by this class:
    -----------------------------
    context_combination : str
        Combination method ('weighted_sum', 'product', or 'max')
    context_weights : ndarray, shape (n_contexts,), or None
        Weights learned by ``fit``; None until a context-aware fit is done
    combined_contexts : ndarray, shape (n_samples,), or None
        Combined contexts of the data used in the last context-aware fit
    """

    def __init__(self, n_states, n_categories=2, random_state=None, n_contexts=None,
                 context_combination='weighted_sum'):
        """
        Initialize the combined context HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states
        n_categories : int, optional (default=2)
            Number of categories (2 for binary outcomes)
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables
        context_combination : str, optional (default='weighted_sum')
            Method for combining context variables ('weighted_sum', 'product', or 'max')
        """
        super().__init__(n_states, n_categories, random_state, n_contexts)
        self.context_combination = context_combination
        self.context_weights = None
        # Defined up front so the attribute exists before the first fit.
        self.combined_contexts = None

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the combined context HMM component to data.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables

        Returns:
        --------
        self : object
            Returns self
        """
        if contexts is None:
            # Drop any state left over from an earlier context-aware fit so
            # that predict/score do not feed combined contexts into a model
            # that was just refitted without them.
            self.context_weights = None
            self.combined_contexts = None
            super().fit(X, sample_weight, None)
            return self

        contexts = np.asarray(contexts)
        self.n_contexts = contexts.shape[1]
        self.context_weights = self._learn_context_weights(X, contexts)

        # Collapse the context columns into one value per sample and keep it.
        combined = self._combine_contexts(contexts)
        self.combined_contexts = combined

        # The parent model receives the combined context as a single column.
        super().fit(X, sample_weight, combined.reshape(-1, 1))
        return self

    def predict(self, X, contexts=None):
        """
        Predict the most likely state sequence.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables; ignored when no context weights were learned

        Returns:
        --------
        states : ndarray, shape (n_samples,)
            Predicted state sequence
        """
        return super().predict(X, self._prepare_contexts(contexts))

    def predict_proba(self, X, contexts=None):
        """
        Predict state probabilities.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables; ignored when no context weights were learned

        Returns:
        --------
        state_probs : ndarray, shape (n_samples, n_states)
            State probabilities
        """
        return super().predict_proba(X, self._prepare_contexts(contexts))

    def _prepare_contexts(self, contexts):
        """
        Turn raw contexts into the one-column matrix passed to the parent model.

        Parameters:
        -----------
        contexts : ndarray, shape (n_samples, n_contexts), or None
            Context variables

        Returns:
        --------
        combined : ndarray, shape (n_samples, 1), or None
            Combined contexts as a column, or None when no contexts were
            given or no context weights have been learned
        """
        if contexts is None or self.context_weights is None:
            return None
        return self._combine_contexts(contexts).reshape(-1, 1)

    def _learn_context_weights(self, X, contexts):
        """
        Learn weights for combining context variables.

        Each weight is the absolute Pearson correlation between a context
        column and the outcome, normalized so the weights sum to 1. A column
        whose correlation is undefined (for example a constant column) gets
        weight 0 instead of invalidating every other weight. If no column has
        a usable correlation, uniform weights are returned.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts)
            Context variables

        Returns:
        --------
        weights : ndarray, shape (n_contexts,)
            Context weights
        """
        outcome = np.asarray(X).ravel()
        contexts = np.asarray(contexts)
        n_contexts = contexts.shape[1]
        weights = np.zeros(n_contexts)

        # Zero-variance inputs make corrcoef divide by zero and return NaN;
        # silence the floating-point warning and handle the NaN explicitly.
        with np.errstate(divide='ignore', invalid='ignore'):
            for i in range(n_contexts):
                corr = np.corrcoef(contexts[:, i], outcome)[0, 1]
                if np.isfinite(corr):
                    weights[i] = abs(corr)

        total = np.sum(weights)
        if total > 0:
            return weights / total

        # No informative column: fall back to uniform weights.
        return np.ones(n_contexts) / n_contexts

    def _combine_contexts(self, contexts):
        """
        Combine multiple context variables into a single context representation.

        Parameters:
        -----------
        contexts : ndarray, shape (n_samples, n_contexts)
            Context variables

        Returns:
        --------
        combined_contexts : ndarray, shape (n_samples,)
            Combined context representation

        Raises:
        -------
        ValueError
            If the combination method is unknown, if no context weights are
            available, or if ``contexts`` is not a 2-D array with one column
            per learned weight
        """
        method = self.context_combination
        if method not in ('weighted_sum', 'product', 'max'):
            raise ValueError(f"Unknown context combination method: {self.context_combination}")

        if self.context_weights is None:
            raise ValueError("Context weights are not set; call fit() with contexts first")

        contexts = np.asarray(contexts)
        weights = np.asarray(self.context_weights)

        # Without this check a column-count mismatch can broadcast silently
        # (e.g. (n, 1) * (k,)) and yield a meaningless combination.
        if contexts.ndim != 2 or contexts.shape[1] != weights.shape[0]:
            raise ValueError(
                f"contexts must have shape (n_samples, {weights.shape[0]}); "
                f"got {contexts.shape}")

        if method == 'weighted_sum':
            # Weighted sum of context variables
            return np.sum(contexts * weights, axis=1)
        if method == 'product':
            # Weighted geometric combination.
            # NOTE: negative context values raised to fractional weights
            # produce NaN; this method presumably expects non-negative contexts.
            return np.prod(contexts ** weights, axis=1)
        # method == 'max': maximum of weighted context variables
        return np.max(contexts * weights, axis=1)

    def score(self, X, contexts=None):
        """
        Compute the average log-likelihood per sample.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables; ignored when no context weights were learned

        Returns:
        --------
        score : float
            Average log-likelihood per sample

        Raises:
        -------
        ValueError
            If X is empty
        """
        log_likelihood = self._compute_log_likelihood(X, self._prepare_contexts(contexts))

        # Return average log-likelihood per sample
        return log_likelihood / len(X)

    def _compute_log_likelihood(self, X, contexts=None):
        """
        Compute log-likelihood of the data given the model.

        Runs the scaled forward algorithm: the forward vector is renormalized
        at every step to avoid underflow, and the log of each normalization
        constant is accumulated. The sum of those logs is the log-likelihood
        of the whole sequence.

        Reads ``emissionprob_``, ``startprob_``, ``transmat_``, ``n_states``,
        ``context_aware``, ``alpha_`` and ``beta_`` from the instance; these
        are not set in this class and are presumably provided by the parent
        class after fitting.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1); used as integer indices into
            the emission matrix
        contexts : ndarray, shape (n_samples, 1), optional
            Combined context column used for context-dependent transitions

        Returns:
        --------
        log_likelihood : float
            Log-likelihood; ``-inf`` if the sequence has zero probability
            under the model

        Raises:
        -------
        ValueError
            If X is empty
        """
        # Ensure X is a 1D array
        X = np.asarray(X).ravel()
        n_samples = len(X)
        if n_samples == 0:
            raise ValueError("X must contain at least one observation")

        # emission_probs[t, i] = P(X[t] | state i)
        emission_probs = np.asarray(self.emissionprob_)[:, X].T

        # One transition matrix per time step t -> t + 1
        if self.context_aware and contexts is not None:
            transition_probs = compute_context_dependent_transitions(
                self.alpha_, self.beta_, contexts)
        else:
            # Zero-copy view repeating the static transition matrix.
            transmat = np.asarray(self.transmat_)
            transition_probs = np.broadcast_to(
                transmat, (n_samples - 1,) + transmat.shape)

        # Base case
        forward = self.startprob_ * emission_probs[0]
        scale = np.sum(forward)
        if scale <= 0:
            # The first observation is impossible under the model.
            return float('-inf')
        log_likelihood = np.log(scale)
        forward = forward / scale

        # Recursive case
        for t in range(1, n_samples):
            # (forward @ T)[j] = sum_i forward[i] * T[i, j]
            forward = (forward @ transition_probs[t - 1]) * emission_probs[t]
            scale = np.sum(forward)
            if scale <= 0:
                return float('-inf')
            # The scaling factors carry the likelihood; the normalized
            # forward vector itself always sums to 1.
            log_likelihood += np.log(scale)
            forward = forward / scale

        return float(log_likelihood)


