"""
Open Grid Chase Environment

Cat (agent) chases Mouse (target) on a 10x10 grid.

Cat:
- Can move in 8 directions (4 cardinal + 4 diagonal)
- Has 8 options corresponding to each direction

Mouse:
- Has a direction (UP, RIGHT, DOWN, LEFT) and moves in that direction EVERY step
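- With probability p_turn, the mouse first TURNS to a new direction, then moves
- Turns are chosen by priority: avoid stepping onto the cat's trajectory, get
  off it if already on it, then move AWAY from the cat (a direction toward the
  cat is possible only as a fallback when no better option remains)
- If at a wall, the mouse stays in place (no bounce; its direction is kept)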

Key insight:
- p_turn=0: Mouse moves predictably in one direction (easy to intercept)
- p_turn=1: Mouse constantly changes direction away from cat (hard to catch)
"""

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from options.Buffer.Buffer import Buffer
from options.Buffer.Experience import Experience
from options.Option import Option

# Mouse headings: the four cardinal directions, numbered clockwise from up.
DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT = range(4)

# Cat actions: all eight compass headings, numbered clockwise from up.
(
    ACT_UP,
    ACT_UP_RIGHT,
    ACT_RIGHT,
    ACT_DOWN_RIGHT,
    ACT_DOWN,
    ACT_DOWN_LEFT,
    ACT_LEFT,
    ACT_UP_LEFT,
) = range(8)

# Side length of the square grid.
GRID_SIZE = 10

# (row delta, col delta) for each cat action. Row 0 is the top row, so
# "up" is a negative row delta.
ACTION_DELTAS = {
    action: delta
    for action, delta in zip(
        (ACT_UP, ACT_UP_RIGHT, ACT_RIGHT, ACT_DOWN_RIGHT,
         ACT_DOWN, ACT_DOWN_LEFT, ACT_LEFT, ACT_UP_LEFT),
        ((-1, 0), (-1, 1), (0, 1), (1, 1),
         (1, 0), (1, -1), (0, -1), (-1, -1)),
    )
}

# (row delta, col delta) for each mouse heading: the cardinal subset of
# ACTION_DELTAS.
MOUSE_DIR_DELTAS = {
    DIR_UP: ACTION_DELTAS[ACT_UP],
    DIR_RIGHT: ACTION_DELTAS[ACT_RIGHT],
    DIR_DOWN: ACTION_DELTAS[ACT_DOWN],
    DIR_LEFT: ACTION_DELTAS[ACT_LEFT],
}


def ind2coord(index):
    """Split a row-major cell index into its (row, col) pair."""
    # divmod yields (index // GRID_SIZE, index % GRID_SIZE) in one call.
    return divmod(index, GRID_SIZE)


def coord2ind(coord):
    """Flatten a (row, col) pair into its row-major cell index."""
    row, col = coord[0], coord[1]
    return row * GRID_SIZE + col


class OpenGridChase(gym.Env):
    """
    Cat chases Mouse on a 10x10 grid.

    - Cat can move in 8 directions (a move into a wall is clipped per axis)
    - Mouse moves in its current direction every step and stays in place
      when that would take it off the grid
    - Before moving, the mouse turns with probability p_turn, prioritizing:
      1. Not stepping INTO the cat's trajectory line
      2. Getting OFF the cat's trajectory line
      3. Moving AWAY from the cat's position
    - The mouse is caught when the cat lands on it or it moves onto the cat;
      the cat's state then becomes ``absorbing_state`` and the episode ends.

    ``self.state`` is the cat's cell index; the mouse's cell and heading are
    returned in the info dict of every transition.

    NOTE(review): only the private ``_step``/``_reset`` are defined, so
    gymnasium's public ``step``/``reset`` are not overridden -- callers are
    presumably expected to use the private names; confirm.
    """
    # gymnasium reads 'render_modes'; the legacy 'render.modes' key is kept
    # for any code that still looks it up.
    metadata = {'render_modes': ['human'], 'render.modes': ['human']}

    def __init__(self, p_turn=0.0, gamma=0.99, switch_reward=-0.1,
                 step_reward=-0.1, terminal_reward=10):
        """Build the environment and place cat and mouse at random.

        Args:
            p_turn: Probability that the mouse picks a new heading each step.
            gamma: Discount factor, stored as ``self.gamma`` (not used by the
                methods of this class).
            switch_reward: Reward added when an option ends without ending
                the episode (see ``execute_option_with_term_prob``).
            step_reward: Reward for a step that does not end the episode.
            terminal_reward: Reward for catching the mouse.
        """
        self.grid_size = GRID_SIZE
        self.n_states = GRID_SIZE * GRID_SIZE  # 100 states
        self.gamma = gamma
        self.switch_reward = switch_reward
        self.step_reward = step_reward
        self.terminal_reward = terminal_reward

        # Probability of mouse turning each step
        self.p_turn = p_turn

        # All cells are valid (open grid)
        self.valid_cells = list(range(self.n_states))

        self.action_space = spaces.Discrete(8)  # 8 directions for cat
        self.observation_space = spaces.Discrete(self.n_states)
        # One past the last real cell; the cat is moved here once the mouse
        # is caught.
        self.absorbing_state = self.n_states

        # Cat (agent) position and last attempted movement direction (dr, dc)
        self.state = None
        self.cat_direction = None

        # Mouse (target) position and heading (DIR_UP/RIGHT/DOWN/LEFT)
        self.mouse_state = None
        self.mouse_direction = None

        self.done = False
        self._reset()

    def _move_cat(self, action):
        """Return the cat's cell after taking ``action`` from ``self.state``.

        Each axis is clipped to the grid independently, so a diagonal move
        into a wall slides along it and a move into a corner leaves the cat
        in place. Side effect: records the *attempted* (dr, dc) in
        ``self.cat_direction`` (even if a wall blocked it), which the mouse's
        evasion logic reads. ``self.state`` itself is not updated here.
        """
        row, col = ind2coord(self.state)
        dr, dc = ACTION_DELTAS[action]

        new_row = row + dr
        new_col = col + dc

        # Boundary check, per axis
        if new_row < 0 or new_row >= self.grid_size:
            new_row = row
        if new_col < 0 or new_col >= self.grid_size:
            new_col = col

        # Track cat's movement direction for mouse evasion
        self.cat_direction = (dr, dc)

        return coord2ind((new_row, new_col))

    def _is_position_on_cat_trajectory(self, pos_row, pos_col):
        """Return True if (pos_row, pos_col) lies ahead on the cat's line.

        The position must be strictly ahead of the cat on every axis the cat
        is moving along. For cardinal movement it must share the cat's row
        (horizontal) or column (vertical); for diagonal movement it may be up
        to one cell off the exact diagonal. Returns False before the cat has
        moved (``cat_direction is None``).
        """
        if self.cat_direction is None:
            return False

        cat_row, cat_col = ind2coord(self.state)
        cat_dr, cat_dc = self.cat_direction

        # Vector from cat to position
        delta_row = pos_row - cat_row
        delta_col = pos_col - cat_col

        # If cat isn't moving, not on trajectory
        if cat_dr == 0 and cat_dc == 0:
            return False

        # Position must be strictly "ahead" of cat on each moving axis
        if cat_dr != 0:
            if (cat_dr > 0 and delta_row <= 0) or (cat_dr < 0 and delta_row >= 0):
                return False  # Position is behind or at same row
        if cat_dc != 0:
            if (cat_dc > 0 and delta_col <= 0) or (cat_dc < 0 and delta_col >= 0):
                return False  # Position is behind or at same col

        # Check if position is on the trajectory line
        if cat_dr == 0:
            # Moving horizontally only - on trajectory if same row
            return delta_row == 0
        elif cat_dc == 0:
            # Moving vertically only - on trajectory if same col
            return delta_col == 0
        else:
            # Diagonal movement: compare how many steps along each axis the
            # position is; allow a difference of up to 1 (one cell of slack).
            ratio_row = delta_row / cat_dr
            ratio_col = delta_col / cat_dc
            return abs(ratio_row - ratio_col) <= 1

    def _is_on_cat_trajectory(self):
        """Check if mouse's current position is on the cat's trajectory."""
        mouse_row, mouse_col = ind2coord(self.mouse_state)
        return self._is_position_on_cat_trajectory(mouse_row, mouse_col)

    def _get_escape_directions(self):
        """Return the mouse headings preferred for escaping the cat.

        Only directions that keep the mouse on the grid are considered. They
        are filtered in priority order, and each filter is skipped if it
        would leave nothing:
        1. Drop directions that step INTO the cat's trajectory.
        2. If currently on the trajectory, keep directions that leave it.
        3. Keep directions that increase Manhattan distance to the cat.

        Returns all four directions only if no move stays on the grid.
        """
        cat_row, cat_col = ind2coord(self.state)
        mouse_row, mouse_col = ind2coord(self.mouse_state)
        currently_on_trajectory = self._is_on_cat_trajectory()
        # Loop-invariant: current Manhattan distance from mouse to cat.
        dist_before = abs(mouse_row - cat_row) + abs(mouse_col - cat_col)

        candidates = []

        for direction in [DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT]:
            dr, dc = MOUSE_DIR_DELTAS[direction]
            new_row = mouse_row + dr
            new_col = mouse_col + dc

            # Skip if would hit wall
            if not (0 <= new_row < self.grid_size and 0 <= new_col < self.grid_size):
                continue

            would_be_on_trajectory = self._is_position_on_cat_trajectory(new_row, new_col)
            dist_after = abs(new_row - cat_row) + abs(new_col - cat_col)

            candidates.append({
                'dir': direction,
                'steps_into_trajectory': not currently_on_trajectory and would_be_on_trajectory,
                'escapes_trajectory': currently_on_trajectory and not would_be_on_trajectory,
                'moves_away': dist_after > dist_before,
            })

        if not candidates:
            return [DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT]

        # Priority 1: Filter out directions that step INTO trajectory
        safe = [c for c in candidates if not c['steps_into_trajectory']]
        if safe:
            candidates = safe

        # Priority 2: Prefer directions that escape trajectory
        escaping = [c for c in candidates if c['escapes_trajectory']]
        if escaping:
            candidates = escaping

        # Priority 3: Prefer directions that move away from cat
        away = [c for c in candidates if c['moves_away']]
        if away:
            candidates = away

        return [c['dir'] for c in candidates]

    def _get_valid_mouse_directions(self):
        """Get directions for mouse to escape.

        Delegates to ``_get_escape_directions`` (trajectory-aware escape).
        """
        return self._get_escape_directions()

    def _move_mouse(self):
        """Maybe turn first, then move in current direction.

        With probability p_turn, mouse first picks a new direction uniformly
        among the preferred escape directions. Then mouse moves one cell in
        its current direction. If that would leave the grid, it stays in
        place (no bounce) and keeps its direction.
        """
        # Turn BEFORE moving, so a turn takes effect this very step.
        if np.random.random() < self.p_turn:
            valid_dirs = self._get_valid_mouse_directions()
            if valid_dirs:
                self.mouse_direction = np.random.choice(valid_dirs)

        # Now move in current direction
        row, col = ind2coord(self.mouse_state)
        dr, dc = MOUSE_DIR_DELTAS[self.mouse_direction]

        new_row = row + dr
        new_col = col + dc

        # If would hit wall, stay in place (no bounce, keep direction)
        if new_row < 0 or new_row >= self.grid_size:
            new_row = row
        if new_col < 0 or new_col >= self.grid_size:
            new_col = col

        self.mouse_state = coord2ind((new_row, new_col))

    def _info(self):
        """Return the info dict attached to every transition."""
        return {
            "mouse_state": self.mouse_state,
            "mouse_direction": self.mouse_direction,
        }

    def _catch(self):
        """End the episode on a catch and return the terminal transition."""
        self.done = True
        self.state = self.absorbing_state
        return self.state, self.terminal_reward, self.done, self._info()

    def _step(self, action):
        """Advance one step: cat moves, then (if not caught) the mouse moves.

        Returns:
            (state, reward, done, info) where ``state`` is the cat's cell (or
            ``absorbing_state`` after a catch), ``reward`` is
            ``terminal_reward`` on a catch, ``step_reward`` otherwise, and 0
            once already terminal, and ``info`` holds the mouse's cell and
            heading.
        """
        assert self.action_space.contains(action), f"Invalid action: {action}"

        if self.state == self.absorbing_state or self.done:
            return self.absorbing_state, 0, True, self._info()

        # 1. Cat moves
        self.state = self._move_cat(action)

        # 2. Cat landed on mouse (catch!). This also covers the cat and mouse
        #    "passing through" each other: a swap requires the cat to move
        #    onto the mouse's pre-move cell, and the mouse has not moved yet,
        #    so no separate swap check is needed after the mouse moves.
        if self.state == self.mouse_state:
            return self._catch()

        # 3. Mouse moves (always moves, maybe turns)
        self._move_mouse()

        # 4. Mouse ran into cat
        if self.state == self.mouse_state:
            return self._catch()

        # 5. Standard step
        return self.state, self.step_reward, self.done, self._info()

    def execute_option_with_term_prob(self, option: Option, term_prob: float, term_prob_idx: int,
                                       buffer: Buffer, max_steps: int = None):
        """Execute ``option`` until it terminates and store one experience.

        After every step that does not end the episode, the option terminates
        with probability ``term_prob``, so (ignoring episode end and
        ``max_steps``) its duration is geometric with E[duration] = 1/term_prob.

        Args:
            option: The option to execute; its ``intra_option_policy`` picks
                each primitive action from (cat state, mouse state).
            term_prob: Per-step termination probability.
            term_prob_idx: Index identifying ``term_prob`` for the learner;
                stored in the experience's ``duration_idx`` field.
            buffer: Experience buffer that receives one ``Experience``.
            max_steps: Optional cap on the number of primitive steps.

        Returns:
            steps_done: Number of primitive steps actually executed.
            done: Whether the episode has ended.
        """
        steps_done = 0
        cumulative_reward = 0.0

        # Snapshot the option's start, including the mouse's heading, which
        # is not part of self.state.
        state_at_start = self.state
        mouse_at_start = self.mouse_state
        mouse_dir_at_start = self.mouse_direction

        # NOTE(review): if the episode is already done, or max_steps <= 0, no
        # step is taken but an experience with committed_duration=0 is still
        # stored below -- confirm callers never trigger this.
        while not self.done:
            if max_steps is not None and steps_done >= max_steps:
                break

            action = option.intra_option_policy.get_action(self.state, self.mouse_state)
            _, reward, done, _ = self._step(action)
            steps_done += 1
            # Rewards are summed undiscounted; self.gamma is not applied here.
            cumulative_reward += reward
            self.done = done

            # Stochastic termination, only drawn while the episode is live.
            if not self.done and np.random.random() < term_prob:
                break

        # Add switch cost at option end (if not episode end)
        if not self.done:
            cumulative_reward += self.switch_reward

        # Store one experience for the whole option; the duration_idx field
        # is reused to carry term_prob_idx.
        buffer.add(Experience(
            state_at_start, mouse_at_start,
            option.intra_option_policy.action,
            option, cumulative_reward,
            self.state, self.mouse_state,
            context_at_option_start=mouse_at_start,
            duration_idx=term_prob_idx,
            committed_duration=steps_done,  # Actual duration for logging
            mouse_direction=mouse_dir_at_start
        ))

        return steps_done, self.done

    def _reset(self):
        """Place cat and mouse on distinct random cells with a random heading.

        Returns:
            (cat_state, mouse_state) cell indices.
        """
        # Start cat at random position
        self.state = np.random.choice(self.valid_cells)

        # Start mouse at a different random position (rejection sampling)
        self.mouse_state = np.random.choice(self.valid_cells)
        while self.mouse_state == self.state:
            self.mouse_state = np.random.choice(self.valid_cells)

        # Random initial direction for mouse
        self.mouse_direction = np.random.choice([DIR_UP, DIR_RIGHT, DIR_DOWN, DIR_LEFT])

        # Cat has no direction yet, so nothing counts as on its trajectory
        self.cat_direction = None

        self.done = False
        return self.state, self.mouse_state

    def render(self, mode='human'):
        """Print the grid as text.

        '.' is empty, an arrow is the mouse (showing its heading), 'C' is the
        cat and '@' is the cat on the mouse's cell. The cat is not drawn once
        it is in the absorbing state.
        """
        grid = [['.' for _ in range(self.grid_size)] for _ in range(self.grid_size)]

        # Draw mouse with direction indicator
        mr, mc = ind2coord(self.mouse_state)
        dir_chars = {DIR_UP: '↑', DIR_RIGHT: '→', DIR_DOWN: '↓', DIR_LEFT: '←'}
        grid[mr][mc] = dir_chars.get(self.mouse_direction, 'M')

        # Draw cat
        if self.state != self.absorbing_state:
            cr, cc = ind2coord(self.state)
            if grid[cr][cc] != '.':
                grid[cr][cc] = '@'  # Cat on mouse
            else:
                grid[cr][cc] = 'C'

        print('\n'.join([''.join(row) for row in grid]))
        print()
