"""
Context-aware transition matrices for HMM.

This module provides functions for computing context-dependent transition probabilities
and updating the context-specific parameters during the M-step of the EM algorithm.
"""

import numpy as np
from scipy.optimize import minimize

def compute_context_dependent_transitions(alpha, beta, contexts):
    """
    Compute context-dependent transition probabilities.

    For every sample ``t`` and source state ``i``, row ``trans_probs[t, i]``
    is the softmax over destination states ``j`` of the logits

        alpha[i, j] + sum_k beta[i, j, k] * contexts[t, k]

    Parameters:
    -----------
    alpha : ndarray, shape (n_states, n_states)
        Base transition log-probabilities (logits)
    beta : ndarray, shape (n_states, n_states, n_context_vars)
        Context-specific adjustment parameters
    contexts : ndarray, shape (n_samples, n_context_vars)
        Context variables for each sample

    Returns:
    --------
    trans_probs : ndarray, shape (n_samples, n_states, n_states)
        Context-dependent transition probabilities for each sample. Each row
        sums to 1, except rows whose logits are all ``-inf``, which are left
        as all zeros.
    """
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    contexts = np.asarray(contexts, dtype=float)

    # Unnormalized log-probabilities, shape (n_samples, n_states, n_states):
    # logits[t, i, j] = alpha[i, j] + beta[i, j, :] @ contexts[t, :]
    logits = alpha[np.newaxis, :, :] + np.einsum('ijk,tk->tij', beta, contexts)

    # Subtract each row's maximum before exponentiating. Without this, large
    # logits overflow to inf and the normalized row becomes NaN (inf / inf).
    # `initial=-inf` keeps the reduction valid when n_states == 0.
    row_max = logits.max(axis=-1, keepdims=True, initial=-np.inf)
    # A row that is entirely -inf has max -inf; shifting by it would produce
    # NaN (-inf - -inf), so shift such rows by 0 and let them stay all zeros.
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    trans_probs = np.exp(logits - row_max)

    # Normalize so each row sums to 1; rows with zero mass are left untouched.
    row_sum = trans_probs.sum(axis=-1, keepdims=True)
    np.divide(trans_probs, row_sum, out=trans_probs, where=row_sum > 0)

    return trans_probs

def update_context_dependent_transitions(alpha, beta, contexts, transition_counts, 
                                      sample_weight=None, regularization=0.01):
    """
    Update context-dependent transition parameters using optimization.

    For each source state ``i``, the row parameters ``alpha[i]`` and
    ``beta[i]`` are fitted jointly. The fit minimizes the L2-regularized
    expected negative log-likelihood of a softmax (multinomial-logistic)
    model::

        -sum_t sum_j counts[t, i, j] * log softmax_j(alpha[i] + beta[i] @ c_t)
        + regularization * (||alpha[i]||^2 + ||beta[i]||^2)

    All destinations of a row share one normalizer, so they are optimized
    together. Fitting each ``(i, j)`` entry against only its own counts would
    simply reward raising that entry's probability.

    Parameters:
    -----------
    alpha : ndarray, shape (n_states, n_states)
        Current base transition log-probabilities (used as the starting point)
    beta : ndarray, shape (n_states, n_states, n_context_vars)
        Current context-specific adjustment parameters (starting point)
    contexts : ndarray, shape (n_samples, n_context_vars)
        Context variables for each sample
    transition_counts : ndarray, shape (n_samples-1, n_states, n_states)
        Expected transition counts from forward-backward algorithm.
        With ``n_samples - 1`` slices, slice ``t`` is paired with
        ``contexts[t + 1]``. A ``(n_samples, n_states, n_states)`` array is
        also accepted and paired index-by-index.
    sample_weight : ndarray, shape (n_samples-1,), optional
        Per-transition weights (same length as ``transition_counts``)
    regularization : float, optional (default=0.01)
        Regularization strength for L2 regularization

    Returns:
    --------
    alpha_new : ndarray, shape (n_states, n_states)
        Updated base transition log-probabilities. Softmax rows are invariant
        to adding a constant, so these are the representative chosen by the
        L2 penalty and are not normalized log-probabilities in general.
    beta_new : ndarray, shape (n_states, n_states, n_context_vars)
        Updated context-specific adjustment parameters

    Raises:
    -------
    ValueError
        If ``transition_counts`` has neither ``n_samples`` nor
        ``n_samples - 1`` slices.
    """
    alpha = np.asarray(alpha, dtype=float)
    beta = np.asarray(beta, dtype=float)
    contexts = np.asarray(contexts, dtype=float)
    counts = np.asarray(transition_counts, dtype=float)

    n_states = alpha.shape[0]
    n_context_vars = contexts.shape[1]
    n_transitions = counts.shape[0]

    # Pair every transition-count slice with exactly one context row.
    if n_transitions == contexts.shape[0] - 1:
        # NOTE(review): assumes transition t -> t+1 is governed by the
        # context at the destination time t+1 -- confirm against the
        # forward-backward implementation that consumes these parameters.
        ctx = contexts[1:]
    elif n_transitions == contexts.shape[0]:
        ctx = contexts
    else:
        raise ValueError(
            "transition_counts must have n_samples or n_samples - 1 slices; "
            f"got {n_transitions} for {contexts.shape[0]} context rows"
        )

    # Apply sample weights if provided
    if sample_weight is not None:
        weights = np.asarray(sample_weight, dtype=float)
        counts = counts * weights[:, np.newaxis, np.newaxis]

    def row_objective(params, row_counts):
        """Return (regularized NLL, gradient) for one source state's parameters."""
        alpha_row = params[:n_states]
        beta_row = params[n_states:].reshape(n_states, n_context_vars)

        # logits[t, j] = alpha_row[j] + beta_row[j] @ ctx[t]
        logits = alpha_row + ctx @ beta_row.T

        # Stable log-softmax: shift by the row max before exponentiating.
        shift = logits.max(axis=1, keepdims=True)
        log_norm = shift + np.log(np.exp(logits - shift).sum(axis=1, keepdims=True))
        log_probs = logits - log_norm

        neg_ll = -np.sum(row_counts * log_probs)
        reg_term = regularization * (np.sum(alpha_row ** 2) + np.sum(beta_row ** 2))

        # d(neg_ll)/d(logits[t, j]) = N_t * p[t, j] - counts[t, j],
        # where N_t is the total count leaving this state at step t.
        totals = row_counts.sum(axis=1, keepdims=True)
        d_logits = totals * np.exp(log_probs) - row_counts
        grad_alpha = d_logits.sum(axis=0) + 2.0 * regularization * alpha_row
        grad_beta = d_logits.T @ ctx + 2.0 * regularization * beta_row

        return neg_ll + reg_term, np.concatenate([grad_alpha, grad_beta.ravel()])

    alpha_new = np.empty_like(alpha)
    beta_new = np.empty_like(beta)

    # Rows are independent given the counts, so each is fitted separately.
    for i in range(n_states):
        initial_params = np.concatenate([alpha[i], beta[i].ravel()])
        result = minimize(row_objective, initial_params, args=(counts[:, i, :],),
                          jac=True, method='L-BFGS-B')
        alpha_new[i] = result.x[:n_states]
        beta_new[i] = result.x[n_states:].reshape(n_states, n_context_vars)

    return alpha_new, beta_new

def initialize_context_parameters(n_states, n_contexts, random_state=42):
    """
    Initialize context-dependent transition parameters.

    ``alpha`` favours staying in the same state: the self-transition gets
    probability 0.7 and the remaining 0.3 is split evenly among the other
    states. With a single state the only transition is the self-transition,
    so its log-probability is 0. ``beta`` is drawn from N(0, 0.01**2).

    Parameters:
    -----------
    n_states : int
        Number of states
    n_contexts : int
        Number of context variables
    random_state : int or None, optional (default=42)
        Seed for a local ``np.random.RandomState`` used to draw ``beta``.
        The global NumPy RNG is not touched. The default reproduces the
        values previously obtained via ``np.random.seed(42)``.

    Returns:
    --------
    alpha : ndarray, shape (n_states, n_states)
        Base transition log-probabilities
    beta : ndarray, shape (n_states, n_states, n_contexts)
        Context-specific adjustment parameters
    """
    if n_states > 1:
        alpha = np.full((n_states, n_states), np.log(0.3 / (n_states - 1)))
        np.fill_diagonal(alpha, np.log(0.7))
    else:
        # Zero or one state: avoid dividing by (n_states - 1) == 0. A lone
        # state transitions to itself with probability 1, i.e. log(1) = 0.
        alpha = np.zeros((n_states, n_states))

    # Use a local generator so callers' global RNG state is not reset.
    rng = np.random.RandomState(random_state)
    beta = rng.normal(0, 0.01, size=(n_states, n_states, n_contexts))

    return alpha, beta


