from typing import Optional, Tuple
import numpy as np


class FourRoomsMDP:
    """Deterministic four-rooms grid world on a 13x13 grid.

    Cells are addressed as ``(row, col)``. Every non-wall cell is assigned a
    consecutive integer state index in row-major order. The layout has a
    wall border plus interior walls with four doorways, leaving 104
    non-wall states.

    Actions are integers that move the agent by one cell:
    ``0`` = row - 1 (up when rendered), ``1`` = row + 1 (down),
    ``2`` = col - 1 (left), ``3`` = col + 1 (right). Moving into a wall
    leaves the agent where it is.

    Args:
        start_state: ``(row, col)`` of the start cell, or ``None`` to have
            ``reset`` sample one.
        goal_state: ``(row, col)`` of the goal cell, or ``None`` for a
            reward-free environment. A goal that is not a non-wall cell
            raises ``KeyError`` while the transition matrix is built.
    """

    def __init__(
        self,
        start_state: Optional[Tuple[int, int]] = None,
        goal_state: Optional[Tuple[int, int]] = None,
    ):
        self.rows = 13
        self.cols = 13

        # action index -> (row delta, col delta)
        self.actions = {
            0: (-1, 0),
            1: (1, 0),
            2: (0, -1),
            3: (0, 1),
        }
        self.num_actions = len(self.actions)

        self.walls = self._create_walls()

        self.start_state = start_state
        self.goal_state = goal_state
        # Set by reset(); defined up front so it can be read before the
        # first reset without raising AttributeError.
        self.current_state: Optional[int] = None

        # Bidirectional mapping between state indices and grid coordinates,
        # enumerated in row-major order over the non-wall cells.
        self.state_to_coord = [
            (r, c)
            for r in range(self.rows)
            for c in range(self.cols)
            if (r, c) not in self.walls
        ]
        self.coord_to_state = {
            coord: idx for idx, coord in enumerate(self.state_to_coord)
        }
        self.num_states = len(self.state_to_coord)

        self._transition_matrix = self._build_transition_matrix()

    def _create_walls(self) -> set:
        """Return the set of ``(row, col)`` wall cells."""
        walls = set()

        # Outer border.
        for r in range(self.rows):
            walls.add((r, 0))
            walls.add((r, self.cols - 1))
        for c in range(self.cols):
            walls.add((0, c))
            walls.add((self.rows - 1, c))

        # Vertical divider in the upper half (row 2 is left open).
        for r in [1, 3, 4, 5]:
            walls.add((r, 6))
        # Vertical divider in the lower half (row 9 is left open).
        for r in [7, 8, 10, 11]:
            walls.add((r, 7))

        # Horizontal divider (columns 3 and 10 are left open).
        for c in [1, 2, 4, 5, 6, 7, 8, 9, 11]:
            walls.add((6, c))

        return walls

    def step(self, state: int, action: int) -> Tuple[int, float, bool, dict]:
        """Apply ``action`` in ``state`` and return the outcome.

        This is a pure function of its arguments: it does not read or
        update ``self.current_state``.

        Args:
            state: Index of the current state.
            action: Index of the action to take.

        Returns:
            ``(next_state, reward, done, info)``. ``reward`` is ``1.0`` and
            ``done`` is ``True`` exactly when a goal is configured and
            ``next_state`` is the goal; otherwise ``0.0`` and ``False``.
            ``info`` is an empty dict.

        Raises:
            ValueError: If ``state`` or ``action`` is out of range.
        """
        if state < 0 or state >= self.num_states:
            raise ValueError(f"Invalid state: {state}")
        if action < 0 or action >= self.num_actions:
            raise ValueError(f"Invalid action: {action}")

        r, c = self.state_to_coord[state]
        dr, dc = self.actions[action]
        target = (r + dr, c + dc)

        # Bumping into a wall is a no-op. The border is all walls, so a
        # non-wall target is always a valid in-grid cell.
        next_state = state if target in self.walls else self.coord_to_state[target]

        reached_goal = self.goal_state is not None and bool(
            next_state == self.coord_to_state[self.goal_state]
        )
        reward = 1.0 if reached_goal else 0.0

        return next_state, reward, reached_goal, {}

    def reset(self) -> int:
        """Set ``current_state`` to the start state and return its index.

        If no start state is configured, one is sampled uniformly from the
        non-wall cells and stored in ``self.start_state``, so later calls
        reuse the same start.
        """
        if self.start_state is None:
            self.start_state = self.state_to_coord[
                np.random.choice(self.num_states)
            ]
        self.current_state = self.coord_to_state[self.start_state]
        return self.current_state

    def render(self, current_state: Optional[int] = None) -> None:
        """Print an ASCII picture of the grid.

        Walls are drawn as ``█``, the start as ``S``, the goal as ``G`` and
        the agent as ``A``. The start and goal are skipped when they are not
        set, and the agent is drawn last so it stays visible when it stands
        on the start or goal cell.

        Args:
            current_state: State index of the agent; ignored when ``None``
                or out of range.
        """
        grid = [[" " for _ in range(self.cols)] for _ in range(self.rows)]

        for (r, c) in self.walls:
            grid[r][c] = "█"

        if self.start_state is not None:
            start_r, start_c = self.start_state
            grid[start_r][start_c] = "S"

        if self.goal_state is not None:
            goal_r, goal_c = self.goal_state
            grid[goal_r][goal_c] = "G"

        if current_state is not None and 0 <= current_state < self.num_states:
            r, c = self.state_to_coord[current_state]
            grid[r][c] = "A"

        separator = "-" * (self.cols * 2 + 1)
        print(separator)
        for row in grid:
            print("|" + "|".join(row) + "|")
            print(separator)

    def _build_transition_matrix(self) -> np.ndarray:
        """Return ``P`` with ``P[s, a, s'] = 1`` for the successor of ``(s, a)``."""
        P = np.zeros((self.num_states, self.num_actions, self.num_states))
        for s in range(self.num_states):
            for a in range(self.num_actions):
                next_s, _, _, _ = self.step(s, a)
                P[s, a, next_s] = 1.0
        return P

    def get_transition_matrix(self) -> np.ndarray:
        """Return the ``(num_states, num_actions, num_states)`` transition tensor.

        The internal array is returned without copying.
        """
        return self._transition_matrix

    def get_policy_transition_matrix(self, policy) -> np.ndarray:
        """Return the state-to-state transition matrix induced by ``policy``.

        Args:
            policy: Array-like of shape ``(num_states, num_actions)`` where
                ``policy[s, a]`` is the probability of taking ``a`` in ``s``.

        Returns:
            Array ``P_pi`` of shape ``(num_states, num_states)`` with
            ``P_pi[s, t] = sum_a policy[s, a] * P[s, a, t]``.

        Raises:
            ValueError: If ``policy`` does not have the expected shape.
        """
        policy = np.asarray(policy)
        if policy.shape != (self.num_states, self.num_actions):
            raise ValueError(f"Policy shape must be ({self.num_states}, {self.num_actions})")

        P = self.get_transition_matrix()
        # The next-state axis needs its own index (t); reusing 's' for it
        # makes einsum reject the subscript string.
        return np.einsum('sa,sat->st', policy, P)

    def get_states(self) -> list:
        """Return the list of all state indices."""
        return list(range(self.num_states))

    def get_adjacency_matrix(self) -> np.ndarray:
        """Return a 0/1 matrix with ``A[s, t] = 1`` if some action takes ``s`` to ``t``.

        Wall bumps count as self-transitions, so ``A[s, s]`` is 1 for any
        state that has at least one neighbouring wall.
        """
        P = self.get_transition_matrix()
        return (P.sum(axis=1) > 0).astype(int)