import math
import random
import numpy as np
import time
from IPython.display import display, clear_output
import matplotlib.image as mpimg
from gridworld import *
from agent import *
from itertools import count
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.image as mpimg
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase


# def max_entropy_value_iteration(env, policy, alpha=1.0, gamma=0.99, tolerance: float = 1e-4, max_iterations: int = 1000) -> float:
#     """
#     Updates the Maximum Entropy Value Function until convergence.
    
#     Mathematical Foundation:
#     ----------------------
#     In Maximum Entropy RL, the value function combines expected returns with an entropy term:
#         V(s) = H(π(·|s)) + E_a[Q(s,a)]
    
#     where:
#     - H(π(·|s)) is the entropy of the policy at state s: -Σ_a π(a|s)log(π(a|s))
#     - Q(s,a) = γ * Σ_s' P(s'|s,a) * V(s') is the future value
#     - γ is the discount factor
    
#     The expectation E_a[Q(s,a)] under the optimal maximum entropy policy equals:
#         α * log(Σ_a exp(Q(s,a)/α))
    
#     where α is the temperature parameter. This equivalence comes from the fact that 
#     the optimal policy follows the Boltzmann distribution:
#         π*(a|s) = exp(Q(s,a)/α) / Σ_a' exp(Q(s,a')/α)
    
#     Therefore, the complete update rule is:
#         V(s) = H(π(·|s)) + α * [max_a(Q(s,a)/α) + log(Σ_a exp(Q(s,a)/α - max_a(Q(s,a)/α)))]
    
#     The last form uses the log-sum-exp trick for numerical stability.

#     Parameters:
#     ----------
#     policy : array-like
#         Current policy distribution over actions for each state
#     alpha : float, default=1.0
#         Temperature parameter that determines the contribution of the entropy term
#     gamma : float, default=0.99
#         Discount factor for future rewards
#     tolerance : float, default=1e-4
#         Convergence threshold for value iteration
#     max_iterations : int, default=1000
#         Maximum number of iterations for value iteration

#     Returns:
#     -------
#     tuple(numpy.ndarray, float)
#         - Updated value function for each state
#         - Final maximum delta (for convergence checking)

#     Note:
#     ----
#     The implementation uses several numerical stability techniques:
#     - Minimum probability thresholds for entropy computation
#     - Log-sum-exp trick for softmax computation
#     - Value clipping for exponential terms
#     """
#     max_delta = float('inf')
#     iteration = 0
#     min_prob = 1e-8

#     n_states = env.getNumStates()
#     n_actions = len(env.getActionSet())
#     action_set = env.getActionSet()
#     action_to_idx = {a: i for i, a in enumerate(action_set)}
    
#     # Get available states and actions
#     available_states = env.getAvailableStates()
#     available_actions_dict = {
#         s: [action_to_idx[a] for a in env.getAvailableActions(s)]
#         for s in available_states
#     }

#     mV = np.zeros(n_states)  # Mixture entropy value function
#     mQ = np.zeros((n_states, n_actions))  # Mixture entropy Q-function

#     while max_delta > tolerance and iteration < max_iterations:
#         old_mv = mV.copy()
#         max_delta = 0
        
#         # Update all states
#         for state in available_states:
#             available_actions = available_actions_dict[state]
#             mixture_policy = policy[state]
             
#             # Calculate immediate mixture entropy with numerical stability
#             valid_probs = mixture_policy[available_actions]
#             valid_mask = valid_probs > min_prob
#             if np.any(valid_mask):
#                 valid_probs = valid_probs[valid_mask]
#                 valid_probs = valid_probs / np.sum(valid_probs)  # Renormalize
#                 mixture_entropy = -np.sum(
#                     valid_probs * np.log(np.clip(valid_probs, min_prob, 1.0))
#                 )
#             else:
#                 mixture_entropy = 0.0
            
#             # Calculate inner sum for each action
#             action_values = []
#             for action_idx in available_actions:
#                 action = action_set[action_idx]
#                 transitions = env.transition_probabilities(state, action)
                
#                 # Calculate sum_{s'} p(s'|s,a)mV*(s') with numerical stability
#                 expected_next_value = 0.0
#                 for next_state, prob in transitions.items():
#                     if np.isfinite(mV[next_state]):
#                         expected_next_value += prob * mV[next_state]
                
#                 action_values.append(expected_next_value)
            
#             # Convert to numpy array for vectorized operations
#             action_values = np.array(action_values)
            
#             # Apply the update rule with mixture entropy as immediate reward
#             if len(action_values) > 0:
#                 # Add numerical stability to exponential terms
#                 scaled_values = (gamma / alpha) * action_values
#                 max_val = np.max(scaled_values)
#                 exp_terms = np.exp(np.clip(scaled_values - max_val, -100, 100))
#                 summed_exp = np.sum(exp_terms)
                
#                 if summed_exp > min_prob:
#                     mV[state] = mixture_entropy + alpha * (max_val + np.log(summed_exp))
#                 else:
#                     mV[state] = mixture_entropy
#             else:
#                 mV[state] = mixture_entropy
            
#             # Update maximum delta with numerical stability check
#             if np.isfinite(old_mv[state]) and np.isfinite(mV[state]):
#                 max_delta = max(max_delta, abs(old_mv[state] - mV[state]))
        
#         iteration += 1
        
#         if not np.isfinite(max_delta):
#             print("Warning: Non-finite values detected in mixture entropy update")
#             break
        
#     return mV, max_delta

def max_entropy_value_iteration(env, policy, gamma=0.99, tolerance=1e-4, max_iterations=1000):
    """
    Evaluate the entropy-augmented value function of a fixed policy.

    Core equations (applied as synchronous sweeps until convergence):
    V(s) = H(π(·|s)) + E_a[Q(s,a)]
    Q(s,a) = γ * Σ_s' P(s'|s,a) * V(s')
    H(π(·|s)) = -Σ_a π(a|s)log(π(a|s))

    The policy entropy is the only "reward" in this backup; it does not
    change between sweeps, so it is computed once up front.

    Parameters:
        env: Environment exposing getNumStates(), getActionSet(),
            getAvailableStates(), getAvailableActions(state) and
            transition_probabilities(state, action) (a mapping
            next_state -> probability).
        policy: Indexable as policy[state][action_index], where the action
            index is the position of the action in env.getActionSet().
        gamma: Discount factor applied to the next-state value.
        tolerance: Iteration stops once the largest absolute change of V
            in a sweep drops below this value.
        max_iterations: Upper bound on the number of sweeps.

    Returns:
        (V, max_delta): the value array of length getNumStates() (states
        not listed by getAvailableStates() stay at 0) and the largest
        absolute change of the last sweep. If no sweep was executed
        (max_iterations <= 0), V is all zeros and max_delta is inf.
    """
    n_states = env.getNumStates()
    action_set = env.getActionSet()
    action_to_idx = {a: i for i, a in enumerate(action_set)}

    # Get available states and, per state, the indices of its available actions
    available_states = env.getAvailableStates()
    available_actions_dict = {
        s: [action_to_idx[a] for a in env.getAvailableActions(s)]
        for s in available_states
    }

    # Pre-compute entropy H(π(·|s)) for all available states
    state_entropies = np.zeros(n_states)
    for state in available_states:
        state_policy = policy[state]
        entropy = 0
        for action_idx in available_actions_dict[state]:
            prob = state_policy[action_idx]
            if prob > 0:  # 0 * log(0) is taken as 0
                entropy -= prob * np.log(prob)
        state_entropies[state] = entropy

    # Initialize value function
    V = np.zeros(n_states)

    # Defined before the loop so the return below is valid even when the
    # loop body never runs (max_iterations <= 0).
    max_delta = float('inf')

    for _ in range(max_iterations):
        # Synchronous backup: every update reads the previous sweep's values
        V_prev = V
        V = np.zeros_like(V_prev)

        for state in available_states:
            state_policy = policy[state]

            # E_a[Q(s,a)] under the given policy
            expected_q = 0
            for action_idx in available_actions_dict[state]:
                transitions = env.transition_probabilities(state, action_set[action_idx])

                # Q(s,a) = γ * Σ_s' P(s'|s,a) * V(s')
                q_value = 0
                for next_state, trans_prob in transitions.items():
                    q_value += gamma * trans_prob * V_prev[next_state]

                expected_q += state_policy[action_idx] * q_value

            # V(s) = H(π(·|s)) + E_a[Q(s,a)]
            V[state] = state_entropies[state] + expected_q

        # Check convergence
        max_delta = np.max(np.abs(V - V_prev))
        if max_delta < tolerance:
            break

    return V, max_delta

import numpy as np

def max_entropy_value_iteration_optimized(env, policy, gamma=0.99, tolerance=1e-4, max_iterations=1000):
    """
    Optimized implementation of Maximum Entropy Value Iteration with pre-computed entropy.

    Core equations (applied as synchronous sweeps until convergence):
    V(s) = H(π(·|s)) + E_a[Q(s,a)]
    Q(s,a) = γ * Σ_s' P(s'|s,a) * V(s')
    H(π(·|s)) = -Σ_a π(a|s)log(π(a|s))

    Everything that does not depend on V is computed once before the
    sweeps: the per-state entropy, the policy probabilities, and the
    discounted transition weights γ * P(s'|s,a). The environment is
    therefore queried for transitions only once per (state, action) pair.

    Parameters:
        env: Environment exposing getNumStates(), getActionSet(),
            getAvailableStates(), getAvailableActions(state) and
            transition_probabilities(state, action) (a mapping
            next_state -> probability).
        policy: Indexable as policy[state][action_index], where the action
            index is the position of the action in env.getActionSet().
        gamma: Discount factor applied to the next-state value.
        tolerance: Iteration stops once the largest absolute change of V
            in a sweep drops below this value.
        max_iterations: Upper bound on the number of sweeps.

    Returns:
        (V, max_delta): the value array of length getNumStates() (states
        not listed by getAvailableStates() stay at 0) and the largest
        absolute change of the last sweep. If no sweep was executed
        (max_iterations <= 0), V is all zeros and max_delta is inf.
    """
    n_states = env.getNumStates()
    action_set = env.getActionSet()
    action_to_idx = {a: i for i, a in enumerate(action_set)}

    # Get available states and, per state, the indices of its available actions
    available_states = env.getAvailableStates()
    available_actions_dict = {
        s: [action_to_idx[a] for a in env.getAvailableActions(s)]
        for s in available_states
    }

    # Pre-compute entropy for all states - this is constant throughout iterations
    state_entropies = np.zeros(n_states)
    for state in available_states:
        state_policy = policy[state]
        probs = np.array([state_policy[action_idx] for action_idx in available_actions_dict[state]])
        positive = probs[probs > 0]  # 0 * log(0) is taken as 0
        if positive.size:
            state_entropies[state] = -np.sum(positive * np.log(positive))

    # Pre-compute the backup of every state as a list of
    #   (π(a|s), [(s', γ * P(s'|s,a)), ...])
    # entries. Actions whose probability is effectively zero cannot
    # contribute to the expectation and are dropped here once instead of
    # being tested on every sweep.
    backups = []
    for state in available_states:
        action_terms = []
        for action_idx in available_actions_dict[state]:
            prob = policy[state][action_idx]
            if prob < 1e-10:
                continue
            transitions = env.transition_probabilities(state, action_set[action_idx])
            discounted = [
                (next_state, gamma * trans_prob)
                for next_state, trans_prob in transitions.items()
            ]
            action_terms.append((prob, discounted))
        backups.append((state, action_terms))

    # Initialize value function
    V = np.zeros(n_states)

    # Defined before the loop so the return below is valid even when the
    # loop body never runs (max_iterations <= 0).
    max_delta = float('inf')

    # Main iteration loop
    for _ in range(max_iterations):
        # Synchronous backup: every update reads the previous sweep's values
        V_prev = V
        V = np.zeros_like(V_prev)

        for state, action_terms in backups:
            # E_a[Q(s,a)] from the cached terms
            expected_q = 0
            for prob, discounted in action_terms:
                q_value = 0
                for next_state, weight in discounted:
                    q_value += weight * V_prev[next_state]
                expected_q += prob * q_value

            # V(s) = H(π(·|s)) + E_a[Q(s,a)]
            V[state] = state_entropies[state] + expected_q

        # Check convergence
        max_delta = np.max(np.abs(V - V_prev))
        if max_delta < tolerance:
            break

    return V, max_delta

def plotPolicy(env, policy, with_number=False, put_start=False, put_goal=False, ax=None, plot_policy=True, save_path=None):
    """
    Plot the policy for a grid world environment.

    For every state except the start state, the cell is drawn as a gray
    square when env.matrixMDP marks it -1, as a hatched square when it is
    marked -2, and otherwise (if plot_policy is set) gets an arrow pointing
    in the direction of the highest-probability action
    (index 0 = up, 1 = right, 2 = down, 3 = left). The start cell is
    skipped entirely by the loop (no arrow, no number).

    Parameters:
        env: The environment object
        policy: The policy to plot; policy[state] must support argmax()
        with_number: Whether to show state numbers
        put_start: Whether to mark the start state (draws 'robot.png')
        put_goal: Whether to mark the goal state (draws 'apple.jpg' and
            suppresses the arrow in the goal cell)
        ax: Matplotlib axis to plot on (creates a new 10x10 figure if None)
        plot_policy: Whether to plot policy arrows
        save_path: Path to save the figure (None means don't save)

    Returns:
        The axis that was drawn on.
    """
    numRows, numCols = env.getGridDimensions()
    matrixMDP = env.matrixMDP
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    # Arrow displacement for each greedy action index; any other index
    # yields a zero-length arrow.
    arrow_offsets = {0: (0, 0.25), 1: (0.25, 0), 2: (0, -0.25), 3: (-0.25, 0)}

    i_0, j_0 = env.getStateXY(env.getStartState())
    goal_xy = env.getStateXY(env.getGoalState())

    for idx in range(env.getNumStates()):
        i, j = env.getStateXY(idx)
        if i_0 == i and j_0 == j:
            continue
        if with_number:
            ax.text(j + 0.1, numRows - i - 0.2, str(idx), fontsize=8)

        cell = matrixMDP[i][j]
        if cell == -1:
            ax.add_patch(patches.Rectangle((j, numRows - i - 1), 1.0, 1.0, facecolor="gray"))
        elif cell == -2:
            ax.add_patch(patches.Rectangle((j, numRows - i - 1), 1.0, 1.0, facecolor='none', edgecolor='black', hatch='//'))
        elif plot_policy and not (put_goal and (i, j) == goal_xy):
            dx, dy = arrow_offsets.get(int(policy[idx].argmax()), (0, 0))
            ax.arrow(j + 0.5, numRows - i - 0.5, dx, dy, head_width=0.2, head_length=0.05, fc='k', ec='k')

    if put_start:
        i, j = env.getStateXY(env.getStartState())
        robot_image = mpimg.imread('robot.png')
        ax.imshow(robot_image, aspect='auto', extent=(j, j + 1, numRows - i - 1, numRows - i))
    if put_goal:
        i, j = env.getStateXY(env.getGoalState())
        apple_image = mpimg.imread('apple.jpg')
        ax.imshow(apple_image, aspect='auto', extent=(j, j + 1, numRows - i - 1, numRows - i))

    ax.set_xlim([0, numCols])
    ax.set_ylim([0, numRows])
    # Remove x and y labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Save the figure if a path is provided. Go through the figure that
    # owns `ax` rather than pyplot's "current figure": when the caller
    # passes in an axis, the current figure may be a different one.
    if save_path is not None:
        fig = ax.figure
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    return ax

def plotPolicyold(env, policy, with_number=False, put_start=False, put_goal=False, ax=None, plot_policy=True):
    """
    Older variant of the policy plot: draws on an existing axis and returns nothing.

    Unlike plotPolicy, the start cell is treated like any other cell, a
    missing axis falls back to the current pyplot axis, and nothing is saved.

    Parameters:
        env: The environment object
        policy: The policy to plot; policy[state] must support argmax()
        with_number: Whether to show state numbers
        put_start: Whether to draw 'robot.png' on the start cell
        put_goal: Whether to draw 'apple.jpg' on the goal cell (this also
            suppresses the arrow in that cell)
        ax: Matplotlib axis to plot on (plt.gca() if None)
        plot_policy: Whether to plot policy arrows
    """
    n_rows, n_cols = env.getGridDimensions()
    grid = env.matrixMDP
    axis = plt.gca() if ax is None else ax

    # (action index, arrow displacement): 0 = up, 1 = right, 2 = down, 3 = left
    headings = ((0, (0, 0.25)), (1, (0.25, 0)), (2, (0, -0.25)), (3, (-0.25, 0)))

    for state in range(env.getNumStates()):
        row, col = env.getStateXY(state)
        if with_number:
            axis.text(col + 0.1, n_rows - row - 0.2, str(state), fontsize=8)

        cell = grid[row][col]
        if cell == -1:
            axis.add_patch(patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, facecolor="gray"))
            continue
        if cell == -2:
            axis.add_patch(patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, facecolor='none', edgecolor='black', hatch='//'))
            continue

        # Free cell: no arrow when it is the goal cell being replaced by an
        # image, or when arrows are disabled altogether.
        covered_by_goal = (row, col) == env.getStateXY(env.getGoalState()) and put_goal
        if covered_by_goal or not plot_policy:
            continue

        greedy = policy[state].argmax()
        step_x, step_y = 0, 0
        for code, offset in headings:
            if greedy == code:
                step_x, step_y = offset
                break
        axis.arrow(col + 0.5, n_rows - row - 0.5, step_x, step_y, head_width=0.2, head_length=0.05, fc='k', ec='k')

    if put_start:
        row, col = env.getStateXY(env.getStartState())
        robot_image = mpimg.imread('robot.png')
        axis.imshow(robot_image, aspect='auto', extent=(col, col + 1, n_rows - row - 1, n_rows - row))

    if put_goal:
        row, col = env.getStateXY(env.getGoalState())
        apple_image = mpimg.imread('apple.jpg')
        axis.imshow(apple_image, aspect='auto', extent=(col, col + 1, n_rows - row - 1, n_rows - row))

    axis.set_xlim([0, n_cols])
    axis.set_ylim([0, n_rows])

    # Strip ticks and tick labels on both axes
    for clear in (axis.set_xticks, axis.set_yticks, axis.set_xticklabels, axis.set_yticklabels):
        clear([])


def dynPlotPolicy(env, numStates):
    """
    Redraw the grid with the agent's current position on the current pyplot figure.

    Clears the current figure, then draws: gray squares for cells marked -1
    in env.matrixMDP, hatched squares for cells marked -2, a cyan square on
    the goal cell, a red square at (env.currX, env.currY), and dotted grid
    lines. Finishes with plt.show().

    Parameters:
        env: The environment object
        numStates: Number of state indices (0 .. numStates - 1) to scan
    """
    n_rows, n_cols = env.getGridDimensions()
    grid = env.matrixMDP
    plt.clf()
    canvas = plt.gca()

    def unit_square(row, col, **style):
        # Grid rows count downwards while the plot's y axis counts upwards.
        return patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, **style)

    for state in range(numStates):
        row, col = env.getStateXY(state)
        if grid[row][col] == -1:
            canvas.add_patch(unit_square(row, col, facecolor="gray"))
        if grid[row][col] == -2:
            canvas.add_patch(unit_square(row, col, facecolor='none', edgecolor='black', hatch='//'))

    goal_row, goal_col = env.getStateXY(env.getGoalState())
    canvas.add_patch(unit_square(goal_row, goal_col, facecolor="cyan"))

    agent_row, agent_col = env.currX, env.currY
    canvas.add_patch(unit_square(agent_row, agent_col, facecolor="red"))

    plt.xlim([0, n_cols])
    plt.ylim([0, n_rows])

    # Dotted cell borders, including the outer right and top edges
    for x in range(n_cols + 1):
        plt.axvline(x, color='k', linestyle=':')
    for y in range(n_rows + 1):
        plt.axhline(y, color='k', linestyle=':')

    plt.show()
 
    
def plot_lap_distances(env, rewards, ax=None):
    """
    Colour every grid cell by its entry in `rewards` and attach a colour bar.

    The start cell shows 'robot.png', cells marked -1 in env.matrixMDP are
    drawn gray, and every other cell is filled with the Blues colormap
    colour of rewards[state], normalised between min(rewards) and
    max(rewards).

    Parameters:
        env: The environment object (uses numRows, numCols, matrixMDP)
        rewards: Per-state values, indexable by state index
        ax: Matplotlib axis to plot on (plt.gca() if None)
    """
    target = plt.gca() if ax is None else ax

    # Scale the values to [0, 1] so the colormap covers the full range
    scale = Normalize(vmin=min(rewards), vmax=max(rewards))
    palette = plt.cm.Blues

    n_rows = env.numRows
    for state in range(env.getNumStates()):
        row, col = env.getStateXY(state)
        corner = (col, n_rows - row - 1)

        if (row, col) == env.getStateXY(env.getStartState()):
            robot_image = mpimg.imread('robot.png')
            target.imshow(robot_image, aspect='auto', extent=(col, col + 1, n_rows - row - 1, n_rows - row))
        elif env.matrixMDP[row][col] == -1:
            target.add_patch(patches.Rectangle(corner, 1.0, 1.0, facecolor="gray"))
        else:
            shade = palette(scale(rewards[state]))
            target.add_patch(patches.Rectangle(corner, 1.0, 1.0, color=shade))

    target.set_xlim([0, env.numCols])
    target.set_ylim([0, env.numRows])

    # Strip ticks and tick labels on both axes
    for clear in (target.set_xticks, target.set_yticks, target.set_xticklabels, target.set_yticklabels):
        clear([])

    # Colour bar tied to the same normalisation and colormap as the cells
    mappable = ScalarMappable(norm=scale, cmap=palette)
    mappable.set_array([])  # Only needed for matplotlib versions < 3.1
    plt.colorbar(mappable, ax=target, orientation='vertical', label='Reward Value')


def test_poilicy(env, pi, eps_counts=1000, eps_len=100):
    """
    Roll out a stochastic policy and print summary statistics.

    Runs `eps_counts` episodes. Each episode starts with env.reset() and
    samples actions from env.getActionSet() with weights
    pi[env.getCurrentState()] until env.isTerminal() or `eps_len` steps
    have been taken. An episode counts as a success when it ends with
    (env.currX, env.currY) equal to (env.goalX, env.goalY).

    Prints the mean episode return, the success rate and the mean episode
    length; returns nothing.
    """
    episode_lengths = []
    episode_returns = []
    successes = 0

    for _ in range(eps_counts):
        env.reset()
        steps = 0
        total_reward = 0
        while True:
            if env.isTerminal() or steps >= eps_len:
                break
            weights = pi[env.getCurrentState()]
            chosen = random.choices(env.getActionSet(), weights=weights, k=1)[0]
            _, reward, _ = env.step(chosen)
            total_reward += reward
            steps += 1

        reached_goal = env.currX == env.goalX and env.currY == env.goalY
        if reached_goal:
            successes += 1
        episode_lengths.append(steps)
        episode_returns.append(total_reward)

    print("rewards: ", np.mean(episode_returns))
    print("success rate: ", successes/eps_counts)
    print("average length: ", np.mean(episode_lengths))


def play_policy(env, pi):
    """
    Animate one episode of a stochastic policy, one frame per step.

    After env.reset(), actions are sampled from env.getActionSet() with
    weights pi[env.getCurrentState()] until env.isTerminal(). After every
    step a new figure is drawn with dynPlotPolicy, displayed, held for
    0.1 s, then the notebook output is cleared and the figure closed.
    There is no step limit: the loop ends only when the environment
    reports a terminal state.
    """
    env.reset()
    while not env.isTerminal():
        candidates = env.getActionSet()
        weights = pi[env.getCurrentState()]
        chosen = random.choices(candidates, weights=weights, k=1)[0]
        env.step(chosen)

        # Render this step as its own frame
        plt.figure()
        dynPlotPolicy(env, env.getNumStates())
        frame = plt.gcf()
        display(frame)
        time.sleep(0.1)
        clear_output(wait=True)
        plt.close()
        

def plotAgent(env, agent, put_start=False, put_goal=False, ax=None, plot_policy=True):
    """
    Plot the deterministic action an agent selects in every grid cell.

    For each state index the agent is queried with
    agent.select_action(agent.encode_state(idx), deterministic=True).item().
    Cells marked -1 in env.matrixMDP are drawn gray, cells marked -2 are
    hatched, and the remaining cells get an arrow for the selected action
    (0 = up, 1 = right, 2 = down, 3 = left). Ends with plt.show().

    Parameters:
        env: The environment object
        agent: Object providing encode_state() and select_action()
        put_start: Whether to draw 'robot.png' on the start cell
        put_goal: Whether to draw 'apple.jpg' on the goal cell (this also
            suppresses the arrow in that cell)
        ax: Matplotlib axis to plot on (plt.gca() if None)
        plot_policy: Whether to plot action arrows
    """
    n_rows, n_cols = env.getGridDimensions()
    grid = env.matrixMDP
    axis = plt.gca() if ax is None else ax

    # (action index, arrow displacement)
    headings = ((0, (0, 0.25)), (1, (0.25, 0)), (2, (0, -0.25)), (3, (-0.25, 0)))

    for state_idx in range(env.getNumStates()):
        encoded = agent.encode_state(state_idx)
        action = agent.select_action(encoded, deterministic=True).item()
        row, col = env.getStateXY(state_idx)

        cell = grid[row][col]
        if cell == -1:
            axis.add_patch(patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, facecolor="gray"))
            continue
        if cell == -2:
            axis.add_patch(patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, facecolor='none', edgecolor='black', hatch='//'))
            continue

        # Free cell: no arrow when it is the goal cell being replaced by an
        # image, or when arrows are disabled altogether.
        covered_by_goal = (row, col) == env.getStateXY(env.getGoalState()) and put_goal
        if covered_by_goal or not plot_policy:
            continue

        step_x, step_y = 0, 0
        for code, offset in headings:
            if action == code:
                step_x, step_y = offset
                break
        axis.arrow(col + 0.5, n_rows - row - 0.5, step_x, step_y, head_width=0.2, head_length=0.05, fc='k', ec='k')

    if put_start:
        row, col = env.getStateXY(env.getStartState())
        robot_image = mpimg.imread('robot.png')
        axis.imshow(robot_image, aspect='auto', extent=(col, col + 1, n_rows - row - 1, n_rows - row))

    if put_goal:
        row, col = env.getStateXY(env.getGoalState())
        apple_image = mpimg.imread('apple.jpg')
        axis.imshow(apple_image, aspect='auto', extent=(col, col + 1, n_rows - row - 1, n_rows - row))

    axis.set_xlim([0, n_cols])
    axis.set_ylim([0, n_rows])

    # Strip ticks and tick labels on both axes
    for clear in (axis.set_xticks, axis.set_yticks, axis.set_xticklabels, axis.set_yticklabels):
        clear([])
    plt.show()


def plotPath(env, states, path):
    """
    Draw the visited states of a trajectory on the grid and save the image.

    Clears the current pyplot figure, then draws gray squares for cells
    marked -1 in env.matrixMDP, hatched squares for cells marked -2, red
    squares for every state index contained in `states`, a cyan square on
    the goal cell, a red square at (env.currX, env.currY), and dotted grid
    lines. The figure is written to `path` and closed; nothing is shown.

    Parameters:
        env: The environment object
        states: Collection of visited state indices (membership-tested)
        path: Output file path passed to plt.savefig
    """
    n_rows, n_cols = env.getGridDimensions()
    grid = env.matrixMDP
    plt.clf()
    canvas = plt.gca()

    def unit_square(row, col, **style):
        # Grid rows count downwards while the plot's y axis counts upwards.
        return patches.Rectangle((col, n_rows - row - 1), 1.0, 1.0, **style)

    for state in range(env.getNumStates()):
        row, col = env.getStateXY(state)
        if grid[row][col] == -1:
            canvas.add_patch(unit_square(row, col, facecolor="gray"))
        if grid[row][col] == -2:
            canvas.add_patch(unit_square(row, col, facecolor='none', edgecolor='black', hatch='//'))
        if state in states:
            canvas.add_patch(unit_square(row, col, facecolor="red"))

    goal_row, goal_col = env.getStateXY(env.getGoalState())
    canvas.add_patch(unit_square(goal_row, goal_col, facecolor="cyan"))

    agent_row, agent_col = env.currX, env.currY
    canvas.add_patch(unit_square(agent_row, agent_col, facecolor="red"))

    plt.xlim([0, n_cols])
    plt.ylim([0, n_rows])

    # Dotted cell borders, including the outer right and top edges
    for x in range(n_cols + 1):
        plt.axvline(x, color='k', linestyle=':')
    for y in range(n_rows + 1):
        plt.axhline(y, color='k', linestyle=':')

    plt.savefig(path)
    plt.close()


def playAgent(env, agent):
    """
    Run one episode with the agent and return the states it was in.

    Starting from env.reset(), each step records the current state, asks
    the agent for an action via select_action(encode_state(state)), and
    applies env.getActionSet()[action.item()] to the environment. The
    episode stops when the environment reports done, or after 101 steps
    at the latest. The state reached by the final step is not recorded.

    Returns:
        List of the states observed before each step.
    """
    max_extra_steps = 100  # step indices 0..100 inclusive are allowed
    visited = []
    current = env.reset()

    for _ in range(max_extra_steps + 1):
        visited.append(current)
        encoded = agent.encode_state(current)
        choice = agent.select_action(encoded)
        move = env.getActionSet()[choice.item()]
        current, _reward, finished = env.step(move)
        if finished:
            break

    return visited


def plot_component_policies(em_policy, env, states_to_plot, save_path=None):
    """
    Show, per state, the action distribution of every mixture component
    next to the component weights π^c.

    The figure has one row per entry of `states_to_plot`. Each row holds
    one bar chart per component (values em_policy.pi_b[k, state]) followed
    by a horizontal bar chart of em_policy.pi_c[state], listed so that
    component 0 ends up at the top.

    Args:
        em_policy: EMPolicyIteration instance (uses n_components,
            action_set, pi_b and pi_c)
        env: GridWorld instance
        states_to_plot: List of state indices to plot
        save_path: Optional path to save the figure

    Raises:
        ValueError: If `states_to_plot` is empty.
    """
    if not states_to_plot:
        raise ValueError("states_to_plot cannot be empty")

    n_rows = len(states_to_plot)
    n_components = em_policy.n_components
    action_labels = em_policy.action_set
    n_cols = n_components + 1  # one extra, slightly narrower column for π^c

    fig = plt.figure(figsize=(4*n_cols, 4*n_rows))
    layout = plt.GridSpec(n_rows, n_cols, width_ratios=[1]*n_components + [0.8])

    # One colour per action bar
    bar_colors = ['blue', 'green', 'red', 'orange']

    # Reversed so that c0 is the topmost bar of the horizontal chart
    component_labels = [f'c{k}' for k in range(n_components)][::-1]

    for row, state in enumerate(states_to_plot):
        state_x, state_y = env.getStateXY(state)
        is_top_row = row == 0

        # Action probabilities of each component
        for k in range(n_components):
            panel = fig.add_subplot(layout[row, k])
            positions = np.arange(len(action_labels))
            drawn = panel.bar(positions, em_policy.pi_b[k, state], color=bar_colors)

            panel.set_xticks(positions)
            panel.set_xticklabels(action_labels, rotation=45)
            panel.set_ylim(0, 1.1)

            if is_top_row:
                panel.set_title(f'Component {k}\n')
            if k == 0:
                panel.set_ylabel(f'State ({state_x}, {state_y})')

            # Value labels: inside tall bars, above short ones
            for rect in drawn:
                value = rect.get_height()
                label_y = value - 0.1 if value > 0.2 else value + 0.1
                panel.text(rect.get_x() + rect.get_width()/2., label_y,
                           f'{value:.2f}',
                           ha='center', va='bottom')

        # Component weights π^c for this state
        weights_panel = fig.add_subplot(layout[row, -1])
        slots = np.arange(n_components)
        drawn = weights_panel.barh(slots, em_policy.pi_c[state][::-1], color='purple', alpha=0.5)

        weights_panel.set_yticks(slots)
        weights_panel.set_yticklabels(component_labels)
        weights_panel.set_xlim(-0.2, 1.1)

        if is_top_row:
            weights_panel.set_title('π^c\n')

        # Value labels just past the end of each bar
        for rect in drawn:
            extent = rect.get_width()
            weights_panel.text(extent + 0.1, rect.get_y() + rect.get_height()/2,
                               f'{extent:.2f}',
                               ha='center', va='center')

    plt.tight_layout()

    # Optionally save, then always show
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()

def plot_policy_distribution(policy, env, states_to_plot, save_path=None):
    """
    Plot histograms of action probabilities for specified states.

    One bar chart is drawn per state, side by side, showing
    policy.pi[state] with the numeric value on each bar and policy.V[state]
    in a text box.

    Args:
        policy: Policy instance (MaxEntropyPolicy); uses action_set, pi and V
        env: GridWorld instance
        states_to_plot: List of state indices to plot
        save_path: Optional path to save the figure. When given, the figure
            is written to that path and closed without being shown;
            otherwise interactive mode is switched on and the figure is shown.

    Raises:
        ValueError: If `states_to_plot` is empty.
    """
    if not states_to_plot:
        raise ValueError("states_to_plot cannot be empty")

    n_states = len(states_to_plot)
    actions = policy.action_set

    # Create figure
    fig = plt.figure(figsize=(4*n_states, 3))

    # Color map for different actions
    colors = ['blue', 'green', 'red', 'orange']

    x = np.arange(len(actions))

    for i, state in enumerate(states_to_plot):
        state_x, state_y = env.getStateXY(state)

        # Create subplot for this state
        ax = fig.add_subplot(1, n_states, i+1)

        # Bar plot of the action probabilities
        bars = ax.bar(x, policy.pi[state], color=colors)

        # Customize subplot
        ax.set_xticks(x)
        ax.set_xticklabels(actions, rotation=45)
        ax.set_ylim(0, 1.1)
        ax.grid(True, alpha=0.3)

        # Add state coordinates in title
        ax.set_title(f'State ({state_x}, {state_y})')

        # Probability labels: inside tall bars, above short ones
        for bar in bars:
            height = bar.get_height()
            rheight = round(height, 2)
            text_height = height - 0.1 if height > 0.2 else height + 0.1
            ax.text(bar.get_x() + bar.get_width()/2., text_height,
                   f'{rheight:.2f}',
                   ha='center', va='center')

        # Add value function information
        value = policy.V[state]
        ax.text(0.02, 0.95, f'V(s) = {value:.2f}',
                transform=ax.transAxes,
                bbox=dict(facecolor='white', alpha=0.8))

    fig.tight_layout()

    # No plt.clf() afterwards: once the figure has been closed, clf() makes
    # pyplot create a new empty current figure (leaking one figure per
    # call), and on the display path it would wipe the figure just shown.
    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close(fig)  # Close this specific figure
    else:
        plt.ion()  # Turn interactive mode back on for display
        plt.show()

    
    
