import torch
import torch.nn.functional as F 
from torch.utils.data import DataLoader
import h5py
import os
import numpy as np
import operator
from itertools import accumulate
import copy
from utils import NodeType
from utils import calc_face_to_face_distance_4
import math

# Module-level accumulators mutated as a side effect of faces_to_faces():
# each call appends the face-pair array it returns to previous_face_edges and
# appends a running total (seeded with 0) to frame_index.
previous_face_edges = []
frame_index = [0]
def custom_collate(data):
    """Merge a list of per-frame sample dicts into one batched dict of tensors.

    Per-node fields are concatenated along dim 0 unchanged. Index-valued
    fields are first shifted so they keep pointing at the right rows after
    concatenation:

    * node indices (``elements``, ``faces``, ``*_sequ``, ``mesh_edge``,
      ``world_edge``) are shifted by the number of ``stress`` rows of all
      preceding samples,
    * ``faces_edges`` is shifted by the number of ``elements`` rows of all
      preceding samples,
    * ``faces_to_faces`` and ``cells_faces`` are shifted by the number of
      ``faces`` rows of all preceding samples.

    Args:
        data: sequence of dicts whose values can be passed to
            ``torch.tensor`` and expose ``.shape`` (e.g. numpy arrays).

    Returns:
        dict mapping the same field names to concatenated tensors.
    """
    def as_tensors(key):
        # One tensor per sample, in batch order.
        return [torch.tensor(sample[key]) for sample in data]

    def start_offsets(key):
        # Exclusive prefix sum of the per-sample row counts of ``key``.
        sizes = [sample[key].shape[0] for sample in data]
        return torch.tensor([0] + list(accumulate(sizes))[:-1])

    def merged(key):
        return torch.concat(as_tensors(key), dim = 0)

    def merged_shifted(key, offsets):
        shifted = [t + offsets[i] for i, t in enumerate(as_tensors(key))]
        return torch.concat(shifted, dim = 0)

    node_start = start_offsets("stress")
    element_start = start_offsets("elements")
    face_start = start_offsets("faces")

    return {
        "final_flag": merged("final_flag"),
        "elements_sequ": merged_shifted("elements_sequ", node_start),
        "faces_sequ": merged_shifted("faces_sequ", node_start),
        "cells_faces": merged_shifted("cells_faces", face_start),
        "last_stress": merged("last_stress"),
        "last_pos": merged("last_pos"),
        "world_pos": merged("world_pos"),
        "stress": merged("stress"),
        "node_type": merged("node_type"),
        "target": merged("target"),
        "next_pos_need": merged("next_pos_need"),
        "elements": merged_shifted("elements", node_start),
        "faces": merged_shifted("faces", node_start),
        "faces_edges": merged_shifted("faces_edges", element_start),
        "faces_to_faces": merged_shifted("faces_to_faces", face_start),
        "mesh_edge": merged_shifted("mesh_edge", node_start),
        "mesh_pos": merged("mesh_pos"),
        "world_edge": merged_shifted("world_edge", node_start),
    }
def cells_to_edges(cells):
    """Build the two-way edge list of a quad (4-node) or hex (8-node) mesh.

    Every cell contributes its boundary edges (4 for a quad, 12 for a
    hexahedron). Edges shared by several cells are kept once, in order of
    first appearance, and the result holds each unique edge in both
    directions.

    Args:
        cells: integer array of shape (n_cells, 4) or (n_cells, 8) holding
            node indices per cell.

    Returns:
        Array of shape (2 * n_unique_edges, 2). The first half stores each
        edge as ``[larger_index, smaller_index]``; the second half stores the
        same edges reversed.

    Raises:
        ValueError: if the cells have neither 4 nor 8 nodes. (Previously this
            printed a message and called ``exit(0)``, which terminated the
            process with a *success* status.)
    """
    n_corners = cells.shape[1]
    if n_corners == 4:
        corner_pairs = [(0, 1), (1, 2), (2, 3), (3, 0)]
    elif n_corners == 8:
        corner_pairs = [
            (0, 1), (1, 2), (2, 3), (3, 0),  # first quad of corners
            (4, 5), (5, 6), (6, 7), (7, 4),  # second quad of corners
            (0, 4), (1, 5), (2, 6), (3, 7),  # edges joining the two quads
        ]
    else:
        raise ValueError(
            "WRONG CELL TYPE: expected 4 or 8 nodes per cell, got %d" % n_corners
        )

    edges = np.concatenate(
        [np.stack([cells[:, a], cells[:, b]], axis=1) for a, b in corner_pairs],
        axis=0,
    )
    # Edges inside the mesh appear once per adjacent cell; store every edge
    # with a canonical (max, min) orientation so duplicates become equal rows.
    senders = np.max(edges, axis=1)
    receivers = np.min(edges, axis=1)
    packed_edges = np.stack([senders, receivers], axis=1)
    # Drop duplicates while keeping the order of first occurrence.
    _, first_seen = np.unique(packed_edges, return_index=True, axis=0)
    unique_edges = packed_edges[np.sort(first_seen)]
    # Two-way connectivity: append every edge with its endpoints swapped.
    return np.concatenate([unique_edges, unique_edges[:, ::-1]], axis=0)

def nodes_to_edges(nodes, node_type):
    """Connect every OBSTACLE node to each NORMAL/HANDLE node closer than 1.

    Args:
        nodes: array of node coordinates, one row per node.
        node_type: array of ``NodeType`` values, one entry per node.

    Returns:
        Integer array of shape (2 * n_pairs, 2). The first half holds
        ``[obstacle_node, other_node]`` rows, the second half the same pairs
        with the two columns swapped.
    """
    radius = 1

    movable = np.logical_or(node_type == NodeType.NORMAL, node_type == NodeType.HANDLE)
    obstacle = (node_type == NodeType.OBSTACLE)

    # Global node ids (first-axis indices) of the two groups.
    movable_ids = np.where(movable)[0].reshape(-1)
    obstacle_ids = np.where(obstacle)[0].reshape(-1)

    movable_pos = nodes[movable.reshape(-1), :]
    obstacle_pos = nodes[obstacle.reshape(-1), :]

    # Dense (n_obstacle, n_movable) matrix of Euclidean distances.
    offsets = movable_pos - obstacle_pos[:, None, :]
    distances = np.sqrt(np.sum(offsets ** 2, axis=-1))

    near_obstacle, near_movable = np.where(distances < radius)
    forward = np.stack([obstacle_ids[near_obstacle], movable_ids[near_movable]], axis=1)
    backward = forward[:, ::-1]
    return np.concatenate([forward, backward], axis = 0)

def find_intersecting_pairs(B, A):
    """Find all overlapping pairs between two sets of 1-D intervals.

    A sweep over the sorted interval endpoints is used. At equal coordinates
    start events are processed before end events, so intervals that merely
    touch at an endpoint are reported as overlapping.

    Args:
        B: array of shape (n_b, 2) with ``[start, end]`` per row.
        A: array of shape (n_a, 2) with ``[start, end]`` per row.

    Returns:
        List of ``(index_into_A, index_into_B)`` tuples. Note the tuple order
        is the reverse of the parameter order.
    """
    # Tie-break order for events that share a coordinate.
    rank = {'Start_B': 0, 'Start_A': 1, 'End_B': 2, 'End_A': 3}

    events = []
    for tag, intervals in (('A', A), ('B', B)):
        for row in range(intervals.shape[0]):
            events.append((intervals[row, 0], 'Start_' + tag, row))
            events.append((intervals[row, 1], 'End_' + tag, row))
    events.sort(key=lambda event: (event[0], rank[event[1]]))

    open_a, open_b = set(), set()
    pairs = []
    for _, kind, row in events:
        if kind == 'Start_A':
            # A newly opened A interval overlaps every B interval still open.
            pairs.extend([(row, other) for other in open_b])
            open_a.add(row)
        elif kind == 'Start_B':
            pairs.extend([(other, row) for other in open_a])
            open_b.add(row)
        elif kind == 'End_A':
            open_a.discard(row)
        else:
            open_b.discard(row)

    return pairs
def faces_to_faces(faces, node_pos, node_type, mask, mask1):
    """Pair faces selected by ``mask1`` with nearby faces selected by ``mask``.

    The axis-aligned bounding box of every ``mask1`` face is grown by a
    margin of 0.3 on each side; a pair is reported when that grown box
    overlaps the bounding box of a ``mask`` face on all three axes.

    Side effects: the returned array is appended to the module-level
    ``previous_face_edges`` list and a running total is appended to
    ``frame_index``.

    Args:
        faces: integer array of 4-node faces (reshaped to (-1, 4) node ids).
        node_pos: array of 3-D node positions indexed by the ids in ``faces``.
        node_type: unused; kept for interface compatibility.
        mask: 1-D array, entries equal to 1 select the first face group.
        mask1: 1-D array, entries equal to 1 select the second face group.

    Returns:
        Integer array of shape (n_pairs, 2) with rows
        ``[mask1_face_index, mask_face_index]``; shape (0, 2) when no pair
        is found.
    """
    threshold = 0.3

    faces_pos = node_pos[faces.reshape(-1)].reshape(-1, 4, 3)

    na_index = np.array(np.where(mask == 1)).reshape(-1)
    nb_index = np.array(np.where(mask1 == 1)).reshape(-1)
    na_faces_pos = faces_pos[na_index, :]
    nb_faces_pos = faces_pos[nb_index, :]

    # Per-face bounding boxes; only the mask1 boxes are grown by the margin.
    na_min = np.min(na_faces_pos, axis = 1).reshape(-1, 3)
    na_max = np.max(na_faces_pos, axis = 1).reshape(-1, 3)
    nb_min = np.min(nb_faces_pos, axis = 1).reshape(-1, 3) - threshold
    nb_max = np.max(nb_faces_pos, axis = 1).reshape(-1, 3) + threshold

    # 1-D interval overlap per axis; tuples come back as (nb_local, na_local).
    per_axis = []
    for axis in range(3):
        na_intervals = np.stack([na_min[:, axis], na_max[:, axis]], axis = 1)
        nb_intervals = np.stack([nb_min[:, axis], nb_max[:, axis]], axis = 1)
        per_axis.append(find_intersecting_pairs(na_intervals, nb_intervals))

    # Keep the pairs that overlap on x, y and z, in the order of the z sweep.
    x_pairs, y_pairs = set(per_axis[0]), set(per_axis[1])
    results = [pair for pair in per_axis[2] if pair in x_pairs and pair in y_pairs]

    # reshape(-1, 2) keeps the array two-dimensional when there is no pair;
    # a bare np.array([]) is 1-D and made the column indexing below raise.
    edge_index = np.array(results, dtype=int).reshape(-1, 2)
    # Translate group-local indices back to global face indices.
    edge_index[:, 0] = nb_index[edge_index[:, 0]]
    edge_index[:, 1] = na_index[edge_index[:, 1]]

    # NOTE(review): edge_index.shape[1] is always 2, so frame_index grows by 2
    # per call rather than by the number of pairs; shape[0] may have been
    # intended. Left unchanged because other code may read frame_index.
    frame_index.append(frame_index[-1] + edge_index.shape[1])
    previous_face_edges.append(edge_index)
    return edge_index
                

def cells_to_faces(cells):
    """Extract the unique faces of a quad (4-node) or hex (8-node) mesh.

    For 4-node cells every cell is its own face. For 8-node cells the six
    quadrilateral faces of each cell are generated and faces shared between
    cells are merged.

    Args:
        cells: integer array of shape (n_cells, 4) or (n_cells, 8).

    Returns:
        Tuple ``(faces, faces_edges, cells_faces, faces_sequ)``:

        * ``faces``: unique faces. For hex input each row holds the node ids
          sorted ascending (stored as floats); for quad input it equals the
          cell rows.
        * ``faces_edges``: per face, the indices of the cells that own it.
          A missing second owner is stored as ``-1e9``; for quad input both
          columns are ``-1e9``.
        * ``cells_faces``: per hex cell, the indices (into ``faces``) of its
          six faces; an empty array for quad input.
        * ``faces_sequ``: the same faces as ``faces`` but with the node order
          in which they were generated from the cell.

    Raises:
        ValueError: if the cells have neither 4 nor 8 nodes. (Previously this
            printed a message and called ``exit(0)``, which terminated the
            process with a *success* status.)
    """
    n_corners = cells.shape[1]
    if n_corners == 4:
        faces = np.stack([cells[:, 0], cells[:, 1], cells[:, 2], cells[:, 3]], axis = 1)
        faces_edges = np.zeros((faces.shape[0], 2)) - 1e9
        return faces, faces_edges, np.array([]), copy.deepcopy(faces)
    if n_corners != 8:
        raise ValueError(
            "WRONG CELL TYPE: expected 4 or 8 nodes per cell, got %d" % n_corners
        )

    n_cells = cells.shape[0]
    face_corners = (
        (0, 1, 2, 3), (4, 5, 6, 7), (0, 1, 5, 4),
        (1, 2, 6, 5), (2, 3, 7, 6), (0, 3, 7, 4),
    )
    # Faces are laid out face-type major: rows [k * n_cells, (k + 1) * n_cells)
    # hold face type k of every cell, so ``row % n_cells`` recovers the cell.
    old_faces = np.concatenate(
        [np.stack([cells[:, c] for c in corners], axis = 1) for corners in face_corners],
        axis = 0,
    )
    n_faces = old_faces.shape[0]

    # Sorting the node ids inside each face makes shared faces compare equal;
    # the extra last column remembers the row each face came from.
    canonical = np.sort(old_faces, axis = 1)
    origin = np.arange(n_faces, dtype = float).reshape(-1, 1)
    faces_list = np.concatenate([canonical, origin], axis = 1).tolist()
    # Lexicographic sort places identical faces next to each other.
    faces_list.sort()

    cells_faces = [[] for _ in range(n_cells)]
    faces_edges = []
    unique_faces_list = []
    faces_sequ = []
    for row in faces_list:
        key, origin_row = row[:-1], int(row[-1])
        belong_element = origin_row % n_cells
        if not unique_faces_list or key != unique_faces_list[-1]:
            # First time this face is seen: open a new unique face.
            faces_edges.append([belong_element])
            unique_faces_list.append(key)
            faces_sequ.append(old_faces[origin_row])
        else:
            # Same face reached from another cell: record the extra owner.
            faces_edges[-1].append(belong_element)
        cells_faces[belong_element].append(len(unique_faces_list) - 1)

    # Faces with a single owner get the -1e9 sentinel as second owner.
    for owners in faces_edges:
        if len(owners) == 1:
            owners.append(-1e9)

    return np.array(unique_faces_list), np.array(faces_edges), np.array(cells_faces), np.array(faces_sequ)
        


# Torch device: CUDA when available, otherwise CPU. It is not referenced by
# any of the functions visible in this part of the file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def node_sorted(x, pos):
    """Reorder the node ids of every row by distance to the row's centroid.

    Args:
        x: integer array of shape (n_items, n_nodes); each row lists node ids.
        pos: array of 3-D node positions indexed by the ids in ``x``.

    Returns:
        Array with the shape and dtype of ``x`` whose rows hold the same ids,
        ordered by ascending Euclidean distance from the mean position of the
        row's nodes. Nodes at equal distance keep their original relative
        order.
    """
    item_size, node_size = x.shape[0], x.shape[1]
    x_pos = pos[x.reshape(-1)].reshape(item_size, node_size, 3)
    x_mean = np.mean(x_pos, axis = 1, keepdims = True)
    # Keep the distances as floats. The previous version wrote them into an
    # integer copy of ``x``, which truncated every distance to a whole number
    # and therefore sorted on the wrong key.
    distances = np.sqrt(np.sum((x_pos - x_mean) ** 2, axis = 2))
    # A stable sort makes the result deterministic for equidistant nodes.
    row_indices = np.argsort(distances, axis = 1, kind = "stable")
    return np.take_along_axis(x, row_indices, axis = 1)

def face_sorted(cells_faces, faces, pos):
    """Reorder the face ids of every cell by distance to the cell's centre.

    A face is represented by the mean position of its nodes; the cell centre
    is the mean of its faces' representative points.

    Args:
        cells_faces: integer array of shape (n_cells, n_faces_per_cell) with
            indices into ``faces``. An empty array is returned unchanged.
        faces: integer array of shape (n_faces, n_nodes_per_face) with node
            ids.
        pos: array of 3-D node positions indexed by the ids in ``faces``.

    Returns:
        Array shaped like ``cells_faces`` whose rows hold the same face ids,
        ordered by ascending distance of the face centre from the cell
        centre. Faces at equal distance keep their original relative order.
    """
    if cells_faces.size == 0:
        # cells_to_faces() returns an empty 1-D array for quad meshes; there
        # is nothing to sort and ``shape[1]`` would not exist.
        return cells_faces

    item_size, face_size = cells_faces.shape[0], cells_faces.shape[1]

    faces_pos = pos[faces.reshape(-1)].reshape(faces.shape[0], -1, 3)
    faces_mean = np.mean(faces_pos, axis = 1)

    cells_pos = faces_mean[cells_faces.reshape(-1)].reshape(item_size, face_size, 3)
    cells_mean = np.mean(cells_pos, axis = 1, keepdims = True)
    # Keep the distances as floats. The previous version wrote them into an
    # integer copy of ``cells_faces``, truncating every distance to a whole
    # number and therefore sorting on the wrong key.
    distances = np.sqrt(np.sum((cells_pos - cells_mean) ** 2, axis = 2))
    # A stable sort makes the result deterministic for equidistant faces.
    row_indices = np.argsort(distances, axis = 1, kind = "stable")
    return np.take_along_axis(cells_faces, row_indices, axis = 1)


def calc_Mises(x):
    """Reduce six tensor components per node to one von Mises-type scalar.

    Computes ``sqrt(0.5 * (d03^2 + d35^2 + d50^2 + 6 * (c1^2 + c2^2 + c4^2)))``
    where ``c_k = x[:, :, k]`` and ``d_ij = c_i - c_j``.

    Presumably the last axis stores a symmetric tensor as
    ``[xx, xy, xz, yy, yz, zz]`` - TODO confirm against the data producer.

    Args:
        x: array of shape (n_frames, n_nodes, >=6).

    Returns:
        Array of shape (n_frames, n_nodes, 1).
    """
    c0, c1, c2, c3, c4, c5 = (x[:, :, k] for k in range(6))
    difference_terms = (c0 - c3) ** 2 + (c3 - c5) ** 2 + (c5 - c0) ** 2
    shear_terms = c1 ** 2 + c2 ** 2 + c4 ** 2
    equivalent = np.sqrt(0.5 * (difference_terms + 6 * shear_terms))
    return equivalent.reshape(x.shape[0], x.shape[1], 1)
def calc_PEEQ(x):
    """Reduce six tensor components per node to one equivalent-strain scalar.

    Computes
    ``sqrt(2/3 * (d03^2 + d35^2 + d50^2 + 0.5 * (c1^2 + c2^2 + c4^2)))``
    where ``c_k = x[:, :, k]`` and ``d_ij = c_i - c_j``.

    NOTE(review): the coefficients (2/3 on the difference terms, 0.5 on the
    remaining terms) are reproduced exactly as found; please confirm they
    match the intended equivalent plastic strain definition.

    Args:
        x: array of shape (n_frames, n_nodes, >=6).

    Returns:
        Array of shape (n_frames, n_nodes, 1).
    """
    c0, c1, c2, c3, c4, c5 = (x[:, :, k] for k in range(6))
    difference_terms = (c0 - c3) ** 2 + (c3 - c5) ** 2 + (c5 - c0) ** 2
    shear_terms = c1 ** 2 + c2 ** 2 + c4 ** 2
    equivalent = np.sqrt(2.0 / 3.0 * (difference_terms + 0.5 * shear_terms))
    return equivalent.reshape(x.shape[0], x.shape[1], 1)
def GetDataSetMSB(dataset_dir, split, batch_size = 1):
    """Convert the simulation archives of one split into a single HDF5 file.

    Every entry of ``<dataset_dir>/<split>`` is loaded with ``np.load`` and
    turned into per-frame samples that are written to
    ``<dataset_dir>/full_data_<split>.h5`` under
    ``<instance name>/<frame_id>/<field name>``. Frames are sampled every 4
    steps and each sample targets the frame 4 steps ahead.

    Changes compared with the previous version:

    * the output path is derived from ``dataset_dir`` and ``split`` instead
      of being hard-coded to ``/strip_simple/full_data_valid.h5`` (identical
      for ``GetDataSetMSB("/strip_simple/", "valid")``);
    * each archive is read once and closed, instead of being re-read on
      every ``instance_data[...]`` access;
    * ``world_edge`` is computed from the same world positions that are
      stored as ``world_pos``. It used to mix raw ``STRIP_U`` values (which
      only become ``world_pos`` after ``INIT_STRIP_POS`` is added) with
      mould positions.

    Args:
        dataset_dir: directory that contains the split sub-directory and
            receives the HDF5 file.
        split: name of the sub-directory to convert.
        batch_size: batch size of the returned loader.

    Returns:
        A ``DataLoader`` using ``custom_collate``. NOTE(review): the sample
        list handed to it is never filled here, so the loader is empty; the
        useful output of this function is the HDF5 file.
    """
    file_path = os.path.join(dataset_dir, split)
    frame_data_collected = []
    output_file_path = os.path.join(dataset_dir, "full_data_" + split + ".h5")
    array_names = ("STRIP_U", "STRIP_S", "STRIP_PE", "INIT_STRIP_POS",
                   "INIT_MOULD_POS", "STRIP_MESH", "MOULD_MESH")
    step_slice = 4           # distance (in frames) between input and target
    real_slice = step_slice  # stride between stored frames (same for every split)

    with h5py.File(output_file_path, 'w') as f:
        # Sorted so the processing order does not depend on the file system.
        for instance in sorted(os.listdir(file_path)):
            instance_path = os.path.join(file_path, instance)
            # Materialise the needed arrays once: an open .npz archive reads
            # and decompresses a member again on every key access.
            with np.load(instance_path) as archive:
                instance_data = {name: archive[name] for name in array_names}
            group_1 = f.create_group(instance)
            print(instance_path)

            strip_u = instance_data["STRIP_U"]
            init_strip_pos = instance_data["INIT_STRIP_POS"]
            init_mould_pos = instance_data["INIT_MOULD_POS"]
            strip_mesh = instance_data["STRIP_MESH"]
            mould_mesh = instance_data["MOULD_MESH"]
            trace_length = strip_u.shape[0]
            n_strip = strip_u.shape[1]

            # Per-node features: columns 0:3 are positions, the remaining
            # columns are the two scalars derived from STRIP_S and STRIP_PE.
            strip_features = np.concatenate([
                                            strip_u, calc_Mises(instance_data["STRIP_S"]),
                                            calc_PEEQ(instance_data["STRIP_PE"]),
                                            ], axis = 2)
            strip_features[:, :, 0:3] += init_strip_pos
            # Mould nodes stay at their initial position with zero extra features.
            mould_features = np.zeros((trace_length, init_mould_pos.shape[0], strip_features.shape[2]))
            mould_features[:, :, 0:3] = init_mould_pos
            node_features = np.concatenate([strip_features, mould_features], axis = 1)
            mesh_pos = np.concatenate([init_strip_pos, init_mould_pos], axis = 0)

            # Strip nodes come first, mould nodes after them.
            node_type = np.zeros(node_features.shape[1])
            node_type[n_strip:] = NodeType.OBSTACLE
            node_type[:n_strip] = NodeType.NORMAL
            # Strip nodes on either x-extreme of the initial strip are handles.
            strip_x = init_strip_pos[:, 0]
            x_min, x_max = strip_x.min(), strip_x.max()
            at_x_end = (np.abs(strip_x - x_min) < 1e-3) | (np.abs(strip_x - x_max) < 1e-3)
            node_type[np.flatnonzero(at_x_end)] = NodeType.HANDLE

            faces_0, faces_edges_0, cells_faces_0, faces_sequ_0 = cells_to_faces(strip_mesh)
            # Mould node ids are shifted past the strip nodes.
            faces_1, faces_edges_1, cells_faces_1, faces_sequ_1 = cells_to_faces(mould_mesh + n_strip)
            faces = np.concatenate([faces_0, faces_1], axis = 0).astype(int)
            faces_sequ = np.concatenate([faces_sequ_0, faces_sequ_1], axis = 0)
            faces_edges = np.concatenate([faces_edges_0, faces_edges_1], axis = 0)
            cells_faces = cells_faces_0

            # A face is on the surface when its second owner is the negative
            # "no owner" sentinel.
            is_surface = faces_edges[:, -1] < 0
            first_node_type = node_type[faces[:, 0]]
            on_strip = (first_node_type == NodeType.NORMAL) | (first_node_type == NodeType.HANDLE)
            mask = (on_strip & is_surface).astype(int)
            mask1 = ((first_node_type == NodeType.OBSTACLE) & is_surface).astype(int)

            mesh_edges = np.concatenate([cells_to_edges(strip_mesh), cells_to_edges(mould_mesh) + n_strip], axis = 0)

            frame_limit = trace_length - step_slice
            for frame_id in range(0, frame_limit, real_slice):
                # The last stride of the range is never stored.
                if frame_id + real_slice >= frame_limit:
                    continue
                next_id = frame_id + step_slice
                # The first frame has no predecessor and uses itself instead.
                prev_id = frame_id - step_slice if frame_id > 0 else frame_id
                world_pos = node_features[frame_id, :, :3]
                stress = node_features[frame_id, :, 3:]

                # Displacement to the target frame, zeroed for NORMAL nodes.
                next_pos_need = node_features[next_id, :, :3] - world_pos
                next_pos_need[(node_type == NodeType.NORMAL).reshape(-1), :] = 0

                new_faces = node_sorted(faces, world_pos)
                data_need_to_store = {
                    "final_flag": np.array([frame_id + real_slice + real_slice >= frame_limit]),
                    "elements_sequ": strip_mesh,
                    "faces_sequ": faces_sequ,
                    "mesh_pos": mesh_pos,
                    "mesh_edge": mesh_edges,
                    "cells_faces": face_sorted(cells_faces, new_faces, world_pos),
                    "last_stress": node_features[prev_id, :, 3:],
                    "last_pos": node_features[prev_id, :, :3],
                    "world_pos": world_pos,
                    "stress": stress,
                    "node_type": node_type,
                    "elements": node_sorted(strip_mesh, world_pos),
                    "faces": new_faces,
                    "world_edge": nodes_to_edges(world_pos, node_type),
                    "faces_edges": faces_edges,
                    "next_pos_need": next_pos_need,
                    "faces_to_faces": faces_to_faces(faces_sequ, world_pos, node_type, mask, mask1),
                    "target": np.concatenate([node_features[next_id, :, :3], node_features[next_id, :, 3:] - stress], axis=1),
                    }
                group_2 = group_1.create_group(str(frame_id))
                for name, value in data_need_to_store.items():
                    group_2.create_dataset(name, data=value)

    dataset = DataLoader(dataset = frame_data_collected, batch_size = batch_size, shuffle = False, collate_fn=custom_collate)
    return dataset
# Script entry point: convert the "valid" split located under /strip_simple/.
if __name__ == '__main__':
    dataset = GetDataSetMSB("/strip_simple/", "valid")
    
    
