import math
import random
import numpy as np
import time
from IPython.display import display, clear_output
import matplotlib.image as mpimg
from gridworld import *
from utils import *

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.image as mpimg
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase
from typing import Dict, List, Tuple, Optional
import numpy as np

class Agent:
    """
    An agent implementing entropy-regularized (soft) policy iteration.

    Policy evaluation applies the soft Bellman backup

        V(s) = alpha * log sum_a exp(Q(s, a) / alpha),
        Q(s, a) = r(s, a) + gamma * sum_{s'} P(s' | s, a) V(s'),

    over the actions available in ``s``; goal states instead take the maximum
    immediate reward over their available actions. Policy improvement sets
    ``pi(. | s)`` to the softmax of ``Q(s, .) / alpha`` over available actions
    (unavailable actions get probability 0). Both steps shift by the maximum
    term before exponentiating (log-sum-exp trick) for numerical stability.
    """

    def __init__(
        self,
        gamma: float,
        env,
        rewards: Optional[np.ndarray] = None,
        alpha: float = 0.1
    ) -> None:
        """
        Initialize the agent.

        Args:
            gamma: Discount factor (0 < gamma <= 1).
            env: Environment object exposing getNumStates, getActionSet,
                getAvailableStates, getAvailableActions, isStateWall,
                isGoalState, getRewardS and transition_probabilities.
            rewards: Optional custom reward matrix indexed as
                ``rewards[state, action_index]``; when None, rewards come from
                ``env.getRewardS``.
            alpha: Temperature parameter for entropy regularization.

        Raises:
            ValueError: If parameters are invalid or the environment has no
                states or no actions.
        """
        if not 0 < gamma <= 1:
            raise ValueError("Gamma must be in (0, 1]")
        if alpha <= 0:
            raise ValueError("Alpha must be positive")

        self.gamma = gamma
        self.env = env
        self.alpha = alpha
        self.rewards = rewards

        # Initialize state space
        self.num_states = env.getNumStates()
        if self.num_states <= 0:
            raise ValueError("Environment must have at least one state")

        # Initialize action space
        self.action_set = env.getActionSet()
        if not self.action_set:
            raise ValueError("Environment must have at least one action")

        # Value function starts at zero; policy starts uniform over all actions.
        self.V = np.zeros(self.num_states)
        self.pi = np.full(
            (self.num_states, len(self.action_set)),
            1 / len(self.action_set)
        )

        # Evaluation-sweep counter; cumulative across solve_policy_iteration calls.
        self.itr = 0

    def get_reward(self, state: int, action: int) -> float:
        """Get reward for state-action pair, using custom rewards if provided."""
        if self.rewards is not None:
            return self.rewards[state, action]
        return self.env.getRewardS(state, action)

    def entropy(self, probabilities: np.ndarray) -> float:
        """
        Calculate the entropy of a probability distribution.

        Args:
            probabilities: Array of probabilities.

        Returns:
            Entropy value (natural log); a 1e-10 offset avoids log(0).
        """
        return -np.sum(probabilities * np.log(probabilities + 1e-10))

    def _available_action_indices(self, available_actions) -> List[int]:
        """Return indices into ``self.action_set`` of actions in ``available_actions``."""
        return [
            a for a, action in enumerate(self.action_set)
            if action in available_actions
        ]

    def _soft_q_terms(self, s: int, indices: List[int]) -> np.ndarray:
        """
        Return ``Q(s, a) / alpha`` for every action index of ``self.action_set``.

        Entries whose index is not in ``indices`` are ``-inf`` so that they get
        zero weight after exponentiation.
        """
        terms = np.full(len(self.action_set), -np.inf)
        for a in indices:
            Ts = self.env.transition_probabilities(s, self.action_set[a])
            sum_next = sum(Ts[next_s] * self.V[next_s] for next_s in Ts)
            terms[a] = (self.get_reward(s, a) + self.gamma * sum_next) / self.alpha
        return terms

    def _eval_policy(self, theta: float = 0.001) -> float:
        """
        Run one in-place sweep of the soft Bellman backup over all states.

        Args:
            theta: Unused; kept for API compatibility. Convergence is checked
                by the caller against the returned delta.

        Returns:
            Maximum absolute change of ``self.V`` during this sweep, including
            changes to goal-state values.
        """
        delta = 0.0
        for s in self.env.getAvailableStates():
            if self.env.isStateWall(s):
                continue

            indices = self._available_action_indices(self.env.getAvailableActions(s))
            if not indices:
                # No usable action: leave V[s] untouched instead of setting it to -inf.
                continue

            old_value = self.V[s]
            if self.env.isGoalState(s):
                # Goal states do not bootstrap: their value is the best immediate reward.
                self.V[s] = max(self.get_reward(s, a) for a in indices)
            else:
                terms = self._soft_q_terms(s, indices)
                max_term = np.max(terms)
                # Max-shifted log-sum-exp; unavailable actions are -inf -> weight 0.
                self.V[s] = self.alpha * (
                    np.log(np.sum(np.exp(terms - max_term))) + max_term
                )
            delta = max(delta, abs(old_value - self.V[s]))

        return delta

    def _improve_policy(self) -> bool:
        """
        Set each state's policy to the softmax of ``Q(s, .) / alpha``.

        Returns:
            True if no state's action probabilities moved by more than 0.1
            (absolute tolerance of ``np.allclose``), i.e. the policy is stable.
        """
        policy_stable = True
        for s in self.env.getAvailableStates():
            if self.env.isStateWall(s):
                continue

            indices = self._available_action_indices(self.env.getAvailableActions(s))
            if not indices:
                continue

            old_policy = np.copy(self.pi[s])
            terms = self._soft_q_terms(s, indices)
            # After the max shift the largest weight is exactly 1, so the sum is
            # >= 1 and needs no epsilon guard.
            weights = np.exp(terms - np.max(terms))
            self.pi[s] = weights / np.sum(weights)

            # Check if policy has significantly changed
            if not np.allclose(self.pi[s], old_policy, atol=0.1):
                policy_stable = False

        return policy_stable

    def one_step_policy_iteration(self) -> None:
        """Perform one evaluation sweep followed by one improvement step."""
        self._eval_policy()
        self._improve_policy()

    def solve_policy_iteration(
        self,
        iterations: int = 10_000,
        theta: float = 0.001,
        verbose: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Solve the MDP using entropy-regularized policy iteration.

        Args:
            iterations: Budget for the evaluation-sweep counter ``self.itr``
                (which is not reset between calls).
            theta: Convergence threshold for policy evaluation.
            verbose: Whether to print progress information.

        Returns:
            Tuple of (value_function, policy).

        Raises:
            ValueError: If ``iterations`` or ``theta`` is not positive.
        """
        if iterations <= 0:
            raise ValueError("iterations must be positive")
        if theta <= 0:
            raise ValueError("theta must be positive")

        policy_stable = False
        while not policy_stable and self.itr < iterations:
            # Policy evaluation: sweep until V changes by less than theta.
            while True:
                self.itr += 1
                delta = self._eval_policy(theta)

                if verbose and self.itr % 100 == 0:
                    print(f"Iteration {self.itr}, delta: {delta}")

                if delta < theta or self.itr >= iterations:
                    break

            # Policy improvement
            policy_stable = self._improve_policy()

        if verbose:
            if policy_stable:
                print(f"Converged after {self.itr} iterations")
            else:
                print(f"Maximum iterations ({iterations}) reached")

        return self.V, self.pi

    def solvePolicyIteration(
        self,
        iterations: int = 10_000,
        theta: float = 0.001,
        verbose: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """camelCase alias of ``solve_policy_iteration`` (same name as ``AgentMax``'s solver)."""
        return self.solve_policy_iteration(iterations, theta, verbose)
    
class AgentMax:
    """
    Agent with a hard-max Bellman backup for policy evaluation and a softmax
    (temperature ``alpha``) policy-improvement step.
    """

    # Class-level defaults. __init__ sets per-instance values for all of these
    # except ``itr``, which becomes an instance attribute on its first increment.
    V = None
    pi = None
    gamma = 0.9
    numStates = 0
    actionSet = None
    env = None
    itr = 0

    def __init__(self, gamma, env):
        """
        Args:
            gamma: Discount factor.
            env: Environment exposing getNumStates, getActionSet,
                getAvailableStates, getAvailableActions, isStateWall,
                isGoalState, getRewardS and transition_probabilities.
        """
        self.gamma = gamma
        self.env = env
        self.numStates = env.getNumStates()
        self.V = np.zeros(self.numStates)
        self.actionSet = env.getActionSet()
        # Start from a uniform policy over all actions.
        self.pi = np.full((self.numStates, len(self.actionSet)), 1 / len(self.actionSet))
        self.rewards = None  # not read by this class; kept for attribute parity with Agent
        self.alpha = 0.2  # softmax temperature used by _improvePolicy

    def _availableIndices(self, s):
        """Indices into ``self.actionSet`` of the actions available in state ``s``."""
        available = self.env.getAvailableActions(s)
        return [a for a, action in enumerate(self.actionSet) if action in available]

    def _qValues(self, s, indices):
        """Q(s, a) for every action index; ``-inf`` for indices not in ``indices``."""
        q = np.full(len(self.actionSet), -np.inf)
        for a in indices:
            Ts = self.env.transition_probabilities(s, self.actionSet[a])
            q[a] = self.env.getRewardS(s, a) + self.gamma * sum(
                Ts[next_state_id] * self.V[next_state_id] for next_state_id in Ts
            )
        return q

    def _evalPolicy(self):
        """
        Run one in-place sweep of the hard-max Bellman backup.

        Non-goal states get ``max_a Q(s, a)`` over their available actions;
        goal states get ``env.getRewardS(s, None)``.

        Returns:
            Maximum absolute change of ``self.V`` during the sweep.
        """
        delta = 0.0
        for s in self.env.getAvailableStates():
            if self.env.isStateWall(s):
                continue
            old_value = self.V[s]
            if self.env.isGoalState(s):
                self.V[s] = self.env.getRewardS(s, None)
            else:
                indices = self._availableIndices(s)
                if not indices:
                    continue
                # Unavailable actions are -inf (not 0) so they cannot win the max;
                # otherwise a state with only negative Q-values would be clamped to 0.
                self.V[s] = np.max(self._qValues(s, indices))
            delta = max(delta, abs(old_value - self.V[s]))

        return delta

    def _improvePolicy(self):
        """
        Set each state's policy to the softmax of ``Q(s, .) / alpha``.

        Returns:
            True if no state's action probabilities moved by more than 0.1.
        """
        policy_stable = True
        for s in self.env.getAvailableStates():
            if self.env.isStateWall(s):
                continue
            indices = self._availableIndices(s)
            if not indices:
                # Skip: with no available action the softmax would be 0/0 (NaN row).
                continue

            old_policy = np.copy(self.pi[s])
            terms = self._qValues(s, indices) / self.alpha
            # Max-shift for numerical stability; -inf entries get weight 0.
            weights = np.exp(terms - np.max(terms))
            self.pi[s] = weights / np.sum(weights)

            # Check if the policy has significantly changed
            if not np.allclose(self.pi[s], old_policy, atol=0.1):
                policy_stable = False

        return policy_stable

    def solvePolicyIteration(self, iterations, theta=0.001):
        """
        Alternate evaluation phases and softmax improvement until the policy is stable.

        Args:
            iterations: Budget for the evaluation-sweep counter ``self.itr``
                (cumulative; not reset between calls). Once exceeded, solving stops
                after the current improvement step even if the policy is not stable.
            theta: An evaluation phase ends once a sweep changes V by less than this.

        Returns:
            Tuple ``(V, pi)``.
        """
        policy_stable = False
        while not policy_stable:
            while True:
                self.itr += 1
                delta = self._evalPolicy()
                if delta < theta or self.itr > iterations:
                    break
            policy_stable = self._improvePolicy()
            if self.itr > iterations:
                # Budget exhausted: stop instead of looping until stability.
                break
        return self.V, self.pi

if __name__ == "__main__":
    gamma = 0.99
    iterations = 1000
    stochastic = True
    env_name = 'mdps/labyrinth.mdp'

    # 1) Entropy-regularized agent on the environment's own rewards.
    #    Agent's signature is (gamma, env, rewards=None, alpha=...).
    env = GridWorld(path=env_name, useNegativeRewards=False, stochastic=stochastic)
    agent_without_ref = Agent(gamma, env, alpha=0.2)
    V_without_ref, pi_without_ref = agent_without_ref.solve_policy_iteration(10000)
    test_poilicy(env, pi_without_ref)

    # 2) Hard-max evaluation agent on a fresh copy of the same environment.
    env = GridWorld(path=env_name, useNegativeRewards=False, stochastic=stochastic)
    agent_max = AgentMax(gamma, env)
    V_max, pi_max = agent_max.solvePolicyIteration(iterations)
    test_poilicy(env, pi_max)

    # 3) Rewards derived from the graph Laplacian of the environment.
    env = GridWorld(path=env_name, useNegativeRewards=True)

    numStates = env.getNumStates()
    A = env.getAdjacencyMatrix()

    # Degree matrix D = diag(row sums of A); zero-degree rows get 1 instead.
    degrees = np.asarray(A, dtype=float).sum(axis=1)
    degrees[degrees == 0.0] = 1.0
    D = np.diag(degrees)

    L = D - A
    # eigh returns eigenvalues in ascending order with matching eigenvector columns.
    eigenvalues, eigenvectors = np.linalg.eigh(L)

    # Embed states with eigenvector columns 1..cv-1 (column 0 is skipped).
    cv = 3

    start_state = env.getStartState()
    rewards = np.zeros((numStates, len(env.getActionSet())))
    for idx in range(numStates):
        i, j = env.getStateXY(idx)
        # Cells marked -1 in matrixMDP keep reward 0 (presumably walls — confirm).
        if env.matrixMDP[i][j] != -1:
            # Same reward for every action: embedding distance from the start state.
            rewards[idx, :] = np.linalg.norm(
                eigenvectors[start_state, 1:cv] - eigenvectors[idx, 1:cv]
            )

    # NOTE(review): Agent has no reference-policy parameter, so an agent
    # regularized toward ``lap_ref`` cannot be built yet. Until that exists,
    # solve the Laplacian-reward problem directly.
    lap_ref = Agent(gamma, env, rewards=rewards, alpha=1.0)
    v_with_lapc_ref, pi_with_lapc_ref = lap_ref.solve_policy_iteration(iterations)
    test_poilicy(env, pi_with_lapc_ref)