# load_data.py

import torch
from torch_geometric.data import Data
import numpy as np
import os, pickle, random

def load_data(i, j, args):
    """Load the training sample for simulation ``i`` at step ``j`` as a PyG ``Data``.

    Reads ``MeshAttributes_{i}.pkl`` and ``MeshData_{i}.pkl`` from
    ``args['data_folder']``. Node (``x``) and edge (``edge_attr``) input
    features come from ``mesh_attributes[0]``. Each row gets two extra columns
    ``[j + 1, args['time_step']]``. Targets come from ``mesh_attributes[j + 1]``.

    Args:
        i: Simulation number; selects which pair of pickle files is read.
        j: Step number; targets are taken from ``mesh_attributes[j + 1]``.
        args: Mapping providing ``'data_folder'`` (directory holding the
            pickle files) and ``'time_step'`` (value appended to every
            node/edge feature row).

    Returns:
        ``torch_geometric.data.Data`` with ``x``, ``edge_attr``, ``y_E``,
        ``y_H`` (float32), ``edge_index`` (int64) and ``points`` (numpy array).

    Raises:
        AssertionError: if any of the produced tensors or ``points`` is empty.
    """
    simulation_num = i
    step_num = j
    data_folder = args['data_folder']
    attributes_path = os.path.join(data_folder, f"MeshAttributes_{simulation_num}.pkl")
    data_path = os.path.join(data_folder, f"MeshData_{simulation_num}.pkl")

    # NOTE: pickle.load can execute arbitrary code; only point this at trusted dataset files.
    with open(attributes_path, 'rb') as file:
        mesh_attributes = pickle.load(file)
    with open(data_path, 'rb') as file:
        mesh_data = pickle.load(file)

    initial = mesh_attributes[0]
    target = mesh_attributes[step_num + 1]

    # Build directly as int64. Going through torch.Tensor (float32) first would
    # silently corrupt node indices above 2**24. GraphStructure is presumably
    # (num_edges, 2) -- TODO confirm. The transpose yields PyG's (2, num_edges)
    # layout, and contiguous() follows PyG's documented `.t().contiguous()` idiom.
    edge_index = torch.as_tensor(mesh_data["GraphStructure"], dtype=torch.int64).t().contiguous()

    x = torch.from_numpy(initial["PointAttributes"]).to(torch.float32)
    edge_attr = torch.from_numpy(initial["FaceAttributes"]).to(torch.float32)
    # Column 0 of the target point attributes, kept 2-D as (num_points, 1).
    y_E = torch.from_numpy(target["PointAttributes"][:, 0]).unsqueeze(1).to(torch.float32)
    y_H = torch.from_numpy(target["FaceAttributes"][:, 2:4]).to(torch.float32)
    points = np.array(initial["Points"])

    # 1x2 row [step_num + 1, time_step] appended to every node and edge feature row.
    time_step = torch.tensor([[step_num + 1, args['time_step']]], dtype=torch.float32)
    x = torch.cat((x, time_step.repeat(x.size(0), 1)), dim=1)
    edge_attr = torch.cat((edge_attr, time_step.repeat(edge_attr.size(0), 1)), dim=1)

    tensors = {
        "edge_index": edge_index,
        "x": x,
        "edge_attr": edge_attr,
        "y_E": y_E,
        "y_H": y_H,
    }
    # Explicit raises instead of `assert`, so the checks survive `python -O`.
    # AssertionError is kept so existing callers still see the same exception
    # type. numel() == 0 already covers any zero-length dimension.
    for name, tensor in tensors.items():
        if tensor.numel() == 0:
            raise AssertionError(f"{name} is empty, shape={tensor.shape}")
    if points.size == 0:
        raise AssertionError(f"points is empty, shape={points.shape}")

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y_E=y_E, y_H=y_H, points=points)