import os
import json
import math
import pandas as pd
import numpy as np
import torch
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

try:
    from scipy.spatial import cKDTree
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

if _HAS_SCIPY:
    print("[dataprocess] Using SciPy cKDTree acceleration for kNN.")
else:
    print("[dataprocess] SciPy not found – falling back to NumPy kNN (slower).")

# Column names in the (labelled) trajectory CSV. The defaults appear to follow
# the HighD naming; build_dataset() swaps LONG_COL/LAT_COL to 'y'/'x' for NGSIM.
LONG_COL = 'x'  # driving direction
LAT_COL = 'y'  # lateral direction
VELX_COL = 'xVelocity'
VELY_COL = 'yVelocity'
ACCX_COL = 'xAcceleration'
ACCY_COL = 'yAcceleration'
LANE_COL = 'laneId'
ID_COL = 'id'
FRAME_COL = 'frame'
WIDTH_COL = 'width'
HEIGHT_COL = 'height'

# Per-row intent label, read by get_intent().
INTENT_COL='intent_label'

# Visible window around the ego along LONG_COL, in the dataset's position units
# (named metres; NOTE(review): NGSIM coordinates may be in feet — confirm).
VISIBLE_FRONT_M = 90.0
VISIBLE_BACK_M = 60.0
K_DIFFUSION = 6  # k for the undirected kNN "diffusion" edges
K_ADV_FWD = 3  # forward neighbours per node for the "advection" edges

FPS = 25.0  # HighD is recorded at 25 Hz
DT = 1.0 / FPS
FUTURE_S = 5.0
HISTORY_S=3.0
MAX_SAMPLES=10000

FUTURE_STEPS = int(round(FUTURE_S * FPS))
HISTORY_STEPS = int(round(HISTORY_S * FPS))
def set_fps(fps:float):
    global FPS,DT,FUTURE_STEPS,HISTORY_STEPS
    FPS=float(fps)
    DT=1.0 / FPS
    FUTURE_STEPS = int(round(FUTURE_S * FPS))
    HISTORY_STEPS = int(round(HISTORY_S * FPS))

# Macro density evaluated only at key horizon
KEY_TIMES_S = [0.0,1.0, 3.0, 5.0]
def _key_indices(fps: float = FPS, key_times_s: list[float] = KEY_TIMES_S) -> list[int]:
    idx = []
    for s in key_times_s:
        i = int(round(s * fps))  # frames after t=0
        i = max(0, min(FUTURE_STEPS, i))
        idx.append(i)
    return idx

def _gaussian_kernel_1d(dx: np.ndarray, h: float) -> np.ndarray:
    h = float(max(1e-6, h))
    return (1.0 / (np.sqrt(2.0 * np.pi) * h)) * np.exp(-0.5 * (dx / h) ** 2)


def _kde_rho_u_at(
    frame_df: pd.DataFrame,
    x_query: float,
    lanes_considered: set,
    h_m: float,
    long_min: float,
    long_max: float,
) -> tuple[float, float, int]:
    """Gaussian-KDE estimate of density and mean speed at one longitudinal position.

    Vehicles of ``frame_df`` whose LANE_COL is in ``lanes_considered`` are
    treated as unit masses at their LONG_COL position. Vehicles within
    ``3 * h_m`` of either road end (``long_min`` / ``long_max``) are mirrored
    across that end (reflection boundary correction), so kernel mass is not
    lost beyond the road limits.

    Args:
        frame_df: vehicles of one frame (any ROI pre-filtering is the caller's).
        x_query: LONG_COL position at which to evaluate.
        lanes_considered: lane ids to include.
        h_m: kernel bandwidth, in LONG_COL units.
        long_min: lower reflection boundary.
        long_max: upper reflection boundary.

    Returns:
        ``(density, u_hat, n_lanes)``.
        - ``density``: the kernel sum times 1000, divided by ``n_lanes``
          (veh/km/lane when positions are in metres).
        - ``u_hat``: the kernel-weighted mean longitudinal velocity, or 0.0 if
          the kernel sum is ~0.
        - ``n_lanes``: the number of distinct lanes that actually contain a
          vehicle (at least 1).
        Returns ``(0.0, 0.0, 0)`` when no vehicle is in the considered lanes.
    """
    present = frame_df[frame_df[LANE_COL].isin(lanes_considered)]
    if present.empty:
        return 0.0, 0.0, 0

    pos = present[LONG_COL].to_numpy(dtype=np.float64)
    vel_col = VELX_COL if LONG_COL == 'x' else VELY_COL
    vel = present[vel_col].to_numpy(dtype=np.float64, copy=True)

    a, b = float(long_min), float(long_max)
    # Both masks index the ORIGINAL samples, so both mirror sets are built from
    # them before anything is concatenated. Applying the right mask to an array
    # already extended by the left mirrors raises IndexError whenever vehicles
    # sit near both boundaries.
    near_left = pos < (a + 3.0 * h_m)
    near_right = pos > (b - 3.0 * h_m)
    all_pos = np.concatenate([pos, (2.0 * a) - pos[near_left], (2.0 * b) - pos[near_right]])
    all_vel = np.concatenate([vel, vel[near_left], vel[near_right]])

    k = _gaussian_kernel_1d(x_query - all_pos, h_m)
    rho_per_m = float(k.sum())  # sum of kernels = veh/m
    u_hat = float((k * all_vel).sum() / rho_per_m) if rho_per_m > 1e-12 else 0.0

    # NOTE(review): normalises by the lanes that currently hold a vehicle, not by
    # len(lanes_considered); confirm this is the intended per-lane density.
    lanes_present = {int(l) for l in present[LANE_COL].dropna().unique().tolist()}
    n_lanes = max(1, len(lanes_present))
    rho_per_km_per_lane = rho_per_m * 1000.0 / float(n_lanes)
    return rho_per_km_per_lane, u_hat, n_lanes

# Dataset extents along LONG_COL and over lane ids. build_dataset() fills these
# from the loaded table. While any is None, window_bounds_check() accepts every
# window, and the density helpers use the local ROI as the reflection boundary.
DATA_LONG_MIN = None
DATA_LONG_MAX = None
DATA_LANE_MIN = None
DATA_LANE_MAX = None

# Lazily built (frame, id)-indexed copy of the trajectory table, used by
# compute_traditional_density_seq() for ego lookups.
DF_MULTI = None


def get_long_vel(row: pd.Series) -> float:
    """Return the row's velocity component along the driving (LONG_COL) axis."""
    column = VELX_COL if LONG_COL == 'x' else VELY_COL
    return float(row[column])


def get_lat_vel(row: pd.Series) -> float:
    """Return the row's velocity component across the driving direction (LAT_COL axis)."""
    if LONG_COL == 'x':
        return float(row[VELY_COL])
    return float(row[VELX_COL])


def get_long_acc(row: pd.Series) -> float:
    """Return the row's acceleration along the driving (LONG_COL) axis.

    Missing data yields 0.0. This covers an absent column, ``None``, or NaN.
    NaN is pandas' missing-value marker; a ``None`` check alone never fires
    for numeric columns, so NaN would leak into the ego state vector.
    """
    value = row.get(ACCX_COL if LONG_COL == 'x' else ACCY_COL, 0.0)
    return 0.0 if pd.isna(value) else float(value)


def get_lat_acc(row: pd.Series) -> float:
    """Return the row's acceleration across the driving direction (LAT_COL axis).

    Missing data yields 0.0. This covers an absent column, ``None``, or NaN.
    NaN is pandas' missing-value marker; a ``None`` check alone never fires
    for numeric columns, so NaN would leak into the ego state vector.
    """
    value = row.get(ACCY_COL if LONG_COL == 'x' else ACCX_COL, 0.0)
    return 0.0 if pd.isna(value) else float(value)


# Load dataset
def load_trajectories(csv_path: str) -> pd.DataFrame:
    """Read a trajectory CSV and normalise its key columns.

    Raises:
        ValueError: listing every required column that the file lacks.

    Rows missing id/frame/position/lane are dropped, and id, frame and lane
    are cast to int.
    """
    required = {ID_COL, FRAME_COL, LONG_COL, LAT_COL, WIDTH_COL, HEIGHT_COL, VELX_COL, VELY_COL, LANE_COL}
    table = pd.read_csv(csv_path)
    absent = [col for col in required if col not in table.columns]
    if absent:
        raise ValueError('The following columns are missing from the trajectories file: {}'.format(absent))

    key_columns = [ID_COL, FRAME_COL, LONG_COL, LAT_COL, LANE_COL]
    table = table.dropna(subset=key_columns).copy()
    for col in (ID_COL, FRAME_COL, LANE_COL):
        table[col] = table[col].astype(int)
    return table


def slice_frame(df: pd.DataFrame, frame: int) -> pd.DataFrame:
    """Return an independent copy of every row of *df* recorded at *frame*."""
    in_frame = df[FRAME_COL] == frame
    return df.loc[in_frame].copy()


def pick_ego(frame_df: pd.DataFrame, ego_id: int) -> pd.Series:
    """Return the first row of *frame_df* whose id equals *ego_id*.

    Raises:
        ValueError: if the id is absent (the message names the frame).
    """
    matches = frame_df.loc[frame_df[ID_COL] == ego_id]
    if not matches.empty:
        return matches.iloc[0]
    raise ValueError(f'Ego id {ego_id} not found in frame {int(frame_df[FRAME_COL].iloc[0])}')


# GK: node features (relative to ego)
# Column order of the arrays returned by build_gk_node_features(). dx/dy are the
# neighbour's offset from the ego along LAT_COL/LONG_COL, dvx/dvy the difference
# of the raw VELX_COL/VELY_COL values; row 0 of those arrays is the ego itself.
GK_NODE_FEATURE_NAMES = [
    'dx', 'dy', 'dvx', 'dvy', 'same_lane', 'dist', 'inv_ttc', 'inv_thw', 'size', 'width'
]


def safe_inv(x, clip=1e6):
    """Return ``min(1/x, clip)``, or 0.0 when ``x <= 1e-6`` (treated as undefined)."""
    x = float(x)
    if x <= 1e-6:
        return 0.0
    return float(min(1.0 / x, clip))


def compute_inv_ttc(dy, dv_long):
    """Inverse time-to-collision between the ego and a neighbour along the driving axis.

    Args:
        dy: neighbour minus ego longitudinal position (signed).
        dv_long: neighbour minus ego longitudinal velocity (signed).

    Returns:
        ``closing_speed / |dy|`` when the longitudinal gap is shrinking, else 0.0.
        The gap shrinks when a neighbour ahead is slower (dy > 0, dv_long < 0)
        or a neighbour behind is faster (dy < 0, dv_long > 0), whichever way
        the axis points. A zero gap returns 0.0.
    """
    if dy == 0:
        return 0.0
    # d|dy|/dt = sign(dy) * dv_long, so the closing speed is its negation.
    closing_speed = -dv_long if dy > 0 else dv_long
    if closing_speed > 1e-6:
        return safe_inv(abs(dy) / closing_speed)
    return 0.0


def compute_inv_thw(gap_long, ego_speed_long):
    """Inverse time headway ``ego_speed_long / |gap_long|``; 0.0 unless both are positive."""
    if gap_long > 1e-6 and ego_speed_long > 1e-6:
        return safe_inv(abs(gap_long) / ego_speed_long)
    return 0.0


def _gk_size_width(row: pd.Series) -> tuple[float, float]:
    """Return the ('size', 'width') node features for one vehicle row.

    'width' is the HEIGHT_COL value (5.0 if the column is absent), and 'size'
    flags it as >= 6.5. NOTE(review): which raw column holds vehicle length vs.
    width differs between datasets — confirm HEIGHT_COL is the intended source.
    """
    width = float(row.get(HEIGHT_COL, 5.0))
    size = 1.0 if width >= 6.5 else 0.0
    return size, width


def build_gk_node_features(frame_df: pd.DataFrame, ego: pd.Series, neighbor_ids: list) -> np.ndarray:
    """Per-node features for the ego-centred interaction graph.

    Row 0 is the ego; rows 1.. follow ``neighbor_ids``. Columns follow
    GK_NODE_FEATURE_NAMES. Kinematic columns are neighbour minus ego, with dx
    along LAT_COL and dy along LONG_COL. The ego row's relative columns are zero.

    Args:
        frame_df: all vehicles of the current frame; must contain every id in
            ``neighbor_ids``.
        ego: the ego's row in the same frame.
        neighbor_ids: neighbour vehicle ids, in output row order.

    Returns:
        float32 array of shape ``(1 + len(neighbor_ids), len(GK_NODE_FEATURE_NAMES))``.
    """
    n_nodes = 1 + len(neighbor_ids)
    feats = np.zeros((n_nodes, len(GK_NODE_FEATURE_NAMES)), dtype=np.float32)

    ex, ey = float(ego[LAT_COL]), float(ego[LONG_COL])
    ev_long = get_long_vel(ego)
    evx = float(ego.get(VELX_COL, 0.0))
    evy = float(ego.get(VELY_COL, 0.0))
    elane = int(ego[LANE_COL])

    # The ego row always carries its own size/width, encoded exactly like the
    # neighbour rows, so its meaning does not depend on the neighbour count.
    feats[0, -2], feats[0, -1] = _gk_size_width(ego)
    if len(neighbor_ids) == 0:
        return feats

    sub = frame_df.set_index(ID_COL)
    for k, nid in enumerate(neighbor_ids, start=1):
        r = sub.loc[int(nid)]
        nx, ny = float(r[LAT_COL]), float(r[LONG_COL])
        nvx = float(r.get(VELX_COL, 0.0))
        nvy = float(r.get(VELY_COL, 0.0))
        nlane = int(r[LANE_COL])

        dx = nx - ex
        dy = ny - ey
        dvx = nvx - evx
        dvy = nvy - evy
        same_lane = 1.0 if nlane == elane else 0.0
        dist = float((dx ** 2 + dy ** 2) ** 0.5)

        dv_long = get_long_vel(r) - ev_long
        inv_ttc = compute_inv_ttc(dy=dy, dv_long=dv_long)
        inv_thw = compute_inv_thw(gap_long=abs(dy), ego_speed_long=abs(ev_long))
        size, width = _gk_size_width(r)

        feats[k, :] = [dx, dy, dvx, dvy, same_lane, dist, inv_ttc, inv_thw, size, width]
    return feats


def select_visible_subgraph(frame_df: pd.DataFrame, ego: pd.Series,
                            front_range_m: float = 120.0, back_range_m: float = 90.0) -> pd.DataFrame:
    """Return a copy of the vehicles around the ego.

    Keeps vehicles whose LONG_COL lies in
    ``[ego - back_range_m, ego + front_range_m]`` (inclusive) and whose lane
    differs from the ego's by at most one. "Front" is always the increasing
    LONG_COL direction, regardless of which way the ego travels.
    """
    ego_long = float(ego[LONG_COL])
    in_range = frame_df[LONG_COL].between(ego_long - back_range_m, ego_long + front_range_m)
    near_lane = (frame_df[LANE_COL] - ego[LANE_COL]).abs() <= 1
    return frame_df[in_range & near_lane].copy()

def get_future_traj(df: pd.DataFrame, ego_id: int, start_frame: int, K: int = FUTURE_STEPS) -> tuple[np.ndarray, int]:
    """Collect the ego's future positions for frames ``start_frame+1 .. start_frame+K``.

    Returns:
        ``(traj, valid)``.
        - ``traj``: float32 array of shape (K, 2) holding ``[LAT_COL, LONG_COL]``
          per step, NaN-padded after the last valid step.
        - ``valid``: the number of consecutive steps found. Collection stops
          at the first frame where the ego is absent.
    """
    n_steps = int(K)
    traj = np.full((n_steps, 2), np.nan, dtype=np.float32)
    # Filter by ego once; each step is then an index lookup rather than two
    # full-table scans. The first row per frame is kept (same as row.iloc[0]).
    by_frame = df[df[ID_COL] == ego_id].drop_duplicates(subset=FRAME_COL).set_index(FRAME_COL)
    valid = 0
    for t in range(1, n_steps + 1):
        fr = start_frame + t
        if fr not in by_frame.index:
            break
        r = by_frame.loc[fr]
        traj[t - 1, 0] = float(r[LAT_COL])
        traj[t - 1, 1] = float(r[LONG_COL])
        valid = t
    return traj, valid

# Intent label
# NOTE(review): INTENT_IDX is not referenced in the code visible here; presumably
# it is the integer encoding of the INTENT_COL values collected by get_intent()
# — confirm against the labelling script.
INTENT_IDX={"free_flow":0,'car_following':1,'lane_changing':2,'merging':3,'emergency':4}
# Presumably the label value a training loss should skip (-100 is PyTorch's
# cross-entropy default ignore_index); not referenced in the code visible here.
IGNORE_INDEX = -100

def build_ego_state(ego: pd.Series):
    """Return the ego's 10-element float32 state vector.

    Layout: ``[lat_pos, long_pos, v_long, v_lat, a_long, a_lat, heading,
    lane_id, height, width]``.
    - ``heading`` is ``atan2(yVelocity, xVelocity)``, or 0.0 when both
      components are ~0.
    - ``height`` defaults to 2.0 and ``width`` to 5.0 when their column is
      absent.
    """
    lateral = float(ego[LAT_COL])
    longitudinal = float(ego[LONG_COL])
    kinematics = (get_long_vel(ego), get_lat_vel(ego), get_long_acc(ego), get_lat_acc(ego))

    vel_x = float(ego.get(VELX_COL, 0.0))
    vel_y = float(ego.get(VELY_COL, 0.0))
    is_stationary = abs(vel_x) < 1e-6 and abs(vel_y) < 1e-6
    heading = 0.0 if is_stationary else math.atan2(vel_y, vel_x)

    lane_id = int(ego[LANE_COL])
    height = float(ego[HEIGHT_COL]) if HEIGHT_COL in ego else 2.0
    width = float(ego[WIDTH_COL]) if WIDTH_COL in ego else 5.0

    state = [lateral, longitudinal, *kinematics, heading, lane_id, height, width]
    return np.array(state, dtype=np.float32)


def get_vk_feat(df: pd.DataFrame, ego_id: int, start_frame: int, H: int = HISTORY_STEPS) -> tuple[np.ndarray, int]:
    """Stack the ego's state vectors over frames ``start_frame-H .. start_frame`` (oldest first).

    Collection stops at the first frame where the ego is absent, so the result
    may be shorter than ``H + 1``.

    Returns:
        ``(vk_feat, n)``: float32 array of shape (n, 10) built with
        build_ego_state(), and n. When even the first frame is missing this is
        an empty (0, 10) array and 0, instead of the exception np.stack([])
        would raise.
    """
    # Filter by ego once; each step is then an index lookup rather than two
    # full-table scans. The first row per frame is kept (same as row.iloc[0]).
    by_frame = df[df[ID_COL] == ego_id].drop_duplicates(subset=FRAME_COL).set_index(FRAME_COL)
    states = []
    for t in range(H, -1, -1):
        ft = start_frame - t
        if ft not in by_frame.index:
            break
        states.append(build_ego_state(by_frame.loc[ft]))
    if not states:
        return np.zeros((0, 10), dtype=np.float32), 0
    vk_feat = np.stack(states, axis=0).astype(np.float32)
    return vk_feat, vk_feat.shape[0]

def get_intent(df:pd.DataFrame,ego_id:int,start_frame:int,H:int=HISTORY_STEPS):
    """Collect the ego's INTENT_COL values over frames ``start_frame-H .. start_frame``.

    Values are returned oldest first. Collection stops at the first frame where
    the ego is absent, so the returned list may be shorter than ``H + 1``, or
    empty.
    """
    # Filter by ego once; each step is then an index lookup rather than two
    # full-table scans. The first row per frame is kept (same as row.iloc[0]).
    by_frame = df[df[ID_COL] == ego_id].drop_duplicates(subset=FRAME_COL).set_index(FRAME_COL)
    labels = []
    for t in range(H, -1, -1):
        ft = start_frame - t
        if ft not in by_frame.index:
            break
        labels.append(by_frame.loc[ft][INTENT_COL])
    return labels

def get_history_graph_seq(df:pd.DataFrame,ego_id: int, start_frame: int, H: int = HISTORY_STEPS):
    """Build the per-frame interaction graphs over the ego's history window.

    Walks frames ``start_frame - H`` .. ``start_frame`` (oldest first) and stops
    at the first frame where the frame or the ego is missing. For each frame it
    builds:
    - the visible neighbourhood;
    - node features, with the ego at node 0 and the others sorted by id;
    - undirected kNN diffusion edges;
    - forward-advection edges.

    Returns:
        ``(x_seq, ei_diff_seq, ei_adv_seq, num_x_seq, num_diff_seq, num_adv_seq, valid)``:
        lists with one entry per collected frame (float32 node features, int64
        (2, E) edge arrays, node/edge counts), and the number of frames collected.
    """
    x_seq, ei_diff_seq, ei_adv_seq = [], [], []
    num_x_seq, num_diff_seq, num_adv_seq = [], [], []

    for offset in reversed(range(H + 1)):
        frame_id = start_frame - offset
        frame_df = slice_frame(df, frame_id)
        if frame_df.empty:
            break
        ego_match = frame_df[frame_df[ID_COL] == ego_id]
        if ego_match.empty:
            break
        ego = ego_match.iloc[0]

        visible = select_visible_subgraph(frame_df, ego, front_range_m=VISIBLE_FRONT_M, back_range_m=VISIBLE_BACK_M)

        # Ego first, remaining visible vehicles in ascending id order.
        others = sorted(visible[ID_COL].unique().tolist())
        if ego_id in others:
            others.remove(ego_id)
        node_ids = [ego_id] + others
        id2idx = {nid: idx for idx, nid in enumerate(node_ids)}
        neighbor_ids = [nid for nid in node_ids if nid != int(ego_id)]

        node_feats = build_gk_node_features(frame_df, ego, neighbor_ids)
        diff_edges = build_edges_diffusion_knn(visible, id2idx, k=K_DIFFUSION)
        adv_edges = build_edges_advection_forward_knn(visible, id2idx, k_forward=K_ADV_FWD)

        x_seq.append(node_feats.astype(np.float32))
        ei_diff_seq.append(diff_edges.astype(np.int64))
        ei_adv_seq.append(adv_edges.astype(np.int64))
        num_x_seq.append(node_feats.shape[0])
        num_diff_seq.append(diff_edges.shape[1])
        num_adv_seq.append(adv_edges.shape[1])

    return x_seq, ei_diff_seq, ei_adv_seq, num_x_seq, num_diff_seq, num_adv_seq, len(num_x_seq)

def window_bounds_check(y, lane, front_m: float, back_m: float) -> bool:
    """Check that the ego-centred window stays inside the recorded dataset.

    The window spans ``[y - back_m, y + front_m]`` along LONG_COL (1e-6
    tolerance) and lanes ``lane - 1 .. lane + 1``. Returns False if it extends
    past the DATA_* extents. Returns True otherwise, including while any extent
    is still unknown (None).
    """
    extents = (DATA_LONG_MIN, DATA_LONG_MAX, DATA_LANE_MIN, DATA_LANE_MAX)
    if any(e is None for e in extents):
        return True
    too_far_back = float(y - back_m) < DATA_LONG_MIN - 1e-6
    too_far_ahead = float(y + front_m) > DATA_LONG_MAX + 1e-6
    lane_below = int(lane - 1) < DATA_LANE_MIN
    lane_above = int(lane + 1) > DATA_LANE_MAX
    return not (too_far_back or too_far_ahead or lane_below or lane_above)

def get_last_state(df:pd.DataFrame,ego_id:int,start_frame:int,T:int,flag:bool):
    """Look up the ego's lane and LONG_COL position at the end of a window.

    Args:
        df: full trajectory table.
        ego_id: vehicle id.
        start_frame: anchor frame.
        T: window length in frames.
        flag: True for the future end (``start_frame + T``), False for the
            history start (``start_frame - T``).

    Returns:
        ``(True, lane, y)`` when the ego exists at that frame, otherwise
        ``(False, None, None)``. Both outcomes are 3-tuples, so callers can
        always unpack three values.
    """
    fr = int(start_frame + T) if flag else int(start_frame - T)
    match = df[(df[FRAME_COL] == fr) & (df[ID_COL] == ego_id)]
    if match.empty:
        return False, None, None
    r = match.iloc[0]
    return True, int(r[LANE_COL]), float(r[LONG_COL])

def knn_edges(points_xy: np.ndarray, k: int, exclude_self: bool = True) -> np.ndarray:
    """Directed k-nearest-neighbour edges ``i -> j`` over 2-D points.

    Args:
        points_xy: (N, 2) coordinates.
        k: neighbours per point; clipped to N - 1.
        exclude_self: drop the point itself from its own neighbour list.

    Returns:
        (2, E) int64 array, row 0 = source i, row 1 = neighbour j. Empty (2, 0)
        when N <= 1 or k <= 0. Uses SciPy's cKDTree when available, otherwise
        a dense O(N^2) NumPy distance matrix.
    """
    N = int(points_xy.shape[0])
    if N <= 1 or k <= 0:
        return np.zeros((2, 0), dtype=np.int64)
    k = min(int(k), N - 1)

    if _HAS_SCIPY:
        tree = cKDTree(points_xy)
        # Query k+1 so self can be dropped. k >= 1 here, so the query asks for
        # at least 2 neighbours and cKDTree.query always returns a 2-D (N, k+1)
        # index array. No reshaping is needed; adding an axis would create
        # (N, 1, k+1) rows that break the per-row slicing below.
        _, idxs = tree.query(points_xy, k=k + 1)
        src, dst = [], []
        for i in range(N):
            neigh = idxs[i]
            if exclude_self:
                neigh = neigh[neigh != i]
            # At most k (self may be absent when several points coincide).
            for j in neigh[:k].tolist():
                src.append(i)
                dst.append(int(j))
        return np.vstack([src, dst]).astype(np.int64)

    dx = points_xy[:, None, 0] - points_xy[None, :, 0]
    dy = points_xy[:, None, 1] - points_xy[None, :, 1]
    dist2 = dx * dx + dy * dy
    if exclude_self:
        np.fill_diagonal(dist2, np.inf)

    nbr_idx = np.argpartition(dist2, kth=k - 1, axis=1)[:, :k]
    src, dst = [], []
    for i in range(N):
        for j in nbr_idx[i]:
            src.append(i)
            dst.append(int(j))
    return np.vstack([src, dst]).astype(np.int64)


def build_edges_diffusion_knn(vis_df: pd.DataFrame, id2idx: dict, k: int = K_DIFFUSION) -> np.ndarray:
    """Undirected kNN 'diffusion' edges between the graph's nodes.

    Node positions are ``(LAT_COL, LONG_COL)`` of each id in ``id2idx``, which
    must all be present in ``vis_df``. Directed kNN edges are merged into
    unique unordered pairs, and each pair is emitted in both directions.

    Returns:
        (2, E) int64 array; (2, 0) when no edges exist.
    """
    by_id = vis_df.set_index(ID_COL)
    pts = np.zeros((len(id2idx), 2), dtype=np.float32)
    for nid, idx in id2idx.items():
        rec = by_id.loc[int(nid)]
        pts[idx] = (float(rec[LAT_COL]), float(rec[LONG_COL]))

    directed = knn_edges(pts, k=k, exclude_self=True)
    if directed.shape[1] == 0:
        return np.zeros((2, 0), dtype=np.int64)

    pairs = {(u, v) if u < v else (v, u) for u, v in directed.T.tolist()}
    both_ways = [edge for a, b in pairs for edge in ((a, b), (b, a))]
    return np.array(both_ways, dtype=np.int64).T


def build_edges_advection_forward_knn(vis_df: pd.DataFrame, id2idx: dict, k_forward: int = 3) -> np.ndarray:
    """Directed 'advection' edges from each node's nearest forward neighbours to the node.

    "Forward" means ahead along LONG_COL in the dominant travel direction, taken
    as the sign of the median longitudinal velocity (+1 if that is zero). For
    each node i, up to ``k_forward`` forward vehicles j, nearest in Euclidean
    (LAT, LONG) distance, contribute an edge j -> i.

    With SciPy, candidates are pre-selected among the ``pre_k`` nearest
    neighbours, so a forward vehicle farther away than that can be missed. The
    NumPy fallback considers every forward vehicle.

    Returns:
        (2, E) int64 array, row 0 = source j, row 1 = target i; (2, 0) when empty.
    """
    N = len(id2idx)
    if N <= 1 or k_forward <= 0:
        return np.zeros((2, 0), dtype=np.int64)

    sub = vis_df.set_index(ID_COL)

    lat = np.zeros((N,), dtype=np.float32)
    lon = np.zeros((N,), dtype=np.float32)
    v_long = np.zeros((N,), dtype=np.float32)

    for nid, i in id2idx.items():
        r = sub.loc[int(nid)]
        lat[i] = float(r[LAT_COL])
        lon[i] = float(r[LONG_COL])
        v_long[i] = get_long_vel(r)

    pts = np.stack([lat, lon], axis=1)
    long_dir = np.sign(np.median(v_long))
    if long_dir == 0:
        long_dir = 1.0

    # KDTree preselection size: modest multiple of k to keep it fast but robust
    pre_k = int(min(max(8, k_forward * 6), max(1, N - 1)))

    src, dst = [], []
    if _HAS_SCIPY:
        tree = cKDTree(pts)
        # N >= 2 here, so 1 <= pre_k <= N - 1 and the query asks for 2..N
        # neighbours (self included). cKDTree.query then always returns a 2-D
        # (N, pre_k + 1) index array, so no reshaping is needed.
        _, idxs = tree.query(pts, k=pre_k + 1)
        for i in range(N):
            neigh = idxs[i]
            neigh = neigh[neigh != i]  # drop self
            if neigh.size == 0:
                continue
            # forward condition in longitudinal direction
            forward_mask = ((lon[neigh] - lon[i]) * long_dir) > 0
            f_idx = neigh[forward_mask]
            if f_idx.size == 0:
                continue
            # choose k_forward with smallest Euclidean distance
            dd = (lat[f_idx] - lat[i]) ** 2 + (lon[f_idx] - lon[i]) ** 2
            ksel = min(k_forward, f_idx.size)
            pick = f_idx[np.argpartition(dd, kth=ksel - 1)[:ksel]]
            for j in pick.tolist():
                src.append(int(j))
                dst.append(i)
        if len(src) == 0:
            return np.zeros((2, 0), dtype=np.int64)
        return np.vstack([src, dst]).astype(np.int64)

    for i in range(N):
        dlon = (lon - lon[i]) * long_dir
        idxs = np.where(dlon > 0)[0]
        if idxs.size == 0:
            continue
        dlat = lat[idxs] - lat[i]
        dlon_abs = lon[idxs] - lon[i]
        dist2 = dlat * dlat + dlon_abs * dlon_abs
        k = min(k_forward, idxs.size)
        nbr = idxs[np.argpartition(dist2, kth=k - 1)[:k]]
        for j in nbr.tolist():
            src.append(j)
            dst.append(i)
    if len(src) == 0:
        return np.zeros((2, 0), dtype=np.int64)
    return np.vstack([src, dst]).astype(np.int64)

def compute_traditional_density_instant(
        df: pd.DataFrame,
        frame: int,
        ego: pd.Series,
        front_m: float = VISIBLE_FRONT_M,
        back_m: float = VISIBLE_BACK_M,
        lane_radius: int = 1,
) -> dict:
    """KDE density around the ego at a single frame.

    Considers lanes ``ego_lane ± lane_radius``, and vehicles whose LONG_COL lies
    in ``[y0 - back_m, y0 + front_m]``, with a 25-unit bandwidth.

    Returns:
        Dict with the following keys:
        - ``k_veh_per_km_per_lane``: the density.
        - ``n_lanes_considered``: occupied lanes, as reported by the KDE.
        - ``N_vehicles``: distinct ids in the window.
        - ``L_m``: window length.
        An empty frame yields zeros with ``L_m = front_m + back_m``.
    """
    ego_long = float(ego[LONG_COL])
    ego_lane = int(ego[LANE_COL])
    lanes = set(range(ego_lane - lane_radius, ego_lane + lane_radius + 1))

    frame_df = slice_frame(df, frame)
    if frame_df.empty:
        return {
            'k_veh_per_km_per_lane': 0.0,
            'n_lanes_considered': 0,
            'N_vehicles': 0,
            'L_m': float(front_m + back_m),
        }

    lo, hi = ego_long - back_m, ego_long + front_m
    roi = frame_df[frame_df[LONG_COL].between(lo, hi)].copy()

    density, _, lane_count = _kde_rho_u_at(
        frame_df=roi,
        x_query=ego_long,
        lanes_considered=lanes,
        h_m=25.0,
        long_min=lo if DATA_LONG_MIN is None else DATA_LONG_MIN,
        long_max=hi if DATA_LONG_MAX is None else DATA_LONG_MAX,
    )

    in_lanes = roi[roi[LANE_COL].isin(lanes)]
    return {
        'k_veh_per_km_per_lane': float(density),
        'n_lanes_considered': int(lane_count),
        'N_vehicles': int(in_lanes[ID_COL].nunique()),
        'L_m': float(hi - lo),
    }


def build_one_sample(df: pd.DataFrame, frame_df: pd.DataFrame, frame: int, ego_id: int,
                     raise_errors: bool = False) -> dict | None:
    """Assemble one training sample for ``ego_id`` anchored at ``frame``.

    The sample is rejected (None) in any of these cases:
    - the ego is missing at either window end;
    - the visible window around either end leaves the recorded dataset;
    - the future/history sequences are incomplete;
    - a step raises (unless ``raise_errors``).

    Args:
        df: full trajectory table.
        frame_df: rows of ``frame``; accepted for interface compatibility, not used.
        frame: anchor frame (t = 0).
        ego_id: ego vehicle id.
        raise_errors: re-raise unexpected exceptions instead of returning None.
            The default keeps the best-effort behaviour; set True to find out
            why samples are being skipped.

    Returns:
        Dict containing the following:
        - ``vk_feat``: (HISTORY_STEPS + 1, 10) ego states.
        - ``future_traj``: (FUTURE_STEPS, 2).
        - Per-history-frame graph lists.
        - ``trad_k_seq``: (len(KEY_TIMES_S), 1) density.
        - ``hist_intent``.
        Returns None when the sample is rejected.
    """
    try:
        ego_id, frame = int(ego_id), int(frame)

        # Cheap endpoint checks first. Index instead of tuple-unpacking: a failed
        # lookup may return a shorter tuple than a successful one.
        fut_state = get_last_state(df, ego_id=ego_id, start_frame=frame, T=FUTURE_STEPS, flag=True)
        if fut_state[0] is not True:
            return None
        hist_state = get_last_state(df, ego_id=ego_id, start_frame=frame, T=HISTORY_STEPS, flag=False)
        if hist_state[0] is not True:
            return None
        if window_bounds_check(fut_state[2], fut_state[1], front_m=VISIBLE_FRONT_M, back_m=VISIBLE_BACK_M) is not True:
            return None
        if window_bounds_check(hist_state[2], hist_state[1], front_m=VISIBLE_FRONT_M, back_m=VISIBLE_BACK_M) is not True:
            return None

        fut_traj_xy, fut_valid = get_future_traj(df, ego_id=ego_id, start_frame=frame, K=FUTURE_STEPS)
        if fut_valid < FUTURE_STEPS:
            return None
        vk_feat, vk_feat_valid = get_vk_feat(df, ego_id=ego_id, start_frame=frame, H=HISTORY_STEPS)
        if vk_feat_valid < HISTORY_STEPS + 1:
            return None
        (x_nodes_seq, ei_diff_seq, ei_adv_seq,
         num_x_seq, num_diff_seq, num_adv_seq, valid) = get_history_graph_seq(
            df, ego_id=ego_id, start_frame=frame, H=HISTORY_STEPS)
        if valid < HISTORY_STEPS + 1:
            return None

        trad_k_seq = compute_traditional_density_seq(
            df=df,
            ego_id=ego_id,
            start_frame=frame,
        )  # [Kp,1]

        # NOTE(review): the cast to float32 below assumes INTENT_COL holds numeric
        # codes (cf. INTENT_IDX); string labels would raise and skip the sample.
        anchor_hist = get_intent(df, ego_id=ego_id, start_frame=frame, H=HISTORY_STEPS)

        return {
            'frame': frame,
            'ego_id': ego_id,
            'vk_feat': vk_feat,
            'future_traj': fut_traj_xy,
            'x_nodes_seq': x_nodes_seq,
            'edge_index_diff_seq': ei_diff_seq,
            'edge_index_adv_seq': ei_adv_seq,
            'num_x_seq': num_x_seq,
            'num_diff_seq': num_diff_seq,
            'num_adv_seq': num_adv_seq,
            'trad_k_seq': trad_k_seq.astype(np.float32),  # [Kp,1] sequence at keypoints
            'hist_intent': np.array(anchor_hist).astype(np.float32),
        }
    except Exception:
        if raise_errors:
            raise
        return None

# DataFrame that DF_MULTI was built from. Used to invalidate the cache when a
# different table is passed (e.g. a second build_dataset() call).
_DF_MULTI_SOURCE = None


def compute_traditional_density_seq(
        df: pd.DataFrame,
        ego_id: int,
        start_frame: int,
) -> np.ndarray:
    """KDE traffic density around the ego at the key horizons.

    For each offset in ``_key_indices(FPS, KEY_TIMES_S)`` (frames after
    ``start_frame``):
    - look up the ego's lane and LONG_COL position;
    - estimate density (veh/km/lane) over the ego lane ± 1 within
      ``[y - VISIBLE_BACK_M, y + VISIBLE_FRONT_M]``, with bandwidth 25.
    Horizons where the ego is absent stay 0.

    Returns:
        float32 array of shape (number of key horizons, 1).
    """
    global DF_MULTI, _DF_MULTI_SOURCE
    # (frame, id)-indexed copy of df for fast lookups. Rebuild it whenever a
    # different DataFrame is passed; otherwise an index built from an earlier
    # dataset would silently be reused.
    if DF_MULTI is None or _DF_MULTI_SOURCE is not df:
        DF_MULTI = df.set_index([FRAME_COL, ID_COL]).sort_index()
        _DF_MULTI_SOURCE = df

    key_idx = _key_indices(FPS, KEY_TIMES_S)
    out = np.zeros((len(key_idx), 1), dtype=np.float32)
    h_m = 25.0

    for j, k in enumerate(key_idx):
        fr = start_frame + k
        try:
            ego_t = DF_MULTI.loc[(fr, ego_id)]
        except KeyError:
            continue  # ego absent at this horizon -> density stays 0

        lane0 = int(ego_t[LANE_COL])
        lanes = set(range(lane0 - 1, lane0 + 2))
        yq = float(ego_t[LONG_COL])

        # The ego exists at `fr`, so the frame is in the index. Partial indexing
        # on the sorted MultiIndex avoids a full-table scan.
        fdf = DF_MULTI.loc[fr]

        y_min = yq - VISIBLE_BACK_M
        y_max = yq + VISIBLE_FRONT_M
        roi = fdf[(fdf[LONG_COL] >= y_min) & (fdf[LONG_COL] <= y_max)]

        rho_k, _, _ = _kde_rho_u_at(
            frame_df=roi,
            x_query=yq,
            lanes_considered=lanes,
            h_m=h_m,
            long_min=DATA_LONG_MIN if DATA_LONG_MIN is not None else y_min,
            long_max=DATA_LONG_MAX if DATA_LONG_MAX is not None else y_max,
        )
        out[j, 0] = float(rho_k)
    return out

def build_dataset(
        csv_path: str,
        out_dir: str,
        start_frame: int | None = None,
        end_frame: int | None = None,
        stride: int = 1,
        max_samples: int | None = None,
        dataset: str = 'HighD'
):
    """Build and save the sample list for one trajectory CSV.

    Args:
        csv_path: labelled trajectory CSV.
        out_dir: output file path for ``torch.save``; parent directories are created.
        start_frame: optional inclusive lower bound on the frames scanned.
        end_frame: optional inclusive upper bound on the frames scanned.
        stride: keep every ``stride``-th frame.
        max_samples: cap on the number of (frame, ego) pairs examined.
        dataset: 'NGSIM' switches to 10 Hz with LONG/LAT = 'y'/'x'; 'HighD'
            selects 25 Hz with 'x'/'y'. Any other name leaves the current
            settings untouched.

    Returns:
        The list of sample dicts. Collection also stops at MAX_SAMPLES.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
    """
    global LONG_COL, LAT_COL, DF_MULTI
    global DATA_LONG_MIN, DATA_LONG_MAX, DATA_LANE_MIN, DATA_LANE_MAX

    if not os.path.exists(csv_path):
        raise FileNotFoundError(f'CSV file {csv_path} does not exist')

    print(f"Loading trajectory data from {csv_path}...")
    df = load_trajectories(csv_path)

    # Set the dataset conventions explicitly in both known cases, so that a
    # previous NGSIM call cannot leak its FPS/axes into a later HighD call.
    name = dataset.lower()
    if name == 'ngsim':
        set_fps(10.0)
        LONG_COL, LAT_COL = 'y', 'x'
    elif name == 'highd':
        set_fps(25.0)
        LONG_COL, LAT_COL = 'x', 'y'
    # The density index cache belongs to one table; drop any from an earlier call.
    DF_MULTI = None

    DATA_LONG_MIN = float(df[LONG_COL].min())
    DATA_LONG_MAX = float(df[LONG_COL].max())
    DATA_LANE_MIN = int(df[LANE_COL].min())
    DATA_LANE_MAX = int(df[LANE_COL].max())
    print(f"[dataprocess] Dataset bounds: LONG in [{DATA_LONG_MIN:.2f}, {DATA_LONG_MAX:.2f}], LANE in [{DATA_LANE_MIN}, {DATA_LANE_MAX}]")

    # Prepare frames to process
    frames = sorted(df[FRAME_COL].unique().tolist())
    if start_frame is not None:
        frames = [f for f in frames if f >= start_frame]
    if end_frame is not None:
        frames = [f for f in frames if f <= end_frame]
    frames = frames[::max(1, int(stride))]

    print("Collecting frame-ego pairs to process...")
    # Group once rather than scanning the full table for every frame. unique()
    # keeps order of appearance, as in the per-frame slice.
    ids_by_frame = df.groupby(FRAME_COL, sort=False)[ID_COL].unique()
    frame_ego_pairs = []
    for fr in tqdm(frames, desc="Scanning frames"):
        for ego_id in ids_by_frame.loc[fr].tolist():
            frame_ego_pairs.append((fr, int(ego_id)))

    if max_samples is not None and len(frame_ego_pairs) > max_samples:
        frame_ego_pairs = frame_ego_pairs[:max_samples]

    total_pairs = len(frame_ego_pairs)
    print(f"Total frame-ego pairs to process: {total_pairs}")

    samples = []
    skipped = 0
    with tqdm(total=total_pairs, desc="Processing samples (single-thread)") as pbar:
        for fr, ego_id in frame_ego_pairs:
            fdf = slice_frame(df, fr)
            sample = build_one_sample(df, fdf, fr, ego_id)
            if sample is not None:
                samples.append(sample)
            else:
                skipped += 1
            pbar.update(1)
            pbar.set_postfix({
                'processed': len(samples),
                'skipped': skipped,
                'success_rate': f'{100 * len(samples) / (len(samples) + skipped):.1f}%'
            })
            if len(samples) >= MAX_SAMPLES:
                break

    print(f"\nSaving dataset to {out_dir}...")
    parent_dir = os.path.dirname(out_dir)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(samples, out_dir)

    attempted = len(samples) + skipped
    # Guard: with no (frame, ego) pairs the ratio is undefined.
    success_rate = 100 * len(samples) / attempted if attempted else 0.0
    print(f'Done! Total samples: {len(samples)}, Skipped: {skipped}')
    print(f'Success rate: {success_rate:.2f}%')
    print(f'Dataset saved to: {out_dir}')

    return samples

if __name__ == '__main__':
    # HighD alternative:
    #   csv_path='RawDataset/HighD/data/01_tracks_labeled.csv'
    #   out_dir='ProcessedDataset/HighD/01_gk_vk_dataset.pt'
    #   dataset='HighD'
    run_config = dict(
        csv_path='RawDataset/vehicle-trajectory-data/0750am-0805am/trajectories-0750am-0805am_labeled.csv',
        out_dir='ProcessedDataset/NGSIM/trajectories-0750am-0805am_gk_vk_dataset.pt',
        stride=1,
        max_samples=None,
        dataset='NGSIM',
    )
    build_dataset(**run_config)
