import math
import torch
from utils.earth_computation import haversine_distance

def ensure_tensor(tensor, device=None):
    """
    Return the input as a torch.Tensor, optionally moved to a device

    Parameters:
    tensor: an existing torch.Tensor (returned as-is, dtype untouched) or any
        value accepted by torch.tensor (converted to float32)
    device: optional target device; when None the tensor stays where it is

    Returns:
    torch.Tensor: the (possibly converted and moved) tensor
    """
    if isinstance(tensor, torch.Tensor):
        result = tensor
    else:
        # Non-tensor inputs (lists, scalars, arrays) are always built as float32
        result = torch.tensor(tensor, dtype=torch.float32)

    return result if device is None else result.to(device)

def euclidean_distance(p1, p2):
    """
    Euclidean distance between two points (tensor or array-like inputs)

    Both inputs are converted with ensure_tensor; p2 is moved to p1's device
    when they differ. The squared differences are summed over every element,
    so the result is always a 0-dim tensor.
    """
    first = ensure_tensor(p1)
    second = ensure_tensor(p2)

    # Bring the second operand onto the first operand's device if needed
    if second.device != first.device:
        second = second.to(first.device)

    offset = first - second
    return offset.pow(2).sum().sqrt()

def batch_euclidean_distance(p1, p2):
    """
    Pairwise Euclidean distances between two batched point sets

    Parameters:
    p1: torch.Tensor, shape (batch_size, n, 2)
    p2: torch.Tensor, shape (batch_size, m, 2)

    Returns:
    torch.Tensor: shape (batch_size, n, m); entry [b, i, j] is the distance
        between p1[b, i] and p2[b, j]
    """
    left = ensure_tensor(p1)
    right = ensure_tensor(p2)

    # Keep both operands on the device of the first one
    if right.device != left.device:
        right = right.to(left.device)

    # (batch_size, n, 1, 2) - (batch_size, 1, m, 2) broadcasts to (batch_size, n, m, 2)
    pairwise_offsets = left.unsqueeze(2) - right.unsqueeze(1)

    # Reduce the coordinate axis to get one distance per point pair
    return pairwise_offsets.pow(2).sum(dim=-1).sqrt()

def frechet_distance(gt_trajectory, pred_trajectory):
    """
    Calculate discrete Fréchet distance between two trajectories (supports batch)

    Parameters:
    gt_trajectory: torch.Tensor, shape (batch_size, m, 2) or (m, 2)
    pred_trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2)
        A single trajectory on one side is broadcast against a batch on the
        other side.

    Returns:
    torch.Tensor: Fréchet distance, shape (batch_size,) or scalar. The result
        is squeezed, so a batch of one also yields a scalar. The dtype follows
        the pairwise point distances (float64 inputs are no longer downcast
        to float32).

    Raises:
    ValueError: if either trajectory contains no points
    """
    gt_trajectory = ensure_tensor(gt_trajectory)
    pred_trajectory = ensure_tensor(pred_trajectory)

    # Ensure on the same device
    if gt_trajectory.device != pred_trajectory.device:
        pred_trajectory = pred_trajectory.to(gt_trajectory.device)

    # Handle single trajectory case
    if gt_trajectory.dim() == 2:
        gt_trajectory = gt_trajectory.unsqueeze(0)
    if pred_trajectory.dim() == 2:
        pred_trajectory = pred_trajectory.unsqueeze(0)

    m, n = gt_trajectory.shape[1], pred_trajectory.shape[1]
    if m == 0 or n == 0:
        raise ValueError(
            "frechet_distance requires at least one point in each trajectory "
            f"(got m={m}, n={n})"
        )

    # Batch calculate all point pair distance matrices
    point_distances = batch_euclidean_distance(gt_trajectory, pred_trajectory)  # (batch_size, m, n)

    # Dynamic programming over the coupling table, one row at a time. Only the
    # previous row is needed, so the full (batch_size, m, n) table is never
    # materialised and no dtype/batch-size assumption is baked into a buffer.

    # First row: the only way to reach (0, j) is along the row, so each cell
    # is the running maximum of the distances seen so far.
    first_row = torch.cummax(point_distances[:, 0, :], dim=1).values
    prev_row = [first_row[:, j] for j in range(n)]

    for i in range(1, m):
        # First column: reachable only from the cell directly above
        cur_row = [torch.maximum(prev_row[0], point_distances[:, i, 0])]
        for j in range(1, n):
            # Cheapest predecessor among up, left and diagonal
            best_prev = torch.minimum(
                torch.minimum(prev_row[j], cur_row[j - 1]),
                prev_row[j - 1],
            )
            cur_row.append(torch.maximum(best_prev, point_distances[:, i, j]))
        prev_row = cur_row

    return prev_row[n - 1].squeeze()

def calculate_initial_bearing(lon1, lat1, lon2, lat2):
    """
    Initial bearing from point 1 to point 2 (batch version, radians in and out)

    Parameters:
    lon1, lat1, lon2, lat2: coordinates in radians, tensors of identical shape

    Returns:
    torch.Tensor: bearing in radians, range [-π, π], same shape as the inputs
    """
    delta_lon = lon2 - lon1
    cos_lat2 = torch.cos(lat2)

    # First and second arguments of atan2, in the same evaluation order as the
    # usual great-circle bearing formula
    east = torch.sin(delta_lon) * cos_lat2
    north = torch.cos(lat1) * torch.sin(lat2) - torch.sin(lat1) * cos_lat2 * torch.cos(delta_lon)

    return torch.atan2(east, north)

def calculate_curvature_lonlat(trajectory):
    """
    Signed curvature along a longitude-latitude trajectory (batch version, radian input)

    At every interior point the change of bearing between the incoming and the
    outgoing segment is divided by the mean length of the two segments.

    Parameters:
    trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2), [lon, lat] in radians

    Returns:
    torch.Tensor: curvature at each point (1/meter, assuming haversine_distance
        returns meters), shape (batch_size, n) with endpoints left at 0;
        size-1 dimensions are squeezed away
    """
    trajectory = ensure_tensor(trajectory)

    # Promote a single trajectory to a batch of one
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n, _ = trajectory.shape

    # Endpoints have no neighbour on one side and stay at zero
    curvatures = torch.zeros((batch_size, n), device=trajectory.device)
    zero = torch.tensor(0.0, device=trajectory.device)

    for mid in range(1, n - 1):
        # Coordinates of the three consecutive points, each of shape (batch_size,)
        lon_a, lat_a = trajectory[:, mid - 1, 0], trajectory[:, mid - 1, 1]
        lon_b, lat_b = trajectory[:, mid, 0], trajectory[:, mid, 1]
        lon_c, lat_c = trajectory[:, mid + 1, 0], trajectory[:, mid + 1, 1]

        # Heading of the incoming and outgoing segment
        heading_in = calculate_initial_bearing(lon_a, lat_a, lon_b, lat_b)
        heading_out = calculate_initial_bearing(lon_b, lat_b, lon_c, lat_c)

        # Wrap the raw heading change back into [-π, π]
        turn = heading_out - heading_in
        turn = torch.atan2(torch.sin(turn), torch.cos(turn))

        # Lengths of the two segments
        len_in = haversine_distance(lon_a, lat_a, lon_b, lat_b)
        len_out = haversine_distance(lon_b, lat_b, lon_c, lat_c)
        mean_len = (len_in + len_out) / 2

        # Coincident points would divide by (almost) zero; report zero curvature there
        curvatures[:, mid] = torch.where(mean_len < 1e-10, zero, turn / mean_len)

    return curvatures.squeeze()

def calculate_curvature_turning_radius(trajectory):
    """
    Turning radius from the circle through three consecutive points (batch version)

    The side lengths of each point triple come from haversine_distance; the
    triangle area is obtained with Heron's formula and the circumradius is
    R = (a*b*c) / (4*area).

    Parameters:
    trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2), [lon, lat] in radians

    Returns:
    torch.Tensor: turning radius (meters, assuming haversine_distance returns
        meters), shape (batch_size, n); endpoints and degenerate (collinear or
        coincident) triples are infinity; size-1 dimensions are squeezed away
    """
    trajectory = ensure_tensor(trajectory)

    # Promote a single trajectory to a batch of one
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n, _ = trajectory.shape

    # Infinity means "straight line"; endpoints are never overwritten
    turning_radii = torch.full((batch_size, n), float('inf'), device=trajectory.device)

    for mid in range(1, n - 1):
        prev_pt = trajectory[:, mid - 1, :]  # (batch_size, 2)
        mid_pt = trajectory[:, mid, :]       # (batch_size, 2)
        next_pt = trajectory[:, mid + 1, :]  # (batch_size, 2)

        # Side lengths of the triangle prev-mid-next
        a = haversine_distance(prev_pt[:, 0], prev_pt[:, 1], mid_pt[:, 0], mid_pt[:, 1])
        b = haversine_distance(mid_pt[:, 0], mid_pt[:, 1], next_pt[:, 0], next_pt[:, 1])
        c = haversine_distance(prev_pt[:, 0], prev_pt[:, 1], next_pt[:, 0], next_pt[:, 1])

        # Heron's formula: squared area from the semi-perimeter
        s = (a + b + c) / 2
        heron = s * (s - a) * (s - b) * (s - c)

        # Only triangles with a clearly positive squared area get an area value
        non_degenerate = heron > 1e-20
        tri_area = torch.zeros_like(a)
        tri_area[non_degenerate] = torch.sqrt(heron[non_degenerate])

        # Circumradius for usable triangles, infinity everywhere else
        usable = (tri_area > 1e-10) & non_degenerate
        radius = torch.full_like(a, float('inf'))
        radius[usable] = (a[usable] * b[usable] * c[usable]) / (4 * tri_area[usable])

        turning_radii[:, mid] = radius

    return turning_radii.squeeze()

def curvature_to_turning_radius(curvatures):
    """
    Turn curvature values into turning radii (batch version)

    Parameters:
    curvatures: torch.Tensor, curvature values, shape (batch_size, n) or (n,)

    Returns:
    torch.Tensor: 1 / |curvature|, with infinity wherever |curvature| < 1e-10;
        size-1 dimensions are squeezed away
    """
    curvatures = ensure_tensor(curvatures)

    # Promote a single trajectory to a batch of one
    if curvatures.dim() == 1:
        curvatures = curvatures.unsqueeze(0)

    magnitude = torch.abs(curvatures)
    near_zero = magnitude < 1e-10
    infinite_radius = torch.tensor(float('inf'), device=curvatures.device)

    # Effectively straight points get an infinite radius instead of a division by ~0
    radii = torch.where(near_zero, infinite_radius, 1.0 / magnitude)

    return radii.squeeze()

def smooth_curvature(curvatures, window_size=5):
    """
    Moving average smoothing of curvature data (batch version)

    The sequence is edge-replicated before averaging so the output keeps the
    input length. For an odd window the average is centred; for an even window
    it covers one more sample to the right than to the left.

    Parameters:
    curvatures: torch.Tensor, raw curvature, shape (batch_size, n) or (n,)
    window_size: sliding window size, must be >= 1

    Returns:
    torch.Tensor: smoothed curvature with the same number of samples as the
        input (size-1 dimensions squeezed away). Sequences shorter than
        window_size are returned unsmoothed.

    Raises:
    ValueError: if window_size is smaller than 1
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    curvatures = ensure_tensor(curvatures)

    # Handle single trajectory case
    if curvatures.dim() == 1:
        curvatures = curvatures.unsqueeze(0)

    batch_size, seq_len = curvatures.shape

    if seq_len < window_size:
        return curvatures.squeeze()

    # Averaging kernel; built in the input's dtype because conv1d requires
    # input and weight to share a dtype (e.g. float64 curvatures)
    kernel = torch.ones(
        1, 1, window_size, dtype=curvatures.dtype, device=curvatures.device
    ) / window_size

    # Total padding must be window_size - 1 so the valid convolution returns
    # exactly seq_len samples; symmetric w//2 padding only achieves that for
    # odd windows and yields seq_len + 1 samples for even ones.
    pad_left = (window_size - 1) // 2
    pad_right = window_size // 2
    padded = torch.nn.functional.pad(
        curvatures.unsqueeze(1),
        (pad_left, pad_right),
        mode='replicate'
    )  # (batch_size, 1, seq_len + window_size - 1)

    # Convolution operation
    smoothed = torch.nn.functional.conv1d(
        padded,
        kernel,
        padding=0
    ).squeeze(1)  # (batch_size, seq_len)

    return smoothed.squeeze()

def analyze_curvature_features(curvatures, trajectory):
    """
    Analyze curvature features (batch version)

    Parameters:
    curvatures: torch.Tensor, curvature values, shape (batch_size, n) or (n,)
    trajectory: torch.Tensor, trajectory data, shape (batch_size, n, 2) or (n, 2).
        Only normalised to a tensor here; the statistics are computed from
        curvatures alone.

    Returns:
    dict: curvature feature statistics, each value shape (batch_size,)
        'max_curvature': largest |curvature| among interior points with |curvature| > 1e-10
        'mean_curvature': mean of those (signed) curvatures
        'std_curvature': their standard deviation; 0 when fewer than two such
            points exist (a single sample has no spread and torch.std would
            return NaN)
        'sharp_turns_count': number of those points with |curvature| > 0.001
        'straight_segments': number of points (endpoints included) with
            |curvature| < 0.0001; equals n when no interior point qualifies
    """
    curvatures = ensure_tensor(curvatures)
    trajectory = ensure_tensor(trajectory)

    # Handle single trajectory case
    if curvatures.dim() == 1:
        curvatures = curvatures.unsqueeze(0)
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    batch_size, n = curvatures.shape

    # Thresholds on |curvature|; with curvature in 1/meter they correspond to
    # turning radii of 1000 and 10000 respectively
    sharp_turn_threshold = 0.001
    straight_threshold = 0.0001

    # Filter valid curvatures (exclude endpoints and values close to zero)
    interior = curvatures[:, 1:-1]  # (batch_size, n-2)
    significant = torch.abs(interior) > 1e-10

    # Initialize results; zeros are also the final values for trajectories
    # without any significant interior curvature
    max_curvature = torch.zeros(batch_size, device=curvatures.device)
    mean_curvature = torch.zeros(batch_size, device=curvatures.device)
    std_curvature = torch.zeros(batch_size, device=curvatures.device)
    sharp_turns_count = torch.zeros(batch_size, device=curvatures.device, dtype=torch.long)
    straight_segments = torch.zeros(batch_size, device=curvatures.device, dtype=torch.long)

    for i in range(batch_size):
        samples = interior[i][significant[i]]

        if samples.numel() == 0:
            # Nothing but (near-)zero curvature: every point counts as straight
            straight_segments[i] = n
            continue

        magnitudes = torch.abs(samples)
        max_curvature[i] = torch.max(magnitudes)
        mean_curvature[i] = torch.mean(samples)
        # The default (unbiased) std is NaN for a single sample; keep the 0
        # the result tensor was initialised with in that case
        if samples.numel() > 1:
            std_curvature[i] = torch.std(samples)

        # Count sharp turns
        sharp_turns_count[i] = torch.sum(magnitudes > sharp_turn_threshold)

        # Count straight segments (over all points of the trajectory)
        straight_segments[i] = torch.sum(torch.abs(curvatures[i]) < straight_threshold)

    return {
        'max_curvature': max_curvature,
        'mean_curvature': mean_curvature,
        'std_curvature': std_curvature,
        'sharp_turns_count': sharp_turns_count,
        'straight_segments': straight_segments
    }

def curvature_calculation(trajectory, window_size=5):
    """
    Compute raw and smoothed curvature for a trajectory batch

    Parameters:
    trajectory: torch.Tensor, shape (batch_size, n, 2) or (n, 2), [lon, lat] in radians
    window_size: smoothing window size passed on to smooth_curvature

    Returns:
    dict: 'raw_curvatures' (output of calculate_curvature_lonlat) and
        'smoothed_curvatures' (that output after smooth_curvature). Feature
        statistics are not computed here.
    """
    trajectory = ensure_tensor(trajectory)

    # Promote a single trajectory to a batch of one
    if trajectory.dim() == 2:
        trajectory = trajectory.unsqueeze(0)

    # Unpacking doubles as a check that the input is 3-dimensional
    _batch_size, _seq_len, _coords = trajectory.shape

    raw = calculate_curvature_lonlat(trajectory)
    smoothed = smooth_curvature(raw, window_size)

    return {
        'raw_curvatures': raw,
        'smoothed_curvatures': smoothed,
    }

# Test example
if __name__ == "__main__":
    # Smoke test: build random trajectories (radian system) and run the batch
    # curvature, Fréchet distance and turning radius routines on them.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    batch_size = 128
    n_points = 144

    # Sinusoidal latitude bump applied to points 50..99 of every trajectory;
    # it does not depend on the trajectory, so it is computed once up front
    bump = 0.001 * torch.sin(torch.arange(50, device=device, dtype=torch.float32) * 0.063)

    # Generate batch example trajectories
    trajectories = []
    for _ in range(batch_size):
        # Random start/end so every trajectory in the batch differs
        start_lon = 2.12 + 0.01 * torch.rand(1, device=device).item()
        end_lon = 2.18 + 0.01 * torch.rand(1, device=device).item()
        start_lat = 0.545 + 0.01 * torch.rand(1, device=device).item()
        end_lat = 0.567 + 0.01 * torch.rand(1, device=device).item()

        lons_rad = torch.linspace(start_lon, end_lon, n_points, device=device)
        lats_rad = torch.linspace(start_lat, end_lat, n_points, device=device)

        # Add some curvature
        lats_rad[50:100] += bump

        trajectories.append(torch.stack([lons_rad, lats_rad], dim=1))

    trajectories = torch.stack(trajectories)  # (batch_size, n_points, 2)

    print(f"Trajectory shape: {trajectories.shape}")
    print(f"Trajectory device: {trajectories.device}")

    # Test batch curvature calculation
    result = curvature_calculation(trajectories)

    # curvature_calculation only returns the raw and smoothed curvatures, so
    # the feature statistics are computed explicitly from the smoothed values
    print("\nCurvature analysis results (batch):")
    features = analyze_curvature_features(result['smoothed_curvatures'], trajectories)
    for key, value in features.items():
        if value.dim() == 0:
            print(f"{key}: {value.item():.6f}")
        else:
            print(f"{key}: shape {value.shape}, first 5 values: {value[:5].cpu().numpy()}")

    # Test batch Fréchet distance
    pred_trajectories = trajectories + torch.randn_like(trajectories) * 0.001
    fd = frechet_distance(trajectories, pred_trajectories)
    print(f"\nFrétchet distance: shape {fd.shape}, first 5 values: {fd[:5].cpu().numpy()}")

    # Test batch turning radius calculation
    turning_radii = calculate_curvature_turning_radius(trajectories)
    print(f"\nTurning radius: shape {turning_radii.shape}")

    # Calculate average turning radius (excluding infinity)
    finite_mask = turning_radii != float('inf')
    avg_turning_radii = torch.zeros(batch_size, device=device)
    for b in range(batch_size):
        batch_finite = finite_mask[b]
        if batch_finite.any():
            avg_turning_radii[b] = turning_radii[b][batch_finite].mean()
        else:
            # No finite radius at all: the whole trajectory is straight
            avg_turning_radii[b] = float('inf')

    print(f"First 5 average turning radii: {avg_turning_radii[:5].cpu().numpy()}")