import os
import sys
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, parent_dir)
import copy
import torch
import random
import math
import numpy as np
from scipy.spatial import distance
from shapely.geometry import LineString, Polygon, Point


def get_interactive_agents_mask(data, hist_steps, max_interaction_horizon):
    """Flag the agents whose future trajectory directly interacts with the ego.

    Agent index 0 along the second axis is treated as the ego; every other
    agent is tested against it.  Only time steps after ``hist_steps`` (i.e.
    ``hist_steps + 1`` onwards) are considered.  An agent is marked
    interactive when all of the following hold:

    * the closest pair of (valid agent step, valid ego step) positions is at
      most 4 apart (coarse pre-filter, same unit as ``position`` --
      presumably metres, TODO confirm);
    * the ego reaches that closest point no earlier than the agent and at
      most ``max_interaction_horizon`` steps after it;
    * the rotated bounding boxes of both at those steps intersect.

    Args:
        data: Nested mapping with ``data['agent']`` holding the tensors
            ``position`` (bs, num_agent, T, 2), ``heading``, ``valid_mask``
            and ``shape``.  ``heading`` and ``valid_mask`` are indexed as
            (bs, num_agent, T) and ``shape`` as (bs, num_agent, T, 2);
            ``valid_mask`` is assumed to be boolean -- TODO confirm.
        hist_steps: Index of the last history step; the future window starts
            at ``hist_steps + 1``.
        max_interaction_horizon: Largest allowed delay, in steps, between the
            agent and the ego reaching their closest positions.

    Returns:
        Boolean tensor of shape (bs, num_agent - 1) on the device of
        ``data['agent']['position']``; entry ``[b, i]`` refers to agent
        ``i + 1`` of sample ``b``.
    """
    agent_data = data['agent']
    bs, num_agent, _, _ = agent_data['position'].shape
    future = slice(hist_steps + 1, None)

    interactive_agents_mask = np.zeros([bs, num_agent - 1], dtype=bool)

    # Every tensor is cut to the same future window so that one time index
    # addresses the same step in all of them.
    valid_future = agent_data['valid_mask'][:, :, future].detach().cpu().numpy().astype(bool)
    heading_future = agent_data['heading'][:, :, future].detach().cpu()
    pos_future = agent_data['position'][:, :, future].detach().cpu()
    # Swap the last two entries so index 0 is what generate_rotated_bbox_points
    # treats as the length (presumably stored as (width, length) -- TODO confirm).
    shape_future = agent_data['shape'][:, :, future][..., [1, 0]].detach().cpu()

    for bs_idx in range(bs):
        # Real future-step indices of the valid ego samples; needed to map a
        # column of the distance matrix back to an actual time step.
        ego_valid_steps = np.flatnonzero(valid_future[bs_idx, 0])
        if ego_valid_steps.size == 0:
            continue
        ego_pos = pos_future[bs_idx, 0][torch.from_numpy(ego_valid_steps)]
        ego_pos_np = ego_pos.numpy().reshape(-1, 2)

        for agent_idx in range(num_agent - 1):
            agent_valid = valid_future[bs_idx, agent_idx + 1]
            if not agent_valid.any():
                continue

            # Rows: agent future steps, columns: valid ego future steps.
            distances = distance.cdist(
                pos_future[bs_idx, agent_idx + 1].numpy().reshape(-1, 2),
                ego_pos_np,
            )
            # Invalid agent steps must never win the argmin.
            distances[~agent_valid] = 1e6
            min_t_agent, ego_col = (
                int(v) for v in np.unravel_index(distances.argmin(), distances.shape)
            )
            min_dist = distances[min_t_agent, ego_col]

            if min_dist > 4:  # coarse distance judgement
                continue

            # Translate the column of the filtered ego trajectory back to the
            # real future step before comparing it with the agent's step.
            min_t_ego = int(ego_valid_steps[ego_col])
            if min_t_ego < min_t_agent or min_t_ego - min_t_agent > max_interaction_horizon:
                continue

            agent_bbox = generate_rotated_bbox_points(
                pos_future[bs_idx, agent_idx + 1, min_t_agent].reshape(-1, 2),
                shape_future[bs_idx, agent_idx + 1, min_t_agent].reshape(-1, 2),
                heading_future[bs_idx, agent_idx + 1, min_t_agent].reshape(-1, 1),
            ).squeeze(0)
            ego_bbox = generate_rotated_bbox_points(
                pos_future[bs_idx, 0, min_t_ego].reshape(-1, 2),
                shape_future[bs_idx, 0, min_t_ego].reshape(-1, 2),
                heading_future[bs_idx, 0, min_t_ego].reshape(-1, 1),
            ).squeeze(0)

            if Polygon(ego_bbox).intersects(Polygon(agent_bbox)):
                interactive_agents_mask[bs_idx, agent_idx] = True

    return torch.from_numpy(interactive_agents_mask).to(agent_data['position'].device)

def _as_float_array(values):
    """Convert a torch tensor or array-like into a float64 NumPy array."""
    if hasattr(values, 'detach'):  # torch tensor: export through its own API
        values = values.detach().cpu().numpy()
    return np.asarray(values, dtype=np.float64)


def generate_rotated_bbox_points(centers, size, angle):
    """Return the four corners of rotated rectangles.

    Each rectangle is built axis-aligned around its centre from half the
    length (x direction) and half the width (y direction) and then rotated
    counter-clockwise by ``angle`` radians about the centre.

    Args:
        centers: (n, 2) rectangle centres; torch tensor or array-like.
        size: (n, 2) array whose column 0 is the length and column 1 the
            width.
        angle: Rotation per rectangle.  Any shape holding n values is
            accepted (e.g. ``(n,)`` or ``(n, 1)``); it is flattened so that
            batched column vectors no longer fail to broadcast.

    Returns:
        float64 array of shape (n, 4, 2) with the corners ordered
        (+length, +width), (+length, -width), (-length, -width),
        (-length, +width) in the unrotated frame, or ``None`` when
        ``centers`` is empty.
    """
    if centers.shape[0] == 0:
        return None

    centers = _as_float_array(centers)
    size = _as_float_array(size)
    # Flatten so that an (n, 1) column of headings lines up with the n boxes.
    angle = _as_float_array(angle).reshape(-1)

    half_lengths = size[:, 0] / 2.0
    half_widths = size[:, 1] / 2.0

    # Corner signs in the order top-right, bottom-right, bottom-left, top-left.
    sign_x = np.array([1.0, 1.0, -1.0, -1.0])
    sign_y = np.array([1.0, -1.0, -1.0, 1.0])

    # Corner offsets relative to the centre, shape (n, 4).
    dx = half_lengths[:, np.newaxis] * sign_x
    dy = half_widths[:, np.newaxis] * sign_y

    s = np.sin(angle)[:, np.newaxis]
    c = np.cos(angle)[:, np.newaxis]

    # Apply the 2-D rotation [[c, -s], [s, c]] and shift back to the centre.
    x = (c * dx - s * dy) + centers[:, 0:1]
    y = (s * dx + c * dy) + centers[:, 1:2]

    return np.stack([x, y], axis=-1)  # size: n,4,2