import torch
import enum
import numpy as np
import math

class NodeType(enum.IntEnum):
    """Integer codes used to label the kind of each mesh node."""
    NORMAL = 0
    OBSTACLE = 1
    AIRFOIL = 2
    HANDLE = 3
    INFLOW = 4
    OUTFLOW = 5
    WALL_BOUNDARY = 6
    # NOTE(review): SIZE is not contiguous with the values above; presumably it
    # is the width used when one-hot encoding node types rather than a real
    # node category -- confirm against the code that consumes it.
    SIZE = 9


def plane_equation(p1, p2, p3):
    """
    Compute the coefficients of the plane through three triangle vertices.

    The plane is expressed as ``a*x + b*y + c*z + d = 0`` where ``(a, b, c)``
    is the (non-normalised) normal ``(p2 - p1) x (p3 - p1)``.

    :param p1: coordinates of the first vertex
    :param p2: coordinates of the second vertex
    :param p3: coordinates of the third vertex
    :return: the plane coefficients a, b, c, d
    """
    normal = np.cross(p2 - p1, p3 - p1)
    # d is chosen so that p1 itself satisfies the plane equation.
    offset = -np.dot(normal, p1)
    return normal[0], normal[1], normal[2], offset

def Side(A, B, C, P):
    """
    Tell whether P lies on the same side of line AB as C does.

    The cross products ``AB x AC`` and ``AB x AP`` point the same way exactly
    when C and P are on the same side of AB; a zero dot product (P on the
    line) also counts as the same side.

    :return: truth value of ``dot(AB x AC, AB x AP) >= 0``
    """
    edge = B - A
    toward_c = np.cross(edge, C - A)
    toward_p = np.cross(edge, P - A)
    return np.dot(toward_c, toward_p) >= 0

def PointInTri(P, A, B, C):
    """
    Same-side test for whether P falls inside triangle ABC.

    P is accepted when, for every edge, it lies on the same side of that edge
    as the opposite vertex. The checks run in the order AB, BC, CA and stop at
    the first failure.

    :return: the result of the first failing ``Side`` check, or of the last
        check if all earlier ones passed
    """
    inside_ab = Side(A, B, C, P)
    if not inside_ab:
        return inside_ab
    inside_bc = Side(B, C, A, P)
    if not inside_bc:
        return inside_bc
    return Side(C, A, B, P)

def calc_point_to_face_distance(p_x, f_a, f_b, f_c):
    """
    Distance from a point to a triangle, measured along the plane normal.

    The point is projected orthogonally onto the plane of the triangle. If the
    projection falls inside the triangle the perpendicular distance is
    returned together with both end points; otherwise a sentinel is returned
    so that callers can fall back to edge/vertex candidates.

    :param p_x: query point (3 coordinates)
    :param f_a: first triangle vertex
    :param f_b: second triangle vertex
    :param f_c: third triangle vertex
    :return: ``(dist, p_x, p_y)`` where ``p_y`` is the foot of the
        perpendicular on the triangle, or ``(1e9, -1, -1)`` when the foot lies
        outside the triangle or the triangle has zero area
    """
    a, b, c, d = plane_equation(f_a, f_b, f_c)
    norm_sq = a ** 2 + b ** 2 + c ** 2
    if norm_sq == 0:
        # Zero-area triangle: there is no plane to project onto (and the
        # division below would be 0/0).
        return 1e9, -1, -1

    # Plane equation evaluated at p_x: |normal| times the signed distance.
    signed = a * p_x[0] + b * p_x[1] + c * p_x[2] + d
    dist = np.abs(signed) / np.sqrt(norm_sq)
    # Move against the signed offset so the foot lands on the plane whichever
    # side of it p_x is on.
    p_y = p_x - (signed / norm_sq) * np.array([a, b, c])
    if PointInTri(p_y, f_a, f_b, f_c):
        return dist, p_x, p_y
    return 1e9, -1, -1

def IsInside(a, c, b):
    """
    Check whether b lies within the coordinate range spanned by a and c.

    Only one axis is examined: the first of x, y, z on which a and c differ.
    The check is inclusive at both ends. If a and c coincide on all three
    axes the answer is False.

    :param a: one end point
    :param c: the other end point
    :param b: point to test (note the argument order: a, c, b)
    :return: True or False
    """
    for axis in range(3):
        if a[axis] != c[axis]:
            lower = min(a[axis], c[axis])
            upper = max(a[axis], c[axis])
            return bool(lower <= b[axis] and b[axis] <= upper)
    return False

def calc_line_to_line_distance(x, y, f, g):
    """
    Closest approach between the infinite lines through x-y and f-g.

    The lines are parameterised as ``x + s * (y - x)`` and ``f + t * (g - f)``
    and the pair ``(s, t)`` minimising the squared distance is found from the
    two normal equations. The result is only accepted when both closest
    points pass the ``IsInside`` range check for their own segment.

    :param x: start of the first segment
    :param y: end of the first segment
    :param f: start of the second segment
    :param g: end of the second segment
    :return: ``(dist, p_x, p_y)`` with ``p_x`` on the first segment and
        ``p_y`` on the second, or ``(1e9, -1, -1)`` when no such interior
        pair exists
    """
    a, t_a = x, y - x
    b, t_b = f, g - f
    offset = a - b

    # Normal equations:  a_1*s + a_2*t + a_3 = 0  and  a_4*s + a_5*t + a_6 = 0
    a_1 = np.dot(t_a, t_a)
    a_2 = -np.dot(t_a, t_b)
    a_3 = np.dot(offset, t_a)
    a_4 = np.dot(t_a, t_b)
    a_5 = -np.dot(t_b, t_b)
    a_6 = np.dot(offset, t_b)

    if a_1 * a_5 == a_2 * a_4:
        # Singular system: the directions are parallel or a segment has zero
        # length.
        if a_2 == 0:
            # With a zero determinant this means at least one direction is the
            # zero vector; a zero-length segment cannot pass the IsInside
            # check, so report "no candidate" instead of dividing by zero.
            return 1e9, -1, -1
        # Parallel lines: fix s = 0 and solve the first equation for t, i.e.
        # drop a perpendicular from the start of the first segment.
        p_x = a
        p_y = b + (-a_3 / a_2) * t_b
    else:
        try:
            s, t = np.linalg.solve(
                np.array([[a_1, a_2], [a_4, a_5]]), np.array([-a_3, -a_6])
            )
        except np.linalg.LinAlgError:
            # Numerically singular even though the products above differed.
            return 1e9, -1, -1
        p_x = a + s * t_a
        p_y = b + t * t_b

    if IsInside(x, y, p_x) and IsInside(f, g, p_y):
        return math.sqrt(np.sum((p_x - p_y) ** 2)), p_x, p_y
    return 1e9, -1, -1
        
    
def calc_point_to_line_distance(sx, sy, ey):
    """
    Perpendicular distance from point sx to the segment sy-ey.

    sx is projected onto the line through sy and ey; the result is only
    accepted when the foot of the perpendicular passes the ``IsInside`` range
    check for the segment.

    :param sx: query point
    :param sy: segment start
    :param ey: segment end
    :return: ``(dist, sx, foot)`` or ``(1e9, -1, -1)`` when the foot falls
        outside the segment or the segment has zero length
    """
    t_y = ey - sy
    length_sq = np.sum(t_y ** 2)
    if length_sq == 0:
        # Zero-length segment: no direction to project onto (avoids 0/0).
        return 1e9, -1, -1
    # Parameter of the projection along the segment direction.
    x = np.dot(sx - sy, t_y) / length_sq
    p_x = sy + x * t_y
    if IsInside(sy, ey, p_x):
        return math.sqrt(np.sum((sx - p_x) ** 2)), sx, p_x
    return 1e9, -1, -1

        
def calc_edge_to_edge_distance(sx, ex, sy, ey):
    """
    Smallest distance between segment sx-ex and segment sy-ey.

    Nine candidates are examined in a fixed order: the line-to-line closest
    approach, each end point of one segment against the other segment, and
    the four end-point pairs. The first candidate with the strictly smallest
    distance wins.

    :return: ``(d, ta, tb)`` where ``ta`` is the closest point on sx-ex and
        ``tb`` the closest point on sy-ey
    """
    def candidates():
        # Interior-to-interior closest approach.
        yield calc_line_to_line_distance(sx, ex, sy, ey)
        # End points of the first segment against the second segment.
        yield calc_point_to_line_distance(sx, sy, ey)
        yield calc_point_to_line_distance(ex, sy, ey)
        # End points of the second segment against the first one; the helper
        # returns (dist, query point, foot), so swap to keep (first, second).
        for endpoint in (sy, ey):
            dist, on_second, on_first = calc_point_to_line_distance(endpoint, sx, ex)
            yield dist, on_first, on_second
        # End point against end point.
        for end_first in (sx, ex):
            for end_second in (sy, ey):
                yield math.sqrt(np.sum((end_first - end_second) ** 2)), end_first, end_second

    best = (1e9, -1, -1)
    for candidate in candidates():
        if candidate[0] < best[0]:
            best = candidate
    return best

def calc_face_to_face_distance(obs_face, obj_face):
    """
    Smallest distance between two triangles.

    Candidates are tried in a fixed order: each vertex of ``obs_face`` against
    ``obj_face``, each vertex of ``obj_face`` against ``obs_face``, then every
    edge of one triangle against every edge of the other. The first candidate
    with the strictly smallest distance wins.

    :param obs_face: three vertices, indexable as ``obs_face[0..2]``
    :param obj_face: three vertices, indexable as ``obj_face[0..2]``
    :return: ``(d, ta, tb)`` with ``ta`` on ``obs_face`` and ``tb`` on
        ``obj_face``; ``(1e9, -1, -1)`` if no candidate beats the sentinel
    """
    edge_ends = ((0, 1), (1, 2), (2, 0))

    def candidates():
        for corner in range(3):
            yield calc_point_to_face_distance(
                obs_face[corner], obj_face[0], obj_face[1], obj_face[2])
        for corner in range(3):
            # Query point belongs to obj_face here, so swap the two points.
            dist, on_obj, on_obs = calc_point_to_face_distance(
                obj_face[corner], obs_face[0], obs_face[1], obs_face[2])
            yield dist, on_obs, on_obj
        for obs_start, obs_end in edge_ends:
            for obj_start, obj_end in edge_ends:
                yield calc_edge_to_edge_distance(
                    obs_face[obs_start], obs_face[obs_end],
                    obj_face[obj_start], obj_face[obj_end])

    best = (1e9, -1, -1)
    for candidate in candidates():
        if candidate[0] < best[0]:
            best = candidate
    return best

def calc_face_to_face_distance_4(obs_face, obj_face):
    """
    Smallest distance between two four-vertex faces.

    Each face is split into the triangles (0, 1, 2) and (0, 2, 3) and the
    minimum over the four triangle pairs is taken.

    :param obs_face: array indexable as ``obs_face[[i, j, k], :]``
    :param obj_face: array indexable as ``obj_face[[i, j, k], :]``
    :return: the smallest distance only (closest points are not returned);
        1e9 if no triangle pair beats the sentinel
    """
    triangles = ([0, 1, 2], [0, 2, 3])
    shortest = 1e9
    for obs_tri in triangles:
        for obj_tri in triangles:
            dist, _, _ = calc_face_to_face_distance(obs_face[obs_tri, :], obj_face[obj_tri, :])
            if dist < shortest:
                shortest = dist
    return shortest

def find_intersecting_pairs(B, A):
    """
    Sweep over interval end points and list every overlapping (A, B) pair.

    :param B: array whose row ``i`` is ``[start, end]`` of interval ``i`` in
        the first set (note that B is the first argument)
    :param A: array with the same layout for the second set
    :return: list of ``(index into A, index into B)`` tuples, one per
        overlapping pair; intervals that merely touch at an end point count
        as overlapping
    """
    # Tie-break for events at the same coordinate (lower value runs first):
    # both kinds of start are processed before any end, which is what makes
    # touching intervals overlap.
    order_at_tie = {'End_A': 3, 'End_B': 2, 'Start_A': 1, 'Start_B': 0}

    events = [
        event
        for idx in range(A.shape[0])
        for event in ((A[idx, 0], 'Start_A', idx), (A[idx, 1], 'End_A', idx))
    ]
    events += [
        event
        for idx in range(B.shape[0])
        for event in ((B[idx, 0], 'Start_B', idx), (B[idx, 1], 'End_B', idx))
    ]
    events.sort(key=lambda event: (event[0], order_at_tie[event[1]]))

    open_a, open_b = set(), set()
    pairs = []
    for _, kind, idx in events:
        if kind == 'End_A':
            open_a.discard(idx)
        elif kind == 'End_B':
            open_b.discard(idx)
        elif kind == 'Start_A':
            # A new A interval overlaps every B interval currently open.
            pairs.extend((idx, b_idx) for b_idx in open_b)
            open_a.add(idx)
        else:  # Start_B
            # A new B interval overlaps every A interval currently open.
            pairs.extend((a_idx, idx) for a_idx in open_a)
            open_b.add(idx)

    return pairs
def faces_to_faces(faces, node_pos, node_type, mask, mask1, threshold=0.05):
    """
    Find pairs of triangular faces whose axis-aligned bounding boxes overlap.

    Faces selected by ``mask`` use their tight bounding box; faces selected by
    ``mask1`` use their bounding box grown by ``threshold`` on every side. A
    pair is reported when the boxes overlap on x, y and z.

    :param faces: integer array of node indices, three per face
    :param node_pos: node coordinates, indexed by the entries of ``faces``
    :param node_type: unused by this function
    :param mask: per-face selector; entries equal to 1 pick the first group
        (assumed one entry per face -- TODO confirm with caller)
    :param mask1: per-face selector for the second (padded) group
    :param threshold: padding added to the boxes of the ``mask1`` faces
    :return: integer array of shape (K, 2); column 0 is the face index from
        the ``mask1`` group, column 1 the face index from the ``mask`` group.
        The array is empty with shape (0, 2) when nothing overlaps.
    """
    faces_pos = node_pos[faces.reshape(-1)].reshape(-1, 3, 3)

    na_index = np.array(np.where(mask == 1)).reshape(-1)
    na_faces_pos = faces_pos[na_index, :]
    na_min = np.min(na_faces_pos, axis=1).reshape(-1, 3)
    na_max = np.max(na_faces_pos, axis=1).reshape(-1, 3)

    nb_index = np.array(np.where(mask1 == 1)).reshape(-1)
    nb_faces_pos = faces_pos[nb_index, :]
    nb_min = np.min(nb_faces_pos, axis=1).reshape(-1, 3) - threshold
    nb_max = np.max(nb_faces_pos, axis=1).reshape(-1, 3) + threshold

    # One sweep per axis; each returns (index into nb group, index into na
    # group) tuples because the padded group is passed as the second argument.
    per_axis = []
    for axis in range(3):
        na_span = np.concatenate(
            [na_min[:, axis:axis + 1], na_max[:, axis:axis + 1]], axis=1)
        nb_span = np.concatenate(
            [nb_min[:, axis:axis + 1], nb_max[:, axis:axis + 1]], axis=1)
        per_axis.append(find_intersecting_pairs(na_span, nb_span))

    overlap_x = set(per_axis[0])
    overlap_y = set(per_axis[1])
    # Keep the z-axis ordering of the surviving pairs.
    results = [pair for pair in per_axis[2]
               if pair in overlap_x and pair in overlap_y]

    if not results:
        # Integer dtype so the empty result can be used like the non-empty one.
        return np.zeros((0, 2), dtype=np.int64)

    edge_index = np.array(results, dtype=np.int64).reshape(-1, 2)
    # Translate group-local positions back to indices into ``faces``.
    edge_index[:, 0] = nb_index[edge_index[:, 0]]
    edge_index[:, 1] = na_index[edge_index[:, 1]]
    return edge_index

def node_sorted(x, pos):
    """
    Reorder the node indices of every item by distance to the item's centroid.

    :param x: integer array of shape (items, nodes_per_item) holding indices
        into ``pos``
    :param pos: node coordinates with three components per node
    :return: array shaped like ``x`` in which each row lists the same node
        indices ordered from nearest to farthest from the mean position of
        that row's nodes (ties keep their original order)
    """
    item_size = x.shape[0]
    node_size = x.shape[1]
    x_pos = pos[x.reshape(-1)].reshape(item_size, node_size, 3)
    x_mean = np.mean(x_pos, axis=1)
    # Keep the distances in a float array: writing them into a copy of the
    # integer index array would truncate them before sorting.
    w_x = np.sqrt(np.sum((x_pos - x_mean[:, None, :]) ** 2, axis=2))
    row_indices = np.argsort(w_x, axis=1, kind="stable")
    sorted_arr = x[np.arange(item_size)[:, None], row_indices]
    return sorted_arr

def face_sorted(cells_faces, faces, pos):
    """
    Reorder the face indices of every cell by distance to the cell's centre.

    Each face is represented by the mean position of its nodes; each cell's
    centre is the mean of its faces' representative points.

    :param cells_faces: integer array of shape (cells, faces_per_cell) holding
        indices into ``faces``
    :param faces: integer array whose row ``i`` lists the node indices of
        face ``i``
    :param pos: node coordinates with three components per node
    :return: array shaped like ``cells_faces`` in which each row lists the
        same face indices ordered from nearest to farthest from that cell's
        centre (ties keep their original order)
    """
    item_size = cells_faces.shape[0]
    face_size = cells_faces.shape[1]

    faces_size = faces.shape[0]
    faces_pos = pos[faces.reshape(-1)].reshape(faces_size, -1, 3)
    faces_mean = np.mean(faces_pos, axis=1)

    cells_pos = faces_mean[cells_faces.reshape(-1)].reshape(item_size, face_size, 3)
    cells_mean = np.mean(cells_pos, axis=1)
    # Keep the distances in a float array: writing them into a copy of the
    # integer index array would truncate them before sorting.
    w_x = np.sqrt(np.sum((cells_pos - cells_mean[:, None, :]) ** 2, axis=2))
    row_indices = np.argsort(w_x, axis=1, kind="stable")
    sorted_arr = cells_faces[np.arange(item_size)[:, None], row_indices]
    return sorted_arr


def cells_to_edges(cells):
    """
    Build the two-way edge list of a mesh from its cells.

    :param cells: integer array of shape (cells, nodes_per_cell). Rows with
        four columns contribute all six node pairs; any other width is
        treated as a triangle and contributes the three pairs of its first
        three columns.
    :return: integer array of shape (2 * E, 2). The first E rows are the
        unique edges as ``[larger node id, smaller node id]`` in order of
        first appearance; the last E rows are the same edges reversed.
    """
    corners = 4 if cells.shape[-1] == 4 else 3
    edges = np.concatenate(
        [np.stack([cells[:, i], cells[:, j]], axis=1)
         for i in range(corners) for j in range(i + 1, corners)],
        axis=0)
    # Edges are duplicated inside the mesh and single on its boundary.
    # Orient every edge as (max, min) so duplicates become identical rows.
    receivers = np.min(edges, axis=1)
    senders = np.max(edges, axis=1)
    packed_edges = np.stack([senders, receivers], axis=1)
    # Remove duplicates, then restore first-appearance order (np.unique sorts).
    unique_rows, indices = np.unique(packed_edges, return_index=True, axis=0)
    unique_edges = unique_rows[np.argsort(indices)]
    # Create two-way connectivity by appending every edge with columns swapped.
    return np.concatenate([unique_edges, unique_edges[:, ::-1]], axis=0)


                
def nodes_to_edges(nodes, node_type, threshold=0.1):
    """
    Connect every OBSTACLE node to the NORMAL/HANDLE nodes closer than a
    threshold.

    :param nodes: node coordinates, one row per node
    :param node_type: per-node ``NodeType`` codes; the first axis must index
        nodes (assumed shape (N, 1) -- TODO confirm with caller)
    :param threshold: maximum Euclidean distance for a pair to be linked
    :return: integer array of shape (2 * E, 2). The first E rows are
        ``[obstacle node, normal/handle node]``; the last E rows are the same
        pairs reversed.
    """
    mask = np.logical_or(node_type == NodeType.NORMAL, node_type == NodeType.HANDLE)
    na_index = np.array(np.where(mask == 1))[0, :].reshape(-1)
    node_a = nodes[mask.reshape(-1), :]
    mask = (node_type == NodeType.OBSTACLE)
    nb_index = np.array(np.where(mask == 1))[0, :].reshape(-1)
    node_b = nodes[mask.reshape(-1), :]

    # Dense pairwise distances: rows follow node_b, columns follow node_a.
    distances = np.sqrt(np.sum((node_a - node_b[:, None, :]) ** 2, axis=-1))
    edge_index = np.array(np.where(distances < threshold))
    # Translate group-local positions back to global node indices.
    edge_index[0, :] = nb_index[edge_index[0, :]]
    edge_index[1, :] = na_index[edge_index[1, :]]
    unique_edges = np.transpose(edge_index)
    # Two-way connectivity: append every edge with its columns swapped.
    return np.concatenate([unique_edges, unique_edges[:, ::-1]], axis=0)

def cells_to_faces(cells):
    """
    Extract the unique triangular faces of a tetrahedral mesh.

    :param cells: integer array of shape (cells, 4) with the node indices of
        each tetrahedron
    :return: a tuple ``(faces, faces_edges, cells_faces)`` of arrays:

        * ``faces`` -- one row per unique face, node indices sorted
          ascending, rows in lexicographic order;
        * ``faces_edges`` -- for each face, the cells that own it. A face
          owned by a single cell gets ``-1000000000`` (the integer value of
          ``-1e9``) as its second entry;
        * ``cells_faces`` -- for each cell, the row indices into ``faces`` of
          its four faces.
    """
    no_neighbour = int(-1e9)
    n_cells = cells.shape[0]

    # The four corner triples of every tetrahedron, stacked block by block.
    faces = np.concatenate(
        [np.stack([cells[:, i], cells[:, j], cells[:, k]], axis=1)
         for i, j, k in ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))],
        axis=0)
    faces = np.sort(faces, axis=1)
    # Row r of ``faces`` came from cell r % n_cells. Working with Python ints
    # (rather than a float array) keeps every index an integer.
    owners = [row % n_cells for row in range(faces.shape[0])]
    # Sort by (face, owner) so that identical faces end up adjacent.
    records = sorted(zip(map(tuple, faces.tolist()), owners))

    unique_faces_list = []
    faces_edges = []
    cells_faces = [[] for _ in range(n_cells)]
    for face, owner in records:
        if not unique_faces_list or face != unique_faces_list[-1]:
            unique_faces_list.append(face)
            faces_edges.append([owner])
        else:
            faces_edges[-1].append(owner)
        cells_faces[owner].append(len(unique_faces_list) - 1)

    # Boundary faces have only one owner; pad them with the sentinel.
    for face_owners in faces_edges:
        if len(face_owners) == 1:
            face_owners.append(no_neighbour)

    unique_faces = np.array(unique_faces_list, dtype=faces.dtype).reshape(-1, 3)
    return unique_faces, np.array(faces_edges), np.array(cells_faces)