import math
import random
import numpy as np
import time
from IPython.display import display, clear_output
import matplotlib.image as mpimg
from gridworld import *
from agent import *
from utils import *
import numpy as np
from typing import List, Tuple, Dict
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random
from typing import Dict, List, Tuple
from abc import ABC
from typing import Dict, Any
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Tuple


def max_entropy_value_iteration(env, policy, gamma=0.99, tolerance=1e-4, max_iterations=1000):
    """Evaluate the discounted cumulative entropy of a fixed policy.

    Repeatedly applies the backup

        V(s)      = H(π(·|s)) + Σ_a π(a|s) Q(s,a)
        Q(s,a)    = γ Σ_s' P(s'|s,a) V(s')
        H(π(·|s)) = -Σ_a π(a|s) log π(a|s)

    starting from V = 0. Only the states returned by
    ``env.getAvailableStates()`` and the actions returned by
    ``env.getAvailableActions(state)`` take part; every other entry of V
    stays 0.

    Args:
        env: Environment exposing ``getNumStates()``, ``getActionSet()``,
            ``getAvailableStates()``, ``getAvailableActions(state)`` and
            ``transition_probabilities(state, action)``; the latter must
            return a mapping from next state to probability.
        policy: Indexable as ``policy[state][action_idx]``, where
            ``action_idx`` is the position of the action in
            ``env.getActionSet()``.
        gamma: Discount factor.
        tolerance: Iteration stops once the largest absolute change of V
            in one sweep drops below this value.
        max_iterations: Upper bound on the number of sweeps.

    Returns:
        Tuple ``(V, max_delta)``: the value array of length
        ``env.getNumStates()`` and the largest absolute change of the last
        sweep. ``max_delta`` is ``inf`` when no sweep was performed
        (``max_iterations <= 0``).
    """
    action_set = env.getActionSet()
    action_to_idx = {a: i for i, a in enumerate(action_set)}
    n_states = env.getNumStates()

    available_states = env.getAvailableStates()
    available_actions_dict = {
        s: [action_to_idx[a] for a in env.getAvailableActions(s)]
        for s in available_states
    }

    # The policy and the transition model are fixed for the whole evaluation,
    # so the per-state entropy and the transition tables are gathered once
    # instead of on every sweep.
    state_entropies = np.zeros(n_states)
    transitions = {}
    for state in available_states:
        state_policy = policy[state]
        entropy = 0
        for action_idx in available_actions_dict[state]:
            prob = state_policy[action_idx]
            if prob > 0:  # 0 * log(0) is taken as 0
                entropy -= prob * np.log(prob)
            transitions[state, action_idx] = env.transition_probabilities(
                state, action_set[action_idx]
            )
        state_entropies[state] = entropy

    V = np.zeros(n_states)
    # Defined up front so the return below is valid even when the loop body
    # never runs.
    max_delta = float('inf')

    for _ in range(max_iterations):
        # Synchronous sweep: every update reads the previous iterate only.
        V_prev = V
        V = np.zeros_like(V_prev)

        for state in available_states:
            state_policy = policy[state]

            expected_q = 0
            for action_idx in available_actions_dict[state]:
                q_value = 0
                for next_state, trans_prob in transitions[state, action_idx].items():
                    q_value += gamma * trans_prob * V_prev[next_state]
                expected_q += state_policy[action_idx] * q_value

            V[state] = state_entropies[state] + expected_q

        max_delta = np.max(np.abs(V - V_prev))
        if max_delta < tolerance:
            break

    return V, max_delta


class MaxEntropyPolicy:
    """Tabular maximum-entropy policy computed with soft value iteration.

    The environment must expose ``getNumStates()``, ``getActionSet()``,
    ``getAvailableStates()``, ``getAvailableActions(state)``,
    ``transition_probabilities(state, action)`` (a mapping from next state
    to probability) and ``getStartState()``.

    Instance state:
        V: array of length n_states with the soft state values.
        Q: array (n_states, n_actions) holding the *undiscounted* expected
            next-state value Σ_{s'} p(s'|s,a) V(s'); the discount is applied
            where Q is consumed (see ``update_policy``).
        pi: array (n_states, n_actions) with π(a|s); entries of actions that
            are not available in a state are 0.
        history: dict of per-iteration lists filled by ``policy_iteration``.
    """

    def __init__(self, env, gamma: float = 0.95, alpha: float = 1.0):
        """Set up value tables and a random initial policy.

        Args:
            env: Environment (see the class docstring for the required API).
            gamma: Discount factor.
            alpha: Temperature of the entropy regularisation.
        """
        self.env = env
        self.n_states = env.getNumStates()
        self.n_actions = len(env.getActionSet())
        self.action_set = env.getActionSet()
        self.action_to_idx = {a: i for i, a in enumerate(self.action_set)}
        self.gamma = gamma
        self.alpha = alpha

        # Initialize value functions
        self.Q = np.zeros((self.n_states, self.n_actions))
        self.V = np.zeros(self.n_states)

        # Per state, the indices (into action_set) of the actions it allows
        self.available_states = env.getAvailableStates()
        self.available_actions_dict = {
            s: [self.action_to_idx[a] for a in env.getAvailableActions(s)]
            for s in self.available_states
        }

        # Random initial policy: one Dirichlet(0.5) draw per available state,
        # supported on that state's available actions only.
        self.pi = np.zeros((self.n_states, self.n_actions))
        for s in self.available_states:
            available_actions = self.available_actions_dict[s]
            probs = np.random.dirichlet(np.ones(len(available_actions)) * 0.5)
            self.pi[s, available_actions] = probs

        self.history = {
            'value_diffs': [],
            'policy_entropy': [],
            'accumulated_entropy': [],
        }

    def _expected_next_value(self, state: int, action_idx: int) -> float:
        """Return Σ_{s'} p(s'|s,a) V(s') for the action at ``action_idx``."""
        transitions = self.env.transition_probabilities(state, self.action_set[action_idx])
        return sum(prob * self.V[next_state]
                   for next_state, prob in transitions.items())

    def update_value_functions(self, tolerance: float = 1e-4, max_iterations: int = 1000) -> float:
        """Run soft value iteration on V until convergence, then refresh Q.

        Each sweep applies, for every available state,

            V(s) = α ln Σ_a exp( (γ/α) Σ_{s'} p(s'|s,a) V(s') )

        where the sum over a runs over the actions available in s. States
        are updated in place, so later states of a sweep already see the new
        values of earlier ones. The log-sum-exp is evaluated with the usual
        max-shift so large values do not overflow ``exp``.

        After the loop, ``Q[s, a]`` is set to Σ_{s'} p(s'|s,a) V(s') — note
        that no discount is applied to the stored Q.

        Args:
            tolerance: Sweeps stop once the largest absolute change of V in
                a sweep is not greater than this value.
            max_iterations: Upper bound on the number of sweeps.

        Returns:
            The largest absolute change of V during the last sweep
            (``inf`` if no sweep ran).
        """
        scale = self.gamma / self.alpha
        max_delta = float('inf')
        iteration = 0

        while max_delta > tolerance and iteration < max_iterations:
            old_v = self.V.copy()
            max_delta = 0

            for state in self.available_states:
                scaled = scale * np.array([
                    self._expected_next_value(state, action_idx)
                    for action_idx in self.available_actions_dict[state]
                ])

                # Shift by the largest exponent before exponentiating; fall
                # back to no shift when there is nothing finite to shift by.
                peak = np.max(scaled) if scaled.size else 0.0
                if not np.isfinite(peak):
                    peak = 0.0
                self.V[state] = self.alpha * (peak + np.log(np.sum(np.exp(scaled - peak))))

                max_delta = max(max_delta, abs(old_v[state] - self.V[state]))

            iteration += 1

        for state in self.available_states:
            for action_idx in self.available_actions_dict[state]:
                self.Q[state, action_idx] = self._expected_next_value(state, action_idx)

        return max_delta

    def policy_iteration(self, n_iterations: int, tolerance: float = 1e-4,
                        convergence_window: int = 10,
                        verbose: bool = False) -> bool:
        """Alternate value updates and policy updates until the values settle.

        Every iteration runs ``update_value_functions`` to convergence,
        recomputes the policy of every available state and appends to
        ``history``: the returned value change, the sum of per-state policy
        entropies, and V at ``env.getStartState()``.

        Convergence is declared once more than ``convergence_window`` value
        changes have been observed and both the mean and the standard
        deviation of the last ``convergence_window`` of them are below
        ``tolerance``.

        Args:
            n_iterations: Maximum number of iterations.
            tolerance: Threshold for the value updates and the convergence test.
            convergence_window: Number of recent value changes inspected.
            verbose: Print progress every 100 iterations and on convergence.

        Returns:
            True if convergence was declared, False otherwise. The
            "did not converge" message is printed regardless of ``verbose``.
        """
        recent_diffs = []

        for iteration in range(n_iterations):
            # Step 1: update value functions until convergence
            max_value_diff = self.update_value_functions(tolerance=tolerance)

            # Step 2: update the policy of every state
            for state in self.available_states:
                self.update_policy(state)

            total_entropy = sum(self.compute_entropy(self.pi[s])
                              for s in self.available_states)
            self.history['value_diffs'].append(max_value_diff)
            self.history['policy_entropy'].append(total_entropy)
            self.history['accumulated_entropy'].append(self.V[self.env.getStartState()])

            # Sliding window: the test only starts once the window overflowed
            recent_diffs.append(max_value_diff)
            if len(recent_diffs) > convergence_window:
                recent_diffs.pop(0)
                mean_diff = np.mean(recent_diffs)
                std_diff = np.std(recent_diffs)

                if mean_diff < tolerance and std_diff < tolerance:
                    if verbose:
                        print(f"Converged at iteration {iteration}")
                        print(f"Final value diff mean: {mean_diff:.6f}, std: {std_diff:.6f}")
                        print(f"Final total entropy: {total_entropy:.6f}")
                    return True

            if iteration % 100 == 0 and verbose:
                print(f"Iteration {iteration}, Max Value Diff: {max_value_diff:.6f}")
                print(f"Total Entropy: {total_entropy:.6f}")

        print("Did not converge within maximum iterations")
        return False

    def compute_entropy(self, probs: np.ndarray, min_prob: float = 1e-8) -> float:
        """Return the entropy of ``probs`` restricted to entries above ``min_prob``.

        Entries not greater than ``min_prob`` are dropped and the rest is
        renormalised before computing -Σ p log p. Returns 0.0 when nothing
        remains.
        """
        valid_probs = probs[probs > min_prob]
        if not np.any(valid_probs):
            return 0.0
        valid_probs = valid_probs / np.sum(valid_probs)
        return -np.sum(valid_probs * np.log(np.clip(valid_probs, min_prob, 1.0)))

    def update_policy(self, state: int) -> None:
        """Set π(·|state) to the softmax of (γ/α)·Q over the available actions.

        This is the policy matching the soft backup used in
        ``update_value_functions``. Probabilities of available actions are
        floored at 1e-8 and renormalised; actions that are not available in
        ``state`` get probability exactly 0 so they can never be sampled.
        """
        available_actions = self.available_actions_dict[state]

        logits = (self.gamma / self.alpha) * self.Q[state, available_actions]

        # Subtract max for numerical stability
        exp_logits = np.exp(logits - np.max(logits))
        probs = np.clip(exp_logits / np.sum(exp_logits), 1e-8, 1.0)

        self.pi[state] = 0.0
        self.pi[state, available_actions] = probs / probs.sum()

    def get_action(self, state: int) -> str:
        """Sample an action for ``state``.

        States outside ``available_states`` get a uniformly random action
        from the full action set; otherwise the action is drawn from
        π(·|state).
        """
        if state not in self.available_states:
            return np.random.choice(self.action_set)

        action_idx = np.random.choice(len(self.action_set), p=self.pi[state])
        return self.action_set[action_idx]

    def plot_training_history(self):
        """Plot value-change and policy-entropy histories and return the figure."""
        import matplotlib.pyplot as plt

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

        ax1.plot(self.history['value_diffs'])
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Max Value Difference')
        ax1.set_title('Value Function Convergence')
        ax1.grid(True)

        ax2.plot(self.history['policy_entropy'])
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Total Policy Entropy')
        ax2.set_title('Policy Entropy Evolution')
        ax2.grid(True)

        plt.tight_layout()
        plt.show()

        # Return the figure created above: after show() the "current figure"
        # may no longer be this one.
        return fig


class BaseMixturePolicy(ABC):
    """Shared state and helpers for tabular mixture policies.

    A mixture policy is π^m(a|s) = Σ_k π^c(k|s) π^b_k(a|s), stored as

        pi_c: array (n_states, n_components) with the mixture weights,
        pi_b: array (n_components, n_states, n_actions) with the components.

    Subclasses supply the update rules through the abstract methods.
    """

    def __init__(self, env, n_components: int, gamma: float = 0.95):
        """
        Base class for mixture policy implementations.

        Args:
            env: Environment that implements getNumStates(), getActionSet(),
                getAvailableStates(), getAvailableActions(), transition_probabilities()
            n_components: Number of mixture components
            gamma: Discount factor
        """
        self.env = env
        self.n_components = n_components
        self.gamma = gamma

        # Initialize environment-related attributes
        self.n_states = env.getNumStates()
        self.n_actions = len(env.getActionSet())
        self.action_set = env.getActionSet()
        self.action_to_idx = {a: i for i, a in enumerate(self.action_set)}

        # Per state, the indices (into action_set) of the actions it allows
        self.available_states = env.getAvailableStates()
        self.available_actions_dict = {
            s: [self.action_to_idx[a] for a in env.getAvailableActions(s)]
            for s in self.available_states
        }

        # Initialize base policies and values
        self._initialize_policies()
        self._initialize_values()

        # Initialize history tracking
        self.history = {
            'value_diffs': [],
            'mixture_entropy': [],
            'component_entropy': [],
            'total_entropy_diff': [],
            'accumulated_entropy': [],
            'policy_diffs': [],
        }

    def _initialize_policies(self) -> None:
        """Draw random mixture weights and component policies.

        Weights are one Dirichlet(0.5) draw per state; each component gets
        one Dirichlet(0.5) draw over the available actions of every
        available state. Both arrays are then floored at 1e-10, which also
        lifts the entries of unavailable actions from 0 to 1e-10.
        """
        self.pi_c = np.random.dirichlet(np.ones(self.n_components) * 0.5, size=self.n_states)
        self.pi_c = np.clip(self.pi_c, 1e-10, 1.0)

        self.pi_b = np.zeros((self.n_components, self.n_states, self.n_actions))
        for k in range(self.n_components):
            for s in self.available_states:
                available_actions = self.available_actions_dict[s]
                probs = np.random.dirichlet(np.ones(len(available_actions)) * 0.5)
                self.pi_b[k, s, available_actions] = probs
        self.pi_b = np.clip(self.pi_b, 1e-10, 1.0)

    def _initialize_values(self) -> None:
        """Initialize value functions"""
        self.Q = np.zeros((self.n_states, self.n_actions))
        self.V = np.zeros(self.n_states)

    def compute_entropy(self, probs: np.ndarray) -> float:
        """
        Compute entropy of a probability distribution.

        Entries not greater than 1e-10 are ignored (treated as 0·log 0 = 0),
        so the logarithm is only taken of strictly positive values.

        Args:
            probs: Probability distribution array

        Returns:
            float: Entropy value
        """
        valid_probs = probs[probs > 1e-10]
        return -np.sum(valid_probs * np.log(valid_probs))

    def get_mixture_policy(self, state: int) -> np.ndarray:
        """
        Compute π^m(a|s) = Σ_k π^c(k|s)π^b_k(a|s) for a single state.

        The result is floored at 1e-10 and renormalised over all actions.

        Args:
            state: State index

        Returns:
            np.ndarray: Mixture policy for the given state
        """
        pi_m = np.sum(self.pi_c[state, :, np.newaxis] * self.pi_b[:, state, :], axis=0)
        pi_m = np.clip(pi_m, 1e-10, 1.0)
        pi_m /= pi_m.sum()
        return pi_m

    def get_full_mixture_policy(self) -> np.ndarray:
        """
        Compute π^m(a|s) for all states.

        Rows of states that are not available stay all-zero.

        Returns:
            np.ndarray: Matrix of shape (n_states, n_actions) containing mixture policies
        """
        pi_m = np.zeros((self.n_states, self.n_actions))

        for state in self.available_states:
            pi_m[state] = self.get_mixture_policy(state)

        return pi_m

    def compute_responsibilities(self, state: int) -> np.ndarray:
        """
        Compute γ_k(a|s) = π^c(k|s)π^b_k(a|s)/π^m(a|s).

        The denominator is offset by 1e-10 and the result is clipped to
        [1e-10, 1].

        Args:
            state: State index

        Returns:
            np.ndarray: Responsibility matrix of shape (n_components, n_actions)
        """
        pi_m = self.get_mixture_policy(state)
        responsibilities = np.zeros((self.n_components, self.n_actions))

        for k in range(self.n_components):
            responsibilities[k] = (self.pi_c[state, k] * self.pi_b[k, state]) / (pi_m + 1e-10)

        return np.clip(responsibilities, 1e-10, 1.0)

    @abstractmethod
    def update_value_functions(self, state: int) -> None:
        """Update value functions for the given state.

        Note: ValueMixtureEntropyWeightedAgent overrides this without the
        ``state`` argument and sweeps over all available states.
        """
        pass

    @abstractmethod
    def update_mixture_weights(self, state: int) -> None:
        """Update mixture weights π^c(k|s) for the given state"""
        pass

    @abstractmethod
    def update_component_policies(self, state: int) -> None:
        """Update component policies π^b_k(a|s) for the given state"""
        pass

    def _action_probabilities(self, state: int) -> np.ndarray:
        """Return the sampling distribution over the full action set.

        The mixture policy is restricted to the actions available in
        ``state`` (floored at 1e-10) and renormalised; every other action
        gets probability exactly 0.
        """
        available_actions = self.available_actions_dict[state]
        pi_m = self.get_mixture_policy(state)

        probs = np.zeros(len(self.action_set))
        probs[available_actions] = np.clip(pi_m[available_actions], 1e-10, 1.0)
        probs /= probs.sum()
        return probs

    def get_action(self, state: int) -> str:
        """
        Sample action from the mixture policy.

        States outside ``available_states`` (or with no available action)
        get a uniformly random action from the full action set.

        Args:
            state: State index

        Returns:
            str: Selected action
        """
        if state not in self.available_states or not self.available_actions_dict[state]:
            return np.random.choice(self.action_set)

        action_idx = np.random.choice(len(self.action_set), p=self._action_probabilities(state))
        return self.action_set[action_idx]

    def calculate_total_entropy(self, verbose: bool = True) -> Dict[str, float]:
        """
        Evaluate the mixture policy and the component policies at the start state.

        Each policy is passed to ``max_entropy_value_iteration_optimized``
        and the resulting value at ``env.getStartState()`` is reported; for
        the components the start-state values are averaged over components.

        NOTE(review): ``max_entropy_value_iteration_optimized`` is not
        defined in the visible part of this module — presumably it comes
        from one of the wildcard imports; confirm.

        Args:
            verbose: Whether to print entropy statistics

        Returns:
            Dict with keys 'mixture_entropy' and 'component_entropy'
        """
        start_state = self.env.getStartState()

        mixture_values, _ = max_entropy_value_iteration_optimized(
            self.env,
            self.get_full_mixture_policy(),
            gamma=self.gamma,
            tolerance=1e-5,
            max_iterations=10000
        )
        mixture_entropy_of_initial_state = mixture_values[start_state]

        component_values_list = []
        for k in range(self.n_components):
            # pi_b[k] is the k-th component policy across all states
            component_values, _ = max_entropy_value_iteration_optimized(
                self.env,
                self.pi_b[k],
                gamma=self.gamma,
                tolerance=1e-5,
                max_iterations=10000
            )
            component_values_list.append(component_values[start_state])

        avg_component_entropy_of_initial_state = np.mean(component_values_list, axis=0)

        entropy_stats = {
            'mixture_entropy': mixture_entropy_of_initial_state,
            'component_entropy': avg_component_entropy_of_initial_state
        }

        if verbose:
            print("Entropy statistics:", entropy_stats)

        return entropy_stats

    @staticmethod
    def _tail_statistics(series, window: int = 10) -> Tuple[float, float, float]:
        """Return (last value, mean, std) over the last ``window`` entries.

        All three are NaN when ``series`` is empty.
        """
        values = np.asarray(series, dtype=float)
        if values.size == 0:
            return float('nan'), float('nan'), float('nan')
        tail = values[-window:]
        return values[-1], np.mean(tail), np.std(tail)

    def analyze_convergence(self) -> Dict[str, Any]:
        """
        Analyze convergence patterns from training history.

        Statistics are taken over the last 10 entries of
        ``history['value_diffs']`` and ``history['total_entropy_diff']``.
        A history that is still empty yields NaN statistics, in which case
        ``converged`` is False.

        Returns:
            Dict containing convergence statistics
        """
        final_diff, mean_diff, std_diff = self._tail_statistics(self.history['value_diffs'])
        value_convergence = {
            'final_diff': final_diff,
            'mean_diff': mean_diff,
            'std_diff': std_diff
        }

        final_entropy, mean_entropy, std_entropy = self._tail_statistics(
            self.history['total_entropy_diff']
        )
        entropy_convergence = {
            'final_entropy_diff': final_entropy,
            'mean_entropy_diff': mean_entropy,
            'std_entropy_diff': std_entropy
        }

        return {
            'value_convergence': value_convergence,
            'entropy_convergence': entropy_convergence,
            # Comparisons with NaN are False, so missing data never counts
            # as convergence.
            'converged': bool(std_diff < 1e-4 and std_entropy < 1e-4)
        }

    def plot_training_history(self):
        """Plot the recorded entropy histories and return the figure."""
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(10, 6))

        plt.plot(self.history['mixture_entropy'], label='Mixture Entropy')
        plt.plot(self.history['component_entropy'], label='Component Entropy')
        plt.plot(self.history['total_entropy_diff'], label='Total Entropy Diff')
        plt.xlabel('Iteration')
        plt.ylabel('Entropy')
        plt.title('Entropy Evolution During Training')
        plt.legend()
        plt.grid(True)

        # Fixed y-axis range
        plt.ylim(0, 100)

        plt.tight_layout()
        plt.show()

        # Return the figure created above: after show() the "current figure"
        # may no longer be this one.
        return fig


class ValueMixtureEntropyWeightedAgent(BaseMixturePolicy):
    """Mixture policy trained with entropy-weighted value/policy updates.

    With r_k(a|s) the responsibilities from ``compute_responsibilities`` and
    Q(s,a) the expected next-state value, the updates are built from

        x_k(a|s) = (log r_k(a|s) + γ Q(s,a)) / (1 - α)
        F_k(s)   = Σ_a exp(x_k(a|s))          (a over available actions)

        π_k(a|s) = exp(x_k(a|s)) / F_k(s)
        w_k(s)   = F_k(s)^(1-α) / Σ_j F_j(s)^(1-α)
        V(s)     = log Σ_k F_k(s)^(1-α)

    All three are evaluated in log space with a max-shift so that large
    exponents cannot overflow.
    """

    def __init__(self, env, n_components: int, gamma: float = 0.95, alpha: float = 0.5):
        """
        Initialize ValueMixtureEntropyWeightedAgent following the monotonic improvement theorem.

        Args:
            env: Environment instance
            n_components: Number of policy components
            gamma: Discount factor
            alpha: Entropy weight parameter; must not equal 1

        Raises:
            ValueError: If ``alpha`` equals 1, since every update divides by
                (1 - alpha).
        """
        if alpha == 1:
            raise ValueError("alpha must not be 1: the updates divide by (1 - alpha)")
        super().__init__(env, n_components, gamma)
        self.alpha = alpha
        self.history['V'] = []
        self.history['pi_c'] = []
        self.history['pi_b'] = []

    def compute_marginal_policy(self, state: int) -> np.ndarray:
        """
        Compute π(a|s) = Σ_k w_k(s)·π_k(a|s)

        Args:
            state: State index

        Returns:
            np.ndarray: Marginal policy for state
        """
        # Same quantity as the base-class mixture policy
        return self.get_mixture_policy(state)

    def update_q_values(self) -> None:
        """Set Q(s,a) = Σ_{s'} p(s'|s,a) V(s') for every available state/action.

        No discount is applied here; γ is applied where Q is consumed.
        """
        for state in self.available_states:
            for a_idx in self.available_actions_dict[state]:
                transitions = self.env.transition_probabilities(state, self.action_set[a_idx])
                self.Q[state, a_idx] = sum(prob * self.V[next_state]
                                           for next_state, prob in transitions.items())

    @staticmethod
    def _logsumexp(values: np.ndarray, axis: int = -1) -> np.ndarray:
        """Return log Σ exp(values) along ``axis`` using a max-shift."""
        peak = np.max(values, axis=axis, keepdims=True)
        # Shifting by a non-finite peak would produce NaN (inf - inf), so
        # such slices are left unshifted.
        safe_peak = np.where(np.isfinite(peak), peak, 0.0)
        total = np.sum(np.exp(values - safe_peak), axis=axis, keepdims=True)
        return np.squeeze(safe_peak + np.log(total), axis=axis)

    def _component_logits(self, state: int) -> np.ndarray:
        """Return x_k(a|s) as an array (n_components, n_available_actions).

        Responsibilities are recomputed from the current pi_c / pi_b on
        every call, so the result reflects any update made since the last
        call.
        """
        available_actions = self.available_actions_dict[state]
        responsibilities = self.compute_responsibilities(state)
        log_resp = np.log(responsibilities[:, available_actions] + 1e-8)
        discounted_q = self.gamma * self.Q[state, available_actions]
        return (log_resp + discounted_q) / (1 - self.alpha)

    def _scaled_log_partitions(self, state: int) -> np.ndarray:
        """Return (1 - α)·log F_k(s) for every component k."""
        return (1 - self.alpha) * self._logsumexp(self._component_logits(state), axis=1)

    def update_value_functions(self) -> float:
        """
        Update V-values according to the theorem:
        V(s) = log Σ_k (Σ_a exp((log r_k(a|s) + γQ(s,a))/(1-α)))^(1-α)

        Returns:
            float: Maximum absolute change in the value function (NaN if any
            value became NaN)
        """
        old_v = self.V.copy()

        for state in self.available_states:
            # log Σ_k exp((1-α) log F_k) == log Σ_k F_k^(1-α)
            self.V[state] = self._logsumexp(self._scaled_log_partitions(state), axis=0)

        if self.V.size == 0:
            return 0.0
        return float(np.max(np.abs(self.V - old_v)))

    def update_component_policies(self, state: int) -> None:
        """
        Update π_k(a|s) according to Equation (eq:policy_update) in the theorem:
        π_k(a|s) = exp((log r_k(a|s) + γQ(s,a))/(1-α)) / Σ_a' exp((log r_k(a'|s) + γQ(s,a'))/(1-α))

        Only the entries of the actions available in ``state`` are written.
        The responsibilities are evaluated once, before any component is
        changed.

        Args:
            state: State index
        """
        available_actions = self.available_actions_dict[state]
        logits = self._component_logits(state)

        # Row-wise softmax with max-shift
        weights = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probs = weights / np.sum(weights, axis=1, keepdims=True)

        for k in range(self.n_components):
            self.pi_b[k, state, available_actions] = probs[k]

    def update_mixture_weights(self, state: int) -> None:
        """
        Update w_k(s) according to Equation (eq:weights_update) in the theorem:
        w_k(s) = (F_k(s))^(1-α) / Σ_j (F_j(s))^(1-α)
        where F_k(s) = Σ_a exp((log r_k(a|s) + γQ(s,a))/(1-α))

        Args:
            state: State index
        """
        scaled = self._scaled_log_partitions(state)

        # softmax((1-α) log F) == F^(1-α) / Σ_j F_j^(1-α)
        weights = np.exp(scaled - np.max(scaled))
        self.pi_c[state] = weights / np.sum(weights)

    def policy_iteration(self, n_iterations: int, tolerance: float = 1e-4,
                         convergence_window: int = 10, verbose: bool = False, save_history=True) -> bool:
        """
        Main policy iteration loop following the monotonic improvement theorem.

        Each iteration refreshes Q from V, updates the component policies
        and then the mixture weights state by state, and finally updates V.
        Snapshots of V, pi_c and pi_b as well as the value and policy
        changes are appended to ``history`` on every iteration.

        Convergence is declared once more than ``convergence_window`` policy
        changes have been observed, the mean and standard deviation of the
        last ``convergence_window`` of them are below ``tolerance`` and the
        latest value change is below ``tolerance`` too.

        Args:
            n_iterations: Maximum number of iterations
            tolerance: Convergence tolerance
            convergence_window: Window size for checking convergence
            verbose: Whether to print progress
            save_history: Whether to additionally record the entropy
                statistics from ``calculate_total_entropy`` (every third
                iteration, and only while the policy still changes by more
                than 1e-8)

        Returns:
            bool: Whether convergence was achieved
        """
        recent_diffs = []

        for iteration in range(n_iterations):
            old_v = self.V.copy()
            old_pi_c = self.pi_c.copy()
            old_pi_b = self.pi_b.copy()

            self.history['V'].append(old_v.copy())

            # Step 1: Update Q-values using current V-values
            self.update_q_values()

            # Step 2: Update policies and weights for all states. The order
            # matters: the weight update sees the freshly updated components.
            for state in self.available_states:
                self.update_component_policies(state)
                self.update_mixture_weights(state)

            self.history['pi_c'].append(self.pi_c.copy())
            self.history['pi_b'].append(self.pi_b.copy())

            # Step 3: Update value functions according to the theorem
            max_value_diff = self.update_value_functions()

            max_policy_diff = max(np.max(np.abs(self.pi_c - old_pi_c)),
                                  np.max(np.abs(self.pi_b - old_pi_b)))

            self.history['value_diffs'].append(max_value_diff)
            self.history['policy_diffs'].append(max_policy_diff)

            if save_history and iteration % 3 == 0 and max_policy_diff > 1e-8:
                entropy_stats = self.calculate_total_entropy(verbose=False)
                self.history['mixture_entropy'].append(entropy_stats['mixture_entropy'])
                self.history['component_entropy'].append(entropy_stats['component_entropy'])

            # Monotonic-improvement check (values should not decrease);
            # the 1e-8 slack absorbs rounding noise.
            if verbose:
                decreased = self.V < old_v - 1e-8
                if np.any(decreased):
                    print(f"Warning: Value function decreased at iteration {iteration}")
                    print(f"Violating states: {np.where(decreased)[0]}")
                    print(f"Max decrease: {np.min(self.V - old_v)}")

            # Sliding window: the test only starts once the window overflowed
            recent_diffs.append(max_policy_diff)
            if len(recent_diffs) > convergence_window:
                recent_diffs.pop(0)
                mean_diff = np.mean(recent_diffs)
                std_diff = np.std(recent_diffs)

                if mean_diff < tolerance and std_diff < tolerance and max_value_diff < tolerance:
                    if verbose:
                        print(f"Converged at iteration {iteration}")
                        self.calculate_total_entropy(verbose=True)
                    return True

            if verbose:
                print(f"Iteration {iteration}, Value Diff: {max_value_diff:.8f}, "
                      f"Policy Diff: {max_policy_diff:.8f}")

        if verbose:
            print("Did not converge within maximum iterations")
            self.calculate_total_entropy(verbose=True)
        return False

