import math
import mmcv
import torch
from torch import nn as nn
from mmdet.models import weighted_loss
from mmdet.models.builder import LOSSES

def cost_inside(distance, safe_thresh, edge_cost=0.05):
    """Linear cost used while the distance is within the safety threshold.

    The cost is the straight line through ``(0, 1)`` and
    ``(safe_thresh, edge_cost)``: it equals 1 at zero distance and drops to
    ``edge_cost`` exactly at ``safe_thresh``.

    Args:
        distance: Scalar or tensor of distances.
        safe_thresh: Distance at which the cost reaches ``edge_cost``.
        edge_cost: Cost value at ``distance == safe_thresh``.

    Returns:
        The cost, with the same shape as ``distance``.
    """
    slope = (edge_cost - 1) / safe_thresh
    return 1 + distance * slope

def cost_outside(distance, safe_thresh, scale=0.05):
    """Exponentially decaying cost used beyond the safety threshold.

    Evaluates ``scale * exp(-(distance - safe_thresh))``, so the cost equals
    ``scale`` at the threshold and shrinks towards zero as the distance grows.

    Args:
        distance: Tensor of distances.
        safe_thresh: Distance at which the cost equals ``scale``.
        scale: Cost value at ``distance == safe_thresh``.

    Returns:
        Tensor of costs with the same shape as ``distance``.
    """
    overshoot = distance - safe_thresh
    decay = torch.exp(-overshoot)
    return decay * scale

class CriticMapBoundConstrain(nn.Module):
    """Reward an ego trajectory for staying clear of lane boundaries.

    The reward is ``energy - mean_t(cost_t)``, where ``cost_t`` depends on the
    distance between the ego position at step ``t`` and the closest lane
    boundary point. If any trajectory segment crosses the nearest boundary,
    every step of that trajectory gets the maximum cost of 1.

    Args:
        energy (float, optional): Upper bound of the reward; the cost is
            subtracted from it.
        lane_bound_cls_idx (int, optional): Class index that marks lane
            boundary vectors in ``map_type_gt``.
        dis_thresh (float, optional): Safety distance between the ego vehicle
            and a lane boundary.
        edge_cost (float, optional): Cost assigned exactly at ``dis_thresh``.
        point_cloud_range (list, optional): Stored as ``pc_range``; not read
            by any method of this class.
        perception_detach (bool, optional): Stored; not read by any method of
            this class.
    """

    def __init__(
        self,
        energy=1.0,
        lane_bound_cls_idx=2,
        dis_thresh=1.0,
        edge_cost=0.05,
        point_cloud_range=[-15.0, -30.0, -2.0, 15.0, 30.0, 2.0],
        perception_detach=False
        ):
        super(CriticMapBoundConstrain, self).__init__()
        self.energy = energy
        self.lane_bound_cls_idx = lane_bound_cls_idx
        self.dis_thresh = dis_thresh
        self.edge_cost = edge_cost
        self.pc_range = point_cloud_range
        self.perception_detach = perception_detach

    def cost_function(self, dist, safe_thresh, edge_cost=0.05):
        """Per-step cost from the distance to the closest boundary point.

        Args:
            dist: Tensor [B, T, V, P] - distance from the ego position at each
                step to every point of every boundary vector.
            safe_thresh: scalar safety distance.
            edge_cost: cost assigned exactly at ``safe_thresh``.
        Returns:
            cost: Tensor [B, T]
        """
        # Closest boundary point over all vectors and points: [B, T, V*P] => [B, T]
        min_dist = dist.flatten(2, 3).min(dim=-1)[0]

        # Linear ramp inside the safety radius, exponential tail outside it.
        return torch.where(
            min_dist <= safe_thresh,
            cost_inside(min_dist, safe_thresh=safe_thresh, edge_cost=edge_cost),
            cost_outside(min_dist, safe_thresh=safe_thresh, scale=abs(edge_cost))
        )

    def boundary_collison_detection(self, dist, lane_boundary, ego_fut_preds, reward_augment=False):
        '''
        Detect whether the ego trajectory crosses its nearest lane boundary.

        Args:
            dist (Tensor): [B, T, V, P] ego-to-boundary-point distances.
            lane_boundary (Tensor): [B, V, P, 2] or [1, V, P, 2]; a singleton
                batch is broadcast to the ego batch size.
            ego_fut_preds (Tensor): [B, T, 2] absolute ego positions.
            reward_augment (bool): kept for backward compatibility. The lane
                batch is always broadcast to ``B`` now, so the flag no longer
                changes the result.
        Returns:
            Tensor [B, T] (bool): True where the ego segment ending at step t
            intersects a segment of the boundary closest to that step.
        '''
        B, T, V, P = dist.shape

        # For each step pick the boundary vector that owns the closest point.
        nearest_pt_dist = dist.min(dim=-1)[0]                     # [B, T, V]
        min_vector_idxs = torch.argmin(nearest_pt_dist, dim=-1)   # [B, T]

        # Broadcast the map to the ego batch without copying, then gather the
        # nearest vector per (batch, step). Index tensors live on the same
        # device as the data to avoid implicit host-to-device transfers.
        lanes = lane_boundary.expand(B, -1, -1, -1)               # [B, V, P, 2]
        batch_idxs = torch.arange(B, device=min_vector_idxs.device).unsqueeze(1).expand(B, T)
        min_lanes = lanes[batch_idxs, min_vector_idxs]            # [B, T, P, 2]

        # Consecutive boundary points form P-1 segments per (batch, step).
        min_lane_starts = min_lanes[:, :, :-1, :].flatten(0, 2)   # [B*T*(P-1), 2]
        min_lane_ends = min_lanes[:, :, 1:, :].flatten(0, 2)      # [B*T*(P-1), 2]

        # Ego segment t runs from position t-1 to position t; the first one
        # starts at the origin (current ego position).
        origin = ego_fut_preds.new_zeros((B, 1, 2))
        ego_traj_starts = torch.cat((origin, ego_fut_preds[:, :-1, :]), dim=1)  # [B, T, 2]
        ego_traj_ends = ego_fut_preds                                           # [B, T, 2]

        # Pair every ego segment with each of the P-1 boundary segments.
        ego_traj_starts = ego_traj_starts.unsqueeze(2).expand(-1, -1, P - 1, -1).flatten(0, 2)
        ego_traj_ends = ego_traj_ends.unsqueeze(2).expand(-1, -1, P - 1, -1).flatten(0, 2)

        intersect_mask = segments_intersect(ego_traj_starts, ego_traj_ends, min_lane_starts, min_lane_ends)

        # A step collides if it crosses any segment of its nearest boundary.
        return intersect_mask.reshape(B, T, P - 1).any(dim=-1)    # [B, T]

    def forward(
            self,
            ego_fut_preds,
            map_gt,
            map_type_gt,
            reward_augment=False
        ):
        """Forward function.

        Args:
            ego_fut_preds (Tensor): [B, fut_ts, 2] per-step displacements of
                the ego vehicle (accumulated here into positions).
            map_gt (Tensor): [1, num_vec, num_pts, 2] map vectors; only a map
                batch size of 1 is supported.
            map_type_gt (Tensor): [1, num_vec] class index of each map vector.
            reward_augment (bool): forwarded to
                ``boundary_collison_detection``.

        Returns:
            Tensor [B]: reward per trajectory. When the map contains no lane
            boundary, a neutral ``energy / 2`` is returned.
        """
        # map classes: ['divider','ped_crossing','boundary']
        assert map_gt.size(0) == 1, "Only batch_size=1 is supported in CriticMapBoundConstrain."

        lane_bound_mask = (map_type_gt[0] == self.lane_bound_cls_idx)  # (num_vec,)
        if lane_bound_mask.sum() == 0:
            # No boundary to measure against: return the neutral half-energy
            # reward (same value with or without reward augmentation).
            B = ego_fut_preds.shape[0]
            return torch.full((B,), fill_value=self.energy / 2, device=ego_fut_preds.device)

        lane_boundary = map_gt[0, lane_bound_mask].unsqueeze(0)  # [1, V, P, 2]

        # Accumulate per-step displacements into absolute positions.
        ego_fut_preds = ego_fut_preds.cumsum(dim=-2)  # [B, T, 2]

        # [B, T, 1, 1, 2] - [1, 1, V, P, 2] => [B, T, V, P]
        dist = torch.linalg.norm(
            ego_fut_preds[:, :, None, None, :] - lane_boundary[:, None], dim=-1
        )

        cost = self.cost_function(dist, safe_thresh=self.dis_thresh, edge_cost=self.edge_cost)  # [B, T]
        crossed = self.boundary_collison_detection(dist, lane_boundary, ego_fut_preds, reward_augment)  # [B, T]

        # A single crossing anywhere penalises the whole trajectory with the
        # maximum cost at every step.
        crossed_any = crossed.any(dim=-1, keepdim=True)  # [B, 1]
        cost = cost.masked_fill(crossed_any, 1.0)

        return self.energy - cost.mean(dim=-1)  # [B,]


def segments_intersect(line1_start, line1_end, line2_start, line2_end):
    """Test pairs of 2D line segments for intersection, row by row.

    Row ``i`` of the result tells whether segment
    ``line1_start[i] -> line1_end[i]`` intersects segment
    ``line2_start[i] -> line2_end[i]``. Touching end points count as an
    intersection. Parallel pairs, including collinear overlapping ones, are
    always reported as not intersecting.

    Args:
        line1_start, line1_end: Tensors [N, 2] with the first segments.
        line2_start, line2_end: Tensors [N, 2] with the second segments.

    Returns:
        Bool tensor [N].
    """
    # Direction vectors of both segments and the offset between their starts.
    dir1 = line1_end - line1_start
    dir2 = line2_end - line2_start
    offset = line2_start - line1_start

    dir1_x, dir1_y = dir1[:, 0], dir1[:, 1]
    dir2_x, dir2_y = dir2[:, 0], dir2[:, 1]
    offset_x, offset_y = offset[:, 0], offset[:, 1]

    # Zero determinant <=> the two segments are parallel (or degenerate).
    det = dir1_x * dir2_y - dir2_x * dir1_y
    not_parallel = det != 0

    # Position of the crossing point along each segment (0 = start, 1 = end).
    # Division by a zero determinant yields inf/nan, which the final mask drops.
    t1 = (offset_x * dir2_y - offset_y * dir2_x) / det
    t2 = (offset_x * dir1_y - offset_y * dir1_x) / det

    within_first = (t1 >= 0) & (t1 <= 1)
    within_second = (t2 >= 0) & (t2 <= 1)

    return within_first & within_second & not_parallel


class IntrinsicMapBoundConstrain(nn.Module):
    """Reward the current scene for keeping lane boundaries away from the origin.

    The distance of every lane boundary point to the coordinate origin is
    measured, and the reward is ``energy`` minus a cost that depends on the
    smallest of those distances.

    Args:
        energy (float, optional): Value the cost is subtracted from.
        lane_bound_cls_idx (int, optional): Class index that marks lane
            boundary vectors in ``map_type_gt``.
        dis_thresh (float, optional): Safety distance to a lane boundary.
        edge_cost (float, optional): Cost assigned exactly at ``dis_thresh``.
        point_cloud_range (list, optional): Stored as ``pc_range``; not read
            by any method of this class.
        perception_detach (bool, optional): Stored; not read by any method of
            this class.
    """

    def __init__(
        self,
        energy=1.0,
        lane_bound_cls_idx=2,
        dis_thresh=1.0,
        edge_cost=0.05,
        point_cloud_range=[-15.0, -30.0, -2.0, 15.0, 30.0, 2.0],
        perception_detach=False
        ):
        super().__init__()
        self.pc_range = point_cloud_range
        self.perception_detach = perception_detach
        self.lane_bound_cls_idx = lane_bound_cls_idx
        self.energy = energy
        self.dis_thresh = dis_thresh
        self.edge_cost = edge_cost

    def cost_function(self, dist, safe_thresh, edge_cost=0.05):
        """Cost derived from the closest boundary point.

        Args:
            dist: Tensor [B, V, P] - distance to every boundary point
            safe_thresh: scalar safety distance
            edge_cost: cost assigned exactly at ``safe_thresh``
        Returns:
            cost: Tensor [B]
        """
        # [B, V, P] => [B, V*P] => [B]
        nearest = dist.flatten(1, 2).min(dim=-1)[0]

        near_cost = cost_inside(nearest, safe_thresh=safe_thresh, edge_cost=edge_cost)
        far_cost = cost_outside(nearest, safe_thresh=safe_thresh, scale=abs(edge_cost))

        # Linear ramp within the safety radius, exponential tail beyond it.
        return torch.where(nearest <= safe_thresh, near_cost, far_cost)

    def forward(
            self,
            map_gt,
            map_type_gt,
        ):
        """Forward function.

        Args:
            map_gt (Tensor): [1, num_vec, num_pts, 2] map vectors; only a
                batch size of 1 is accepted.
            map_type_gt (Tensor): per-vector class index; ``map_type_gt[0]``
                is compared against ``lane_bound_cls_idx`` and used as a mask
                over ``num_vec``.

        Returns:
            Tensor [1]: the reward, or ``energy / 2`` when the map holds no
            lane boundary.
        """
        # map classes: ['divider','ped_crossing','boundary']
        assert map_gt.size(0) == 1, "Only batch_size=1 is supported in IntrinsicMapBoundConstrain."

        is_boundary = map_type_gt[0] == self.lane_bound_cls_idx  # (num_vec,)
        if not is_boundary.any():
            # Nothing to measure against: hand back the half-energy reward.
            num_batch = map_gt.shape[0]
            return torch.full((num_batch,), fill_value=self.energy/2, device=map_gt.device)

        # Keep only boundary vectors and restore the batch axis: [1, V, P, 2]
        boundary_pts = map_gt[0, is_boundary].unsqueeze(0)

        # Distance of each boundary point to the origin: [1, V, P]
        dist_to_origin = torch.linalg.norm(boundary_pts, dim=-1)

        penalty = self.cost_function(
            dist_to_origin, safe_thresh=self.dis_thresh, edge_cost=self.edge_cost
        )  # [1,]

        return self.energy - penalty


class CriticCollisionConstrain(nn.Module):
    """Collision constraint to push ego vehicle away from other agents during planning horizon.

    The reward is ``energy`` minus the worst per-step collision cost over the
    planning horizon.

    Args:
        energy (float, optional): Value the cost is subtracted from.
        dis_thresh (float, optional): Safety distance to other agents,
            measured after the y axis has been halved.
        edge_cost (float, optional): Cost assigned exactly at ``dis_thresh``.
        point_cloud_range (list, optional): Stored as ``pc_range``; not read
            by any method of this class.
    """

    def __init__(
        self,
        energy=1.0,
        dis_thresh=2,
        edge_cost=0.05,
        point_cloud_range=[-15.0, -30.0, -2.0, 15.0, 30.0, 2.0]
    ):
        super().__init__()
        self.energy = energy
        self.dis_thresh = dis_thresh
        self.edge_cost = edge_cost
        self.pc_range = point_cloud_range

    def cost_function(self, dist, safe_thresh, edge_cost=0.05, valid_mask=None):
        """Aggregate agent distances into one collision cost per sample.

        Args:
            dist: Tensor [B, N, T] - distance to each agent at each frame
            safe_thresh: scalar safety distance
            edge_cost: cost assigned exactly at ``safe_thresh``
            valid_mask: Tensor [B, N, T] marking the entries that take part in
                the cost; ``None`` treats every entry as valid.
        Returns:
            cost: Tensor [B] - maximum per-frame cost over the horizon
        """
        if valid_mask is None:
            valid_mask = torch.ones_like(dist, dtype=torch.bool)
        else:
            # `~` on an integer mask would be a bitwise NOT; force booleans.
            valid_mask = valid_mask.bool()

        # Invalid entries are moved very far away so their cost underflows to
        # 0. Overwriting (rather than adding an offset) also keeps NaN/inf
        # distances of invalid agents out of the result.
        dist = dist.masked_fill(~valid_mask, 1e6)
        # Number of valid agents per frame; at least 1 to avoid dividing by 0.
        valid_num = valid_mask.sum(dim=1).clamp(min=1)  # [B, T]

        inside_mask = dist < safe_thresh
        cost = torch.where(
            inside_mask,
            cost_inside(dist, safe_thresh, edge_cost),
            cost_outside(dist, safe_thresh, edge_cost)
        )  # [B, N, T]

        # If any agent is inside the safety radius at a frame, that frame
        # takes the worst agent cost; otherwise it takes the mean cost over
        # the valid agents.
        exist_inside_mask = inside_mask.any(dim=1)  # [B, T]
        cost = torch.where(
            exist_inside_mask,
            cost.max(dim=1)[0],          # [B, T]
            cost.sum(dim=1) / valid_num  # [B, T]
        )

        # The worst frame decides the trajectory cost.
        return cost.max(dim=-1)[0]  # [B,]

    def forward(
        self,
        ego_fut_preds,          # [B, T, 2]
        agent_cur_position,     # [B, N, 2]
        agent_type,             # [B, N]
        agent_fut_position,     # [B, N, T, 2]
        agent_fut_mask,         # [B, N, T]
    ):
        """Compute the collision reward of the ego trajectory.

        Args:
            ego_fut_preds: [B, T, 2] per-step ego displacements.
            agent_cur_position: [B, N, 2] current agent positions.
            agent_type: [B, N] agent class index; only agents with
                ``agent_type <= 4`` take part.
            agent_fut_position: [B, N, T, 2] per-step agent displacements.
            agent_fut_mask: [B, N, T] validity of each agent future step.

        Returns:
            Tensor [B]: ``energy`` minus the collision cost, or ``energy / 2``
            when there are no agents at all.
        """
        B, N, T = agent_fut_mask.shape
        if N == 0:
            return torch.full((B,), fill_value=self.energy / 2, device=ego_fut_preds.device)

        # A future step is valid only while every earlier step was valid too:
        # after the first 0 the cumulative product stays 0.
        valid_mask = agent_fut_mask.cumprod(dim=2)   # [B, N, T]
        # NOTE(review): classes 0..4 are presumably the collidable agent
        # types; confirm against the dataset's class list.
        agent_mask = (agent_type <= 4).unsqueeze(2)  # [B, N, 1]
        valid_mask = torch.logical_and(valid_mask, agent_mask)

        # Accumulate per-step displacements into positions.
        ego_fut_preds = ego_fut_preds.cumsum(dim=-2)                                               # [B, T, 2]
        agent_fut_position = agent_fut_position.cumsum(dim=2) + agent_cur_position[:, :, None, :]  # [B, N, T, 2]

        # Halve the y axis so a circular threshold becomes an ellipse that is
        # twice as long along y as along x.
        ego_fut_preds_transformation = torch.stack(
            [ego_fut_preds[:, :, 0], ego_fut_preds[:, :, 1] / 2],
            dim=2
        )  # [B, T, 2]

        agent_fut_position_transformation = torch.stack(
            [agent_fut_position[:, :, :, 0], agent_fut_position[:, :, :, 1] / 2],
            dim=3
        )  # [B, N, T, 2]

        # Add an agent axis to the ego trajectory for broadcasting.
        ego_fut_preds_transformation = ego_fut_preds_transformation[:, None, :, :]  # [B, 1, T, 2]

        # Euclidean distance in the rescaled frame.
        dist = torch.linalg.norm(ego_fut_preds_transformation - agent_fut_position_transformation, dim=-1)  # [B, N, T]

        cost = self.cost_function(
            dist=dist, safe_thresh=self.dis_thresh, edge_cost=self.edge_cost, valid_mask=valid_mask
        )  # [B,]

        return self.energy - cost


class IntrinsicCollisionConstrain(nn.Module):
    """Reward the current scene for keeping other agents away from the origin.

    Agent positions are taken relative to the origin, the y axis is halved,
    and the reward is ``energy`` minus a cost based on the resulting
    distances.

    Args:
        energy (float, optional): Value the cost is subtracted from.
        dis_thresh (float, optional): Safety distance to other agents,
            measured after the y axis has been halved.
        edge_cost (float, optional): Cost assigned exactly at ``dis_thresh``.
        point_cloud_range (list, optional): Stored as ``pc_range``; not read
            by any method of this class.
    """

    def __init__(
        self,
        energy=1.0,
        dis_thresh=2,
        edge_cost=0.05,
        point_cloud_range=[-15.0, -30.0, -2.0, 15.0, 30.0, 2.0]
    ):
        super().__init__()
        self.energy = energy
        self.dis_thresh = dis_thresh
        self.edge_cost = edge_cost
        self.pc_range = point_cloud_range

    def cost_function(self, dist, safe_thresh, edge_cost=0.05, agent_mask=None):
        """Aggregate agent distances into one cost per sample.

        Args:
            dist: Tensor [B, N] - distance to each agent
            safe_thresh: scalar safety distance
            edge_cost: cost assigned exactly at ``safe_thresh``
            agent_mask: Tensor [B, N] marking the agents that take part;
                ``None`` treats every agent as valid.
        Returns:
            cost: Tensor [B]
        """
        if agent_mask is None:
            agent_mask = torch.ones_like(dist, dtype=torch.bool)
        else:
            # `~` on an integer mask would be a bitwise NOT; force booleans.
            agent_mask = agent_mask.bool()

        # Ignored agents are moved very far away so their cost underflows to 0.
        dist = dist.masked_fill(~agent_mask, 1e6)
        # At least 1 to avoid dividing by zero when no agent is valid.
        num_agent = agent_mask.sum(dim=-1).clamp(min=1)  # [B,]

        # Per-agent cost: linear ramp inside the radius, exponential outside.
        inside_mask = dist < safe_thresh  # [B, N]
        cost = torch.where(
            inside_mask,
            cost_inside(dist, safe_thresh, edge_cost),
            cost_outside(dist, safe_thresh, edge_cost)
        )

        # Per-scene cost: the worst agent if any is inside the radius,
        # otherwise the mean over the valid agents.
        exist_inside_mask = inside_mask.any(dim=-1)  # [B,]
        return torch.where(
            exist_inside_mask,
            cost.max(dim=-1)[0],
            cost.sum(dim=-1) / num_agent
        )

    def forward(self, agent_cur_position, agent_type=None):
        """
        Args:
            agent_cur_position: Tensor [B, N, 2] - relative positions
            agent_type: Tensor [B, N] - agent class index; only agents with
                ``agent_type <= 4`` take part. ``None`` keeps every agent.
        Returns:
            reward: Tensor [B]; ``energy / 2`` when there are no agents.
        """
        B, N, _ = agent_cur_position.shape
        if N == 0:
            # Take the device from the positions so this also works when
            # agent_type is omitted.
            return torch.full((B,), fill_value=self.energy / 2, device=agent_cur_position.device)

        if agent_type is None:
            agent_mask = torch.ones((B, N), dtype=torch.bool, device=agent_cur_position.device)
        else:
            # NOTE(review): classes 0..4 are presumably the collidable agent
            # types; confirm against the dataset's class list.
            agent_mask = (agent_type <= 4)  # [B, N]

        # Halve the y axis so the circular threshold becomes an ellipse that
        # is twice as long along y as along x.
        agent_cur_position_transform = torch.stack(
            [agent_cur_position[:, :, 0], agent_cur_position[:, :, 1] / 2], dim=2
        )  # [B, N, 2]

        dist = torch.linalg.norm(agent_cur_position_transform, dim=-1)  # [B, N]

        cost = self.cost_function(dist, self.dis_thresh, self.edge_cost, agent_mask)

        return self.energy - cost  # [B]

class CriticImitationConstrain(nn.Module):
    """Imitation reward: high when the predicted steps match the expert's steps."""

    def __init__(self):
        super().__init__()

    def reward_function(self, dist):
        '''
        Map a non-negative error to a reward in (0, 1] via exp(-dist).

        dist (Tensor): (B, T)
        '''
        return (-dist).exp()

    def forward(
            self,
            ego_fut_preds,
            ego_fut_gt,
            ego_fut_masks
        ):
        '''
        Average the per-step imitation reward over the valid steps.

        ego_fut_preds (Tensor): [B, T, 2]
        ego_fut_gt (Tensor): [B, T, 2]
        ego_fut_masks (Tensor): [B, T]

        Returns a tensor [B]; a sample without any valid step gets 0.
        '''
        # Per-step Euclidean error between prediction and expert: (B, T)
        step_error = torch.linalg.norm(ego_fut_preds - ego_fut_gt, dim=-1)

        # Masked-out steps contribute nothing to the sum.
        masked_reward = self.reward_function(step_error) * ego_fut_masks

        # Clamp so an all-invalid row divides by 1 instead of 0.
        num_valid = ego_fut_masks.sum(dim=-1).clamp(min=1)

        return masked_reward.sum(dim=-1) / num_valid  # (B,)


class CriticEndPointConstrain(nn.Module):
    """EndPoint constraint to push ego vehicle go to the end point."""

    def __init__(
        self,
    ):
        super().__init__()

    def reward_function(self, dist, delta=1.5):
        '''
        Reward 1 within ``delta`` of the target, ``delta / dist`` beyond it.

        dist (Tensor): [B,] end-point distance
        delta (float): radius inside which the full reward is given
        '''
        # torch.where evaluates both branches. Clamping the denominator keeps
        # the unselected branch finite at dist == 0, which would otherwise
        # inject NaN into the gradient. For dist >= delta the clamp is a no-op.
        safe_dist = dist.clamp(min=delta)
        return torch.where(
            dist < delta,
            torch.ones_like(dist),
            delta / safe_dist
        )

    def forward(
            self,
            ego_fut_preds,
            ego_fut_gt,
            ego_fut_masks
        ):
        '''
        Compare the end points of the predicted and the expert trajectory.

        ego_fut_preds (Tensor): [B, T, 2] per-step displacements
        ego_fut_gt (Tensor): [B, T, 2] per-step displacements
        ego_fut_masks (Tensor): [B, T]

        Returns a tensor [B]. Samples whose mask is not valid at every step
        get the fixed reward 0.5.
        '''
        # Summing the displacements gives the final position.
        end_point_pred = ego_fut_preds.cumsum(dim=-2)[:, -1, :]  # [B, 2]
        end_point_gt = ego_fut_gt.cumsum(dim=-2)[:, -1, :]       # [B, 2]
        dist = torch.linalg.norm(end_point_pred - end_point_gt, dim=-1)  # [B,]
        reward = self.reward_function(dist)  # [B,]

        # The expert end point is only trustworthy if every step is valid.
        fully_valid = ego_fut_masks.bool().all(dim=-1)  # [B,]
        return torch.where(fully_valid, reward, torch.full_like(reward, 0.5))


@LOSSES.register_module()
class PlanMapDirectionLoss(nn.Module):
    """Planning loss relating the ego heading to the direction of nearby lanes.

    Thin module wrapper around ``plan_map_dir_loss`` that applies
    ``loss_weight`` and resolves the reduction mode.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
        map_thresh (float, optional): Score threshold below which a map
            element counts as "not a lane divider".
        dis_thresh (float, optional): Distance threshold forwarded to
            ``plan_map_dir_loss``.
        lane_div_cls_idx (int, optional): Index of the lane divider class in
            the last axis of ``lane_score_preds``.
        point_cloud_range (list, optional): point cloud range, stored as
            ``pc_range``.
    """

    def __init__(
        self,
        reduction='mean',
        loss_weight=1.0,
        map_thresh=0.5,
        dis_thresh=2.0,
        lane_div_cls_idx=0,
        point_cloud_range = [-15.0, -30.0, -2.0, 15.0, 30.0, 2.0]
    ):
        super().__init__()
        self.pc_range = point_cloud_range
        self.lane_div_cls_idx = lane_div_cls_idx
        self.map_thresh = map_thresh
        self.dis_thresh = dis_thresh
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                ego_fut_preds,
                lane_preds,
                lane_score_preds,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            ego_fut_preds (Tensor): [B, fut_ts, 2]
            lane_preds (Tensor): [B, num_vec, num_pts, 2]
            lane_score_preds (Tensor): [B, num_vec, 3]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Elements that are not confidently lane dividers. The mask is still
        # computed, but it is not applied below: both the padding of these
        # elements and the denormalisation of the map points by `pc_range`
        # are currently disabled.
        not_lane_div_mask = lane_score_preds[..., self.lane_div_cls_idx] < self.map_thresh

        # Work on a copy so the caller's lane predictions stay untouched.
        lane_div_preds = lane_preds.clone()

        raw_loss = plan_map_dir_loss(
            ego_fut_preds,
            lane_div_preds,
            weight=weight,
            dis_thresh=self.dis_thresh,
            reduction=reduction,
            avg_factor=avg_factor,
        )
        return self.loss_weight * raw_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def plan_map_dir_loss(pred, target, dis_thresh=2.0):
    """Planning ego-map directional loss.

    For every future step the closest map instance and its closest point are
    found, the local lane direction is taken from that point and a
    neighbouring point, and the ego heading is compared against it. The
    heading difference is folded into [-pi/2, pi/2], so a lane and its
    reverse count as the same direction.

    Args:
        pred (torch.Tensor): ego_fut_preds, [B, fut_ts, 2]; per-step
            displacements, accumulated here into positions.
        target (torch.Tensor): lane_div_preds, [B, num_vec, num_pts, 2].
        dis_thresh (float): steps whose closest lane point is farther away
            than this have their heading difference zeroed.
        weight (torch.Tensor): [B, fut_ts]; not a parameter of this function
            itself - presumably consumed by the `weighted_loss` decorator
            together with `reduction` and `avg_factor` (confirm against
            mmdet).

    Returns:
        torch.Tensor: Calculated loss [B, fut_ts]
    """
    num_map_pts = target.shape[2]
    # displacements -> positions
    pred = pred.cumsum(dim=-2)
    # distance between the first and the last position; below 1.0 the ego is treated as static
    traj_dis = torch.linalg.norm(pred[:, -1, :] - pred[:, 0, :], dim=-1)
    static_mask = traj_dis < 1.0
    # replicate the map for every future step: [B, fut_ts, num_vec, num_pts, 2]
    target = target.unsqueeze(1).repeat(1, pred.shape[1], 1, 1, 1)

    # find the closest map instance for ego at each timestamp
    dist = torch.linalg.norm(pred[:, :, None, None, :] - target, dim=-1)
    dist = dist.min(dim=-1, keepdim=False)[0]
    min_inst_idxs = torch.argmin(dist, dim=-1).tolist()
    # nested Python lists broadcast as [B, 1] x [B, fut_ts] x [B, fut_ts] index arrays
    batch_idxs = [[i] for i in range(dist.shape[0])]
    ts_idxs = [[i for i in range(dist.shape[1])] for j in range(dist.shape[0])]
    target_map_inst = target[batch_idxs, ts_idxs, min_inst_idxs]  # [B, fut_ts, num_pts, 2]

    # find the closest point on the selected instance and a neighbouring point
    dist = torch.linalg.norm(pred[:, :, None, :] - target_map_inst, dim=-1)
    min_pts_idxs = torch.argmin(dist, dim=-1)
    min_pts_next_idxs = min_pts_idxs.clone()
    is_end_point = (min_pts_next_idxs == num_map_pts-1)
    not_end_point = (min_pts_next_idxs != num_map_pts-1)
    # the neighbour is the following point, except at the last point where the
    # previous one is used (the reversed direction is harmless after folding)
    min_pts_next_idxs[is_end_point] = num_map_pts - 2
    min_pts_next_idxs[not_end_point] = min_pts_next_idxs[not_end_point] + 1
    min_pts_idxs = min_pts_idxs.tolist()
    min_pts_next_idxs = min_pts_next_idxs.tolist()
    # ego heading from consecutive positions
    traj_yaw = torch.atan2(torch.diff(pred[..., 1]), torch.diff(pred[..., 0]))  # [B, fut_ts-1]
    # last ts yaw assume same as previous
    traj_yaw = torch.cat([traj_yaw, traj_yaw[:, [-1]]], dim=-1)  # [B, fut_ts]
    min_pts = target_map_inst[batch_idxs, ts_idxs, min_pts_idxs]
    # steps farther than dis_thresh from their closest lane point are masked
    dist = torch.linalg.norm(min_pts - pred, dim=-1)
    dist_mask = dist > dis_thresh
    min_pts = min_pts.unsqueeze(2)
    min_pts_next = target_map_inst[batch_idxs, ts_idxs, min_pts_next_idxs].unsqueeze(2)
    map_pts = torch.cat([min_pts, min_pts_next], dim=2)
    # lane heading from the closest point to its neighbour
    lane_yaw = torch.atan2(torch.diff(map_pts[..., 1]).squeeze(-1), torch.diff(map_pts[..., 0]).squeeze(-1))  # [B, fut_ts]
    yaw_diff = traj_yaw - lane_yaw
    # fold the difference into [-pi/2, pi/2]; the order of these in-place
    # updates matters because each one re-evaluates its mask
    yaw_diff[yaw_diff > math.pi] = yaw_diff[yaw_diff > math.pi] - math.pi
    yaw_diff[yaw_diff > math.pi/2] = yaw_diff[yaw_diff > math.pi/2] - math.pi
    yaw_diff[yaw_diff < -math.pi] = yaw_diff[yaw_diff < -math.pi] + math.pi
    yaw_diff[yaw_diff < -math.pi/2] = yaw_diff[yaw_diff < -math.pi/2] + math.pi
    yaw_diff[dist_mask] = 0  # zero the difference if no lane is around ego
    yaw_diff[static_mask] = 0  # zero the difference for the whole sample if ego is static

    # NOTE(review): with this formula the value is largest (pi/2) when the
    # headings agree, and the masked entries above also yield pi/2 rather
    # than 0. It behaves like a reward, not a loss; the previous form was
    # `torch.abs(yaw_diff)`. Confirm the intended sign with the caller.
    loss = math.pi/2 - torch.abs(yaw_diff)

    return loss  # [B, fut_ts]
