import torch
import math
import numpy as np
import geopandas as gpd
from shapely.geometry import Point
from pyproj import Geod
import warnings

# Mean radius of the Earth in kilometers (spherical-Earth model).
# haversine_distance multiplies by it, so it returns kilometers; displace_point
# divides its `distance` argument by it, so that distance must be in kilometers too.
EARTH_RADIUS = 6371.0

def distance_to_coastline(lon, lat, coastline_path="ne_10m_coastline.shp", coastline_data=None):
    """
    Calculate the geodesic distance (WGS84 ellipsoid) from a point to the nearest coastline.

    The candidate coastline feature is chosen by planar distance in the data's
    coordinate units (degrees for lon/lat data). The geodesic distance to that
    feature is then approximated by sampling points along each of its segments
    (see _geodesic_point_to_segment). Because the candidate is picked by planar
    distance, a different feature that is geodesically closer can occasionally
    be missed, notably at high latitudes.

    Args:
        lon, lat: Query point longitude / latitude in degrees.
        coastline_path: Shapefile read with geopandas when coastline_data is None.
        coastline_data: Optional pre-loaded GeoDataFrame. Pass it to avoid
            re-reading the file on every call.

    Return:
        Distance (meters). Returns None when the computation fails; the error
        is printed. Errors raised while reading coastline_path are NOT caught.
    """
    if coastline_data is not None:
        coastlines = coastline_data
    else:
        coastlines = gpd.read_file(coastline_path)

    try:
        with warnings.catch_warnings():
            # Presumably silences GeoPandas' warning about planar distances in a
            # geographic CRS; these distances are only used to rank features.
            warnings.simplefilter("ignore", category=UserWarning)
            planar_distances = coastlines.geometry.distance(Point(lon, lat))

        # Use a POSITIONAL argmin. idxmin() returns an index *label*, which is
        # not a valid .iloc position when the index is not 0..n-1 (e.g. after
        # filtering). nanargmin skips empty geometries (NaN distance) and raises
        # ValueError when no usable feature exists; that error is handled below.
        nearest_pos = int(np.nanargmin(planar_distances.to_numpy(dtype=float)))
        closest_feature = coastlines.geometry.iloc[nearest_pos]

        geod = Geod(ellps="WGS84")

        if closest_feature.geom_type == "MultiLineString":
            lines = list(closest_feature.geoms)
        else:
            lines = [closest_feature]

        min_distance = float("inf")
        for line in lines:
            coords = np.asarray(line.coords)
            # Consecutive vertex pairs form the segments of this line.
            for p1, p2 in zip(coords[:-1], coords[1:]):
                # Sampled (approximate) geodesic distance to this segment.
                segment_distance = _geodesic_point_to_segment(geod, lon, lat, p1, p2)
                min_distance = min(min_distance, segment_distance)

        return min_distance

    except Exception as e:
        print(f"miscalculation: {e}")
        return None

def _geodesic_point_to_segment(geod, lon, lat, p1, p2, num_points=20):
    """
    Return: Minimum distance (meters)
    """
    line = [p1, p2]
    if num_points > 2:
        line = list(geod.npts(p1[0], p1[1], p2[0], p2[1], num_points-2))
        line = [p1] + line + [p2]
    
    min_dist = float('inf')
    for point in line:
        _, _, dist = geod.inv(lon, lat, point[0], point[1])
        if dist < min_dist:
            min_dist = dist
    
    return min_dist

def deg_to_rad(deg_tensor):
    """Convert degrees to radians (tensors, arrays or plain numbers)."""
    radians_per_degree = math.pi / 180.0
    return deg_tensor * radians_per_degree

def rad_to_deg(rad_tensor):
    """Convert radians to degrees (tensors, arrays or plain numbers)."""
    degrees_per_radian = 180.0 / math.pi
    return rad_tensor * degrees_per_radian

def deg_to_vec(deg):
    """
    Convert angles in degrees to 2-D unit vectors.

    Each angle theta becomes the pair (cos(theta), sin(theta)), stacked along
    axis 1. For a 1-D input of length N the result has shape (N, 2).
    """
    theta = np.radians(deg)
    cos_part = np.cos(theta)
    sin_part = np.sin(theta)
    return np.stack([cos_part, sin_part], axis=1)

def UnitConversion(input, dest_coord, fr='std', to='rad'):
    """
    Convert 3-D tensors between the normalized ('std') representation and
    radians ('rad') or degrees ('deg').

    Channels of ``input`` (last dimension, indexed as [:, :, k]):
        0   -> scaled so that 'std' 1 == pi/3 rad == 60 deg
        1   -> scaled so that 'std' 1 == pi rad == 180 deg
        2:4 -> scaled by 25 for both 'rad' and 'deg'
    For ``dest_coord``, only channels 0 and 1 are converted.

    Supported conversions: std->rad, rad->std, std->deg, deg->std.

    Args:
        input: 3-D tensor to convert. It is not modified; a converted copy is returned.
        dest_coord: Optional 3-D tensor (or None), converted the same way on channels 0/1.
        fr, to: Source and target unit names.

    Return:
        (converted_input, converted_dest_coord or None). Both are new tensors.

    Raises:
        ValueError: for an unsupported (fr, to) pair.
    """
    # Scale factors (channel 0, channel 1, channels 2:4) that take 'std' to each unit.
    scales_from_std = {
        "rad": (math.pi / 3, math.pi, 25),
        "deg": (60, 180, 25),
    }
    if fr == "std" and to in scales_from_std:
        ch0_scale, ch1_scale, extra_scale = scales_from_std[to]
        rescale = torch.Tensor.mul_  # std -> unit: multiply
    elif to == "std" and fr in scales_from_std:
        ch0_scale, ch1_scale, extra_scale = scales_from_std[fr]
        rescale = torch.Tensor.div_  # unit -> std: divide (exact inverse op)
    else:
        raise ValueError(
            f"from {fr} and to {to} should be one of std->rad, rad->std, std->deg, deg->std"
        )

    def _convert(tensor, include_extra):
        # Basic slicing returns views, so the in-place ops update `tensor`.
        rescale(tensor[:, :, 0], ch0_scale)
        rescale(tensor[:, :, 1], ch1_scale)
        if include_extra:
            rescale(tensor[:, :, 2:4], extra_scale)
        return tensor

    converted = _convert(input.clone(), include_extra=True)
    converted_dest = None if dest_coord is None else _convert(dest_coord.clone(), include_extra=False)
    return converted, converted_dest

def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Batch calculate the great-circle (spherical) distance between two points.

    Input: latitude/longitude tensors in RADIANS (the trig functions are
    applied directly). Shapes must be broadcastable.
    Return: distance tensor in kilometers (EARTH_RADIUS units) with the
    broadcast shape of the inputs.

    Note: gradients may be NaN where the two points coincide, because
    sqrt(a) has an infinite derivative at a == 0.
    """
    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Mathematically 0 <= a <= 1, but round-off (e.g. near-antipodal points)
    # can push a slightly above 1 and make sqrt(1 - a) NaN. Clamp it.
    a = torch.clamp(a, 0.0, 1.0)
    c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))

    return EARTH_RADIUS * c

def calculate_initial_bearing(lat1, lon1, lat2, lon2):
    """
    Batch initial great-circle bearing from point 1 towards point 2.

    Inputs are latitude/longitude tensors in radians with broadcastable shapes.
    Returns radians in [-pi, pi], measured clockwise from north (due east is
    +pi/2). The output has the broadcast shape of the inputs.
    """
    delta_lon = lon2 - lon1
    cos_lat2 = torch.cos(lat2)

    # East and north components of the initial direction of travel.
    east = torch.sin(delta_lon) * cos_lat2
    north = torch.cos(lat1) * torch.sin(lat2) - torch.sin(lat1) * cos_lat2 * torch.cos(delta_lon)

    return torch.atan2(east, north)

def displace_point(lat, lon, bearing, distance):
    """
    Batch move a given distance from the starting point along the specified
    azimuth on a spherical Earth, and calculate the coordinates of the end point.

    Input (tensors with broadcastable shapes, e.g. [B, N]):
    lat: Starting point latitude (degrees)
    lon: Starting point longitude (degrees)
    bearing: Azimuth (radians, clockwise from north)
    distance: Distance in KILOMETERS. It is divided by EARTH_RADIUS, which is
              in km, so passing meters would be off by a factor of 1000.
    Return:
    new_lat, new_lon: end latitude and longitude (degrees). Longitude is
                      wrapped to [-180, 180).
    """
    lat_rad = deg_to_rad(lat)
    lon_rad = deg_to_rad(lon)

    angular_dist = distance / EARTH_RADIUS

    sin_lat = torch.sin(lat_rad)
    cos_lat = torch.cos(lat_rad)
    sin_angular = torch.sin(angular_dist)
    cos_angular = torch.cos(angular_dist)
    cos_bearing = torch.cos(bearing)
    sin_bearing = torch.sin(bearing)

    # Clamp so round-off cannot push the asin argument outside [-1, 1] (-> NaN).
    sin_new_lat = torch.clamp(
        sin_lat * cos_angular + cos_lat * sin_angular * cos_bearing,
        -1.0,
        1.0,
    )
    new_lat_rad = torch.asin(sin_new_lat)

    new_lon_rad = lon_rad + torch.atan2(
        sin_bearing * sin_angular * cos_lat,
        cos_angular - sin_lat * sin_new_lat,
    )

    # Wrap longitude into [-pi, pi).
    new_lon_rad = (new_lon_rad + math.pi) % (2 * math.pi) - math.pi

    return rad_to_deg(new_lat_rad), rad_to_deg(new_lon_rad)

def calculate_resultant_displacement(A, B, C):
    """
    Combine the displacements A->B and A->C into one resultant displacement from A.

    Both displacements are expressed as local north/east vectors at A: the
    great-circle distance (km) along the initial bearing. The two vectors are
    added, and the sum is projected back onto the sphere from A. This is a
    local-tangent-plane approximation, accurate when displacements are small
    compared with the Earth's radius.

    Input (angles in RADIANS, as consumed by haversine_distance and
    calculate_initial_bearing):
    A: Starting tensor, shape [batch_size, num_points, 2] (lat, lon)
    B: End point of the first displacement, shape [batch_size, num_points, 2] (lat, lon)
    C: End point of the second displacement. It is unpacked as
       ``latC, lonC = C``, i.e. split along its FIRST dimension.
       NOTE(review): that matches a (lat, lon) pair or a [2, ...] layout, not
       the [batch, points, 2] layout used by A and B; confirm against callers.
    Return:
    D: Resultant end point, shape [batch_size, num_points, 2] (lat, lon) in
       DEGREES (the output unit of displace_point)
    resultant_distance: Resultant displacement length in kilometers, shape [batch_size, num_points]
    resultant_bearing: Resultant azimuth in degrees, clockwise from north, in
                       [0, 360), shape [batch_size, num_points]
    """
    latA, lonA = A[..., 0], A[..., 1]
    latB, lonB = B[..., 0], B[..., 1]
    latC, lonC = C

    def _north_east(lat_end, lon_end):
        # Local (north, east) components in km of the great-circle displacement A -> end.
        bearing = calculate_initial_bearing(latA, lonA, lat_end, lon_end)
        distance = haversine_distance(latA, lonA, lat_end, lon_end)
        return distance * torch.cos(bearing), distance * torch.sin(bearing)

    north_AB, east_AB = _north_east(latB, lonB)
    north_AC, east_AC = _north_east(latC, lonC)

    # Compose the two vectors.
    north_total = north_AB + north_AC
    east_total = east_AB + east_AC

    resultant_distance = torch.sqrt(east_total ** 2 + north_total ** 2)
    resultant_bearing_rad = torch.atan2(east_total, north_total)
    resultant_bearing_deg = rad_to_deg(resultant_bearing_rad) % 360

    # displace_point expects the start point in DEGREES (it applies deg_to_rad
    # itself), while A is in radians here, so convert first. The distance is
    # already in km, which matches displace_point's division by EARTH_RADIUS.
    latD, lonD = displace_point(
        rad_to_deg(latA), rad_to_deg(lonA), resultant_bearing_rad, resultant_distance
    )

    D = torch.stack([latD, lonD], dim=-1)

    return D, resultant_distance, resultant_bearing_deg
