from __future__ import annotations
import numpy as np
from .env import LongCorridor

ACTIONS = 4  # forward, left, right, back

# Action indices, in the order listed above.
_FORWARD, _LEFT, _RIGHT, _BACK = 0, 1, 2, 3


def sample_trajectory(
    env: LongCorridor,
    T: int = 200,
    p_turn: float = 0.2,
    seed: int | None = None,
    *,
    p_wall_turn: float = 0.9,
    p_back: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Roll out a mostly-forward random walk in *env* and record it.

    At every step the policy picks an action as follows (later rules
    override earlier ones):

    1. Go forward.
    2. If the cell ahead lies outside the ``env.Lx`` x ``env.Ly`` grid,
       turn left or right (chosen uniformly) with probability
       ``p_wall_turn``.
    3. With probability ``p_back``, go backward instead.
    4. With probability ``p_turn``, turn left or right (chosen uniformly)
       regardless of what is ahead.

    Args:
        env: Environment providing ``reset()``, ``egocentric_obs(x, y, h)``,
            ``step(x, y, h, action)``, a ``dirs`` table indexed by heading,
            and grid extents ``Lx`` / ``Ly``. The heading ``h`` is used as
            an index into four directions.
        T: Number of steps to record; must be positive.
        p_turn: Per-step probability of a spontaneous left/right turn.
        seed: Seed for a private ``numpy.random.RandomState``. When
            ``None`` the global ``numpy.random`` state is used.
        p_wall_turn: Probability of turning when the next cell is out of
            bounds.
        p_back: Per-step probability of choosing the backward action.

    Returns:
        A tuple ``(O, A, HDF, P)`` with one row per step:

        * ``O``   -- float32 observations from ``env.egocentric_obs``.
        * ``A``   -- float32 one-hot actions, ``ACTIONS`` columns.
        * ``HDF`` -- float32 heading features: a 4-way one-hot followed by
          ``[sin, cos]`` of ``h * pi / 2`` (6 columns).
        * ``P``   -- ``(x, y)`` position *before* the action was applied.

    Raises:
        ValueError: If ``T`` is not positive.
    """
    if T <= 0:
        raise ValueError(f"T must be a positive number of steps, got {T}")

    # A seeded call gets its own generator so it is reproducible and does
    # not disturb the global stream; otherwise fall back to the global one.
    rng = np.random if seed is None else np.random.RandomState(seed)
    rand, choice = rng.rand, rng.choice
    turns = [_LEFT, _RIGHT]

    # Lookup tables for one-hot rows (built once instead of every step).
    heading_onehots = np.eye(4, dtype=np.float32)
    action_onehots = np.eye(ACTIONS, dtype=np.float32)

    x, y, h = env.reset()
    obs_list, act_list, hdir_list, pos_list = [], [], [], []
    for _ in range(T):
        obs = env.egocentric_obs(x, y, h)
        angle = h * (np.pi / 2)
        hd_sin_cos = np.array([np.sin(angle), np.cos(angle)], dtype=np.float32)

        # "Wall" here means only the outer grid boundary: the cell one step
        # ahead is tested against the grid extents, nothing else.
        dx, dy = env.dirs[h]
        nx, ny = x + dx, y + dy
        wall_in_front = not (0 <= nx < env.Lx and 0 <= ny < env.Ly)

        # NOTE: the order and number of random draws below is deliberate;
        # changing it changes the trajectory produced for a given seed.
        action = _FORWARD
        if wall_in_front and rand() < p_wall_turn:
            action = choice(turns)
        if rand() < p_back:
            action = _BACK
        if rand() < p_turn:
            action = choice(turns)

        obs_list.append(obs.astype(np.float32))
        act_list.append(action_onehots[action])
        hdir_list.append(np.concatenate([heading_onehots[h], hd_sin_cos], axis=0))
        pos_list.append((x, y))
        x, y, h = env.step(x, y, h, action)

    O = np.stack(obs_list)
    A = np.stack(act_list)
    HDF = np.stack(hdir_list)
    P = np.array(pos_list)
    return O, A, HDF, P
