

import argparse
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Column names expected in the input tracks CSV.
# ---------------------------------------------------------------------------
# LONG_COL is the position column treated as "longitudinal": vehicles in the
# same lane are sorted on it to find the nearest vehicle ahead.  LAT_COL is the
# position column used for the cumulative lateral displacement test.
# NOTE(review): the default input path in this file points at a highD
# recording, and the published highD format describes x (not y) as the driving
# direction -- confirm that treating "y" as longitudinal is intended.
LONG_COL   = "y"
LAT_COL    = "x"
# Velocity components along the x and y axes.
VELX_COL   = "xVelocity"
VELY_COL   = "yVelocity"
# Acceleration column names; not referenced by the code in this section
# (accelerations are re-derived from the velocities by finite differences).
ACCX_COL   = "xAcceleration"
ACCY_COL   = "yAcceleration"
# Lane identifier, vehicle (track) identifier and frame number.
LANE_COL   = "laneId"
ID_COL     = "id"
FRAME_COL  = "frame"

# Sampling configuration.  DT is the time step handed to the finite-difference
# helper; presumably FPS is the recording frame rate in Hz -- confirm against
# the dataset metadata.
FPS = 25.0
DT  = 1.0 / FPS
# The first HISTORY_STEPS rows of every vehicle track are left unlabeled
# (IGNORE_INDEX); HISTORY_S is the same span expressed in seconds.
HISTORY_S = 3.0
HISTORY_STEPS = int(round(HISTORY_S * FPS))  # 75

# Integer class id stored in the "intent_label" output column for each intent.
INTENT_IDX = {
    "free_flow":     0,
    "car_following": 1,
    "lane_changing": 2,
    "merging":       3,
    "emergency":     4,
}
# Label value for rows that do not receive an intent class.
IGNORE_INDEX = -1



def _get_long_vel(row):
    """Return the velocity component along the longitudinal axis as a float.

    ``row`` is any mapping-like record indexable by the velocity column
    names; the y-velocity is used when ``LONG_COL`` is ``"y"``, otherwise
    the x-velocity.
    """
    if LONG_COL == "y":
        component = row[VELY_COL]
    else:
        component = row[VELX_COL]
    return float(component)

def _get_lat_vel(row):
    """Return the velocity component along the lateral axis as a float.

    ``row`` is any mapping-like record indexable by the velocity column
    names; the x-velocity is used when ``LONG_COL`` is ``"y"``, otherwise
    the y-velocity.
    """
    if LONG_COL == "y":
        component = row[VELX_COL]
    else:
        component = row[VELY_COL]
    return float(component)

def _central_diff(arr, dt):
    """Differentiate a uniformly sampled 1-D series with finite differences.

    Interior samples use the central difference
    ``(arr[i + 1] - arr[i - 1]) / (2 * dt)``; the first and last samples use
    the one-sided forward / backward difference.

    Args:
        arr: 1-D sequence of samples (list, tuple or array of any real
            dtype).  It is converted to float64 up front so that plain lists
            work and unsigned integer inputs cannot wrap around when
            subtracted.
        dt: Spacing between consecutive samples.

    Returns:
        A float64 array with the same length as ``arr``.  An empty input gives
        an empty array and a single sample gives ``[0.0]``.
    """
    values = np.asarray(arr, dtype=np.float64)
    n = len(values)
    out = np.zeros(n, dtype=np.float64)
    if n >= 2:
        # One-sided differences at the two ends of the series.
        out[0] = (values[1] - values[0]) / dt
        out[-1] = (values[-1] - values[-2]) / dt
    if n >= 3:
        out[1:-1] = (values[2:] - values[:-2]) / (2.0 * dt)
    return out

# ---------- Precompute "leader" metrics at each frame (same-lane nearest-ahead) ----------
def add_lead_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with same-lane leader metrics for every row.

    Within each (frame, lane) group the rows are ordered by ``LONG_COL`` and
    each row is paired with the next row in that order, which is treated as
    its leader.  The following columns are added:

    * ``_v_long``    -- velocity along the longitudinal axis (float64).
    * ``lead_gap``   -- leader position minus own position (``inf`` if none).
    * ``lead_rel_v`` -- leader ``_v_long`` minus own ``_v_long`` (0 if none).
    * ``thw``        -- ``lead_gap / max(|_v_long|, 0.1)`` (``inf`` if none).
    * ``ttc``        -- ``lead_gap / max(-lead_rel_v, 1e-3)`` when the gap is
      closing (``lead_rel_v < 0``), otherwise ``inf``.

    Rows whose frame or lane key is missing belong to no group and keep the
    "no leader" defaults.

    The metrics are computed in one vectorised pass and written back by row
    position, so the result does not depend on ``df`` having a unique index.

    NOTE(review): "ahead" always means a larger ``LONG_COL`` value.  If the
    data contains traffic moving towards smaller values, those vehicles get
    their follower as "leader" -- confirm this is acceptable for the dataset.

    Args:
        df: Track table containing ``FRAME_COL``, ``LANE_COL``, ``LONG_COL``
            and the velocity columns.  It is not modified.

    Returns:
        A new DataFrame with the five columns described above.
    """
    out = df.copy()

    vel_col = VELY_COL if LONG_COL == "y" else VELX_COL
    out["_v_long"] = out[vel_col].to_numpy(dtype=np.float64)

    n_rows = len(out)
    lead_gap = np.full(n_rows, np.inf, dtype=np.float64)
    lead_rel_v = np.zeros(n_rows, dtype=np.float64)
    thw = np.full(n_rows, np.inf, dtype=np.float64)
    ttc = np.full(n_rows, np.inf, dtype=np.float64)

    pos = out[LONG_COL].to_numpy(dtype=np.float64)
    vel = out["_v_long"].to_numpy()

    # One id per (frame, lane) group.  Rows with a missing key are reported as
    # -1 or NaN depending on the pandas version; both fail the `>= 0` test.
    group_id = (
        out.groupby([FRAME_COL, LANE_COL], sort=False)
        .ngroup()
        .to_numpy(dtype=np.float64)
    )
    rows = np.flatnonzero(group_id >= 0)

    # Stable sort by group first, then by longitudinal position inside the
    # group (the last lexsort key is the primary one).
    order = rows[np.lexsort((pos[rows], group_id[rows]))]
    sorted_gid = group_id[order]

    # Consecutive rows of the same group form an (ego, leader) pair; the last
    # row of every group has no leader and keeps the defaults.
    same_group = sorted_gid[1:] == sorted_gid[:-1]
    ego = order[:-1][same_group]
    lead = order[1:][same_group]

    gap = pos[lead] - pos[ego]            # positive ahead
    rel_v = vel[lead] - vel[ego]          # leader - ego

    lead_gap[ego] = gap
    lead_rel_v[ego] = rel_v
    # Floor the ego speed so a stopped vehicle does not divide by zero.
    thw[ego] = gap / np.maximum(np.abs(vel[ego]), 0.1)
    # TTC is only defined while closing in on the leader (rel_v < 0).
    ttc[ego] = np.where(rel_v < 0.0, gap / np.maximum(-rel_v, 1e-3), np.inf)

    out["lead_gap"] = lead_gap
    out["lead_rel_v"] = lead_rel_v
    out["thw"] = thw
    out["ttc"] = ttc
    return out

def _angle_rate(angle: np.ndarray, dt: float) -> np.ndarray:
    """Differentiate a wrapped angle series [rad] without 2*pi seam jumps.

    Uses the same scheme as ``_central_diff`` (central in the interior,
    one-sided at the ends) but folds each raw difference back into
    ``[-pi, pi]`` first.  Differences already inside that range, and NaNs,
    pass through unchanged.
    """
    n = len(angle)
    rate = np.zeros(n, dtype=np.float64)
    if n < 2:
        return rate

    two_pi = 2.0 * np.pi

    def _fold(delta):
        delta = np.where(delta > np.pi, delta - two_pi, delta)
        return np.where(delta < -np.pi, delta + two_pi, delta)

    rate[0] = _fold(angle[1] - angle[0]) / dt
    rate[-1] = _fold(angle[-1] - angle[-2]) / dt
    if n >= 3:
        rate[1:-1] = _fold(angle[2:] - angle[:-2]) / (2.0 * dt)
    return rate


def label_intent_for_vehicle(sub: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Assign an intent class and a confidence to every row of one track.

    The first ``HISTORY_STEPS`` rows are left as ``IGNORE_INDEX`` with
    confidence 0.  Every later row gets the first matching class, in this
    priority order:

    1. ``emergency``     -- tangential acceleration below -3.5 or ``ttc``
       below 1.2.
    2. ``merging``       -- lane-change condition holds and the mean
       tangential acceleration over a ~1.0 s centred window exceeds 0.8.
    3. ``lane_changing`` -- the lane id differs from the previous row, or
       ``|lateral speed| > 0.35``, or the lateral displacement over a ~1.5 s
       centred window exceeds 1.5, or ``|yaw rate| > 0.12``.
    4. ``car_following`` -- ``0.8 <= thw <= 2.5``, ``|lead_rel_v| < 2.0``,
       ``|lateral speed| < 0.25`` and ``|yaw rate| < 0.08``.
    5. ``free_flow``     -- everything else.

    Two derived signals are computed with care:

    * The lateral speed is the velocity component along the lateral axis
      (the one that is not ``LONG_COL``).  Projecting the velocity onto the
      normal of its own heading would always give ~0 and disable the
      lateral-speed tests.
    * The yaw rate is differentiated with wrap-around handling, so a heading
      close to +/-pi does not produce spurious spikes.

    Args:
        sub: Rows of a single vehicle ordered by frame.  Must contain the
            velocity, ``LAT_COL`` and ``LANE_COL`` columns plus the
            ``lead_rel_v``, ``thw`` and ``ttc`` columns.

    Returns:
        ``(labels, confs)``: an int32 array of class ids and a float32 array
        of confidences, both of length ``len(sub)``.
    """
    n = len(sub)
    labels = np.full(n, IGNORE_INDEX, dtype=np.int32)
    confs = np.zeros(n, dtype=np.float32)
    if n == 0:
        return labels, confs

    # Extract series (numpy arrays for speed)
    vx = sub[VELX_COL].to_numpy(dtype=np.float64, copy=False)
    vy = sub[VELY_COL].to_numpy(dtype=np.float64, copy=False)
    x_lat = sub[LAT_COL].to_numpy(dtype=np.float64, copy=False)
    lane_seq = sub[LANE_COL].to_numpy(dtype=np.int32, copy=False)

    # Heading from the velocity vector; the small offset avoids atan2(0, 0).
    yaw = np.arctan2(vy, vx + 1e-6)
    yaw_rate = _angle_rate(yaw, DT)

    # Accelerations from the velocities, then the component along the heading.
    ax = _central_diff(vx, DT)
    ay = _central_diff(vy, DT)
    a_par = ax * np.cos(yaw) + ay * np.sin(yaw)

    # Lateral speed in the road frame.
    v_lat = vx if LONG_COL == "y" else vy

    # Leader metrics (already on sub via precompute)
    rel_v = sub["lead_rel_v"].to_numpy(dtype=np.float64, copy=False)
    thw = sub["thw"].to_numpy(dtype=np.float64, copy=False)
    ttc = sub["ttc"].to_numpy(dtype=np.float64, copy=False)

    # Half-widths (in rows) of the centred windows.
    half_lc = int(round(1.5 / DT)) // 2   # ~1.5 s for lane-change detection
    half_mg = int(round(1.0 / DT)) // 2   # ~1.0 s for merging accel mean

    # Rows before HISTORY_STEPS keep the IGNORE_INDEX / 0.0 defaults.
    for j in range(HISTORY_STEPS, n):
        if (a_par[j] < -3.5) or (ttc[j] < 1.2):
            labels[j] = INTENT_IDX["emergency"]
            # Confidence scales with deceleration only, so a TTC-triggered
            # emergency without braking gets confidence 0; a non-finite
            # acceleration maps to 1.
            confs[j] = float(np.clip((-a_par[j] / 6.0) if np.isfinite(a_par[j]) else 1.0, 0, 1))
            continue

        lane_changed = (j > 0 and lane_seq[j] != lane_seq[j - 1])
        lo = max(0, j - half_lc)
        hi = min(n, j + half_lc + 1)
        cum_lat = abs(x_lat[hi - 1] - x_lat[lo])

        is_lc = (
            lane_changed
            or (abs(v_lat[j]) > 0.35)
            or (cum_lat > 1.5)
            or (abs(yaw_rate[j]) > 0.12)
        )

        lo_m = max(0, j - half_mg)
        hi_m = min(n, j + half_mg + 1)
        a_win_mean = float(np.mean(a_par[lo_m:hi_m]))

        if is_lc and (a_win_mean > 0.8):
            labels[j] = INTENT_IDX["merging"]
            confs[j] = float(np.clip(a_win_mean / 2.0, 0, 1))
            continue

        if is_lc:
            labels[j] = INTENT_IDX["lane_changing"]
            confs[j] = 0.7 if lane_changed else 0.6
            continue

        if (np.isfinite(thw[j]) and (0.8 <= thw[j] <= 2.5)
                and (abs(rel_v[j]) < 2.0)
                and (abs(v_lat[j]) < 0.25)
                and (abs(yaw_rate[j]) < 0.08)):
            labels[j] = INTENT_IDX["car_following"]
            confs[j] = float(np.clip(1.0 - abs(rel_v[j]) / 4.0, 0, 1))
            continue

        labels[j] = INTENT_IDX["free_flow"]
        confs[j] = 0.6

    return labels, confs

def main(in_path, out_path):
    """Label every row of a tracks CSV with an intent class and save it.

    Steps: load ``in_path``; check that the required columns exist; drop rows
    with missing id / frame / position / lane values and cast the id, frame
    and lane columns to int; sort by frame then id; add the leader metrics;
    label each vehicle track; write the result (without the index) to
    ``out_path``.

    Raises:
        ValueError: if a required column is absent from the input CSV.
    """
    print(f"[load] {in_path}")
    tracks = pd.read_csv(in_path)

    need = {ID_COL, FRAME_COL, LONG_COL, LAT_COL, VELX_COL, VELY_COL, LANE_COL}
    absent = [name for name in need if name not in tracks.columns]
    if absent:
        raise ValueError(f"Missing columns in input CSV: {absent}")

    required_values = [ID_COL, FRAME_COL, LONG_COL, LAT_COL, LANE_COL]
    tracks = tracks.dropna(subset=required_values).copy()
    for key_col in (ID_COL, FRAME_COL, LANE_COL):
        tracks[key_col] = tracks[key_col].astype(int)

    # Stable sort so rows sharing a (frame, id) key keep their file order.
    tracks = tracks.sort_values([FRAME_COL, ID_COL], kind="mergesort")
    tracks = add_lead_metrics(tracks)

    # Defaults for rows the labeler leaves untouched.
    tracks["intent_label"] = IGNORE_INDEX
    tracks["intent_conf"] = 0.0

    per_vehicle = tracks.groupby(ID_COL, sort=False)
    for _, rows in tqdm(per_vehicle, desc="Per-vehicle intent"):
        ordered = rows.sort_values(FRAME_COL, kind="mergesort")
        labels, confs = label_intent_for_vehicle(ordered)
        tracks.loc[ordered.index, "intent_label"] = labels
        tracks.loc[ordered.index, "intent_conf"] = confs

    print(f"[save] {out_path}")
    tracks.to_csv(out_path, index=False)
    print("Done.")

if __name__ == "__main__":
    import os

    # Both paths are optional: with no arguments the script processes the same
    # recording and writes the same output file as the previous hard-coded
    # version.
    parser = argparse.ArgumentParser(
        description="Label each row of a tracks CSV with a driving-intent class."
    )
    parser.add_argument(
        "in_path",
        nargs="?",
        default='RawDataset/HighD/data/01_tracks.csv',
        help="input tracks CSV (default: %(default)s)",
    )
    parser.add_argument(
        "out_path",
        nargs="?",
        default=None,
        help="output CSV (default: the input path with '_labeled' before the extension)",
    )
    args = parser.parse_args()

    out_path = args.out_path
    if out_path is None:
        # Derive the output next to the input so a custom input never
        # overwrites the labeled file of another recording.
        root, ext = os.path.splitext(args.in_path)
        out_path = f"{root}_labeled{ext}"

    main(args.in_path, out_path)
