"""
HMM component for the HMM-GLM framework.

This module provides classes for the HMM component of the HMM-GLM framework,
including categorical and Gaussian emission models.
"""

import numpy as np
from abc import ABC, abstractmethod
from ..context_transitions import (
    compute_context_dependent_transitions,
    update_context_dependent_transitions,
    initialize_context_parameters
)

class HMMComponent(ABC):
    """
    Abstract base class for HMM components.

    Stores the parameters every emission model shares: the number of latent
    states, the initial state distribution ``startprob_`` and either a fixed
    transition matrix ``transmat_`` or, when ``n_contexts`` is given, the
    context-dependent transition parameters ``alpha_`` / ``beta_``.
    Subclasses supply the emission model and implement ``fit``, ``predict``
    and ``predict_proba``.
    """

    def __init__(self, n_states, random_state=None, n_contexts=None):
        """
        Initialize the HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states (must be at least 1)
        random_state : int, optional
            Random seed for reproducibility. When given, it is passed to
            ``np.random.seed`` and therefore reseeds NumPy's *global*
            random number generator.
        n_contexts : int, optional
            Number of context variables for context-aware transitions

        Raises:
        -------
        ValueError
            If ``n_states`` is smaller than 1
        """
        if n_states < 1:
            raise ValueError("n_states must be a positive integer")

        self.n_states = n_states
        self.random_state = random_state
        self.n_contexts = n_contexts

        # Set random seed (global NumPy state, see docstring)
        if random_state is not None:
            np.random.seed(random_state)

        # Uniform initial state distribution
        self.startprob_ = np.ones(n_states) / n_states

        # A context-aware component carries alpha_/beta_ and has no
        # transmat_; a plain component carries transmat_ only.
        if n_contexts is not None:
            self.alpha_, self.beta_ = initialize_context_parameters(n_states, n_contexts)
            self.context_aware = True
        else:
            self.transmat_ = self._initialize_transmat()
            self.context_aware = False

    def _initialize_transmat(self):
        """
        Initialize transition matrix with high self-transition probabilities.

        Each state keeps probability 0.7 of staying put and spreads the
        remaining 0.3 evenly over the other states. With a single state there
        are no other states to move to, so the only valid matrix is ``[[1.0]]``.

        Returns:
        --------
        transmat : ndarray, shape (n_states, n_states)
            Row-stochastic transition matrix
        """
        if self.n_states == 1:
            # Avoids 0.3 / 0 and a row that would sum to 0.7 instead of 1.
            return np.ones((1, 1))

        off_diagonal = 0.3 / (self.n_states - 1)
        transmat = np.full((self.n_states, self.n_states), off_diagonal)
        np.fill_diagonal(transmat, 0.7)
        return transmat

    @abstractmethod
    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the HMM component to data.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self
        """
        pass

    @abstractmethod
    def predict(self, X, contexts=None):
        """
        Predict the most likely state sequence.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        states : ndarray, shape (n_samples,)
            Predicted state sequence
        """
        pass

    @abstractmethod
    def predict_proba(self, X, contexts=None):
        """
        Predict state probabilities.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        state_probs : ndarray, shape (n_samples, n_states)
            State probabilities
        """
        pass

class CategoricalHMMComponent(HMMComponent):
    """
    Categorical HMM component for binary outcomes.

    Observations are integer category codes in ``[0, n_categories)``; the
    default of two categories corresponds to binary outcomes. Each call to
    ``fit`` performs one E-step (forward-backward) followed by one M-step.

    NOTE: ``_forward_backward``, ``_viterbi`` and ``_forward`` are generic
    scaled HMM recursions; ``GaussianHMMComponent`` in this module carries
    its own copies of the same routines.
    """

    def __init__(self, n_states, n_categories=2, random_state=None, n_contexts=None):
        """
        Initialize the categorical HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states
        n_categories : int, optional (default=2)
            Number of categories (2 for binary outcomes)
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables for context-aware transitions

        Raises:
        -------
        ValueError
            If ``n_categories`` is smaller than 1
        """
        if n_categories < 1:
            raise ValueError("n_categories must be a positive integer")

        super().__init__(n_states, random_state, n_contexts)
        self.n_categories = n_categories

        # Initialize emission probabilities
        self.emissionprob_ = self._initialize_emissionprob()

    def _initialize_emissionprob(self):
        """
        Initialize emission probabilities with different success rates across states.

        The probability of *not* emitting category 0 grows linearly with the
        state index from 0.1 to 0.9 (0.5 when there is a single state). That
        mass is shared evenly by categories ``1 .. n_categories - 1`` so every
        row is a proper distribution; for two categories this is exactly
        ``[1 - p, p]``.

        Returns:
        --------
        emissionprob : ndarray, shape (n_states, n_categories)
            Row-stochastic emission matrix
        """
        emissionprob = np.zeros((self.n_states, self.n_categories))

        if self.n_categories == 1:
            # Only one possible symbol: it is emitted with certainty.
            emissionprob[:, 0] = 1.0
            return emissionprob

        if self.n_states > 1:
            success_prob = 0.1 + 0.8 * np.arange(self.n_states) / (self.n_states - 1)
        else:
            success_prob = np.array([0.5])

        emissionprob[:, 0] = 1.0 - success_prob
        emissionprob[:, 1:] = (success_prob / (self.n_categories - 1))[:, np.newaxis]
        return emissionprob

    def _validate_observations(self, X):
        """Return X as a 1D integer array of category codes, or raise ValueError."""
        raw = np.asarray(X).ravel()
        if raw.size == 0:
            raise ValueError("X must contain at least one observation")

        # Cast so that float (0.0 / 1.0) and boolean inputs index the emission
        # matrix by category code instead of failing or acting as a mask.
        codes = raw.astype(int)
        if np.any(codes != raw):
            raise ValueError("X must contain integer category codes")
        if codes.min() < 0 or codes.max() >= self.n_categories:
            raise ValueError(
                "X must contain category codes in [0, %d)" % self.n_categories)
        return codes

    @staticmethod
    def _validate_sample_weight(sample_weight, n_samples):
        """Return sample_weight as a 1D float array of length n_samples (or None)."""
        if sample_weight is None:
            return None
        weights = np.asarray(sample_weight, dtype=float).ravel()
        if weights.shape[0] != n_samples:
            raise ValueError("sample_weight must have one entry per observation")
        return weights

    def _emission_likelihoods(self, codes):
        """Return P(observation | state) for every sample, shape (n_samples, n_states)."""
        return self.emissionprob_[:, codes].T

    def _build_transition_probs(self, n_samples, contexts):
        """Return the per-step transition tensor used by the inference routines."""
        if self.context_aware and contexts is not None:
            return compute_context_dependent_transitions(
                self.alpha_, self.beta_, contexts)

        transmat = getattr(self, "transmat_", None)
        if transmat is None:
            # A context-aware component has no fixed transition matrix.
            raise ValueError(
                "contexts must be provided for a context-aware HMM component")
        return np.tile(transmat[np.newaxis, :, :], (n_samples - 1, 1, 1))

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the categorical HMM component to binary data.

        Runs a single EM iteration: forward-backward under the current
        parameters, then an update of the initial distribution, the
        transition parameters and the emission probabilities.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights. They enter the E-step as an exponent on the
            emission likelihoods and the M-step as multiplicative weights.
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If X is empty or holds invalid category codes, if sample_weight
            has the wrong length, or if the component is context-aware and
            neither contexts nor a ``transmat_`` attribute is available
        """
        codes = self._validate_observations(X)
        n_samples = codes.shape[0]
        weights = self._validate_sample_weight(sample_weight, n_samples)
        use_contexts = self.context_aware and contexts is not None

        # E-step inputs
        emission_probs = self._emission_likelihoods(codes)
        if weights is not None:
            emission_probs = emission_probs ** weights[:, np.newaxis]
        transition_probs = self._build_transition_probs(n_samples, contexts)

        # E-step
        _, state_probs, transition_counts = self._forward_backward(
            emission_probs, transition_probs)

        # M-step: initial state probabilities
        self.startprob_ = state_probs[0] / np.sum(state_probs[0])

        # M-step: transition parameters
        if use_contexts:
            self.alpha_, self.beta_ = update_context_dependent_transitions(
                self.alpha_, self.beta_, contexts[:-1], transition_counts,
                weights[:-1] if weights is not None else None)
        else:
            expected = np.sum(transition_counts, axis=0)
            row_sums = expected.sum(axis=1)
            # States without any expected outgoing transition (e.g. a
            # single-sample sequence) keep their previous row instead of
            # becoming 0/0 = NaN.
            visited = row_sums > 0
            updated = np.array(self.transmat_, dtype=float)
            updated[visited] = expected[visited] / row_sums[visited, np.newaxis]
            self.transmat_ = updated

        # M-step: emission probabilities
        if weights is None:
            responsibilities = state_probs
        else:
            responsibilities = state_probs * weights[:, np.newaxis]
        one_hot = (codes[:, np.newaxis] == np.arange(self.n_categories)).astype(float)
        numerators = responsibilities.T @ one_hot
        denominators = responsibilities.sum(axis=0)
        # Same guard as above: an unvisited state keeps its old emissions.
        updatable = denominators > 0
        self.emissionprob_[updatable] = (
            numerators[updatable] / denominators[updatable, np.newaxis])

        return self

    def predict(self, X, contexts=None):
        """
        Predict the most likely state sequence.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        states : ndarray, shape (n_samples,)
            Predicted state sequence (Viterbi path)
        """
        codes = self._validate_observations(X)
        emission_probs = self._emission_likelihoods(codes)
        transition_probs = self._build_transition_probs(codes.shape[0], contexts)
        return self._viterbi(emission_probs, transition_probs)

    def predict_proba(self, X, contexts=None):
        """
        Predict state probabilities.

        The result comes from the forward recursion only, i.e. row ``t`` is
        conditioned on observations up to and including ``t``.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        state_probs : ndarray, shape (n_samples, n_states)
            State probabilities
        """
        codes = self._validate_observations(X)
        emission_probs = self._emission_likelihoods(codes)
        transition_probs = self._build_transition_probs(codes.shape[0], contexts)
        return self._forward(emission_probs, transition_probs)

    def _forward_pass(self, emission_probs, transition_probs):
        """
        Run the scaled forward recursion.

        Returns the normalized forward variables together with the sum of the
        logs of the per-step normalizers, which is the log-likelihood of the
        sequence under the supplied emission values.
        """
        emission_probs = np.asarray(emission_probs)
        n_samples, n_states = emission_probs.shape

        forward = np.zeros((n_samples, n_states))
        log_likelihood = 0.0
        for t in range(n_samples):
            if t == 0:
                unnormalized = self.startprob_ * emission_probs[0]
            else:
                unnormalized = (forward[t - 1] @ transition_probs[t - 1]) * emission_probs[t]

            # Normalize to prevent underflow; the discarded scale carries
            # the likelihood.
            scale = np.sum(unnormalized)
            forward[t] = unnormalized / scale
            log_likelihood += np.log(scale)

        return forward, log_likelihood

    def _forward_backward(self, emission_probs, transition_probs):
        """
        Perform forward-backward algorithm.

        Parameters:
        -----------
        emission_probs : ndarray, shape (n_samples, n_states)
            Emission likelihood of each observation under each state
        transition_probs : ndarray, shape (>= n_samples - 1, n_states, n_states)
            Transition matrix used for the step from t to t + 1

        Returns:
        --------
        log_likelihood : float
            Sum of the logs of the forward normalizers
        state_probs : ndarray, shape (n_samples, n_states)
            Smoothed state posteriors
        transition_counts : ndarray, shape (n_samples - 1, n_states, n_states)
            Pairwise posteriors, each time slice normalized to sum to 1
        """
        emission_probs = np.asarray(emission_probs)
        transition_probs = np.asarray(transition_probs)
        n_samples, n_states = emission_probs.shape

        # Forward pass
        forward, log_likelihood = self._forward_pass(emission_probs, transition_probs)

        # Backward pass
        backward = np.zeros((n_samples, n_states))
        backward[-1] = 1.0
        for t in range(n_samples - 2, -1, -1):
            unnormalized = transition_probs[t] @ (backward[t + 1] * emission_probs[t + 1])
            # Normalize to prevent underflow
            backward[t] = unnormalized / np.sum(unnormalized)

        # Smoothed state probabilities
        state_probs = forward * backward
        state_probs /= np.sum(state_probs, axis=1)[:, np.newaxis]

        # Pairwise posteriors for every step t -> t + 1, normalized per step
        n_steps = n_samples - 1
        transition_counts = (
            forward[:-1, :, np.newaxis]
            * transition_probs[:n_steps]
            * (emission_probs[1:] * backward[1:])[:, np.newaxis, :]
        )
        transition_counts = transition_counts / transition_counts.sum(
            axis=(1, 2), keepdims=True)

        return log_likelihood, state_probs, transition_counts

    def _viterbi(self, emission_probs, transition_probs):
        """
        Viterbi algorithm for finding most likely state sequence.

        Works in log space; zero probabilities map to ``-inf`` and simply
        lose every comparison, so the divide-by-zero warning from ``np.log``
        is suppressed.
        """
        emission_probs = np.asarray(emission_probs)
        transition_probs = np.asarray(transition_probs)
        n_samples, n_states = emission_probs.shape

        with np.errstate(divide="ignore"):
            log_start = np.log(self.startprob_)
            log_emission = np.log(emission_probs)
            log_transition = np.log(transition_probs[:n_samples - 1])

        viterbi = np.zeros((n_samples, n_states))
        backpointer = np.zeros((n_samples, n_states), dtype=int)
        columns = np.arange(n_states)

        # Base case
        viterbi[0] = log_start + log_emission[0]

        # Recursive case: scores[i, j] is the best log-probability of being
        # in state i at t - 1 and moving to state j.
        for t in range(1, n_samples):
            scores = viterbi[t - 1][:, np.newaxis] + log_transition[t - 1]
            backpointer[t] = np.argmax(scores, axis=0)
            viterbi[t] = scores[backpointer[t], columns] + log_emission[t]

        # Backtrack
        states = np.zeros(n_samples, dtype=int)
        states[-1] = np.argmax(viterbi[-1])
        for t in range(n_samples - 2, -1, -1):
            states[t] = backpointer[t + 1, states[t + 1]]

        return states

    def _forward(self, emission_probs, transition_probs):
        """Forward algorithm for calculating state probabilities."""
        forward, _ = self._forward_pass(emission_probs, transition_probs)
        return forward

class GaussianHMMComponent(HMMComponent):
    """
    Gaussian HMM component for continuous observations.

    Every state emits from a multivariate normal with its own mean and full
    covariance matrix. Each call to ``fit`` performs one E-step
    (forward-backward) followed by one M-step.

    NOTE: ``_forward_backward``, ``_viterbi`` and ``_forward`` are generic
    scaled HMM recursions; ``CategoricalHMMComponent`` in this module carries
    its own copies of the same routines.
    """

    def __init__(self, n_states, n_features, random_state=None, n_contexts=None):
        """
        Initialize the Gaussian HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states
        n_features : int
            Number of features
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables for context-aware transitions
        """
        super().__init__(n_states, random_state, n_contexts)
        self.n_features = n_features

        # Zero means and identity covariances until fit() sees data
        self.means_ = np.zeros((n_states, n_features))
        self.covars_ = np.tile(np.eye(n_features), (n_states, 1, 1))

    def _validate_features(self, X):
        """Return X as a float array of shape (n_samples, n_features), or raise ValueError."""
        data = np.asarray(X, dtype=float)
        if data.ndim == 1 and self.n_features == 1:
            # A flat vector is unambiguous only for a single feature.
            data = data[:, np.newaxis]
        if data.ndim != 2 or data.shape[1] != self.n_features:
            raise ValueError(
                "X must have shape (n_samples, %d)" % self.n_features)
        if data.shape[0] == 0:
            raise ValueError("X must contain at least one sample")
        return data

    @staticmethod
    def _validate_sample_weight(sample_weight, n_samples):
        """Return sample_weight as a 1D float array of length n_samples (or None)."""
        if sample_weight is None:
            return None
        weights = np.asarray(sample_weight, dtype=float).ravel()
        if weights.shape[0] != n_samples:
            raise ValueError("sample_weight must have one entry per sample")
        return weights

    def _build_transition_probs(self, n_samples, contexts):
        """Return the per-step transition tensor used by the inference routines."""
        if self.context_aware and contexts is not None:
            return compute_context_dependent_transitions(
                self.alpha_, self.beta_, contexts)

        transmat = getattr(self, "transmat_", None)
        if transmat is None:
            # A context-aware component has no fixed transition matrix.
            raise ValueError(
                "contexts must be provided for a context-aware HMM component")
        return np.tile(transmat[np.newaxis, :, :], (n_samples - 1, 1, 1))

    def _initialize_means_covars(self, X):
        """
        Initialize means and covariances using K-means.

        Means become the cluster centroids. Each covariance becomes the
        sample covariance of its cluster plus a 1e-6 ridge, or the identity
        when the cluster holds fewer than two samples.
        """
        from sklearn.cluster import KMeans

        X = np.asarray(X)

        # Use K-means to initialize means
        kmeans = KMeans(n_clusters=self.n_states, random_state=self.random_state)
        kmeans.fit(X)

        # Set means to cluster centroids
        self.means_ = kmeans.cluster_centers_

        # Set covariances to cluster covariances
        ridge = 1e-6 * np.eye(self.n_features)
        for i in range(self.n_states):
            members = X[kmeans.labels_ == i]
            if len(members) > 1:
                self.covars_[i] = np.cov(members, rowvar=False) + ridge
            else:
                self.covars_[i] = np.eye(self.n_features)

    def _compute_emission_probs(self, X):
        """
        Compute emission probabilities for Gaussian observations.

        Densities are evaluated in log space and shifted by the per-sample
        maximum before exponentiating, so a sample far from every mean does
        not underflow to an all-zero row (which would normalize to NaN).

        Returns:
        --------
        emission_probs : ndarray, shape (n_samples, n_states)
            Per-sample densities rescaled so that each row sums to 1. The
            rescaling leaves state posteriors unchanged but makes any
            likelihood computed from these values relative, not absolute.
        """
        from scipy.stats import multivariate_normal

        data = self._validate_features(X)
        log_probs = np.empty((data.shape[0], self.n_states))

        for i in range(self.n_states):
            mvn = multivariate_normal(mean=self.means_[i], cov=self.covars_[i])
            log_probs[:, i] = mvn.logpdf(data)

        # Normalize to prevent underflow
        log_probs -= log_probs.max(axis=1, keepdims=True)
        emission_probs = np.exp(log_probs)
        emission_probs /= emission_probs.sum(axis=1, keepdims=True)

        return emission_probs

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the Gaussian HMM component to continuous data.

        Runs a single EM iteration: K-means initialization of the emission
        parameters, forward-backward, then an update of the initial
        distribution, the transition parameters, the means and the
        covariances.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights. They enter the E-step as an exponent on the
            emission values and the M-step as multiplicative weights.
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If X has the wrong shape or is empty, if sample_weight has the
            wrong length, or if the component is context-aware and neither
            contexts nor a ``transmat_`` attribute is available
        """
        data = self._validate_features(X)
        n_samples = data.shape[0]
        weights = self._validate_sample_weight(sample_weight, n_samples)
        use_contexts = self.context_aware and contexts is not None

        # NOTE(review): K-means runs on every call, so means_/covars_ learned
        # by a previous fit() are discarded before the E-step. Kept as-is;
        # confirm whether callers that invoke fit() repeatedly expect a warm
        # start instead.
        self._initialize_means_covars(data)

        # E-step inputs
        emission_probs = self._compute_emission_probs(data)
        if weights is not None:
            emission_probs = emission_probs ** weights[:, np.newaxis]
        transition_probs = self._build_transition_probs(n_samples, contexts)

        # E-step
        _, state_probs, transition_counts = self._forward_backward(
            emission_probs, transition_probs)

        # M-step: initial state probabilities
        self.startprob_ = state_probs[0] / np.sum(state_probs[0])

        # M-step: transition parameters
        if use_contexts:
            self.alpha_, self.beta_ = update_context_dependent_transitions(
                self.alpha_, self.beta_, contexts[:-1], transition_counts,
                weights[:-1] if weights is not None else None)
        else:
            expected = np.sum(transition_counts, axis=0)
            row_sums = expected.sum(axis=1)
            # States without any expected outgoing transition (e.g. a
            # single-sample sequence) keep their previous row instead of
            # becoming 0/0 = NaN.
            visited = row_sums > 0
            updated = np.array(self.transmat_, dtype=float)
            updated[visited] = expected[visited] / row_sums[visited, np.newaxis]
            self.transmat_ = updated

        # M-step: means and covariances
        if weights is None:
            responsibilities = state_probs
        else:
            responsibilities = state_probs * weights[:, np.newaxis]
        ridge = 1e-6 * np.eye(self.n_features)
        for i in range(self.n_states):
            resp = responsibilities[:, i]
            total = np.sum(resp)
            if total > 0:
                self.means_[i] = np.sum(data * resp[:, np.newaxis], axis=0) / total
                centered = data - self.means_[i]
                self.covars_[i] = (
                    np.dot(centered.T, centered * resp[:, np.newaxis]) / total + ridge)

        return self

    def predict(self, X, contexts=None):
        """
        Predict the most likely state sequence.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        states : ndarray, shape (n_samples,)
            Predicted state sequence (Viterbi path)
        """
        data = self._validate_features(X)
        emission_probs = self._compute_emission_probs(data)
        transition_probs = self._build_transition_probs(data.shape[0], contexts)
        return self._viterbi(emission_probs, transition_probs)

    def predict_proba(self, X, contexts=None):
        """
        Predict state probabilities.

        The result comes from the forward recursion only, i.e. row ``t`` is
        conditioned on observations up to and including ``t``.

        Parameters:
        -----------
        X : ndarray, shape (n_samples, n_features)
            Feature matrix
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        state_probs : ndarray, shape (n_samples, n_states)
            State probabilities
        """
        data = self._validate_features(X)
        emission_probs = self._compute_emission_probs(data)
        transition_probs = self._build_transition_probs(data.shape[0], contexts)
        return self._forward(emission_probs, transition_probs)

    def _forward_pass(self, emission_probs, transition_probs):
        """
        Run the scaled forward recursion.

        Returns the normalized forward variables together with the sum of the
        logs of the per-step normalizers, which is the log-likelihood of the
        sequence under the supplied emission values.
        """
        emission_probs = np.asarray(emission_probs)
        n_samples, n_states = emission_probs.shape

        forward = np.zeros((n_samples, n_states))
        log_likelihood = 0.0
        for t in range(n_samples):
            if t == 0:
                unnormalized = self.startprob_ * emission_probs[0]
            else:
                unnormalized = (forward[t - 1] @ transition_probs[t - 1]) * emission_probs[t]

            # Normalize to prevent underflow; the discarded scale carries
            # the likelihood.
            scale = np.sum(unnormalized)
            forward[t] = unnormalized / scale
            log_likelihood += np.log(scale)

        return forward, log_likelihood

    def _forward_backward(self, emission_probs, transition_probs):
        """
        Perform forward-backward algorithm.

        Parameters:
        -----------
        emission_probs : ndarray, shape (n_samples, n_states)
            Emission value of each observation under each state
        transition_probs : ndarray, shape (>= n_samples - 1, n_states, n_states)
            Transition matrix used for the step from t to t + 1

        Returns:
        --------
        log_likelihood : float
            Sum of the logs of the forward normalizers
        state_probs : ndarray, shape (n_samples, n_states)
            Smoothed state posteriors
        transition_counts : ndarray, shape (n_samples - 1, n_states, n_states)
            Pairwise posteriors, each time slice normalized to sum to 1
        """
        emission_probs = np.asarray(emission_probs)
        transition_probs = np.asarray(transition_probs)
        n_samples, n_states = emission_probs.shape

        # Forward pass
        forward, log_likelihood = self._forward_pass(emission_probs, transition_probs)

        # Backward pass
        backward = np.zeros((n_samples, n_states))
        backward[-1] = 1.0
        for t in range(n_samples - 2, -1, -1):
            unnormalized = transition_probs[t] @ (backward[t + 1] * emission_probs[t + 1])
            # Normalize to prevent underflow
            backward[t] = unnormalized / np.sum(unnormalized)

        # Smoothed state probabilities
        state_probs = forward * backward
        state_probs /= np.sum(state_probs, axis=1)[:, np.newaxis]

        # Pairwise posteriors for every step t -> t + 1, normalized per step
        n_steps = n_samples - 1
        transition_counts = (
            forward[:-1, :, np.newaxis]
            * transition_probs[:n_steps]
            * (emission_probs[1:] * backward[1:])[:, np.newaxis, :]
        )
        transition_counts = transition_counts / transition_counts.sum(
            axis=(1, 2), keepdims=True)

        return log_likelihood, state_probs, transition_counts

    def _viterbi(self, emission_probs, transition_probs):
        """
        Viterbi algorithm for finding most likely state sequence.

        Works in log space; zero probabilities map to ``-inf`` and simply
        lose every comparison, so the divide-by-zero warning from ``np.log``
        is suppressed.
        """
        emission_probs = np.asarray(emission_probs)
        transition_probs = np.asarray(transition_probs)
        n_samples, n_states = emission_probs.shape

        with np.errstate(divide="ignore"):
            log_start = np.log(self.startprob_)
            log_emission = np.log(emission_probs)
            log_transition = np.log(transition_probs[:n_samples - 1])

        viterbi = np.zeros((n_samples, n_states))
        backpointer = np.zeros((n_samples, n_states), dtype=int)
        columns = np.arange(n_states)

        # Base case
        viterbi[0] = log_start + log_emission[0]

        # Recursive case: scores[i, j] is the best log-probability of being
        # in state i at t - 1 and moving to state j.
        for t in range(1, n_samples):
            scores = viterbi[t - 1][:, np.newaxis] + log_transition[t - 1]
            backpointer[t] = np.argmax(scores, axis=0)
            viterbi[t] = scores[backpointer[t], columns] + log_emission[t]

        # Backtrack
        states = np.zeros(n_samples, dtype=int)
        states[-1] = np.argmax(viterbi[-1])
        for t in range(n_samples - 2, -1, -1):
            states[t] = backpointer[t + 1, states[t + 1]]

        return states

    def _forward(self, emission_probs, transition_probs):
        """Forward algorithm for calculating state probabilities."""
        forward, _ = self._forward_pass(emission_probs, transition_probs)
        return forward


