from __future__ import annotations

import math
from typing import Tuple, Optional

import numpy as np

# Resolve a `spaces` namespace providing `Box`: prefer gymnasium, then legacy
# gym, and finally a minimal local stand-in so the environment below can still
# be constructed with neither package installed.
# NOTE(review): `except Exception` (rather than ImportError) also hides other
# errors raised while importing these packages; presumably intentional
# best-effort — confirm.
try:
    import gymnasium as gym
    from gymnasium import spaces
except Exception:
    try:
        import gym
        from gym import spaces
    except Exception:
        # Neither package is importable; `gym` is left as None as a marker.
        gym = None

        class _Box:
            """Minimal stand-in for a ``Box`` space.

            Implements only ``low``/``high`` bounds, ``shape``, ``dtype`` and
            ``sample()``.
            """

            def __init__(self, low, high, shape=None, dtype=np.float32):
                # Convert bounds to arrays; when no shape is given it is taken
                # from `low` (a scalar bound therefore yields shape ()).
                low = np.array(low, dtype=dtype)
                high = np.array(high, dtype=dtype)
                if shape is None:
                    shape = low.shape
                # Broadcast scalar (or smaller) bounds to the full shape;
                # np.broadcast_to raises if they are not broadcastable.
                self.low = np.broadcast_to(low, shape).astype(dtype)
                self.high = np.broadcast_to(high, shape).astype(dtype)
                self._shape = tuple(shape)
                self.dtype = dtype

            @property
            def shape(self):
                """Shape of the space as a tuple."""
                return self._shape

            def sample(self):
                """Draw uniformly within [low, high) via NumPy's global RNG, cast to dtype."""
                return np.random.uniform(self.low, self.high).astype(self.dtype)

        class spaces:
            """Namespace mimicking the ``spaces`` module; only ``Box`` is provided."""
            Box = _Box


class OptimalTrapEnv(object):
    """
    A simple 2D navigation task with a hazardous trap region near the shortest path.

    - State: (x, y, vx, vy)
    - Action: (ax, ay), clipped to the action-space bounds [-1, 1]
    - Dynamics: semi-implicit Euler; velocity is updated first, then position
      is advanced with the new velocity.
    - Reward: negative distance to goal minus ``action_cost * ||action||``;
      ``goal_bonus`` is added on a step that ends within ``goal_tolerance``.
    - Cost (``info["cost"]``): 1.0 if inside the hazard circle, else 0.0.

    Episode ends when:
      - terminated: the agent gets within goal_tolerance of goal, or leaves the
        square [-boundary, boundary]^2 (no extra reward penalty is applied)
      - truncated: the step count reaches max_steps

    Randomness: ``reset(seed=...)`` seeds a private per-instance RNG, so the
    global NumPy RNG is never re-seeded by this environment.
    """

    # "render.modes" is the legacy gym key; gymnasium reads "render_modes".
    metadata = {"render.modes": ["human"], "render_modes": ["human"]}

    def __init__(self,
                 max_steps: int = 500,
                 dt: float = 0.1,
                 goal: Tuple[float, float] = (4.0, 0.0),
                 start: Tuple[float, float] = (-4.0, 0.0),
                 goal_tolerance: float = 0.2,
                 trap_center: Tuple[float, float] = (0.0, 0.0),
                 trap_radius: float = 1.0,
                 boundary: float = 6.0,
                 *,
                 goal_bonus: float = 10.0,
                 action_cost: float = 0.01,
                 max_speed: float = 5.0):
        """
        Args:
            max_steps: Number of steps after which an episode is truncated.
            dt: Integration time step.
            goal: Goal position (x, y).
            start: Nominal start position; reset adds up to +/-0.05 noise per axis.
            goal_tolerance: Distance to the goal that counts as reaching it.
            trap_center: Center of the circular hazard region.
            trap_radius: Radius of the hazard region (inclusive boundary).
            boundary: Half-width of the square arena centered at the origin.
            goal_bonus: Reward added on reaching the goal (formerly fixed at 10.0).
            action_cost: Weight of the action-norm penalty (formerly fixed at 0.01).
            max_speed: Velocity bound advertised in ``observation_space``
                (formerly fixed at 5.0). Velocities themselves are not clipped.
        """
        super().__init__()
        self.dt = dt
        self.max_steps = max_steps
        self.goal = np.array(goal, dtype=np.float32)
        self.start = np.array(start, dtype=np.float32)
        self.goal_tolerance = goal_tolerance
        self.trap_center = np.array(trap_center, dtype=np.float32)
        self.trap_radius = trap_radius
        self.boundary = boundary
        self.goal_bonus = float(goal_bonus)
        self.action_cost = float(action_cost)
        self.max_speed = float(max_speed)

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        high = np.array([boundary, boundary, max_speed, max_speed], dtype=np.float32)
        self.observation_space = spaces.Box(low=-high, high=high, shape=(4,), dtype=np.float32)

        # Private RNG; stays None until reset() is given a seed.
        self._rng: Optional[np.random.RandomState] = None
        self._state = np.zeros(4, dtype=np.float32)
        self._step_count = 0

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        Start a new episode near ``start`` with zero velocity.

        Args:
            seed: If given, (re)seeds this instance's private RNG. The resulting
                start positions are identical to what ``np.random.seed(seed)``
                used to produce, but the global NumPy RNG is left untouched.
            options: Accepted for Gymnasium API compatibility; ignored.

        Returns:
            ``(observation, info)``: a float32 copy of the state and an empty dict.
        """
        if seed is not None:
            # Legacy RandomState with the same seed yields the same stream as the
            # global np.random.seed(seed), preserving seeded reproducibility.
            self._rng = np.random.RandomState(seed)
        # Until a seed has been supplied, keep drawing from the global RNG so
        # callers that seed NumPy globally still get reproducible resets.
        rng = self._rng if self._rng is not None else np.random
        noise = rng.uniform(-1, 1, size=2).astype(np.float32)
        pos = self.start + 0.05 * noise
        vel = np.zeros(2, dtype=np.float32)
        self._state = np.concatenate([pos, vel]).astype(np.float32)
        self._step_count = 0
        info = {}
        return self._state.copy(), info

    def step(self, action: np.ndarray):
        """
        Advance the simulation by one time step.

        Args:
            action: Acceleration (ax, ay); clipped to the action-space bounds.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where ``info``
            holds ``"cost"`` (1.0 inside the trap, else 0.0) and ``"dist_to_goal"``.

        Note: no guard prevents stepping after an episode has ended.
        """
        action = np.clip(action, self.action_space.low, self.action_space.high)
        x, y, vx, vy = self._state
        ax, ay = action
        # Semi-implicit Euler: integrate velocity first, then position with the new velocity.
        vx = vx + ax * self.dt
        vy = vy + ay * self.dt
        x = x + vx * self.dt
        y = y + vy * self.dt
        self._state = np.array([x, y, vx, vy], dtype=np.float32)
        self._step_count += 1

        pos = self._state[:2]
        dist_to_goal = float(np.linalg.norm(self.goal - pos))
        reached_goal = dist_to_goal <= self.goal_tolerance
        in_trap = bool(np.linalg.norm(pos - self.trap_center) <= self.trap_radius)
        out_of_bounds = bool(np.any(np.abs(pos) > self.boundary))

        reward = -dist_to_goal - self.action_cost * float(np.linalg.norm(action))
        if reached_goal:
            reward += self.goal_bonus

        cost = 1.0 if in_trap else 0.0

        terminated = reached_goal or out_of_bounds
        truncated = self._step_count >= self.max_steps
        info = {"cost": cost, "dist_to_goal": dist_to_goal}
        return self._state.copy(), float(reward), bool(terminated), bool(truncated), info

    def render(self):
        """Print the current state to stdout (the only supported mode, "human")."""
        x, y, vx, vy = self._state
        print(f"state: x={x:.2f} y={y:.2f} vx={vx:.2f} vy={vy:.2f}")
