import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple, List, Optional

class StochasticGridWorld:
    """Square grid world with slippery four-directional moves.

    When the agent requests an action, the executed action is the requested
    one with probability ``0.25 + action_boost`` and each of the other three
    with probability ``(0.75 - action_boost) / 3``. A move that would leave
    the grid keeps the agent in place. Reward cells are not terminal:
    ``step`` always reports ``done=False``.
    """

    def __init__(self,
                 size: int = 9,
                 reward_region: Optional[List[Tuple[int, int]]] = None,
                 reward_value: float = 1.0,
                 action_boost: float = 0.4,
                 random_seed: Optional[int] = None):
        """
        Initialize a ``size`` x ``size`` stochastic grid world environment.

        Args:
            size: Grid side length (default 9, i.e. a 9x9 grid)
            reward_region: (row, col) cells that pay ``reward_value``.
                Defaults to three cells in the top-right corner. Cells are
                stored as tuples; out-of-bounds cells are kept in
                ``self.reward_region`` but receive no reward.
            reward_value: Reward value for reward region (default 1.0)
            action_boost: Additional probability for intended action (default 0.4).
                Intended prob = 0.25 + action_boost; each other action gets
                (0.75 - action_boost) / 3. Must lie in [-0.25, 0.75] so that
                no probability is negative.
            random_seed: If given, seeds the global ``numpy.random`` and
                ``random`` generators (this affects every user of them).

        Raises:
            ValueError: If ``action_boost`` is outside [-0.25, 0.75].
        """
        # Validate up front: an out-of-range boost would otherwise only surface
        # later as a negative "probability" in step()/get_transition_probabilities().
        if not -0.25 <= action_boost <= 0.75:
            raise ValueError(
                f"action_boost must be in [-0.25, 0.75], got {action_boost}")

        self.size = size
        self.reward_value = reward_value
        self.action_boost = action_boost

        # Calculate probabilities
        self.intended_prob = 0.25 + action_boost  # Base 1/4 + boost
        self.other_prob = (0.75 - action_boost) / 3  # Remaining prob split among 3 other actions

        if random_seed is not None:
            np.random.seed(random_seed)
            random.seed(random_seed)

        # Define actions: 0=up (row-1), 1=right (col+1), 2=down (row+1), 3=left (col-1)
        self.actions = ['up', 'right', 'down', 'left']
        self.action_effects = {
            0: (-1, 0),  # up
            1: (0, 1),   # right
            2: (1, 0),   # down
            3: (0, -1)   # left
        }

        # Initialize reward grid
        self.reward_grid = np.zeros((size, size))

        # Set reward region (default: top-right corner)
        if reward_region is None:
            reward_region = [(0, size-1), (1, size-1), (0, size-2)]

        # Normalize to tuples so membership tests against (row, col) tuples
        # (e.g. when rendering) also work if the caller passed lists.
        self.reward_region = [tuple(cell) for cell in reward_region]
        for row, col in self.reward_region:
            if self.is_valid_position(row, col):
                self.reward_grid[row, col] = reward_value

        # Initialize agent position
        self.agent_pos = [0, 0]  # Start at top-left
        self.initial_pos = [0, 0]

    def reset(self, start_pos: Optional[Tuple[int, int]] = None) -> Tuple[int, int]:
        """
        Reset the environment to initial state.

        Args:
            start_pos: Optional starting position, otherwise use initial_pos

        Returns:
            Starting position as (row, col)
        """
        if start_pos is not None:
            self.agent_pos = list(start_pos)
        else:
            self.agent_pos = list(self.initial_pos)
        return tuple(self.agent_pos)

    def is_valid_position(self, row: int, col: int) -> bool:
        """Check if position is within grid bounds."""
        return 0 <= row < self.size and 0 <= col < self.size

    @staticmethod
    def _check_action(action: int) -> None:
        """Raise ValueError unless ``action`` is one of 0..3."""
        if not 0 <= action <= 3:
            raise ValueError("Action must be between 0 and 3")

    def get_next_position(self, current_pos: List[int], action: int) -> List[int]:
        """
        Get the position reached by executing ``action`` deterministically.

        If the move would leave the grid, the agent stays put. A new list is
        always returned, so callers never receive an alias of ``current_pos``.
        """
        row, col = current_pos
        d_row, d_col = self.action_effects[action]
        new_row, new_col = row + d_row, col + d_col

        if self.is_valid_position(new_row, new_col):
            return [new_row, new_col]
        return [row, col]  # Stay in place if hitting boundary

    def step(self, action: int) -> Tuple[Tuple[int, int], float, bool]:
        """
        Take a step in the environment with stochastic transitions.

        The executed action is ``action`` with probability ``intended_prob``
        and each other action with probability ``other_prob``.

        Args:
            action: Intended action (0=up, 1=right, 2=down, 3=left)

        Returns:
            Tuple of (new_state, reward, done). ``done`` is always False:
            reward cells are not terminal.

        Raises:
            ValueError: If ``action`` is not in 0..3.
        """
        self._check_action(action)

        # Create probability distribution over all actions
        action_probs = [self.other_prob] * 4
        action_probs[action] = self.intended_prob

        # Sample actual action based on probabilities
        actual_action = np.random.choice([0, 1, 2, 3], p=action_probs)

        # Move agent
        self.agent_pos = self.get_next_position(self.agent_pos, actual_action)

        # Calculate reward
        reward = self.reward_grid[self.agent_pos[0], self.agent_pos[1]]

        # Episodes are not terminated on reaching a reward cell.
        return tuple(self.agent_pos), reward, False

    def get_state(self) -> Tuple[int, int]:
        """Get current agent position."""
        return tuple(self.agent_pos)

    def get_reward(self, state: Tuple[int, int]) -> float:
        """Get reward for a given state (0.0 if out of bounds)."""
        row, col = state
        if self.is_valid_position(row, col):
            return self.reward_grid[row, col]
        return 0.0

    def render(self, figsize: Tuple[int, int] = (8, 8), save_path: Optional[str] = None) -> None:
        """
        Render visual representation of the grid world using matplotlib.

        Args:
            figsize: Figure size as (width, height)
            save_path: Optional path to save the figure
        """
        fig, ax = plt.subplots(1, 1, figsize=figsize)

        # Create grid
        for i in range(self.size + 1):
            ax.axhline(i, color='black', linewidth=1)
            ax.axvline(i, color='black', linewidth=1)

        # Fill reward regions
        for row, col in self.reward_region:
            rect = patches.Rectangle((col, self.size - row - 1), 1, 1,
                                   linewidth=1, edgecolor='black',
                                   facecolor='gold', alpha=0.7)
            ax.add_patch(rect)
            # Add 'R' text
            ax.text(col + 0.5, self.size - row - 1 + 0.5, 'R',
                   fontsize=16, fontweight='bold', ha='center', va='center')

        # Add agent
        agent_row, agent_col = self.agent_pos
        agent_rect = patches.Rectangle((agent_col, self.size - agent_row - 1), 1, 1,
                                     linewidth=2, edgecolor='red',
                                     facecolor='lightcoral', alpha=0.8)
        ax.add_patch(agent_rect)
        # Add 'A' text
        ax.text(agent_col + 0.5, self.size - agent_row - 1 + 0.5, 'A',
               fontsize=16, fontweight='bold', ha='center', va='center', color='white')

        # Fill empty cells with light gray
        for row in range(self.size):
            for col in range(self.size):
                if (row, col) not in self.reward_region and [row, col] != self.agent_pos:
                    rect = patches.Rectangle((col, self.size - row - 1), 1, 1,
                                           linewidth=1, edgecolor='black',
                                           facecolor='lightgray', alpha=0.3)
                    ax.add_patch(rect)

        # Set axis properties
        ax.set_xlim(0, self.size)
        ax.set_ylim(0, self.size)
        ax.set_xticks([])  # removes x-axis tick labels
        ax.set_yticks([])  # removes y-axis tick labels
        ax.set_aspect('equal')
        # NOTE: cells are drawn at y = size - row - 1; inverting the axis on
        # top of that flip places row 0 at the BOTTOM-left, not the top-left.
        # The row labels drawn below are computed to match this layout.
        ax.invert_yaxis()

        ax.set_title(f'Grid World - Agent: {tuple(self.agent_pos)}, Reward: {self.get_reward(tuple(self.agent_pos)):.1f}',
                    fontsize=14, fontweight='bold')

        # Add grid coordinates (row labels on the left, column labels below)
        for i in range(self.size):
            ax.text(-0.3, i + 0.5, str(self.size - 1 - i), ha='center', va='center', fontsize=10)
            ax.text(i + 0.5, self.size + 0.3, str(i), ha='center', va='center', fontsize=10)

        # Add legend
        legend_elements = [
            patches.Patch(color='lightcoral', alpha=0.8, label='Agent (A)'),
            patches.Patch(color='gold', alpha=0.7, label='Reward Region (R)'),
            patches.Patch(color='lightgray', alpha=0.3, label='Empty Cell')
        ]
        ax.legend(
            handles=legend_elements,
            loc="upper left",          # anchor the legend box by its top-left corner
            bbox_to_anchor=(0, -0.05), # small negative y pushes it just below the axes
        )

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Grid world saved to {save_path}")

        plt.show()

        # Print text info as well
        print(f"Agent position: {tuple(self.agent_pos)}")
        print(f"Current reward: {self.get_reward(tuple(self.agent_pos))}")
        print(f"Reward regions: {self.reward_region}")
        print(f"Action probabilities - Intended: {self.intended_prob:.3f}, Others: {self.other_prob:.3f}")

    def get_all_states(self) -> List[Tuple[int, int]]:
        """Get all possible states in the grid, in row-major order."""
        return [(row, col) for row in range(self.size) for col in range(self.size)]

    def get_transition_probabilities(self, state: Tuple[int, int], action: int) -> dict:
        """
        Get transition probabilities from a state given an action.

        Mirrors ``step``: the intended action gets ``intended_prob`` and each
        other action ``other_prob``. Actions that lead to the same cell (e.g.
        two wall bumps) have their probabilities summed.

        Args:
            state: Current state (row, col)
            action: Intended action

        Returns:
            Dictionary mapping next_state -> probability (values sum to 1)

        Raises:
            ValueError: If ``action`` is not in 0..3 (same check as ``step``).
        """
        self._check_action(action)
        current_pos = list(state)
        transitions = {}

        for possible_action in range(4):
            next_pos = tuple(self.get_next_position(current_pos, possible_action))
            prob = self.intended_prob if possible_action == action else self.other_prob
            transitions[next_pos] = transitions.get(next_pos, 0.0) + prob

        return transitions

# Example usage and testing


def build_region_map(size: int, region_size: int) -> Tuple[np.ndarray, int]:
    """Group the cells of a ``size`` x ``size`` grid into square tiles.

    Full ``region_size`` x ``region_size`` tiles are numbered first in
    row-major order. Next come the leftover bottom strip (one region per tile
    column) and the leftover right strip (one region per tile row). The
    bottom-right leftover corner is numbered last.

    Args:
        size: Grid side length.
        region_size: Tile side length; must be >= 1.

    Returns:
        ``(region_map, num_regions)`` where ``region_map[row, col]`` is the
        region index of that cell.

    Raises:
        ValueError: If ``region_size < 1``.
    """
    if region_size < 1:
        raise ValueError(f"region_size must be >= 1, got {region_size}")
    num_tiles = size // region_size  # same along rows and columns (square grid)
    residual = size % region_size
    covered = num_tiles * region_size  # first row/col index not inside a full tile

    region_map = np.zeros((size, size), dtype=int)
    region_idx = 0
    # Full tiles, row-major.
    for i in range(num_tiles):
        for j in range(num_tiles):
            region_map[i * region_size:(i + 1) * region_size,
                       j * region_size:(j + 1) * region_size] = region_idx
            region_idx += 1
    if residual > 0:
        # Bottom strip: leftover rows below each tile column.
        for j in range(num_tiles):
            region_map[covered:, j * region_size:(j + 1) * region_size] = region_idx
            region_idx += 1
        # Right strip: leftover columns beside each tile row.
        for i in range(num_tiles):
            region_map[i * region_size:(i + 1) * region_size, covered:] = region_idx
            region_idx += 1
        # Bottom-right corner.
        region_map[covered:, covered:] = region_idx
        region_idx += 1
    return region_map, region_idx


def backward_q_update(q_table: np.ndarray,
                      states: List[Tuple[int, int]],
                      actions: List[int],
                      final_state: Tuple[int, int],
                      reward_table: np.ndarray,
                      region_map: np.ndarray,
                      alpha: float,
                      gamma: float) -> None:
    """Apply one backward sweep of tabular Q-learning over an episode, in place.

    ``states[t]``/``actions[t]`` are the state and action taken at step t. The
    successor of the last state is ``final_state``. Transitions are updated
    from last to first, so each one bootstraps from its true successor state,
    whose values were already refreshed earlier in this sweep. The reward for
    (s, a) is ``reward_table[region_map[s], a]``.
    """
    next_state = final_state
    for s, a in zip(reversed(states), reversed(actions)):
        region = region_map[s[0], s[1]]
        td_target = reward_table[region, a] + gamma * np.max(q_table[next_state[0], next_state[1], :])
        q_table[s[0], s[1], a] += alpha * (td_target - q_table[s[0], s[1], a])
        next_state = s  # s is the successor of the preceding transition


if __name__ == "__main__":
    seed = 0
    np.random.seed(seed)
    random.seed(seed)
    # Create environment
    env = StochasticGridWorld(
        size=9,
        reward_region=[(8, 8), (7, 8), (8, 7), (7, 7), (6, 8), (8, 6), (6, 6), (7, 5), (5, 7), (4, 8), (8, 4), (8, 5), (5, 8), (6, 7), (7, 6)],  # bottom-right corner
        action_boost=0.65,  # Intended action gets 0.25 + 0.65 = 0.9 probability
        random_seed=seed
    )
    env.render()

    horizon = 19  # steps per episode

    expert_traj = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (7, 4), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8)]
    expert_action = [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1]  # right x4, down x8, right x5
    # O(1) state -> expert action lookup; setdefault keeps the first
    # occurrence, matching list.index semantics if a state ever repeats.
    expert_policy = {}
    for traj_state, traj_action in zip(expert_traj, expert_action):
        expert_policy.setdefault(traj_state, traj_action)

    initialization_zone = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # Behavioural cloning: follow the expert where it has data, else act randomly.
    bc_rewards = []
    for reset_pos in initialization_zone:
        env.reset(start_pos=reset_pos)
        state = reset_pos
        episode_reward = 0
        for _ in range(horizon):
            action = expert_policy.get(state)
            if action is None:
                action = random.choice([0, 1, 2, 3])  # Random action if not in expert traj
            state, reward, done = env.step(action)
            episode_reward += reward
        print(f"Total episode reward from start {reset_pos}: {episode_reward}")
        bc_rewards.append(episode_reward)

    bc_rewards_average = sum(bc_rewards) / len(bc_rewards)
    print(f"Average BC episode reward over initialization zone: {bc_rewards_average}")

    # Tunable reward table scaling: group states into regions of region_size x region_size
    region_size = 4  # You can change this to any integer >= 1
    region_map, num_regions = build_region_map(env.size, region_size)

    # reward_table is indexed by region and action, but Q_table remains per-state.
    # A (region, action) entry becomes 1 once the agent matches the expert there.
    reward_table = np.zeros((num_regions, 4))
    Q_table = np.zeros((env.size, env.size, 4))
    epsilon = 0.1  # exploration rate for epsilon-greedy
    alpha = 0.5    # learning rate
    gamma = 0.99   # discount factor
    for _ in range(10000):
        reset_pos = random.choice(initialization_zone)
        env.reset(start_pos=reset_pos)
        state = reset_pos

        states_visited = []
        actions = []
        for _ in range(horizon):
            # Epsilon-greedy action selection
            action = np.argmax(Q_table[state[0], state[1], :])
            if random.random() < epsilon:
                action = random.choice([0, 1, 2, 3])
            actions.append(action)
            states_visited.append(state)
            expert_a = expert_policy.get(state)
            if expert_a is not None and action == expert_a:
                reward_table[region_map[state[0], state[1]], action] = 1
            state, _reward, _done = env.step(action)

        # Q-learning update, sweeping the episode backward from its final state
        backward_q_update(Q_table, states_visited, actions, state,
                          reward_table, region_map, alpha, gamma)

    # Plot max Q-value table as a heatmap
    max_Q = np.max(Q_table, axis=2)
    plt.figure(figsize=(8, 6))
    plt.imshow(max_Q, cmap='viridis', origin='lower')  # 'lower' puts row 0 at the bottom
    plt.title("Max Q-value per state")
    plt.xlabel('Column')
    plt.ylabel('Row')
    plt.colorbar(label='Max Q-value')
    plt.tight_layout()
    plt.show()

    # Render the greedy policy as arrows. With origin='lower' the row index
    # grows upward on screen, so env action 2 ("down", row + 1) is drawn
    # pointing up and env action 0 ("up", row - 1) pointing down.
    arrow_offsets = {
        0: (0, -0.4),  # env "up" (row - 1)
        1: (0.4, 0),   # right (col + 1)
        2: (0, 0.4),   # env "down" (row + 1)
        3: (-0.4, 0),  # left (col - 1)
    }
    plt.figure(figsize=(8, 8))
    plt.imshow(max_Q, cmap='viridis', origin='lower', alpha=0.3)
    for row in range(env.size):
        for col in range(env.size):
            dx, dy = arrow_offsets[int(np.argmax(Q_table[row, col, :]))]
            plt.arrow(col, row, dx, dy, head_width=0.2, head_length=0.2, fc='red', ec='red')
    plt.title("Optimal Policy (Arrows show best action per state)")
    plt.xlabel('Column')
    plt.ylabel('Row')
    plt.xlim(-0.5, env.size - 0.5)
    plt.ylim(-0.5, env.size - 0.5)
    plt.grid(True, which='both', color='black', linewidth=0.5, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Evaluate the greedy Q-learning policy once from each start cell
    eval_rewards = []
    for reset_pos in initialization_zone:
        env.reset(start_pos=reset_pos)
        state = reset_pos
        episode_reward = 0
        for t in range(horizon):
            action = np.argmax(Q_table[state[0], state[1], :])
            print(f"Eval Step {t+1}: Action {env.actions[action]}, State: {state}")
            state, reward, done = env.step(action)
            episode_reward += reward
        eval_rewards.append(episode_reward)

    avg_eval_reward = sum(eval_rewards) / len(eval_rewards)
    print(f"Average Q-learning episode reward over {len(eval_rewards)} evaluation episodes: {avg_eval_reward}")

    
