import torch
import torch.nn.functional as F 
from torch.utils.data import DataLoader
import h5py
import os
import numpy as np
import operator
from itertools import accumulate
import copy
from utils import NodeType
from utils import face_sorted, node_sorted, faces_to_faces
from utils import nodes_to_edges, cells_to_edges, cells_to_faces
import pickle
import math

def write_smesh(filename, points, faces):
    """Write a triangulated surface to a TetGen ``.smesh`` file.

    The file consists of four parts: the node list, the facet list (one
    triangle per facet, no boundary markers), an empty hole list and a
    single region seed point.

    Args:
        filename: Path of the ``.smesh`` file to create (overwritten).
        points: Array-like of shape (num_points, 3) with vertex coordinates.
        faces: Integer array-like of shape (num_faces, 3) holding zero-based
            vertex indices.

    Raises:
        AssertionError: If the shapes, the index dtype or the largest
            vertex index are invalid.
        ValueError: If ``faces`` is empty (NumPy cannot take the maximum of
            an empty array).
    """
    points = np.asarray(points)
    faces = np.asarray(faces)

    assert points.ndim == 2 and points.shape[1] == 3
    assert faces.ndim == 2 and faces.shape[1] == 3
    assert np.issubdtype(faces.dtype, np.integer)
    assert faces.max() < len(points)

    num_points = len(points)

    # Region seed heuristic: take the vertex with the largest x among
    # vertices 30..59 and nudge it by -0.01 along x so that it lies inside
    # the surface.  The window is clamped to the vertices that actually
    # exist, so meshes with fewer than 60 vertices no longer raise
    # IndexError; with 30 or fewer vertices every vertex is a candidate.
    if num_points > 30:
        candidates = range(30, min(60, num_points))
    else:
        candidates = range(num_points)
    # max() keeps the first of several equal maxima, like a strict '>' scan.
    seed = points[max(candidates, key=lambda idx: points[idx, 0])]

    # Part 1: points
    lines = [f"{num_points} 3 0 0\n"]
    lines.extend(f"{i} {p[0]} {p[1]} {p[2]}\n" for i, p in enumerate(points))

    # Part 2: facets (triangles, no boundary markers)
    lines.append(f"{len(faces)} 0\n")
    lines.extend(f"3 {face[0]} {face[1]} {face[2]}\n" for face in faces)

    # Part 3: holes
    lines.append("0\n")

    # Part 4: regions (a single seed point inside the mesh region)
    lines.append("1\n")
    lines.append(f"0 {seed[0] - 0.01} {seed[1]} {seed[2]} 1 1.0\n")

    # Everything is computed up front so a failure never leaves a
    # half-written file behind.
    with open(filename, 'w') as f:
        f.writelines(lines)
        
def read_ele_file(filename):
    """Read the element table of a TetGen ``.ele`` file.

    The first token of the first line is the element count.  That many
    lines directly after the header are parsed as integer rows; anything
    beyond them (such as a trailing comment line) is ignored.

    Args:
        filename: Path of the ``.ele`` file.

    Returns:
        Integer array of shape (num_elements, 4) holding columns 1-4 of
        every element row, i.e. the four values that follow the leading
        element index.
    """
    with open(filename, 'r') as f:
        lines = f.readlines()
    header = list(map(int, lines[0].split()))
    num_elements = header[0]

    if num_elements == 0:
        # np.loadtxt would return a 1-D empty array that cannot be sliced
        # by column; return a well-formed empty table instead.
        return np.empty((0, 4), dtype=int)

    # ndmin=2 keeps a single element row two-dimensional; without it
    # np.loadtxt collapses it to 1-D and the column slice below fails.
    data = np.loadtxt(lines[1:num_elements + 1], dtype=int, ndmin=2)
    elements = data[:, 1:5]
    return elements

def generate_cells(points, faces):
    """Tetrahedralise a closed triangle surface with the ``tetgen`` CLI.

    The surface is written to ``input.smesh`` in the current working
    directory, ``tetgen -pi input.smesh`` is executed, and the resulting
    ``input.1.ele`` is parsed.

    Args:
        points: Array-like of shape (num_points, 3) with vertex coordinates.
        faces: Array-like of shape (num_faces, 3) with vertex indices; it is
            cast to int32 before being written.

    Returns:
        The element table returned by ``read_ele_file`` for ``input.1.ele``.

    Raises:
        FileNotFoundError: If the ``tetgen`` executable cannot be found or
            it did not produce ``input.1.ele``.
    """
    # Local import keeps this block self-contained.
    import subprocess

    smesh_path = "input.smesh"
    ele_path = "input.1.ele"

    faces = np.asarray(faces).astype(np.int32)

    # Remove output left over from an earlier call.  Otherwise a failed
    # TetGen run would go unnoticed and the previous mesh's elements would
    # be returned for this surface.
    try:
        os.remove(ele_path)
    except FileNotFoundError:
        pass

    write_smesh(smesh_path, points, faces)

    # Argument list instead of a shell string: no shell parsing involved.
    result = subprocess.run(["tetgen", "-pi", smesh_path])
    if not os.path.exists(ele_path):
        raise FileNotFoundError(
            f"tetgen (exit code {result.returncode}) did not produce {ele_path}"
        )

    elements = read_ele_file(ele_path)

    return elements

def PreDataSetGC(dataset_dir, split):
    """Convert the pickled cavity-grasping traces of one split to HDF5.

    For every trace in ``cavity_grasping_dataset_<split>.pkl`` (looked up in
    ``dataset_dir``) the first-frame tissue surface is tetrahedralised with
    ``generate_cells``, tissue and gripper vertices are merged into one node
    set, and one HDF5 group per frame is written as
    ``/<trace_id>/<frame_id>`` with one dataset per stored quantity.

    Args:
        dataset_dir: Directory that contains the pickle file.
        split: Split name used in the input and output file names.

    Returns:
        None.  All results are written to the HDF5 file.
    """
    file_path = os.path.join(dataset_dir, "cavity_grasping_dataset_" + split + ".pkl")

    # NOTE: the output location is a fixed absolute path; it does not
    # depend on dataset_dir.
    output_file_path = f"/cavity_grasping_dataset/data_{split}_new.h5"

    # NOTE(review): the file is opened in append mode while groups are made
    # with create_group; presumably a second run on an existing file fails
    # on the first already-present group -- confirm the intended mode.
    with h5py.File(output_file_path, 'a') as f:
        with open(file_path, 'rb') as files:
            # SECURITY: pickle.load can execute arbitrary code; only use
            # this on dataset files from a trusted source.
            file_data = pickle.load(files)

        for trace_id in range(len(file_data)):
            print(trace_id)
            trace_data = file_data[trace_id]
            print(trace_data.keys())

            # Per-frame arrays stacked along a new leading (frame) axis.
            deform_pos = np.stack(trace_data["tissue_mesh_positions"], axis = 0) # frames * N * 3
            rigid_pos = np.stack(trace_data["gripper_position"], axis = 0) # frames * M * 3

            # Tissue nodes come first, gripper nodes are appended after them.
            pos = np.concatenate([deform_pos, rigid_pos], axis = 1) # frames * (N + M) * 3
            deform_point_num = deform_pos.shape[1]
            node_type = np.zeros(pos.shape[1], dtype = int)
            node_type[:deform_point_num] = NodeType.NORMAL
            node_type[deform_point_num:] = NodeType.OBSTACLE

            # Tetrahedra are generated once, from the first frame's surface.
            cells = generate_cells(deform_pos[0], trace_data["tissue_mesh_triangles"])

            group_1 = f.create_group(str(trace_id))
            trace_length = pos.shape[0]

            deform_cells_edges = cells_to_edges(cells)
            deform_faces, deform_faces_edges, cells_faces = cells_to_faces(cells)

            # Gripper triangles index the gripper's own vertices; shift them
            # past the tissue nodes to address the merged node array.
            rigid_faces = np.asarray(trace_data["gripper_triangles"]) + deform_point_num
            # Gripper faces get -1 in both face-edge columns, so the
            # "last column < 0" test below marks all of them as surface.
            rigid_faces_edges = np.zeros((rigid_faces.shape[0], 2)) - 1
            faces_edges = np.concatenate([deform_faces_edges, rigid_faces_edges], axis = 0)
            faces = np.concatenate([deform_faces, rigid_faces], axis = 0)
            rigid_cells_edges = cells_to_edges(rigid_faces)
            cells_edges = np.concatenate([deform_cells_edges, rigid_cells_edges], axis = 0)

            faces = faces.astype(int)

            # A face counts as surface when its last faces_edges entry is negative.
            is_surface = (faces_edges[:, -1] < 0).astype(int)
            # Faces are classified by the type of their first vertex.
            first_node_type = node_type[faces[:, 0]]
            on_surface = is_surface != 0
            # mask: surface faces whose first vertex is NORMAL or HANDLE.
            mask = (
                ((first_node_type == NodeType.NORMAL) | (first_node_type == NodeType.HANDLE))
                & on_surface
            ).astype(int)
            # mask1: surface faces whose first vertex is OBSTACLE.
            mask1 = ((first_node_type == NodeType.OBSTACLE) & on_surface).astype(int)

            step_slice, real_slice = 1, 1
            for frame_id in range(0, trace_length - step_slice, real_slice):
                group_2 = group_1.create_group(str(frame_id))

                # Next-frame positions with the NORMAL nodes zeroed out, so
                # only the remaining (gripper) nodes keep their coordinates.
                next_pos_need = copy.deepcopy(pos[frame_id + step_slice])
                next_pos_need[(node_type == NodeType.NORMAL).reshape(-1), :] = 0

                # The first frame has no predecessor and reuses its own positions.
                last_pos = pos[frame_id - step_slice] if frame_id > 0 else pos[frame_id]

                new_faces = node_sorted(copy.deepcopy(faces), pos[frame_id])
                # Entry order is kept as-is: the helpers are called in this
                # order and the datasets are created in this order.
                need_to_store = {
                    "final_flag": np.array([((frame_id + real_slice) >= (trace_length - step_slice))]),
                    "mesh_pos": copy.deepcopy(pos[0]),
                    "mesh_edge": copy.deepcopy(cells_edges),
                    "cells_faces": face_sorted(copy.deepcopy(cells_faces), new_faces, pos[frame_id]),
                    "last_pos": last_pos,
                    "world_pos": pos[frame_id],
                    "node_type": copy.deepcopy(node_type),
                    "elements": node_sorted(cells, pos[frame_id]),
                    "faces": new_faces,
                    "world_edge": nodes_to_edges(pos[frame_id], node_type),
                    "faces_edges": copy.deepcopy(faces_edges),
                    "next_pos_need": next_pos_need,
                    "faces_to_faces": faces_to_faces(faces, pos[frame_id], node_type, mask, mask1),
                    "target": pos[frame_id + step_slice]
                }

                for name, value in need_to_store.items():
                    group_2.create_dataset(name, data=value)
                    
                
if __name__ == "__main__":
    # Script entry point: build the HDF5 file for the "train" split.
    dataset = PreDataSetGC(dataset_dir="/cavity_grasping_dataset/", split="train")
    
    
