import torch
import math
import numpy as np
import geopandas as gpd
from shapely.geometry import Point
from pyproj import Geod
import warnings

# Earth mean radius in kilometers. Everything built on it shares this unit:
# haversine_distance returns EARTH_RADIUS * c (kilometers), and displace_point
# divides its `distance` argument by EARTH_RADIUS, so it expects kilometers too.
EARTH_RADIUS: float = 6371.0

def distance_to_coastline(lon, lat, coastline_path="ne_10m_coastline.shp", coastline_data = None):
    """
    Calculate the geodesic distance from given longitude and latitude to the nearest coastline (in meters).

    The candidate coastline feature is the one closest in planar (lon/lat
    degree-space) distance; the geodesic distance to that feature is then
    measured segment by segment on the WGS84 ellipsoid. Because the candidate
    is picked in degree-space, it is not guaranteed to be the geodesically
    nearest feature, and the per-segment distance is itself a sampled
    approximation (see _geodesic_point_to_segment).

    :param lon: Longitude of the query point (degrees, as used by pyproj Geod.inv).
    :param lat: Latitude of the query point (degrees).
    :param coastline_path: File read with geopandas when ``coastline_data`` is None.
    :param coastline_data: Optional pre-loaded object with a ``geometry`` column
        (e.g. a GeoDataFrame); pass it to avoid re-reading the file per call.
        Its index does not need to be 0..n-1.
    :return: Distance (meters), returns None on processing failure
    """

    if coastline_data is not None:
        coastlines = coastline_data
    else:
        # Otherwise load coastline data
        coastlines = gpd.read_file(coastline_path)
    try:
        # GeoPandas warns that distances in a geographic CRS are inaccurate; they
        # are only used here to choose the candidate feature, so silence it.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)

            target_point = Point(lon, lat)

            # Brute-force planar scan over all features (no spatial index).
            distances = coastlines.geometry.distance(target_point)
            # Use a *positional* argmin so it pairs correctly with .iloc, even when
            # the index is not a RangeIndex (idxmin() returns a label). nanargmin
            # skips missing geometries, like idxmin's default skipna=True.
            nearest_pos = int(np.nanargmin(np.asarray(distances, dtype=float)))
            closest_segment = coastlines.geometry.iloc[nearest_pos]

        # Initialize geodesic calculator
        geod = Geod(ellps="WGS84")
        min_distance = float('inf')

        # A MultiLineString is measured part by part.
        if closest_segment.geom_type == 'MultiLineString':
            segments = list(closest_segment.geoms)
        else:
            segments = [closest_segment]

        for line in segments:
            coords = np.array(line.coords)
            if len(coords) == 1:
                # Degenerate line with a single vertex: measure to that vertex
                # instead of silently skipping it.
                _, _, vertex_distance = geod.inv(lon, lat, coords[0][0], coords[0][1])
                if vertex_distance < min_distance:
                    min_distance = vertex_distance
                continue
            # Consecutive vertex pairs form the line's segments.
            for p1, p2 in zip(coords[:-1], coords[1:]):
                segment_distance = _geodesic_point_to_segment(geod, lon, lat, p1, p2)
                if segment_distance < min_distance:
                    min_distance = segment_distance

        return min_distance

    except Exception as e:
        # Best-effort contract: report and return None rather than raise.
        print(f"Calculation error: {e}")
        return None

def _geodesic_point_to_segment(geod, lon, lat, p1, p2, num_points=20):
    """
    Approximate the geodesic distance from (lon, lat) to the segment p1 -> p2.

    The segment is sampled at its two endpoints plus ``num_points - 2`` interior
    points generated by ``geod.npts`` (only when ``num_points > 2``). The smallest
    point-to-sample distance from ``geod.inv`` is returned. All samples lie on the
    segment, so the result is an upper bound on the true minimum distance.

    :return: Minimum distance (meters)
    """
    samples = [p1, p2]
    if num_points > 2:
        interior = geod.npts(p1[0], p1[1], p2[0], p2[1], num_points - 2)
        samples = [p1, *interior, p2]

    best = float('inf')
    for sample in samples:
        # geod.inv returns (forward azimuth, back azimuth, distance).
        distance = geod.inv(lon, lat, sample[0], sample[1])[2]
        if distance < best:
            best = distance
    return best

def deg_to_rad(deg_tensor):
    """Convert degrees to radians element-wise (anything that supports `* float`)."""
    radians_per_degree = math.pi / 180.0
    return deg_tensor * radians_per_degree

def rad_to_deg(rad_tensor):
    """Convert radians to degrees element-wise (anything that supports `* float`)."""
    degrees_per_radian = 180.0 / math.pi
    return rad_tensor * degrees_per_radian

def deg_to_vec(deg):
    """
    Convert angles in degrees to (cos, sin) unit vectors.

    The cos/sin components are stacked along axis 1. For a 1-D input of length N,
    the result has shape (N, 2).
    NOTE(review): for a 2-D input of shape (B, N) the result is (B, 2, N), not
    (B, N, 2). Confirm this is what callers expect.
    """
    theta = np.radians(deg)
    components = [np.cos(theta), np.sin(theta)]
    return np.stack(components, axis=1)

def UnitConversion(input, dest_coord, fr='std', to='rad'):
    """
    Rescale position tensors between the 'std', 'rad' and 'deg' unit systems.

    Scale of each channel (last axis) relative to 'std' (= 1):
      - channel 0:        rad = pi/3, deg = 60
      - channel 1:        rad = pi,   deg = 180
      - channels 2 and 3: rad = 25,   deg = 25  (rescaled on ``input`` only)
    For example, a 'std' value of 1 in channel 0 becomes pi/3 in 'rad' and 60 in 'deg'.
    Conversions std<->rad and std<->deg give the same results as before.
    rad<->deg and same-unit (identity) conversions are also accepted.

    :param input: Tensor whose last axis holds at least 4 channels. It is
        cloned and never modified.
    :param dest_coord: Optional tensor whose last axis holds at least 2
        channels (only channels 0 and 1 are converted). It is cloned, or may be None.
    :param fr: Source unit, one of 'std', 'rad', 'deg'.
    :param to: Target unit, one of 'std', 'rad', 'deg'.
    :return: (converted input, converted dest_coord or None)
    :raises ValueError: if ``fr`` or ``to`` is not a known unit.
    """
    # Scale of (channel 0, channel 1, channels 2:4) in each unit, relative to 'std'.
    unit_scales = {
        'std': (1, 1, 1),
        'rad': (math.pi / 3, math.pi, 25),
        'deg': (60, 180, 25),
    }
    if fr not in unit_scales or to not in unit_scales:
        raise ValueError(f"from {fr} and to {to} should be in [std, rad, deg]")

    column_groups = ((Ellipsis, 0), (Ellipsis, 1), (Ellipsis, slice(2, 4)))

    def _rescale(tensor, group, src_scale, dst_scale):
        # Divide out the source scale, then apply the target scale. Skipping
        # factors equal to 1 keeps std<->rad/deg a single in-place multiply or
        # divide, i.e. bit-identical to the original per-branch arithmetic.
        if src_scale == dst_scale:
            return
        if src_scale != 1:
            tensor[group] /= src_scale
        if dst_scale != 1:
            tensor[group] *= dst_scale

    tensor1 = input.clone()
    # `is not None` rather than `!= None`: comparing a tensor to None is not
    # a reliable identity test.
    tensor2 = dest_coord.clone() if dest_coord is not None else None

    src_scales, dst_scales = unit_scales[fr], unit_scales[to]
    for k, group in enumerate(column_groups):
        _rescale(tensor1, group, src_scales[k], dst_scales[k])
        # dest_coord carries only the two coordinate channels.
        if tensor2 is not None and k < 2:
            _rescale(tensor2, group, src_scales[k], dst_scales[k])

    # Both tensors are already fresh clones, so no further copy is needed.
    return tensor1, tensor2


def regulize(position):
    """
    Wrap channel 1 of the last axis (presumably longitude, in radians) into [-pi, pi).

    Channel 0 is left untouched. The wrapped values are written back into
    ``position`` in place, and the same object is returned.
    """
    wrapped_lon = (position[..., 1] + math.pi) % (2 * math.pi) - math.pi
    position[..., 1] = wrapped_lon
    return position

def angular_loss(true, pred):
    """
    Angular loss: 1 - mean(cos(true - pred)).

    The loss is 0 when every angle matches and 2 when every pair differs by pi.
    Angles are in radians.
    """
    angle_error = true - pred
    return 1 - torch.cos(angle_error).mean()

def haversine_distance(lat1, lon1, lat2, lon2, radius=None):
    """
    Batch calculate the great-circle (spherical) distance between point pairs.

    All angles are in radians; they are passed straight to torch.sin/torch.cos
    without conversion. All operations are element-wise, so the inputs follow
    torch broadcasting and the result has their broadcast shape.

    :param radius: Sphere radius. Defaults to EARTH_RADIUS (kilometers). The
        result is in the same unit as ``radius``.
    :return: Distance tensor.
    """
    if radius is None:
        radius = EARTH_RADIUS

    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1], which makes
    # sqrt(1 - a) (or sqrt(a)) NaN. Clamp it rather than assert: the former
    # `assert 0 <= a.all() <= 1` compared a boolean with 0..1 and could never fail.
    a = torch.clamp(a, 0.0, 1.0)
    c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))

    return radius * c

def calculate_initial_bearing(lat1, lon1, lat2, lon2):
    """
    Batch calculate the initial (forward) bearing from point 1 to point 2.

    Angles are in radians. All operations are element-wise, so the inputs
    broadcast like ordinary torch tensors and the output has their broadcast shape.
    Return: bearing (radians) clockwise from north, as produced by torch.atan2
    (within [-pi, pi]).
    """
    delta_lon = lon2 - lon1
    cos_lat2 = torch.cos(lat2)

    # East and north components of the initial direction of travel.
    east = torch.sin(delta_lon) * cos_lat2
    north = torch.cos(lat1) * torch.sin(lat2) - torch.sin(lat1) * cos_lat2 * torch.cos(delta_lon)

    return torch.atan2(east, north)

def displace_point(lat, lon, bearing, distance):
    """
    Batch calculate end point coordinates by moving a given distance along a
    bearing from a start point on a sphere of radius EARTH_RADIUS.

    Input:
        lat: Start latitude (degrees), shape [B, N]
        lon: Start longitude (degrees), shape [B, N]
        bearing: Bearing angle (radians, clockwise from north), shape [B, N]
        distance: Distance in KILOMETERS (the same unit as EARTH_RADIUS; it is
            divided by EARTH_RADIUS to obtain the angular distance), shape [B, N]
        All operations are element-wise, so any mutually broadcastable shapes work.
    Return:
        new_lat, new_lon: End latitude and longitude (degrees). The longitude is
        wrapped into [-180, 180).
    """
    lat_rad = deg_to_rad(lat)
    lon_rad = deg_to_rad(lon)

    # Angular distance (radians) travelled along the great circle.
    angular_dist = distance / EARTH_RADIUS

    sin_lat = torch.sin(lat_rad)
    cos_lat = torch.cos(lat_rad)
    sin_angular = torch.sin(angular_dist)
    cos_angular = torch.cos(angular_dist)
    cos_bearing = torch.cos(bearing)
    sin_bearing = torch.sin(bearing)

    # New latitude. Rounding can push the asin argument slightly outside
    # [-1, 1] (e.g. near the poles), which would yield NaN, so clamp it.
    sin_new_lat = torch.clamp(
        sin_lat * cos_angular + cos_lat * sin_angular * cos_bearing,
        -1.0,
        1.0,
    )
    new_lat_rad = torch.asin(sin_new_lat)

    # New longitude
    new_lon_rad = lon_rad + torch.atan2(
        sin_bearing * sin_angular * cos_lat,
        cos_angular - sin_lat * torch.sin(new_lat_rad)
    )

    # Wrap longitude into [-pi, pi), i.e. [-180, 180) degrees.
    new_lon_rad = (new_lon_rad + math.pi) % (2 * math.pi) - math.pi

    return rad_to_deg(new_lat_rad), rad_to_deg(new_lon_rad)

def calculate_resultant_displacement(A, B, C):
    """
    Batch calculate the resultant of two displacement vectors A->B and A->C.

    Both displacements are decomposed into north/east components in the tangent
    plane at A, summed, and the resulting vector is applied from A.

    Input:
        A: Start point tensor, shape [batch_size, num_points, 2] (lat, lon)
        B: First component end point tensor, shape [batch_size, num_points, 2]
        C: Second component end point tensor, shape [batch_size, num_points, 2].
           A (lat, lon) pair (e.g. the tuple returned by displace_point) is also
           accepted.
    Return:
        D: Resultant end point tensor (lat, lon as returned by displace_point,
           i.e. degrees), shape [batch_size, num_points, 2]
        resultant_distance: Resultant displacement distance in the unit returned
           by haversine_distance (kilometers), shape [batch_size, num_points]
        resultant_bearing: Resultant bearing angle in degrees, [0, 360),
           shape [batch_size, num_points]
    """
    # Separate latitude and longitude
    latA, lonA = A[..., 0], A[..., 1]
    latB, lonB = B[..., 0], B[..., 1]
    # The former `latC, lonC = C` unpacked along the FIRST (batch) axis, which is
    # wrong for the documented [..., 2] layout. Index the last axis like A and B,
    # but keep accepting a (lat, lon) pair for callers that relied on unpacking.
    if isinstance(C, torch.Tensor) and C.shape[-1] == 2:
        latC, lonC = C[..., 0], C[..., 1]
    else:
        latC, lonC = C

    # NOTE(review): haversine_distance and calculate_initial_bearing receive these
    # coordinates unconverted (they treat them as radians), while displace_point
    # below converts lat/lon from degrees. Both cannot hold for the same A. Confirm
    # the intended unit of A/B/C and convert consistently.

    # Displacement vectors AB and AC as (bearing, distance)
    bearing_AB = calculate_initial_bearing(latA, lonA, latB, lonB)
    distance_AB = haversine_distance(latA, lonA, latB, lonB)
    bearing_AC = calculate_initial_bearing(latA, lonA, latC, lonC)
    distance_AC = haversine_distance(latA, lonA, latC, lonC)

    # Decompose vectors in the tangent plane at the start point
    north_AB = distance_AB * torch.cos(bearing_AB)
    east_AB = distance_AB * torch.sin(bearing_AB)

    north_AC = distance_AC * torch.cos(bearing_AC)
    east_AC = distance_AC * torch.sin(bearing_AC)

    # Vector composition
    north_total = north_AB + north_AC
    east_total = east_AB + east_AC

    # Distance and bearing of the resultant vector
    resultant_distance = torch.sqrt(east_total**2 + north_total**2)
    resultant_bearing_rad = torch.atan2(east_total, north_total)
    resultant_bearing_deg = rad_to_deg(resultant_bearing_rad) % 360

    # Resultant end point D
    latD, lonD = displace_point(latA, lonA, resultant_bearing_rad, resultant_distance)

    D = torch.stack([latD, lonD], dim=-1)

    return D, resultant_distance, resultant_bearing_deg