import torch
import torch.nn.functional as F 
from torch.utils.data import DataLoader
import h5py
import os
import numpy as np
import operator
from itertools import accumulate
import copy
from utils import NodeType
import math

# Module-level mutable state. Neither name is read or written by the functions
# defined in this section of the file; they may be used by code elsewhere
# (e.g. modules importing this one) — verify before removing.
previous_face_edges = []
# Single-element list, presumably meant as a mutable frame counter — TODO confirm.
frame_index = [0]
def custom_collate(data):
    """Merge a list of per-frame sample dicts into one disjoint batched graph.

    Every array in every sample is converted with ``torch.tensor`` and the
    samples are concatenated along dim 0. Arrays that hold indices into
    another per-sample array are shifted by that array's running row count,
    so indices stay valid after concatenation:

    * node offsets (rows of ``"stress"``): ``elements``, ``faces``,
      ``elements_sequ``, ``faces_sequ``, ``mesh_edge``, ``world_edge``
    * element offsets (rows of ``"elements"``): ``faces_edges``
    * face offsets (rows of ``"faces"``): ``faces_to_faces``, ``cells_faces``

    All remaining keys are concatenated unchanged.

    Args:
        data: list of sample dicts, each holding array-likes under the keys
            listed in the returned dict. ``"stress"``, ``"elements"`` and
            ``"faces"`` must expose ``.shape`` (e.g. numpy arrays).

    Returns:
        dict mapping each key to the batched ``torch.Tensor``.
    """

    def _stack(key):
        # Plain concatenation, no index shifting.
        return torch.concat([torch.tensor(sample[key]) for sample in data], dim=0)

    def _first_rows(key):
        # Exclusive prefix sum of per-sample row counts: the row at which each
        # sample's block begins in the concatenated array.
        row_counts = [sample[key].shape[0] for sample in data]
        return torch.tensor([0] + list(accumulate(row_counts))[:-1])

    def _stack_shifted(key, first_rows):
        # Add the owning sample's offset before concatenating.
        pieces = [
            torch.tensor(sample[key]) + first_rows[pos]
            for pos, sample in enumerate(data)
        ]
        return torch.concat(pieces, dim=0)

    node_first = _first_rows("stress")
    element_first = _first_rows("elements")
    face_first = _first_rows("faces")

    return {
        "final_flag": _stack("final_flag"),
        "elements_sequ": _stack_shifted("elements_sequ", node_first),
        "faces_sequ": _stack_shifted("faces_sequ", node_first),
        "cells_faces": _stack_shifted("cells_faces", face_first),
        "last_stress": _stack("last_stress"),
        "last_pos": _stack("last_pos"),
        "world_pos": _stack("world_pos"),
        "stress": _stack("stress"),
        "node_type": _stack("node_type"),
        "target": _stack("target"),
        "next_pos_need": _stack("next_pos_need"),
        "elements": _stack_shifted("elements", node_first),
        "faces": _stack_shifted("faces", node_first),
        "faces_edges": _stack_shifted("faces_edges", element_first),
        "faces_to_faces": _stack_shifted("faces_to_faces", face_first),
        "mesh_edge": _stack_shifted("mesh_edge", node_first),
        "mesh_pos": _stack("mesh_pos"),
        "world_edge": _stack_shifted("world_edge", node_first),
        "na_index": _stack("na_index"),
        "nb_index": _stack("nb_index"),
    }

def _read_trace_length(instance_path):
    """Return the length of axis 0 of the ``"STRIP_U"`` entry in a raw archive.

    ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
    underlying file open until ``close()`` is called. It is closed here so
    that iterating over many instances does not leak file handles.
    """
    instance_data = np.load(instance_path)
    try:
        return instance_data["STRIP_U"].shape[0]
    finally:
        # NpzFile exposes close(); a plain ndarray (structured .npy) does not.
        close = getattr(instance_data, "close", None)
        if close is not None:
            close()


def _frame_sample(data_set, is_final):
    """Copy one HDF5 frame group into the sample dict consumed by custom_collate.

    Float fields are cast to float32. Index fields keep their stored dtype.
    ``na_index``/``nb_index`` are zero placeholders with one entry per face.
    """
    # Read "faces" once; it is needed both for the sample and for the
    # placeholder lengths below.
    faces = np.array(data_set["faces"])
    return {
        "final_flag": np.array([is_final]),
        "elements_sequ": np.array(data_set["elements_sequ"]),
        "faces_sequ": np.array(data_set["faces_sequ"]),
        "mesh_pos": np.array(data_set["mesh_pos"]).astype(np.float32),
        "mesh_edge": np.array(data_set["mesh_edge"]),
        "cells_faces": np.array(data_set["cells_faces"]),
        "last_stress": np.array(data_set["last_stress"]).astype(np.float32),
        "last_pos": np.array(data_set["last_pos"]).astype(np.float32),
        "world_pos": np.array(data_set["world_pos"]).astype(np.float32),
        "stress": np.array(data_set["stress"]).astype(np.float32),
        "node_type": np.array(data_set["node_type"]),
        "elements": np.array(data_set["elements"]),
        "faces": faces,
        "world_edge": np.array(data_set["world_edge"]),
        "faces_edges": np.array(data_set["faces_edges"]),
        "next_pos_need": np.array(data_set["next_pos_need"]),
        "faces_to_faces": np.array(data_set["faces_to_faces"]),
        "target": np.array(data_set["target"]),
        "na_index": np.zeros(faces.shape[0]),
        "nb_index": np.zeros(faces.shape[0]),
    }


def GetDataSetMSB(dataset_dir, split, batch_size = 1, need_create = False):
    """Build an unshuffled DataLoader over preprocessed simulation frames.

    For every entry ``instance`` in ``<dataset_dir>/<split>`` (visited in
    sorted order so the batch order is reproducible), the raw archive is
    opened only to read the trace length (axis 0 of ``"STRIP_U"``). Frame
    data is read from ``<dataset_dir>/full_data_<split>.h5`` at
    ``f[instance][str(frame_id)]``.

    Frames are taken every ``step_slice`` (= 4) steps. A frame is kept only
    if ``frame_id + 4 < trace_length - 4``. ``final_flag`` is True for the
    last kept frame of each trace, i.e. when ``frame_id + 8 >= trace_length - 4``.

    Args:
        dataset_dir: root directory holding the split folder and the ``.h5`` file.
        split: split name, e.g. ``"train"``.
        batch_size: number of frames merged per batch by ``custom_collate``.
        need_create: accepted for backward compatibility; currently unused.

    Returns:
        ``torch.utils.data.DataLoader`` with ``shuffle=False`` and
        ``drop_last=True``. A trailing partial batch is discarded.
    """
    file_path = os.path.join(dataset_dir, split)
    input_file_path = os.path.join(dataset_dir, f"full_data_{split}.h5")
    step_slice = 4
    real_slice = step_slice
    frame_data_collected = []

    with h5py.File(input_file_path, 'r') as f:
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for instance in sorted(os.listdir(file_path)):
            group_1 = f[instance]
            instance_path = os.path.join(file_path, instance)
            trace_length = _read_trace_length(instance_path)
            print(instance_path)
            last_start = trace_length - step_slice

            for frame_id in range(0, last_start, real_slice):
                # Skip the final strided frame: the next stride must still
                # fall strictly inside the usable range.
                if frame_id + real_slice >= last_start:
                    continue
                is_final = frame_id + real_slice + real_slice >= last_start
                frame_data_collected.append(
                    _frame_sample(group_1[str(frame_id)], is_final)
                )

    dataset = DataLoader(dataset = frame_data_collected, batch_size = batch_size, shuffle = False, collate_fn=custom_collate, drop_last=True)

    return dataset
if __name__ == '__main__':
    # Smoke run: build the training-split loader from the default location.
    dataset = GetDataSetMSB(
        dataset_dir="/strip_simple/",
        split="train",
        batch_size=1,
        need_create=False,
    )
    
    
