import math

class Policy:
    """
    Town01 single-carriageway controller.
    - Strict TL stop-line guard (Yellow==Red).
    - Distinguish curb-parked blocker on the right vs. in-lane stopped lead.
    - Safe pass: hold left offset until CLEAR of blocker; slight opposite-lane incursion only if clear.
    - Smooth speed planner and stable lateral controller.

    Lateral quantities (``gap_lat_m``, lateral error, yaw error) are treated
    as right-positive throughout. ``compute_action(obs)`` returns
    ``(steer, throttle, brake)`` with steer in [-1, 1] and throttle/brake in
    [0, 1]. The instance keeps state across calls (planned speed, PI
    integrator, TL latch, pass state machine, counters).
    """

    def __init__(
        self,
        lane_width_m=4.0,
        tl_gate_to_line_m=2.5,   # map-specific; override if known
        tl_stop_margin_m=3.2,
        tl_brake_start_m=24.0
    ):
        """
        Args:
            lane_width_m: ego lane width in metres.
            tl_gate_to_line_m: subtracted from the reported traffic-light
                distance to obtain the distance to the stop line.
            tl_stop_margin_m: extra distance kept before the stop line.
            tl_brake_start_m: beyond this stop-line distance an "Unknown"
                light is not treated as a stop request.
        """
        # === Road geometry ===
        self.lane_width_m       = float(lane_width_m)
        self.lane_half_width_m  = 0.5 * self.lane_width_m
        self.centerline_guard_m = 0.40  # keep ~40 cm from centerline when possible

        # === Cruise ===
        self.v_cruise = 6.94  # 25 km/h

        # === Longitudinal PI ===
        self.kp_v = 1.0
        self.ki_v = 0.05
        self.int_a = 0.0
        self.a_throttle_max = 2.0
        self.a_brake_max   = 4.0
        self.a_comf        = 2.2

        # === Lateral controller (lookahead + damping) ===
        self.k_yaw       = 0.75
        self.k_e         = 1.15
        self.k_yaw_rate  = 0.12
        self.stanley_clip= 0.60
        self.v_epsilon   = 0.1
        self.low_v_steer_bias = 2.0

        # Lookahead
        self.L0_la = 0.8
        self.k_la  = 0.6  # s, Lf = L0 + k*v

        # === Command rate limits (max change per call) ===
        self.steer_rate_limit    = 0.18
        self.throttle_rate_limit = 0.16
        self.brake_rate_limit    = 0.40

        # === Traffic lights ===
        self.tl_gate_to_line_m   = float(tl_gate_to_line_m)
        self.tl_stop_margin_m    = float(tl_stop_margin_m)
        self.tl_brake_start_m    = float(tl_brake_start_m)
        self.tl_min_unknown_v    = 3.0
        self.tl_stop_latch       = False

        # === Pedestrians ===
        self.ped_stop_margin_m = 3.5
        self.ped_hold_steps    = 8
        self.ped_hold_counter  = 0

        # === Lead following ===
        self.s0                = 1.2
        self.s0_min_stop       = 1.0
        self.stop_deadband     = 0.08
        self.T_headway         = 0.95
        self.k_v_rel           = 0.5
        self.k_creep_gap       = 1.0
        self.creep_speed_max   = 1.2

        # === Unstick ===
        self.stuck_counter        = 0
        self.stuck_thresh_steps   = 60
        self.unstick_speed        = 1.0
        self.oncoming_safe_dist_m = 15.0
        self.lead_consider_dist_m = 15.0

        # === Planner speed smoothing ===
        self.v_plan       = 0.0
        self.dv_up_max    = 0.60
        self.dv_down_max  = 2.50
        self.launch_dv_up = 0.35

        # === Turn/curvature control ===
        self.a_lat_max = 1.6  # m/s^2

        # === Misc ===
        self.speed_hold_brake = 0.30

        # === Right curb-parked pass logic ===
        self.edge_rel_v_thr         = 0.30   # |rel_long_mps| below this = stationary relative to ego
        self.edge_dist_gate_m       = 10.0   # only engage when within ~10 m
        self.pass_timer             = 0
        self.pass_timer_max         = 120
        self.pass_active            = False
        self.pass_hold_clear_steps  = 25
        self.offset_hold_counter    = 0
        self.opposite_lat_gate_m    = 1.0
        self.no_pass_gap_m          = 25.0   # conservative (bursty oncoming)
        self.opposite_incursion_max = 0.35   # allow slight incursion into opposite lane iff clear

        # --- Stronger clearance behavior (prevents early recentering) ---
        self.pass_alongside_gate_m     = 4.0    # lead gap at/below this counts as "alongside"
        self.pass_clear_front_m        = 8.0    # a lead this far ahead (after reaching the blocker) means clear
        self.pass_lead_lost_hold_steps = 40     # keep offset if lead vanishes briefly
        self.lead_missing_counter      = 0
        self.min_pass_offset_left      = 0.55   # minimum offset while passing (safety margin)
        self.pass_reached_blocker      = False  # True once we came alongside the blocker in this pass
        self.pass_offset_left_last     = 0.0    # last pass offset; reused while the blocker is out of view

    # ---------- Utils ----------
    def _clip(self, x, lo, hi):
        """Clamp ``x`` into ``[lo, hi]``."""
        if x < lo: return lo
        if x > hi: return hi
        return x

    def _finite_or(self, val, default=0.0):
        """Return ``val`` if it is a finite int/float, otherwise ``default``."""
        try:
            if val is None: return default
            if isinstance(val, (int, float)):
                return val if math.isfinite(val) else default
        except Exception:
            # e.g. math.isfinite raises OverflowError on huge ints
            return default
        return default

    def _last_hist(self, hist, default=0.0):
        """Return the last entry of a history list/tuple (sanitized), or ``default``."""
        if isinstance(hist, (list, tuple)) and len(hist) > 0:
            v = hist[-1]
            return self._finite_or(v, default)
        return default

    def _mean_hist(self, hist, default=0.0):
        """Return the mean of the finite numeric entries of a history, or ``default``."""
        if not isinstance(hist, (list, tuple)) or len(hist) == 0:
            return default
        vals = [v for v in hist if isinstance(v, (int, float)) and math.isfinite(v)]
        return (sum(vals) / len(vals)) if vals else default

    def _rate_limit(self, desired, last, limit):
        """Limit the change from ``last`` to ``desired`` to at most ``limit``."""
        return self._clip(desired, last - limit, last + limit)

    def _v_from_stop_dist(self, d):
        """Speed from which the car can stop within ``d`` metres at ``a_comf``: sqrt(2*a*d)."""
        if d is None or d <= 0.0: return 0.0
        return math.sqrt(max(0.0, 2.0 * self.a_comf * d))

    # ---------- Speed target smoothing ----------
    def _smooth_v_target(self, v_raw, obs, dv_up_override=None):
        """
        Rate-limit the planned speed and store it in ``self.v_plan``.

        Increases are limited to ``launch_dv_up`` near standstill, otherwise
        to ``dv_up_max``. ``dv_up_override`` can only enlarge that step.
        Decreases are limited to ``dv_down_max``. The result is clipped to
        ``[0, v_cruise]``.
        """
        v_raw = max(0.0, float(v_raw))
        v_prev = float(self.v_plan)
        v_now  = self._finite_or(obs.get("speed_mps"), 0.0)

        dv_up = self.launch_dv_up if v_now < 0.5 else self.dv_up_max
        if isinstance(dv_up_override, (int, float)) and math.isfinite(dv_up_override):
            dv_up = max(dv_up, float(dv_up_override))

        if v_raw > v_prev:
            v_new = min(v_raw, v_prev + dv_up)
        else:
            v_new = max(v_raw, v_prev - self.dv_down_max)

        self.v_plan = self._clip(v_new, 0.0, self.v_cruise)
        return self.v_plan

    # ---------- Oncoming squeeze-right ----------
    def _oncoming_modifiers(self, obs):
        """
        Return ``(offset_right_m, v_cap)`` for the first oncoming car.

        Only cars within 25 m longitudinally and 1.2 m laterally have an
        effect. The closer the car, the larger the right offset (up to
        0.35 m) and the lower the speed cap (down to 0.6 * v_cruise).
        """
        ocs = obs.get("opposite_cars") or []
        if not ocs: return 0.0, self.v_cruise
        oc = ocs[0] or {}
        gap = self._finite_or(oc.get("gap_long_m"), None)
        lat = self._finite_or(oc.get("gap_lat_m"), None)
        if gap is None or lat is None: return 0.0, self.v_cruise
        dist_gate, lat_gate = 25.0, 1.2
        if gap > dist_gate or abs(lat) > lat_gate: return 0.0, self.v_cruise
        f = self._clip(((dist_gate - gap)/dist_gate) * ((lat_gate - abs(lat))/lat_gate), 0.0, 1.0)
        offset_right = 0.35 * f
        vcap = self.v_cruise * (0.6 + 0.4 * (1.0 - f))
        return offset_right, vcap

    # ---------- TL speed limit (Yellow==Red) ----------
    def _plan_speed_limit_traffic_light(self, obs):
        """
        Return a speed cap (m/s) for the traffic light, or None if it imposes none.

        Any state other than "Green" is treated as a stop request
        (Yellow==Red). The target stop point is ``tl_stop_margin_m`` before
        the stop line, or ``s0_min_stop`` behind an in-lane lead, whichever
        is nearer. Once the car is near the line and almost stationary, a
        stop latch engages. The latch clears on Green, on a far "Unknown"
        light, or when no light is reported.
        """
        tl = obs.get("traffic_light") or {}
        exists = bool(tl.get("exists", False))
        if not exists:
            self.tl_stop_latch = False
            return None

        state   = tl.get("state", "Unknown")
        dist_m  = self._finite_or(tl.get("dist_m"), None)
        is_stop_state = (state != "Green")

        d_line = None
        if dist_m is not None:
            d_line = max(0.0, dist_m - self.tl_gate_to_line_m)

        # NOTE(review): a far "Unknown" light caps speed at tl_min_unknown_v,
        # while inside tl_brake_start_m the stop profile may allow more; confirm intended.
        if (not is_stop_state) or (d_line is not None and d_line > self.tl_brake_start_m and state == "Unknown"):
            self.tl_stop_latch = False
            return None if state == "Green" else self._clip(self.tl_min_unknown_v, 0.0, self.v_cruise)

        lead_gap_in_lane = None
        leads = obs.get("lead_cars") or []
        if leads:
            l = leads[0] or {}
            g  = self._finite_or(l.get("gap_long_m"), None)
            gl = self._finite_or(l.get("gap_lat_m"), None)
            if g is not None and gl is not None and abs(gl) < self.lane_half_width_m:
                lead_gap_in_lane = g

        d_stop_candidates = []
        if d_line is not None:
            d_stop_candidates.append(max(0.0, d_line - self.tl_stop_margin_m))
        if lead_gap_in_lane is not None:
            d_stop_candidates.append(max(0.0, lead_gap_in_lane - self.s0_min_stop))
        if not d_stop_candidates:
            # Stop request with no usable distance: stop.
            return 0.0

        d_stop = min(d_stop_candidates)
        vlim   = self._v_from_stop_dist(d_stop)

        v = self._finite_or(obs.get("speed_mps"), 0.0)
        if (d_line is not None and d_line <= (self.tl_stop_margin_m + 0.8) and v < 0.4):
            self.tl_stop_latch = True

        if self.tl_stop_latch:
            # While latched, only creep up to an in-lane lead (if any); otherwise hold 0.
            approach = None
            if (lead_gap_in_lane is not None) and (lead_gap_in_lane > (self.s0_min_stop + self.stop_deadband)):
                approach = self._clip(self.k_creep_gap * (lead_gap_in_lane - self.s0_min_stop), 0.0, self.creep_speed_max)

            v_line_cap = self._v_from_stop_dist(max(0.0, (d_line if d_line is not None else 0.0) - self.tl_stop_margin_m))
            vlim = min(vlim, v_line_cap if d_line is not None else vlim, approach if approach is not None else 0.0)
            return max(0.0, vlim)

        return max(0.0, vlim)

    # ---------- Pedestrians ----------
    def _plan_speed_limit_pedestrians(self, obs):
        """
        Return a speed cap (m/s) for pedestrians, or None if none apply.

        A pedestrian counts if it is in the ego lane, or approaching it and
        expected to enter within 3.5 s. Each such pedestrian caps the speed
        so the car stops ``ped_stop_margin_m`` short of it. After the last
        blocking pedestrian disappears, a 0 m/s cap is kept for
        ``ped_hold_steps`` more calls.
        """
        peds = obs.get("pedestrians") or []
        best = None
        ped_block_active = False

        for ped in peds:
            ped   = ped or {}  # tolerate None entries, like the other object lists
            lane  = ped.get("lane"); state = ped.get("state")
            gap   = self._finite_or(ped.get("gap_long_m"), None)
            if gap is None or gap <= 0.0: continue

            if lane == "ego" and state == "in_lane":
                vlim = self._v_from_stop_dist(gap - self.ped_stop_margin_m)
                best = vlim if (best is None or vlim < best) else best
                ped_block_active = True
            elif lane == "approach" and state == "approaching_lane":
                t_enter = ped.get("t_enter_lane_s")
                if isinstance(t_enter, (int, float)) and math.isfinite(t_enter) and 0.0 <= t_enter <= 3.5:
                    vlim = self._v_from_stop_dist(gap - self.ped_stop_margin_m)
                    best = vlim if (best is None or vlim < best) else best
                    ped_block_active = True

        if ped_block_active:
            self.ped_hold_counter = self.ped_hold_steps
        else:
            if self.ped_hold_counter > 0:
                self.ped_hold_counter -= 1
                best = 0.0 if best is None else min(best, 0.0)
        return best

    # ---------- Parked-right detection helpers ----------
    def _right_edge_intrusion(self, lead):
        """
        Estimate how much the lead's left side intrudes into ego lane from the right curb.
        Positive = intruding. Returns -1.0 when lat/gap are missing or gap <= 0.
        """
        lat = lead.get("gap_lat_m")
        gap = lead.get("gap_long_m")
        if not (isinstance(lat, (int, float)) and math.isfinite(lat)): return -1.0
        if not (isinstance(gap, (int, float)) and math.isfinite(gap) and gap > 0.0): return -1.0

        car_half = 0.9   # ~half width of sedan
        left_edge_from_lane_center = lat - car_half     # right+ sign
        intrusion = self.lane_half_width_m - left_edge_from_lane_center
        return intrusion  # >0 means it's biting into lane from the right

    def _is_right_curb_parked(self, lead):
        """
        Return True if the lead looks like a car parked at the right curb.

        The lead must intrude into the lane, sit at least 0.6 m to the right,
        and have |rel_long_mps| <= edge_rel_v_thr.
        """
        # NOTE: rel_long_mps is relative to ego (elsewhere v_lead = v_ego + rel),
        # so a truly parked car passes this check only once ego is itself near standstill.
        rel = self._finite_or(lead.get("rel_long_mps"), 0.0)
        lat = self._finite_or(lead.get("gap_lat_m"), None)

        intrusion = self._right_edge_intrusion(lead)
        if intrusion <= 0.05:  # not biting into lane
            return False
        if abs(rel) > self.edge_rel_v_thr:
            return False
        if lat is None or lat < 0.6:  # prefer objects clearly offset to the right side
            return False
        return True

    def _oncoming_no_pass(self, obs):
        """Return True if the first oncoming car is too close (long and lat) to allow a pass."""
        oc = (obs.get("opposite_cars") or [])
        if not oc: return False
        o = oc[0] or {}
        d   = self._finite_or(o.get("gap_long_m"), None)
        lat = self._finite_or(o.get("gap_lat_m"), None)
        if d is None or lat is None: return False
        return (d < self.no_pass_gap_m) and (abs(lat) < self.opposite_lat_gate_m)

    # ---------- Improved pass state (clearance-aware) ----------
    def _update_pass_state(self, pass_allowed, gap, required_offset_left, oncoming_block):
        """
        Advance the parked-right pass state machine by one step.

        Don't recenter until CLEAR of the blocker. If the lead disappears
        briefly, keep holding the offset. After the pass ends, apply a short
        trailing offset.

        Args:
            pass_allowed: True when a new pass may be armed this step.
            gap: longitudinal gap to the lead (m), or None if not observed.
            required_offset_left: left offset (m) needed to clear the blocker.
            oncoming_block: True when oncoming traffic forbids passing.

        Returns:
            (offset_left, vpass_cap, lockout). ``offset_left`` is the
            commanded left offset (m). ``vpass_cap`` is the speed cap while
            passing, or None if no pass ran this step. ``lockout`` is True
            when not passing and the lead is within ``s0_min_stop + 1.0`` m.
        """
        gap_ok = isinstance(gap, (int, float)) and math.isfinite(gap)

        # Track whether the blocker is still detected
        if gap_ok:
            self.lead_missing_counter = 0
        else:
            self.lead_missing_counter += 1

        offset_left = 0.0
        vpass_cap   = None
        lockout     = False

        # Abort pass immediately if oncoming blocks.
        # NOTE(review): this recenters at once even if alongside the blocker; confirm intended.
        if oncoming_block:
            self.pass_active = False
            self.pass_timer  = 0

        # Arm pass
        if (not self.pass_active) and pass_allowed:
            self.pass_active           = True
            self.pass_timer            = self.pass_timer_max
            self.pass_reached_blocker  = False
            self.pass_offset_left_last = 0.0

        if self.pass_active:
            self.pass_timer -= 1

            # Max left we allow (respect centerline; tiny incursion if clear).
            # oncoming_block is normally False here (abort above); kept as a guard.
            base_max  = self.lane_half_width_m - self.centerline_guard_m
            inc_allow = 0.0 if oncoming_block else self.opposite_incursion_max
            dyn_left_max = base_max + inc_allow

            # Enforce a minimum safe offset while passing
            desired = max(required_offset_left, self.min_pass_offset_left)
            if not gap_ok:
                # Blocker out of view: keep at least the offset we were already holding.
                desired = max(desired, self.pass_offset_left_last)
            desired = self._clip(desired, 0.0, dyn_left_max)
            offset_left = desired if desired > 0.0 else 0.0
            self.pass_offset_left_last = offset_left

            # Speed cap while threading past
            if gap_ok:
                d_eff = max(0.0, gap - (self.s0_min_stop + 0.6))
                approach = self._v_from_stop_dist(d_eff)
                vpass_cap = max(1.0, min(1.8, approach if approach > 0.0 else 1.8))
            else:
                vpass_cap = 1.6

            # Clearance logic
            still_alongside = (
                (gap_ok and gap <= self.pass_alongside_gate_m)
                or (gap is None and self.lead_missing_counter < self.pass_lead_lost_hold_steps)
            )
            if still_alongside:
                self.pass_reached_blocker = True
            # A lead far ahead only means "clear" once we have actually reached the
            # blocker; otherwise a pass armed at 8-10 m would end on its first step.
            clear_ok = (
                (self.pass_reached_blocker and gap_ok and gap >= self.pass_clear_front_m)
                or (self.lead_missing_counter >= self.pass_lead_lost_hold_steps)
            )

            # End pass only when clearly past (or timer safety)
            if (clear_ok or self.pass_timer <= 0) and (not still_alongside):
                self.pass_active = False
                self.pass_timer = 0
                self.offset_hold_counter = self.pass_hold_clear_steps

        # trailing hold to avoid snap-back
        if (not self.pass_active) and (self.offset_hold_counter > 0) and offset_left <= 1e-3:
            self.offset_hold_counter -= 1
            offset_left = 0.25

        # lockout near-contact front
        if (not self.pass_active) and gap_ok and gap <= (self.s0_min_stop + 1.0):
            lockout = True

        return offset_left, vpass_cap, lockout

    # ---------- Lead + pass planning ----------
    def _plan_speed_limit_lead_and_pass(self, obs):
        """
        Plan the speed cap from the first lead car, including the parked-right pass.

        Returns:
            (v_limit, offset_left_cmd, v_pass_cap). ``v_limit`` is a speed
            cap in m/s, or None when the lead imposes none.
            ``offset_left_cmd`` is the requested left offset (m).
            ``v_pass_cap`` is the pass speed cap, or None.
        """
        leads = obs.get("lead_cars") or []
        if not leads:
            # The blocker typically drops out of lead_cars once alongside/past.
            # Keep the pass state machine ticking so the offset is held through
            # the dropout and the trailing hold decays (instead of snapping to
            # center with pass_active stuck True).
            if self.pass_active or self.offset_hold_counter > 0:
                offset_left_cmd, v_pass_cap, _ = self._update_pass_state(
                    False, None, 0.0, self._oncoming_no_pass(obs))
                if v_pass_cap is not None:
                    return self._clip(v_pass_cap, 0.0, self.v_cruise), offset_left_cmd, v_pass_cap
                return None, offset_left_cmd, None
            return None, 0.0, None

        lead    = leads[0] or {}
        gap     = self._finite_or(lead.get("gap_long_m"), None)
        gap_lat = self._finite_or(lead.get("gap_lat_m"), None)
        rel_x   = self._finite_or(lead.get("rel_long_mps"), 0.0)
        ttc     = lead.get("ttc_s")
        thw     = lead.get("thw_s")

        v_ego  = self._finite_or(obs.get("speed_mps"), 0.0)
        v_lead = self._clip(v_ego + rel_x, 0.0, 50.0)

        in_lane = (gap_lat is not None and abs(gap_lat) < self.lane_half_width_m)

        # Parked vs stopped-in-traffic reasoning
        parked_right = self._is_right_curb_parked(lead)
        stopped_lead = (v_lead < 0.25) or (isinstance(ttc, (int, float)) and 0.0 <= ttc < 1.2) or \
                       (gap is not None and gap <= (self.s0_min_stop + 1.0))

        # Pass feasibility and required offset
        oncoming_block = self._oncoming_no_pass(obs)
        intrusion = self._right_edge_intrusion(lead)
        required_offset = 0.0 if intrusion <= 0.0 else max(intrusion + 0.55, self.min_pass_offset_left)
        pass_allowed = (parked_right and stopped_lead and (gap is not None) and gap <= self.edge_dist_gate_m and (not oncoming_block))

        offset_left_cmd, v_pass_cap, lockout = self._update_pass_state(pass_allowed, gap, required_offset, oncoming_block)

        # During pass (a cap is produced iff the pass ran this step): cap speed and keep offset
        if v_pass_cap is not None:
            v_cap = 0.0 if lockout else v_pass_cap
            return self._clip(v_cap, 0.0, self.v_cruise), offset_left_cmd, v_pass_cap

        # Normal in-lane following. offset_left_cmd is 0.0 here except during the
        # post-pass trailing hold, which must be forwarded to avoid snap-back.
        if not in_lane:
            return None, offset_left_cmd, None
        if gap is None or gap <= 0.0:
            return 0.0, offset_left_cmd, None

        s_des       = self.s0 + max(0.0, v_ego) * self.T_headway + self.k_v_rel * max(0.0, v_ego - v_lead)
        v_from_gap  = max(0.0, (gap - self.s0) / max(self.T_headway, 0.5))
        v_follow    = min(v_lead, v_from_gap)

        if isinstance(thw, (int, float)) and math.isfinite(thw) and thw >= 0.0 and thw < self.T_headway:
            v_follow = min(v_follow, v_follow * (thw / max(self.T_headway, 1e-3)))

        if isinstance(ttc, (int, float)) and math.isfinite(ttc) and ttc >= 0.0:
            if ttc < 0.9: v_follow = 0.0
            elif ttc < 1.8: v_follow = min(v_follow, v_follow * (ttc / 1.8))

        if gap < s_des:
            v_follow = min(v_follow, self._v_from_stop_dist(max(0.0, gap - self.s0)))

        # Both nearly stopped: creep to s0_min_stop, then hold.
        if v_ego < 0.8 and v_lead < 0.8:
            if gap > (self.s0_min_stop + self.stop_deadband):
                approach_speed = self._clip(self.k_creep_gap * (gap - self.s0_min_stop), 0.0, self.creep_speed_max)
                v_follow = min(v_follow, approach_speed)
            else:
                v_follow = 0.0

        return self._clip(v_follow, 0.0, self.v_cruise), offset_left_cmd, None

    # ---------- Unstick ----------
    def _unstick_logic(self, obs, v_target_current, tl_vlim, ped_vlim):
        """
        Raise the target speed to ``unstick_speed`` when stuck behind a stopped lead.

        The car counts as stuck while, on consecutive calls, it is stopped
        behind a stopped in-lane lead within ``lead_consider_dist_m``, no
        TL/pedestrian cap is at or below 0.5 m/s, and no pass is active. Once
        that lasts more than ``stuck_thresh_steps`` calls, the target is
        raised unless an oncoming car is within ``oncoming_safe_dist_m``.
        """
        no_tl_block  = (tl_vlim  is None) or (tl_vlim  > 0.5)
        no_ped_block = (ped_vlim is None) or (ped_vlim > 0.5)
        v = self._finite_or(obs.get("speed_mps"), 0.0)

        leads = obs.get("lead_cars") or []
        has_stopped_lead = False
        if leads:
            lead = leads[0] or {}
            gap     = self._finite_or(lead.get("gap_long_m"), None)
            gap_lat = self._finite_or(lead.get("gap_lat_m"), None)
            rel     = self._finite_or(lead.get("rel_long_mps"), 0.0)
            v_lead = max(0.0, v + rel)
            if (gap is not None and gap < self.lead_consider_dist_m and
                gap_lat is not None and abs(gap_lat) < self.lane_half_width_m and
                v_lead < 0.3):
                has_stopped_lead = True

        if no_tl_block and no_ped_block and has_stopped_lead and v < 0.3 and (not self.pass_active):
            self.stuck_counter += 1
        else:
            self.stuck_counter = 0

        if self.stuck_counter > self.stuck_thresh_steps:
            oncoming = obs.get("opposite_cars") or []
            safe_to_proceed = True
            if oncoming:
                oc = oncoming[0] or {}
                oc_gap = self._finite_or(oc.get("gap_long_m"), None)
                oc_lat = self._finite_or(oc.get("gap_lat_m"), None)
                if (oc_gap is not None and oc_gap < self.oncoming_safe_dist_m and
                    oc_lat is not None and abs(oc_lat) < self.lane_half_width_m):
                    safe_to_proceed = False
            if safe_to_proceed:
                return max(v_target_current, self.unstick_speed)
        return v_target_current

    # ---------- Curvature/heading caps ----------
    def _curve_speed_limit(self, obs):
        """Cap speed so that v * |yaw_rate| <= a_lat_max (no cap when yaw rate ~ 0)."""
        yaw_rate = abs(self._finite_or(obs.get("yaw_rate_rps"), 0.0))
        return self.v_cruise if yaw_rate < 1e-3 else self._clip(self.a_lat_max / yaw_rate, 0.0, self.v_cruise)

    def _heading_error_speed_limit(self, obs):
        """Scale cruise speed down (to at most 60 % reduction) with the mean absolute yaw error."""
        epsi = abs(self._mean_hist(obs.get("yaw_error_hist4"), 0.0))
        scale = 1.0 - self._clip(0.7 * (epsi / 0.35), 0.0, 0.6)
        return self.v_cruise * scale

    # ---------- Lateral (with pass left-offset, oncoming squeeze-right) ----------
    def _compute_lateral(self, obs, offset_left_cmd):
        """
        Compute a rate-limited steering command in [-1, 1].

        The controller combines yaw-error feedback with a Stanley-style
        lateral term on a lookahead error, plus yaw-rate damping. The lateral
        reference is shifted by the oncoming squeeze-right offset minus
        ``offset_left_cmd``.
        """
        e_y      = self._mean_hist(obs.get("lateral_hist4"), 0.0)      # right+
        e_psi    = self._mean_hist(obs.get("yaw_error_hist4"), 0.0)    # right+
        yaw_rate = self._finite_or(obs.get("yaw_rate_rps"), 0.0)
        # Clamp to >= 0: a negative speed would shrink the lookahead and makes
        # v + low_v_steer_bias zero at v == -bias (division by zero below).
        v        = max(0.0, self._finite_or(obs.get("speed_mps"), 0.0))
        last_st  = self._last_hist(obs.get("steer_cmd_hist4"), 0.0)

        offset_right, _ = self._oncoming_modifiers(obs)
        net = offset_right - max(0.0, offset_left_cmd)  # right+ minus left+

        Lf = self.L0_la + self.k_la * v
        e_y_la = (e_y - net) + Lf * e_psi

        speed_scale = 0.6 + 0.4 * self._clip(v / self.v_cruise, 0.0, 1.0)
        k_yaw_eff = self.k_yaw * speed_scale

        stanley = math.atan2(self.k_e * e_y_la, max(v, self.v_epsilon))
        stanley = self._clip(stanley, -self.stanley_clip, self.stanley_clip)
        raw_steer = -(k_yaw_eff * e_psi + stanley) - self.k_yaw_rate * yaw_rate

        # Soften steering at low speed (factor 0.5 at standstill -> 1.0 at high speed).
        atten = 0.5 + 0.5 * (v / (v + self.low_v_steer_bias))
        raw_steer *= atten
        raw_steer = self._clip(raw_steer, -1.0, 1.0)

        steer = self._rate_limit(raw_steer, last_st, self.steer_rate_limit)
        return self._clip(steer, -1.0, 1.0)

    # ---------- Longitudinal control ----------
    def _longitudinal_control(self, obs, v_target, tl_guard=False):
        """
        Convert a target speed into rate-limited ``(throttle, brake)`` in [0, 1].

        A PI controller on the speed error produces an acceleration. Positive
        values map to throttle and negative values to brake. At standstill
        with a ~0 target, a holding brake is applied. ``tl_guard`` forces
        zero throttle and at least 0.6 brake. Both commands are rate-limited
        against the last values in the observation histories.
        """
        v  = self._finite_or(obs.get("speed_mps"), 0.0)
        ev = v_target - v

        if v_target <= 0.2:
            # Don't let positive integral push us forward when asked to stop.
            self.int_a = min(self.int_a, 0.0)
        self.int_a += self.ki_v * ev
        self.int_a  = self._clip(self.int_a, -self.a_brake_max, self.a_throttle_max)

        a_cmd = self.kp_v * ev + self.int_a
        a_cmd = self._clip(a_cmd, -self.a_brake_max, self.a_throttle_max)

        if a_cmd >= 0.0:
            throttle_des = self._clip(a_cmd / self.a_throttle_max, 0.0, 1.0)
            brake_des    = 0.0
        else:
            throttle_des = 0.0
            brake_des    = self._clip((-a_cmd) / self.a_brake_max, 0.0, 1.0)

        if v < 0.05 and v_target < 0.1:
            brake_des = max(brake_des, self.speed_hold_brake)
            throttle_des = 0.0

        if tl_guard:
            throttle_des = 0.0
            brake_des = max(brake_des, 0.6)

        last_th = self._last_hist(obs.get("throttle_cmd_hist4"), 0.0)
        last_br = self._last_hist(obs.get("brake_cmd_hist4"), 0.0)

        throttle_cmd = self._rate_limit(throttle_des, last_th, self.throttle_rate_limit)
        throttle_cmd = self._clip(throttle_cmd, 0.0, 1.0)

        brake_cmd = self._rate_limit(brake_des, last_br, self.brake_rate_limit)
        brake_cmd = self._clip(brake_cmd, 0.0, 1.0)

        if brake_cmd > 0.0:
            throttle_cmd = 0.0
        return throttle_cmd, brake_cmd

    # ---------- API ----------
    def compute_action(self, obs, path=None):
        """
        Compute one control step.

        Args:
            obs: observation dict (speed, histories, traffic light, lead,
                opposite cars, pedestrians).
            path: accepted for interface compatibility; not used here.

        Returns:
            (steer, throttle, brake) with steer in [-1, 1] and throttle/brake
            in [0, 1]. If ``speed_mps`` is missing or non-finite, steer is
            rate-limited toward 0, throttle is 0, and brake is rate-limited
            toward 0.6.
        """
        # Fail-safe
        speed = obs.get("speed_mps")
        if not isinstance(speed, (int, float)) or not math.isfinite(speed):
            last_steer = self._last_hist(obs.get("steer_cmd_hist4"), 0.0)
            steer = self._rate_limit(0.0, last_steer, self.steer_rate_limit)
            last_br = self._last_hist(obs.get("brake_cmd_hist4"), 0.0)
            brake = self._rate_limit(0.6, last_br, self.brake_rate_limit)
            return (self._clip(steer, -1.0, 1.0), 0.0, self._clip(brake, 0.0, 1.0))

        # Base speed caps
        v_raw = self.v_cruise
        v_raw = min(v_raw, self._curve_speed_limit(obs))
        v_raw = min(v_raw, self._heading_error_speed_limit(obs))
        _, v_oncoming_cap = self._oncoming_modifiers(obs)
        v_raw = min(v_raw, v_oncoming_cap)

        # TL (Yellow==Red)
        tl_vlim  = self._plan_speed_limit_traffic_light(obs)
        if tl_vlim is not None:
            v_raw = min(v_raw, tl_vlim)

        # Pedestrians
        ped_vlim = self._plan_speed_limit_pedestrians(obs)
        if ped_vlim is not None:
            v_raw = min(v_raw, ped_vlim)

        # Lead + parked-right pass
        lead_vlim, offset_left_cmd, v_pass_cap = self._plan_speed_limit_lead_and_pass(obs)
        if lead_vlim is not None:
            v_raw = min(v_raw, lead_vlim)

        # Unstick (not while actively passing)
        if not self.pass_active:
            v_raw = self._unstick_logic(obs, v_raw, tl_vlim, ped_vlim)

        v_raw = self._clip(v_raw, 0.0, self.v_cruise)

        # Catch-up override when roomy gap and lead pulling away
        dv_up_override = None
        leads = obs.get("lead_cars") or []
        if leads and not self.pass_active:
            lead = leads[0] or {}
            gap = self._finite_or(lead.get("gap_long_m"), None)
            rel = self._finite_or(lead.get("rel_long_mps"), 0.0)
            thw = self._finite_or(lead.get("thw_s"), None)
            if gap is not None and ((thw is not None and thw > (self.T_headway + 0.3)) or rel > 0.5):
                dv_up_override = 0.9

        v_target = self._smooth_v_target(v_raw, obs, dv_up_override=dv_up_override)

        # Hard stop-line guard (prevents "stops after TL")
        tl = obs.get("traffic_light") or {}
        if bool(tl.get("exists", False)) and tl.get("state", "Unknown") != "Green":
            dist_m = self._finite_or(tl.get("dist_m"), None)
            if dist_m is not None:
                d_line = max(0.0, dist_m - self.tl_gate_to_line_m)
                if d_line <= (self.tl_stop_margin_m + 0.8):
                    v_target = 0.0
                    tl_guard = True
                else:
                    tl_guard = False
            else:
                # Non-green light with unknown distance: brake.
                tl_guard = True
        else:
            tl_guard = False

        throttle, brake = self._longitudinal_control(obs, v_target, tl_guard=tl_guard)

        # Lateral
        steer = self._compute_lateral(obs, offset_left_cmd)

        steer    = self._clip(steer, -1.0, 1.0)
        throttle = self._clip(throttle, 0.0, 1.0)
        brake    = 0.0 if throttle > 0.0 else self._clip(brake, 0.0, 1.0)
        return (steer, throttle, brake)
