"""
Open Grid Wandering Target Environment

Simple 10x10 open grid (no internal walls) with a wandering target.
- Target moves DIAGONALLY with probability p_move each step
- Diagonal movement covers more distance, penalizing long commitments
- This forces the agent to replan when target drifts

Supports two modes:
1. execute_option: Original termination-based execution
2. execute_option_with_duration: Fixed duration commitment (no termination function)

Key advantages:
- No complex room geometry to confuse the analysis
- Diagonal movement creates clear "wrong direction" scenarios
- Open grid means target can drift anywhere
"""

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from options.Buffer.Buffer import Buffer
from options.Buffer.Experience import Experience
from options.Option import Option

# Cardinal directions: integer action codes interpreted by
# OpenGridWanderingTarget._move (the agent's four possible moves).
UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3

# Side length of the square grid. ind2coord / coord2ind use it to map between
# a flat, row-major state index and a (row, col) pair.
GRID_SIZE = 10


def ind2coord(index):
    """Split a flat, row-major state index into its (row, col) pair."""
    # divmod yields (index // GRID_SIZE, index % GRID_SIZE) as a tuple.
    return divmod(index, GRID_SIZE)


def coord2ind(coord):
    """Flatten a (row, col) pair into its row-major state index."""
    row = coord[0]
    col = coord[1]
    return row * GRID_SIZE + col


class OpenGridWanderingTarget(gym.Env):
    """
    Simple 10x10 open grid with a diagonally wandering target.

    - No internal walls, just borders
    - Each step the target moves one cell diagonally with probability p_move;
      moves that would leave the grid are clamped to the border
    - Agent gets terminal_reward when its cell coincides with the target's
      cell after a step, and step_reward otherwise
    - In execute_option the termination function sees the STALE target
      position (captured at option start)
    """
    # 'render_modes' is the metadata key used by Gymnasium; the legacy
    # 'render.modes' key is kept so code that still looks it up keeps working.
    metadata = {'render_modes': ['human'], 'render.modes': ['human']}

    def __init__(self, p_move=0.0, gamma=0.99, switch_reward=-0.1,
                 step_reward=-0.1, terminal_reward=10):
        """Configure the environment and draw an initial agent/target placement.

        Args:
            p_move: Probability that the target moves on each environment step.
            gamma: Discount factor; stored on the instance but not used by
                this class itself.
            switch_reward: Reward added when an option ends without the
                episode ending (a switching cost when negative).
            step_reward: Reward for every step that does not reach the target.
            terminal_reward: Reward for the step that reaches the target.
        """
        self.grid_size = GRID_SIZE
        self.n_states = GRID_SIZE * GRID_SIZE  # 100 states
        self.gamma = gamma
        self.switch_reward = switch_reward
        self.step_reward = step_reward
        self.terminal_reward = terminal_reward

        # Probability of target moving each step
        self.p_move = p_move

        # All cells are valid (open grid)
        self.valid_cells = list(range(self.n_states))

        # Diagonal directions for target movement (dr, dc)
        # Diagonal moves cover more distance, penalizing long commitments more
        self.target_move_dirs = [
            (-1, -1),  # UP-LEFT
            (-1, 1),   # UP-RIGHT
            (1, -1),   # DOWN-LEFT
            (1, 1),    # DOWN-RIGHT
        ]

        self.action_space = spaces.Discrete(4)  # UP, RIGHT, DOWN, LEFT
        self.observation_space = spaces.Discrete(self.n_states)
        # One index past the last grid cell marks "episode finished".
        self.absorbing_state = self.n_states

        # Target position (the wandering goal)
        self.target_state = None

        self.state = None
        self.done = False
        self._reset()

    def _move(self, state_idx, direction):
        """Return the state reached by moving one cell in `direction`.

        A move that would leave the grid is blocked: the original
        `state_idx` is returned unchanged.
        """
        row, col = ind2coord(state_idx)
        if direction == UP:
            row -= 1
        elif direction == DOWN:
            row += 1
        elif direction == LEFT:
            col -= 1
        elif direction == RIGHT:
            col += 1

        # Boundary check
        if row < 0 or row >= self.grid_size or col < 0 or col >= self.grid_size:
            return state_idx
        return coord2ind((row, col))

    def _move_target_diagonal(self):
        """Move the target one step in a uniformly random diagonal direction.

        Each coordinate is clamped to the grid independently, so next to a
        border the target slides along the wall and in a corner it may stay
        where it is.
        """
        row, col = ind2coord(self.target_state)

        # Pick a random diagonal direction
        dr, dc = self.target_move_dirs[np.random.randint(len(self.target_move_dirs))]

        # Clamp to the grid (no reflection off the walls)
        last = self.grid_size - 1
        new_row = min(max(row + dr, 0), last)
        new_col = min(max(col + dc, 0), last)

        self.target_state = coord2ind((new_row, new_col))

    def _step(self, action):
        """Advance the environment by one primitive action.

        Order of events: the agent moves, then the target possibly moves,
        then the goal check compares the agent's new cell with the target's
        (possibly updated) cell.

        Returns:
            (state, reward, done, info) where info always carries
            "target_state" and "target_moved".
        """
        assert self.action_space.contains(action), f"Invalid action: {action}"

        if self.state == self.absorbing_state or self.done:
            # Episode already over: nothing moves. "target_moved" is included
            # so the info dict has the same keys on every return path.
            return self.absorbing_state, 0, True, {
                "target_state": self.target_state,
                "target_moved": False
            }

        # 1. Agent move
        new_state = self._move(self.state, action)

        # 2. Maybe move target diagonally (context drift!)
        target_moved = False
        if np.random.random() < self.p_move:
            self._move_target_diagonal()
            target_moved = True

        # 3. Check if reached target
        if new_state == self.target_state:
            self.done = True
            self.state = self.absorbing_state
            return self.state, self.terminal_reward, self.done, {
                "target_state": self.target_state,
                "target_moved": target_moved
            }

        # 4. Standard step
        self.state = new_state
        reward = self.step_reward

        return self.state, reward, self.done, {
            "target_state": self.target_state,
            "target_moved": target_moved
        }

    def execute_option(self, option: Option, buffer: Buffer, max_steps: int = None):
        """Execute option until termination or episode end.

        Key: Termination function only sees the target from when option STARTED.
        It doesn't know if the target moved during execution. The intra-option
        policy, in contrast, is given the current target on every step.

        Args:
            option: The option to execute
            buffer: Buffer that receives one Experience per primitive step
            max_steps: Maximum steps allowed in this call (None = no limit)

        Returns:
            steps_done: Actual steps executed
            done: Whether episode ended
        """
        steps_done = 0

        # Capture target at option start - termination only sees this (stale) info
        target_at_option_start = self.target_state

        while not self.done:
            if max_steps is not None and steps_done >= max_steps:
                break

            state_before = self.state
            target_before = self.target_state

            action = option.intra_option_policy.get_action(state_before, target_before)
            next_state, reward, done, info = self._step(action)
            steps_done += 1
            self.done = done

            terminated = False
            if not self.done:
                # Termination sees target_at_option_start, NOT current target
                # It doesn't know if target moved during execution
                term_prob = option.termination_function.get_termination_probability(
                    next_state, target_at_option_start
                )
                terminated = np.random.random() < term_prob
                if terminated:
                    # Charge the switching cost on the step that ends the option
                    reward += self.switch_reward

            # Store both current target (for meta-policy) and target_at_option_start (for termination)
            buffer.add(Experience(
                state_before, target_before, action, option, reward,
                next_state, info["target_state"],
                context_at_option_start=target_at_option_start
            ))

            if terminated:
                break

        return steps_done, self.done

    def execute_option_with_duration(self, option: Option, duration: int, duration_idx: int,
                                      buffer: Buffer, max_steps: int = None):
        """Execute option for a FIXED duration (no termination function).

        This is the duration-commitment mode: at option start, we commit to running
        for exactly `duration` steps (or until episode ends/goal reached, or
        `max_steps` is hit).

        Exactly one Experience is stored per call, even when no step was
        executed (e.g. `duration` is 0 or the episode is already over).

        Args:
            option: The option to execute
            duration: Number of steps to commit to
            duration_idx: Index into DURATIONS list (for storing in experience)
            buffer: Buffer to store experiences
            max_steps: Maximum steps allowed (episode limit)

        Returns:
            steps_done: Actual steps executed
            done: Whether episode ended
        """
        steps_done = 0
        # NOTE(review): rewards are summed without discounting by self.gamma;
        # confirm the learner accounts for the option's duration itself.
        cumulative_reward = 0.0

        # Capture state/target at option start
        state_at_start = self.state
        target_at_start = self.target_state

        while not self.done and steps_done < duration:
            if max_steps is not None and steps_done >= max_steps:
                break

            state_before = self.state
            target_before = self.target_state

            action = option.intra_option_policy.get_action(state_before, target_before)
            next_state, reward, done, info = self._step(action)
            steps_done += 1
            cumulative_reward += reward
            self.done = done

        # Add switch cost at option end (if not episode end)
        if not self.done:
            cumulative_reward += self.switch_reward

        # Store ONE experience for the whole option execution
        # This represents: "I was at state_at_start, saw target_at_start,
        # chose option with duration, got cumulative_reward, ended at self.state"
        buffer.add(Experience(
            state_at_start, target_at_start,
            option.intra_option_policy.action,  # The constant action
            option, cumulative_reward,
            self.state, self.target_state,
            context_at_option_start=target_at_start,
            duration_idx=duration_idx,
            committed_duration=duration
        ))

        return steps_done, self.done

    def _reset(self):
        """Reset environment with random agent and target positions.

        Returns:
            (state, target_state): the new agent and target cell indices,
            which are guaranteed to differ.
        """
        # Start agent at random position
        self.state = np.random.choice(self.valid_cells)

        # Start target at different random position (redraw on collision)
        self.target_state = np.random.choice(self.valid_cells)
        while self.target_state == self.state:
            self.target_state = np.random.choice(self.valid_cells)

        self.done = False
        return self.state, self.target_state

    def render(self, mode='human'):
        """Print a text rendering of the grid.

        'A' marks the agent, 'T' the target, '@' a cell holding both and '.'
        an empty cell. The agent is omitted once the absorbing state is
        reached. `mode` is accepted but not used.
        """
        grid = [['.' for _ in range(self.grid_size)] for _ in range(self.grid_size)]

        if self.state != self.absorbing_state:
            ar, ac = ind2coord(self.state)
            grid[ar][ac] = 'A'

        tr, tc = ind2coord(self.target_state)
        if grid[tr][tc] == 'A':
            grid[tr][tc] = '@'  # Agent on target
        else:
            grid[tr][tc] = 'T'

        print('\n'.join([''.join(row) for row in grid]))
        print()

