"""
Partial Weighted HMM implementation.

This module provides a partially weighted categorical HMM component that tempers
the emission and transition probabilities with separate weights (exponents).
"""

import numpy as np
from ..core.hmm_glm import CategoricalHMMComponent

class PartialWeightedHMMComponent(CategoricalHMMComponent):
    """
    Partial Weighted HMM component that applies different weights to different components.

    Two parts of the model are tempered independently:

    * Emissions: when ``sample_weight`` is given, ``P(x_t | state)`` is raised
      to the power ``sample_weight[t] ** emission_weight`` for every time step.
      Emission probabilities are *not* renormalized afterwards.
    * Transitions: every transition matrix is raised element-wise to
      ``transition_weight`` and then renormalized so each row sums to one.
    """

    def __init__(self, n_states, n_categories=2, random_state=None, n_contexts=None,
                 emission_weight=1.0, transition_weight=0.5):
        """
        Initialize the partial weighted HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states
        n_categories : int, optional (default=2)
            Number of categories (2 for binary outcomes)
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables for context-aware transitions
        emission_weight : float, optional (default=1.0)
            Exponent applied to the sample weights before they temper the
            emission probabilities
        transition_weight : float, optional (default=0.5)
            Exponent applied element-wise to the transition probabilities
            (rows are renormalized afterwards)
        """
        super().__init__(n_states, n_categories, random_state, n_contexts)
        self.emission_weight = emission_weight
        self.transition_weight = transition_weight

    def _as_categories(self, X):
        """Return ``X`` flattened to a 1-D integer array of category labels.

        Raises:
        -------
        ValueError
            If ``X`` is empty, contains non-integer values, or contains labels
            outside ``[0, n_categories)``.
        """
        X = np.asarray(X).ravel()
        if X.size == 0:
            raise ValueError("X must contain at least one sample")
        if X.dtype == bool:
            # A boolean array used as an index acts as a mask, not as labels,
            # which would select the wrong emission columns.
            X = X.astype(int)
        elif not np.issubdtype(X.dtype, np.integer):
            X_int = X.astype(int)
            if not np.array_equal(X_int, X):
                raise ValueError("X must contain integer category labels")
            X = X_int
        # Negative labels would otherwise wrap around silently when indexing.
        if X.min() < 0 or X.max() >= self.n_categories:
            raise ValueError(
                "X labels must lie in [0, %d)" % self.n_categories)
        return X

    @staticmethod
    def _check_sample_weight(sample_weight, n_samples):
        """Return ``sample_weight`` as a flat float array (or None), validating its length."""
        if sample_weight is None:
            return None
        sample_weight = np.asarray(sample_weight, dtype=float).ravel()
        if sample_weight.shape[0] != n_samples:
            raise ValueError(
                "sample_weight has %d entries but X has %d samples"
                % (sample_weight.shape[0], n_samples))
        return sample_weight

    def _weighted_model_probs(self, X, contexts=None, sample_weight=None):
        """Build the tempered emission and transition probabilities.

        Shared by ``fit`` and ``_compute_weighted_log_likelihood`` so both use
        exactly the same weighting scheme.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Validated integer category labels
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : ndarray, shape (n_samples,), optional
            Validated sample weights

        Returns:
        --------
        emission_probs : ndarray, shape (n_samples, n_states)
            ``emission_probs[t, s]`` is the (tempered) probability of ``X[t]``
            under state ``s``
        transition_probs : ndarray
            Transition matrices raised to ``transition_weight`` with each row
            renormalized to sum to one; slice ``t`` is used for the step
            ``t -> t + 1``
        """
        n_samples = len(X)

        # emissionprob_[:, X] is (n_states, n_samples); transpose so row t
        # holds the probability of X[t] under every state.
        emission_probs = np.ascontiguousarray(
            self.emissionprob_[:, X].T, dtype=float)
        if sample_weight is not None:
            emission_weights = sample_weight ** self.emission_weight
            emission_probs = emission_probs ** emission_weights[:, np.newaxis]

        if self.context_aware and contexts is not None:
            from ..core.context_transitions import compute_context_dependent_transitions
            transition_probs = compute_context_dependent_transitions(
                self.alpha_, self.beta_, contexts)
        else:
            transition_probs = np.tile(self.transmat_[np.newaxis, :, :],
                                       (n_samples - 1, 1, 1))

        transition_probs = np.asarray(transition_probs, dtype=float) ** self.transition_weight
        # Tempering breaks row-stochasticity; renormalize each row.
        transition_probs = transition_probs / transition_probs.sum(axis=-1, keepdims=True)
        return emission_probs, transition_probs

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the partial weighted HMM component to data (one EM iteration).

        Parameters:
        -----------
        X : array-like, shape (n_samples,)
            Integer category labels (0 or 1 for binary outcomes); booleans are
            accepted and converted to 0/1
        sample_weight : array-like, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self

        Raises:
        -------
        ValueError
            If ``X`` is empty or has invalid labels, or if ``sample_weight``
            does not match the number of samples.
        """
        X = self._as_categories(X)
        sample_weight = self._check_sample_weight(sample_weight, len(X))

        emission_probs, transition_probs = self._weighted_model_probs(
            X, contexts, sample_weight)

        # E-step (inherited forward-backward); the log-likelihood is not needed here.
        _, state_probs, transition_counts = self._forward_backward(
            emission_probs, transition_probs)

        # M-step: initial state probabilities
        self.startprob_ = state_probs[0] / np.sum(state_probs[0])

        # M-step: transition parameters
        if self.context_aware and contexts is not None:
            from ..core.context_transitions import update_context_dependent_transitions
            self.alpha_, self.beta_ = update_context_dependent_transitions(
                self.alpha_, self.beta_, contexts[:-1], transition_counts,
                sample_weight[:-1] if sample_weight is not None else None)
        else:
            counts = np.sum(transition_counts, axis=0)
            row_sums = counts.sum(axis=1, keepdims=True)
            # States with no expected outgoing transitions keep their previous
            # row instead of becoming NaN.
            self.transmat_ = np.divide(
                counts, row_sums,
                out=np.array(self.transmat_, dtype=float),
                where=row_sums > 0)

        # M-step: emission probabilities (weighted expected category counts)
        n_cat = self.n_categories
        weights = state_probs if sample_weight is None else \
            state_probs * sample_weight[:, np.newaxis]
        one_hot = (X[:, np.newaxis] == np.arange(n_cat)).astype(float)
        numer = weights.T @ one_hot                    # (n_states, n_categories)
        denom = weights.sum(axis=0)[:, np.newaxis]     # (n_states, 1)
        # States with zero total weight keep their previous emission row.
        self.emissionprob_[:, :n_cat] = np.divide(
            numer, denom,
            out=np.array(self.emissionprob_[:, :n_cat], dtype=float),
            where=denom > 0)

        return self

    def _compute_weighted_log_likelihood(self, X, contexts=None, sample_weight=None):
        """
        Compute weighted log-likelihood of the data given the model.

        Runs a scaled forward pass over the tempered emission and transition
        probabilities and returns the sum of the log scaling constants. That
        sum equals the log of the (tempered) sequence probability. Emission
        probabilities are not renormalized after tempering, so when
        ``sample_weight`` is given this is an unnormalized score.

        Parameters:
        -----------
        X : array-like, shape (n_samples,)
            Integer category labels (0 or 1 for binary outcomes)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : array-like, shape (n_samples,), optional
            Sample weights

        Returns:
        --------
        log_likelihood : float
            Weighted log-likelihood; ``-inf`` if the sequence has zero
            probability under the model
        """
        X = self._as_categories(X)
        sample_weight = self._check_sample_weight(sample_weight, len(X))

        emission_probs, transition_probs = self._weighted_model_probs(
            X, contexts, sample_weight)

        log_likelihood = 0.0
        alpha = self.startprob_ * emission_probs[0]
        for t in range(len(X)):
            if t > 0:
                # alpha_t[j] = sum_i alpha_{t-1}[i] * A_{t-1}[i, j] * b_t[j]
                alpha = (alpha @ transition_probs[t - 1]) * emission_probs[t]
            scale = alpha.sum()
            if scale <= 0:
                return -np.inf
            # Normalizing alpha keeps it from underflowing; the discarded
            # scale factors are exactly what the likelihood is made of.
            log_likelihood += np.log(scale)
            alpha = alpha / scale

        return float(log_likelihood)

    def score(self, X, contexts=None, sample_weight=None):
        """
        Compute the average weighted log-likelihood per sample.

        Parameters:
        -----------
        X : array-like, shape (n_samples,)
            Integer category labels (0 or 1 for binary outcomes)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : array-like, shape (n_samples,), optional
            Sample weights

        Returns:
        --------
        score : float
            Average weighted log-likelihood per sample
        """
        X = np.asarray(X).ravel()
        log_likelihood = self._compute_weighted_log_likelihood(X, contexts, sample_weight)

        # Divide by the flattened sample count used by the forward pass.
        return log_likelihood / X.size


