import math
import torch
from utils.earth_computation import haversine_distance

def ensure_tensor(tensor, device=None):
    """
    Coerce the input into a torch.Tensor
    Parameters:
    tensor: an existing torch.Tensor, or anything torch.tensor() accepts (scalar, list, ...).
            Non-tensor inputs are converted with dtype float32; tensors keep their dtype.
    device: optional target device; when None the tensor stays where it is
    Return:
    torch.Tensor: the (possibly converted and/or moved) tensor
    """
    converted = (
        tensor
        if isinstance(tensor, torch.Tensor)
        else torch.tensor(tensor, dtype=torch.float32)
    )
    return converted if device is None else converted.to(device)

def euclidean_distance(p1, p2):
    """
    Straight-line (L2) distance between two points
    Parameters:
    p1, p2: point coordinates (tensor or array-like); p2 is moved to p1's device if needed
    Return:
    torch.Tensor: 0-d tensor; the squared differences of all elements are summed before the sqrt
    """
    first = ensure_tensor(p1)
    second = ensure_tensor(p2)
    if second.device != first.device:
        second = second.to(first.device)

    squared_gap = (first - second) ** 2
    return squared_gap.sum().sqrt()

def batch_euclidean_distance(p1, p2):
    """
    Pairwise Euclidean distances between two batched point sets
    Parameters:
    p1: torch.Tensor, shape (batch_size, n, 2)
    p2: torch.Tensor, shape (batch_size, m, 2); moved to p1's device if needed
    Return:
    torch.Tensor: shape (batch_size, n, m), entry [b, i, j] = ||p1[b, i] - p2[b, j]||
    """
    src = ensure_tensor(p1)
    dst = ensure_tensor(p2)
    if dst.device != src.device:
        dst = dst.to(src.device)

    # Broadcast (batch, n, 1, 2) against (batch, 1, m, 2) -> (batch, n, m, 2)
    offsets = src[:, :, None] - dst[:, None]
    return torch.sqrt((offsets ** 2).sum(dim=-1))

def frechet_distance(gt_trajectory, pred_trajectory):
    """
    Calculate the discrete Frechet distance between two trajectories (supporting batch)
    Parameters:
    gt_trajectory: torch.Tensor, shape (batch_size, m, 2) or (m, 2)
    pred_trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2)
        An unbatched trajectory is broadcast against a batched one.
    Returns:
    torch.Tensor: Frechet distance, shape (batch_size,); squeezed to a scalar when the
        (broadcast) batch size is 1. The dtype follows the pairwise point distances,
        and the result is differentiable with respect to both trajectories.
    Raises:
    ValueError: if either trajectory contains no points
    """
    gt_trajectory = ensure_tensor(gt_trajectory)
    pred_trajectory = ensure_tensor(pred_trajectory)

    if gt_trajectory.device != pred_trajectory.device:
        pred_trajectory = pred_trajectory.to(gt_trajectory.device)

    if gt_trajectory.dim() == 2:
        gt_trajectory = gt_trajectory.unsqueeze(0)
    if pred_trajectory.dim() == 2:
        pred_trajectory = pred_trajectory.unsqueeze(0)

    m, n = gt_trajectory.shape[1], pred_trajectory.shape[1]
    if m == 0 or n == 0:
        raise ValueError(f"Trajectories must be non-empty, got lengths m={m}, n={n}")

    point_distances = batch_euclidean_distance(gt_trajectory, pred_trajectory)  # (batch_size, m, n)

    # Dynamic programming, one row at a time. Each cell is kept as its own (batch_size,)
    # tensor instead of being written into a preallocated matrix. In-place writes into a
    # tensor whose slices were already consumed by torch.maximum would invalidate
    # autograd's saved tensors. This layout also keeps the result in the inputs' dtype.
    prev_row = [point_distances[:, 0, 0]]
    for j in range(1, n):
        prev_row.append(torch.maximum(prev_row[j - 1], point_distances[:, 0, j]))

    for i in range(1, m):
        row = [torch.maximum(prev_row[0], point_distances[:, i, 0])]
        for j in range(1, n):
            # Best coupling reaching (i, j): from above, from the left, or diagonally
            best_prev = torch.minimum(torch.minimum(prev_row[j], row[j - 1]), prev_row[j - 1])
            row.append(torch.maximum(best_prev, point_distances[:, i, j]))
        prev_row = row

    return prev_row[-1].squeeze()

def calculate_initial_bearing(lon1, lat1, lon2, lat2):
    '''
    Initial azimuth (forward bearing) from point 1 towards point 2, element-wise
    Parameter:
    lon1, lat1, lon2, lat2: torch.Tensor radian coordinates of identical (or broadcastable) shape
    Return:
    torch.Tensor: bearing in radians, in [-pi, pi], measured from north (0) towards east (pi/2)
    '''
    delta_lon = lon2 - lon1
    cos_lat2 = torch.cos(lat2)

    east_component = torch.sin(delta_lon) * cos_lat2
    north_component = (
        torch.cos(lat1) * torch.sin(lat2)
        - torch.sin(lat1) * cos_lat2 * torch.cos(delta_lon)
    )
    return torch.atan2(east_component, north_component)

def calculate_curvature_lonlat(trajectory):
    """
    Calculate the curvature of a longitude/latitude trajectory (batch version, radian input)
    Curvature at interior point i = (wrapped change of bearing between segments
    (i-1 -> i) and (i -> i+1)) / (mean length of those two segments).
    Parameter
    trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2), [lon, lat] in radians
    Return
    torch.Tensor: curvature of each point, shape (batch_size, n), squeezed (so a single
        trajectory gives (n,)). Units are 1 / (unit returned by haversine_distance),
        presumably 1/meter. The first and last points, and points whose neighbouring
        segments are shorter than 1e-10 on average, are set to 0.
    """
    trajectory = ensure_tensor(trajectory)

    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n, _ = trajectory.shape

    if n < 3:
        # No interior points: every curvature is 0
        return torch.zeros((batch_size, n), device=trajectory.device).squeeze()

    lon, lat = trajectory[..., 0], trajectory[..., 1]  # each (batch_size, n)

    # Bearing and length of each of the n-1 segments, each computed exactly once
    seg_bearing = calculate_initial_bearing(lon[:, :-1], lat[:, :-1], lon[:, 1:], lat[:, 1:])
    # NOTE(review): assumes haversine_distance is element-wise over 1-D tensors (as its
    # per-point use implies); inputs are flattened to 1-D and the shape is restored.
    seg_length = haversine_distance(
        lon[:, :-1].reshape(-1), lat[:, :-1].reshape(-1),
        lon[:, 1:].reshape(-1), lat[:, 1:].reshape(-1),
    ).reshape(batch_size, n - 1)

    # Change of bearing at each interior point, wrapped into [-pi, pi]
    bearing_diff = seg_bearing[:, 1:] - seg_bearing[:, :-1]
    bearing_diff = torch.atan2(torch.sin(bearing_diff), torch.cos(bearing_diff))

    avg_distance = (seg_length[:, :-1] + seg_length[:, 1:]) / 2

    degenerate = avg_distance < 1e-10
    # Divide by a harmless 1 where masked so the discarded branch cannot inject NaN/inf
    # into gradients flowing through torch.where
    safe_distance = torch.where(degenerate, torch.ones_like(avg_distance), avg_distance)
    interior = torch.where(
        degenerate,
        torch.zeros_like(bearing_diff),
        bearing_diff / safe_distance,
    )

    endpoint = interior.new_zeros((batch_size, 1))
    curvatures = torch.cat([endpoint, interior, endpoint], dim=1)
    return curvatures.squeeze()

def calculate_curvature_turning_radius(trajectory):
    """
    Calculate the turning radius (circumradius of the triangle formed by each point and its
    two neighbours, batch version)
    Parameter
    trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2), [lon, lat] in radians
    Return
    torch.Tensor: turning radius in the unit returned by haversine_distance (presumably
        meters), shape (batch_size, n), squeezed. The first and last points, and
        (near-)collinear or degenerate triples, are +inf (a straight line).
    """
    trajectory = ensure_tensor(trajectory)

    # Handle the situation of a single trajectory
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n, _ = trajectory.shape

    if n < 3:
        return torch.full((batch_size, n), float('inf'), device=trajectory.device).squeeze()

    lon, lat = trajectory[..., 0], trajectory[..., 1]  # each (batch_size, n)

    def _distance(lon1, lat1, lon2, lat2):
        # NOTE(review): assumes haversine_distance is element-wise over 1-D tensors;
        # flatten to 1-D and restore the (batch_size, k) shape afterwards.
        return haversine_distance(
            lon1.reshape(-1), lat1.reshape(-1), lon2.reshape(-1), lat2.reshape(-1)
        ).reshape(lon1.shape)

    segment = _distance(lon[:, :-1], lat[:, :-1], lon[:, 1:], lat[:, 1:])  # (batch_size, n-1)
    a = segment[:, :-1]  # p[i-1] -> p[i]
    b = segment[:, 1:]   # p[i]   -> p[i+1]
    c = _distance(lon[:, :-2], lat[:, :-2], lon[:, 2:], lat[:, 2:])  # p[i-1] -> p[i+1]

    # Triangle area via Kahan's numerically stable form of Heron's formula. The naive
    # s*(s-a)*(s-b)*(s-c) cancels catastrophically for the nearly collinear triples that
    # dominate real trajectories. Sides must be sorted x >= y >= z, and the parentheses kept.
    sides, _ = torch.sort(torch.stack([a, b, c], dim=-1), dim=-1, descending=True)
    x, y, z = sides.unbind(dim=-1)
    area_sq = (x + (y + z)) * (z - (x - y)) * (z + (x - y)) * (x + (y - z)) / 16

    # Negative / tiny squared areas (collinear or rounding-violated triangles) are invalid
    valid_mask = area_sq > 1e-20
    area = torch.sqrt(torch.where(valid_mask, area_sq, torch.ones_like(area_sq)))

    # Circumradius R = (a*b*c) / (4*area); infinite where the triangle is degenerate
    radius_mask = valid_mask & (area > 1e-10)
    interior = torch.where(
        radius_mask,
        (a * b * c) / (4 * area),
        torch.full_like(area, float('inf')),
    )

    endpoint = torch.full((batch_size, 1), float('inf'), dtype=interior.dtype, device=interior.device)
    return torch.cat([endpoint, interior, endpoint], dim=1).squeeze()

def curvature_to_turning_radius(curvatures):
    """
    Convert curvature values to turning radii, element-wise (batch version)
    Parameter
    curvatures: torch.Tensor (or array-like), shape (batch_size, n) or (n,)
    Return
    torch.Tensor: 1 / |curvature|, squeezed; entries with |curvature| < 1e-10 become +inf
    """
    k = ensure_tensor(curvatures)
    if k.dim() == 1:
        k = k.unsqueeze(0)

    magnitude = torch.abs(k)
    straight = magnitude < 1e-10
    infinite = torch.tensor(float('inf'), device=k.device)
    return torch.where(straight, infinite, 1.0 / magnitude).squeeze()

def smooth_curvature(curvatures, window_size=5):
    """
    Moving-average smoothing of curvature data (batch version)
    Each output sample is the mean of a window_size-long window around it. The ends are
    padded by replicating the first/last value, so the output has the same length as the
    input. For an even window the extra sample is taken from the right.
    Parameter
    curvatures: torch.Tensor, original curvature, shape (batch_size, n) or (n,)
    window_size: sliding window size (positive integer)
    Return
    torch.Tensor: the smoothed curvature, same length as the input, squeezed. Inputs
        shorter than window_size are returned unchanged (squeezed).
    Raises
    ValueError: if window_size < 1
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    curvatures = ensure_tensor(curvatures)

    # Accept 0-d and 1-D input by treating it as a single row
    if curvatures.dim() < 2:
        curvatures = curvatures.reshape(1, -1)

    batch_size, seq_len = curvatures.shape

    if seq_len < window_size:
        return curvatures.squeeze()

    # conv1d needs a floating tensor whose dtype matches the kernel
    if not curvatures.is_floating_point():
        curvatures = curvatures.float()

    # The moving average is achieved using one-dimensional convolution
    kernel = torch.ones(1, 1, window_size, dtype=curvatures.dtype, device=curvatures.device) / window_size

    # Total padding must be window_size - 1 for a same-length output; splitting it
    # unevenly keeps even window sizes from producing seq_len + 1 samples.
    left_pad = (window_size - 1) // 2
    right_pad = window_size - 1 - left_pad
    padded = torch.nn.functional.pad(
        curvatures.unsqueeze(1),
        (left_pad, right_pad),
        mode='replicate'
    )  # (batch_size, 1, seq_len + window_size - 1)

    smoothed = torch.nn.functional.conv1d(
        padded,
        kernel,
        padding=0
    ).squeeze(1)  # (batch_size, seq_len)

    return smoothed.squeeze()

def analyze_curvature_features(curvatures, trajectory, *, sharp_turn_threshold=0.001,
                               straight_threshold=0.0001):
    """
    Analysis of curvature characteristics (batch version)
    Statistics are computed over the interior points (first and last excluded) whose
    |curvature| > 1e-10, except 'straight_segments', which counts over all n points.
    Parameter
    curvatures: torch.Tensor, curvature values, shape (batch_size, n) or (n,)
    trajectory: torch.Tensor, trajectory data, shape (batch_size, n, 2) or (n, 2).
        Accepted for interface compatibility; the statistics below do not use it.
    sharp_turn_threshold: |curvature| above this counts as a sharp turn (default 0.001)
    straight_threshold: |curvature| below this counts as straight (default 0.0001)
    Return
    dict: curvature feature statistics, each value of shape (batch_size,)
        'max_curvature': max |k| over the selected points (0 if none)
        'mean_curvature': signed mean of the selected k (0 if none)
        'std_curvature': sample std of the selected k (0 if fewer than two)
        'sharp_turns_count': number of selected points with |k| > sharp_turn_threshold
        'straight_segments': number of points (endpoints included) with |k| < straight_threshold
    """
    curvatures = ensure_tensor(curvatures)
    trajectory = ensure_tensor(trajectory)

    if curvatures.dim() == 1:
        curvatures = curvatures.unsqueeze(0)
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n = curvatures.shape
    device = curvatures.device

    valid_curvatures = curvatures[:, 1:-1]  # (batch_size, n-2)
    valid_mask = torch.abs(valid_curvatures) > 1e-10

    max_curvature = torch.zeros(batch_size, device=device)
    mean_curvature = torch.zeros(batch_size, device=device)
    std_curvature = torch.zeros(batch_size, device=device)
    sharp_turns_count = torch.zeros(batch_size, device=device, dtype=torch.long)

    # Counted the same way for every trajectory (previously hard-coded to n when no
    # interior point was curved, even if the endpoints were not straight)
    straight_segments = torch.sum(torch.abs(curvatures) < straight_threshold, dim=1)

    for i in range(batch_size):
        batch_valid_curvatures = valid_curvatures[i][valid_mask[i]]

        if batch_valid_curvatures.numel() == 0:
            # Nothing curved: max/mean/std/sharp turns stay 0
            continue

        magnitudes = torch.abs(batch_valid_curvatures)
        max_curvature[i] = torch.max(magnitudes)
        mean_curvature[i] = torch.mean(batch_valid_curvatures)
        # The unbiased std of a single sample is NaN; report 0 spread instead
        if batch_valid_curvatures.numel() > 1:
            std_curvature[i] = torch.std(batch_valid_curvatures)

        sharp_turns_count[i] = torch.sum(magnitudes > sharp_turn_threshold)

    return {
        'max_curvature': max_curvature,
        'mean_curvature': mean_curvature,
        'std_curvature': std_curvature,
        'sharp_turns_count': sharp_turns_count,
        'straight_segments': straight_segments
    }

def curvature_calculation(trajectory, window_size=5, expected_length=144):
    '''
    Smoothed curvature at the first, middle and last point of each trajectory (batch version)
    The curvature of the full trajectory is computed with calculate_curvature_lonlat,
    smoothed with smooth_curvature, and then sampled at indices [0, (L-1)//2, L-1],
    where L is the trajectory length (indices 0, 71, 143 for the default L = 144).
    Parameter
    trajectory: torch.Tensor, shape (batch_size, L, 2) or (L, 2), trajectory data
        [lon, lat] in radians
    window_size: moving-average window passed to smooth_curvature
    expected_length: required number of points L per trajectory (default 144); pass None
        to accept any length >= 2
    Return
    torch.Tensor: smoothed curvature of the three points, shape (batch_size, 3), or (3,)
        for a single (L, 2) trajectory
    Raises
    ValueError: if the trajectory shape is not (batch_size, L, 2) / (L, 2) with the
        expected L, or L < 2
    '''
    trajectory = ensure_tensor(trajectory)

    # Handle the situation of a single trajectory
    is_single_trajectory = trajectory.dim() == 2
    if is_single_trajectory:
        trajectory = trajectory.unsqueeze(0)

    # Verify the input shape (raise rather than assert so checks survive `python -O`)
    length_label = expected_length if expected_length is not None else 'n'
    shape_error = f"Expected trajectory shape (batch_size, {length_label}, 2), got {trajectory.shape}"
    if trajectory.dim() != 3 or trajectory.shape[2] != 2:
        raise ValueError(shape_error)
    seq_len = trajectory.shape[1]
    if expected_length is not None and seq_len != expected_length:
        raise ValueError(shape_error)
    if seq_len < 2:
        # With one point, the helpers' squeeze() would make the batch axis look like time
        raise ValueError(f"Trajectory needs at least 2 points, got {seq_len}")

    # Determine the indexes of the three points
    last_point = seq_len - 1
    selected_indices = [0, last_point // 2, last_point]

    # Calculate the original curvature of the complete trajectory
    raw_curvatures = calculate_curvature_lonlat(trajectory)

    # Smooth out the full curvature
    smoothed_curvatures = smooth_curvature(raw_curvatures, window_size)

    # The helpers squeeze() their output, dropping the batch axis when batch_size == 1
    if smoothed_curvatures.dim() == 1:
        smoothed_curvatures = smoothed_curvatures.unsqueeze(0)

    # Select the three positions for every batch element: (batch_size, 3)
    selected_smooth_curvatures = smoothed_curvatures[:, selected_indices]

    if is_single_trajectory:
        selected_smooth_curvatures = selected_smooth_curvatures.squeeze(0)

    return selected_smooth_curvatures
