"""
Dual Objective HMM implementation.

This module provides a dual objective HMM that optimizes for both prediction accuracy
and state diversity.
"""

import numpy as np
from ..core.hmm_glm import CategoricalHMMComponent

class DualObjectiveHMMComponent(CategoricalHMMComponent):
    """
    Dual Objective HMM component that optimizes for both prediction accuracy and state diversity.

    The model is fitted by the parent class first; afterwards the emission
    probabilities are blended towards an evenly spread target so that the
    states become more distinguishable. Scoring adds a bonus that grows with
    the entropy of the predicted state sequence.

    This class reads the following members, which it expects the parent class
    to provide: ``n_states``, ``emissionprob_``, ``transmat_``, ``startprob_``,
    ``context_aware``, ``alpha_``, ``beta_`` and ``predict``.
    """

    def __init__(self, n_states, n_categories=2, random_state=None, n_contexts=None,
                 diversity_weight=0.3):
        """
        Initialize the dual objective HMM component.

        Parameters:
        -----------
        n_states : int
            Number of latent states
        n_categories : int, optional (default=2)
            Number of categories (2 for binary outcomes)
        random_state : int, optional
            Random seed for reproducibility
        n_contexts : int, optional
            Number of context variables for context-aware transitions
        diversity_weight : float, optional (default=0.3)
            Weight for the state diversity objective (0 to 1). It is used both
            as the blending factor in ``_adjust_for_diversity`` and as the
            multiplier of the diversity bonus when scoring.
        """
        super().__init__(n_states, n_categories, random_state, n_contexts)
        self.diversity_weight = diversity_weight

    def fit(self, X, sample_weight=None, contexts=None):
        """
        Fit the dual objective HMM component to data.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        sample_weight : ndarray, shape (n_samples,), optional
            Sample weights
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        self : object
            Returns self
        """
        # First, fit the model normally
        super().fit(X, sample_weight, contexts)

        # Then, adjust emission probabilities to increase state diversity
        self._adjust_for_diversity()

        return self

    def _adjust_for_diversity(self):
        """
        Adjust emission probabilities to increase state diversity.

        The target gives state ``i`` a success probability (column 1) of
        ``i / (n_states - 1)``, i.e. evenly spaced from 0 to 1, with the
        remainder on column 0. With a single state the target is 0.5/0.5.
        Columns beyond the first two (if any) have a target of zero.
        ``emissionprob_`` is replaced by the convex blend of its current value
        and this target, then each row is renormalized to sum to one.
        """
        n_states = self.n_states

        # Evenly spaced success probabilities; a single state has no spread,
        # so it gets the neutral value 0.5.
        if n_states > 1:
            target_success = np.arange(n_states) / (n_states - 1)
        else:
            target_success = np.full(n_states, 0.5)

        target_emissionprob = np.zeros_like(self.emissionprob_)
        target_emissionprob[:, 1] = target_success
        target_emissionprob[:, 0] = 1.0 - target_success

        # Blend current and target emission probabilities
        blended = ((1 - self.diversity_weight) * self.emissionprob_
                   + self.diversity_weight * target_emissionprob)

        # Ensure every row is a valid probability distribution
        self.emissionprob_ = blended / blended.sum(axis=1, keepdims=True)

    def _calculate_state_separation(self):
        """
        Calculate the separation between state emission probabilities.

        Returns:
        --------
        separation : float
            Mean absolute difference between the success probabilities
            (column 1 of ``emissionprob_``) over all unordered pairs of
            states; 0.0 when there are fewer than two states.
        """
        # Extract success probabilities for each state
        success_probs = self.emissionprob_[:, 1]
        n_states = len(success_probs)

        if n_states < 2:
            return 0.0

        # Strict upper triangle selects each unordered pair (i < j) once
        upper = np.triu_indices(n_states, k=1)
        pairwise = np.abs(success_probs[:, np.newaxis] - success_probs[np.newaxis, :])
        return np.mean(pairwise[upper])

    def _compute_weighted_log_likelihood(self, X, contexts=None, sample_weight=None):
        """
        Compute log-likelihood plus a state diversity bonus.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : ndarray, shape (n_samples,), optional
            Accepted for interface compatibility; not used in the computation.

        Returns:
        --------
        weighted_log_likelihood : float
            ``ll + diversity_weight * diversity * abs(ll)`` where ``ll`` is the
            log-likelihood and ``diversity`` the normalized state entropy. A
            non-finite ``ll`` is returned unchanged.
        """
        # Compute standard log-likelihood
        log_likelihood = self._compute_log_likelihood(X, contexts)

        # An impossible sequence stays impossible; multiplying inf by a zero
        # diversity would otherwise turn the result into NaN.
        if not np.isfinite(log_likelihood):
            return log_likelihood

        # Compute state diversity
        state_diversity = self._calculate_state_diversity(X, contexts)

        # Compute weighted log-likelihood with diversity bonus
        diversity_bonus = self.diversity_weight * state_diversity * abs(log_likelihood)
        return log_likelihood + diversity_bonus

    def _calculate_state_diversity(self, X, contexts=None):
        """
        Calculate the diversity of state assignments.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        diversity : float
            Entropy of the empirical distribution of predicted states divided
            by ``log(n_states)`` (0 to 1). Returns 0.0 when there is a single
            state or when no states are predicted.
        """
        # Predict states
        states = np.asarray(self.predict(X, contexts))
        if states.size == 0:
            return 0.0

        # Calculate state frequencies
        state_counts = np.bincount(states, minlength=self.n_states)
        state_probs = state_counts / states.size

        # Remove zero probabilities (0 * log(0) is taken as 0)
        state_probs = state_probs[state_probs > 0]

        # Calculate entropy
        entropy = -np.sum(state_probs * np.log(state_probs))

        # Normalize by maximum entropy
        max_entropy = np.log(self.n_states)
        return entropy / max_entropy if max_entropy > 0 else 0.0

    def score(self, X, contexts=None, sample_weight=None):
        """
        Compute the average weighted log-likelihood per sample.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1)
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions
        sample_weight : ndarray, shape (n_samples,), optional
            Forwarded to ``_compute_weighted_log_likelihood``, which does not
            currently use it.

        Returns:
        --------
        score : float
            Average weighted log-likelihood per sample

        Raises:
        -------
        ValueError
            If ``X`` is empty.
        """
        # Compute weighted log-likelihood (raises on empty input, so the
        # division below never sees a zero sample count)
        weighted_log_likelihood = self._compute_weighted_log_likelihood(X, contexts, sample_weight)

        # Count observations the same way the likelihood does (flattened X)
        n_samples = np.asarray(X).size

        # Return average weighted log-likelihood per sample
        return weighted_log_likelihood / n_samples

    def _compute_log_likelihood(self, X, contexts=None):
        """
        Compute log-likelihood of the data given the model.

        Uses the scaled forward algorithm: the forward vector is renormalized
        at every step to prevent underflow and the log of each normalizer is
        accumulated, which sums to the log-likelihood of the whole sequence.

        Parameters:
        -----------
        X : ndarray, shape (n_samples,)
            Binary outcome variable (0 or 1); used to index the columns of
            ``emissionprob_``, so it must hold integer category indices.
        contexts : ndarray, shape (n_samples, n_contexts), optional
            Context variables for context-aware transitions

        Returns:
        --------
        log_likelihood : float
            Log-likelihood; ``-inf`` if the sequence has zero probability
            under the model.

        Raises:
        -------
        ValueError
            If ``X`` is empty.
        """
        # Ensure X is a 1D array
        X = np.asarray(X).ravel()
        n_samples = len(X)
        if n_samples == 0:
            raise ValueError("X must contain at least one observation")

        # emission_probs[t, i] = P(X[t] | state i)
        emission_probs = self.emissionprob_[:, X].T

        # Prepare transition probabilities. Without contexts the same matrix
        # applies at every step, so there is no need to replicate it.
        use_context = self.context_aware and contexts is not None
        if use_context:
            from ..core.context_transitions import compute_context_dependent_transitions
            transition_probs = compute_context_dependent_transitions(
                self.alpha_, self.beta_, contexts)

        # Base case
        forward = self.startprob_ * emission_probs[0]
        scale = np.sum(forward)
        if scale == 0:
            return -np.inf
        log_likelihood = np.log(scale)
        forward = forward / scale

        # Recursive case
        for t in range(1, n_samples):
            step_transitions = transition_probs[t - 1] if use_context else self.transmat_
            # forward[j] = sum_i forward_prev[i] * A[i, j] * B[j, X[t]]
            forward = (forward @ step_transitions) * emission_probs[t]

            # Normalize to prevent underflow, keeping the normalizer's log
            scale = np.sum(forward)
            if scale == 0:
                return -np.inf
            log_likelihood += np.log(scale)
            forward = forward / scale

        return float(log_likelihood)


