from __future__ import annotations
import numpy as np

class LongCorridor:
    """Long, narrow grid with repeating color bands along x.

    The agent state is ``(x, y, h)``: integer cell coordinates with ``y``
    increasing downward, plus a heading index ``h`` into ``dirs``
    (0=N, 1=E, 2=S, 3=W; turning right is ``h + 1``).

    The egocentric observation is a square window with
    ``2 * (obs_size // 2) + 1`` cells per side (exactly ``obs_size`` when it is
    odd). It is flattened row-major. For every heading, the first row is the
    one furthest ahead of the agent, and each row runs from the agent's left
    to its right. Each cell is one-hot over ``n_colors + 2`` channels: one per
    band color, one for walls / out-of-bounds, and one for objects scattered
    outside the grid boundaries.
    """

    def __init__(
        self,
        Lx: int = 48,
        Ly: int = 5,
        n_colors: int = 6,
        obs_size: int = 5,
        seed: int = 0,
        n_objects: int = 10,
        object_margin: int = 2,
    ):
        """Build the corridor and scatter objects around it.

        Args:
            Lx: Grid width (cells along x); must be >= 1.
            Ly: Grid height (cells along y); must be >= 1.
            n_colors: Number of distinct band colors; must be >= 1.
            obs_size: Requested observation window side; must be >= 1.
            seed: Seed for the environment's ``RandomState``. It drives
                object placement here and start states in ``reset``.
            n_objects: Number of objects to place outside the grid.
            object_margin: Width of the band around the grid in which objects
                may be placed; must be >= 1 when ``n_objects > 0``.

        Raises:
            ValueError: If any constraint above is violated.
        """
        if Lx < 1:
            raise ValueError(f"Lx must be at least 1, got {Lx}")
        if Ly < 1:
            raise ValueError(f"Ly must be at least 1, got {Ly}")
        if n_colors < 1:
            raise ValueError(f"n_colors must be at least 1, got {n_colors}")
        if obs_size < 1:
            raise ValueError(f"obs_size must be at least 1, got {obs_size}")
        if n_objects > 0 and object_margin < 1:
            # With no margin there is no out-of-bounds cell to put an object in.
            raise ValueError(
                f"object_margin must be at least 1 to place objects, got {object_margin}"
            )
        self.Lx, self.Ly = Lx, Ly
        self.n_colors = n_colors
        self.obs_size = obs_size
        self.rng = np.random.RandomState(seed)
        # Repeating bands along x: every row shares the same color layout.
        band_width = max(1, Lx // (n_colors * 2))
        row = (np.arange(Lx, dtype=np.int64) // band_width) % n_colors
        self.pattern = np.tile(row, (Ly, 1))  # shape (Ly, Lx), indexed [y, x]
        # 4 headings: N,E,S,W (dx,dy) with y down
        self.dirs = [(0, -1), (1, 0), (0, 1), (-1, 0)]

        # Objects are points (ox, oy) outside [0, Lx) x [0, Ly) but inside
        # [-object_margin, Lx + object_margin) x [-object_margin, Ly + object_margin).
        self.n_objects = n_objects
        self.object_margin = object_margin
        self.objects = [self._sample_object_position() for _ in range(n_objects)]

    def _sample_object_position(self) -> tuple[int, int]:
        """Draw one point in the ``object_margin``-wide band around the grid.

        The RNG draws happen in a fixed order: side, then x, then y. A given
        seed therefore always yields the same layout.
        """
        m = self.object_margin
        Lx, Ly = self.Lx, self.Ly
        # Decide which side: left, right, top, bottom
        side = self.rng.choice(['left', 'right', 'top', 'bottom'])
        if side == 'left':
            ox = self.rng.randint(-m, 0)
            oy = self.rng.randint(-m, Ly + m)
        elif side == 'right':
            ox = self.rng.randint(Lx, Lx + m)
            oy = self.rng.randint(-m, Ly + m)
        elif side == 'top':
            ox = self.rng.randint(-m, Lx + m)
            oy = self.rng.randint(-m, 0)
        else:  # bottom
            ox = self.rng.randint(-m, Lx + m)
            oy = self.rng.randint(Ly, Ly + m)
        return ox, oy

    def _interior_coordinate(self, length: int) -> int:
        """Sample a coordinate in ``[1, length - 2]``, or 0 when ``length <= 2``.

        An axis of length 1 or 2 has no non-edge cell, so 0 is used (and no
        random number is consumed).
        """
        if length <= 2:
            return 0
        return self.rng.randint(1, length - 1)

    def reset(self) -> tuple[int,int,int]:
        """Sample a random start state ``(x, y, h)``.

        On each axis the position avoids the outermost cells when the axis is
        at least 3 cells long; otherwise it is 0. The heading is uniform over
        the four directions.
        """
        x = self._interior_coordinate(self.Lx)
        y = self._interior_coordinate(self.Ly)
        h = self.rng.randint(0, 4)
        return x, y, h

    def step(self, x: int, y: int, h: int, action: int) -> tuple[int,int,int]:
        """Apply ``action`` to state ``(x, y, h)`` and return the next state.

        Actions: 1 turns left, 2 turns right, 3 moves one cell backward, and
        any other value (including 0) moves one cell forward. A move that
        would leave the grid is ignored; turning never changes position.
        """
        if action == 1:      # left
            return x, y, (h - 1) % 4
        if action == 2:      # right
            return x, y, (h + 1) % 4
        # 3 = backward (opposite heading); anything else = forward.
        move = (h + 2) % 4 if action == 3 else h
        dx, dy = self.dirs[move]
        nx, ny = x + dx, y + dy
        if 0 <= nx < self.Lx and 0 <= ny < self.Ly:
            return nx, ny, h
        return x, y, h

    def egocentric_obs(self, x: int, y: int, h: int) -> np.ndarray:
        """Return the flattened one-hot egocentric view from state ``(x, y, h)``.

        Layout: row-major over a ``2 * (obs_size // 2) + 1`` square window.
        The first row is furthest ahead of the agent, and each row runs from
        the agent's left to its right, for every heading. The window is pushed
        forward by ``max(0, obs_size // 2 - 1)`` cells, so larger windows show
        more ahead than behind; for ``obs_size <= 3`` it is centred on the
        agent.

        Cells inside the grid get their band-color channel. Cells outside get
        the object channel (index ``n_colors + 1``) if an object sits there,
        and the wall channel (index ``n_colors``) otherwise.

        Returns:
            float32 array of length ``cells * (n_colors + 2)``.

        Raises:
            KeyError: If ``h`` is not a heading index in 0..3.
        """
        if h not in range(4):
            raise KeyError(h)
        half = self.obs_size // 2
        # Push the window forward; clamp at 0 so a 1-cell window is never
        # shifted behind the agent.
        shift = max(0, half - 1)
        # Build the view from the same direction vectors `step` moves along,
        # so the view is consistent for every heading.
        fx, fy = self.dirs[h]            # agent's forward direction
        rx, ry = self.dirs[(h + 1) % 4]  # agent's right-hand direction
        C = self.n_colors
        objects = set(self.objects)      # O(1) lookup per out-of-bounds cell
        span = range(-half, half + 1)
        # v indexes rows (negative = further ahead), u indexes columns
        # (positive = to the agent's right).
        cells = [(u, v) for v in span for u in span]
        # +1 for wall, +1 for object
        out = np.zeros((len(cells), C + 2), dtype=np.float32)
        for i, (u, v) in enumerate(cells):
            ahead = shift - v
            wx = x + u * rx + ahead * fx
            wy = y + u * ry + ahead * fy
            if 0 <= wx < self.Lx and 0 <= wy < self.Ly:
                out[i, self.pattern[wy, wx]] = 1.0
            elif (wx, wy) in objects:
                out[i, C + 1] = 1.0  # object channel
            else:
                out[i, C] = 1.0  # wall / OOB channel
        return out.flatten()
