
from types import NoneType
import numpy as np
import torch

import logging


from utils.domain import create_subdomains
import dgl
from models.GNOT.utils import MultipleTensors
from torch.nn.utils.rnn import pad_sequence
from models.GNOT.data_utils import MIODataset
from models.GNOT.utils import TorchQuantileTransformer, UnitTransformer, PointWiseUnitTransformer, MultipleTensors

import typing
from utils.augmentation import Darcy2dTransform, Laplace2dTransform, Laplace2dnTransform, Heat2dTransform, Helmholtz2dTransform
from utils.normalization import Laplace2dNormalizer, Laplace2dnNormalizer, Darcy2dNormalizer, Heat2dNormalizer
from trimesh.interfaces.gmsh import load_gmsh
import gmshparser
logger = logging.getLogger(__name__)

# def get_decomposed_dataset(args):
#     if args.dataset == "ns2d":
#         train_path = '/home/hin4sgh/pde/data/ns2d_1100_train.pkl'
#         test_path = '/home/hin4sgh/pde/data/ns2d_1100_test.pkl'
#     elif args.dataset == "inductor2d":
#         train_path = "/home/hin4sgh/pde/data/inductor2d_1100_train.pkl"
#         test_path = "/home/hin4sgh/pde/data/inductor2d_1100_test.pkl"

#     elif args.dataset == "heat2d":
#         train_path = "/home/hin4sgh/pde/data/heat2d_1100_train.pkl"
#         test_path = "/home/hin4sgh/pde/data/heat2d_1100_test.pkl"

#     else:
#         raise NotImplementedError

#     args.train_num = int(args.train_num) if args.train_num not in ['all', 'none'] else args.train_num
#     args.test_num = int(args.test_num) if args.test_num not in ['all', 'none'] else args.test_num

#     train_dataset = DecomposedDomainDataset(train_path, name=args.dataset, train=True, train_num=args.train_num,
#                                sort_data=args.sort_data,
#                                normalize_y=args.use_normalizer,
#                                normalize_x=args.normalize_x,
#                                n_parts=args.n_parts)
#     test_dataset = DecomposedDomainDataset(test_path, name=args.dataset, train=False, test_num=args.test_num,
#                               sort_data=args.sort_data,
#                               normalize_y=args.use_normalizer,
#                               normalize_x=args.normalize_x, y_normalizer=train_dataset.y_normalizer,
#                               x_normalizer=train_dataset.x_normalizer, up_normalizer=train_dataset.up_normalizer,
#                               n_parts=args.n_parts)

#     args.dataset_config = train_dataset.config

#     return train_dataset, test_dataset

def get_train_dolphinx_dataset(args):
    """Build the training and test datasets selected by ``args.dataset``.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``train_num``,
            ``test_num``, ``sort_data``, ``normalize_y`` and ``normalize_x``;
            for the ``heat2d_simple*`` datasets also ``time_dependent`` and
            ``time_step``. ``train_num`` / ``test_num`` are converted to
            ``int`` in place unless they are ``'all'`` or ``'none'``, and
            ``dataset_config`` is set from the training dataset's ``config``.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of
        ``DolphinxNeuralOperatorDataset`` instances. The test dataset is given
        the training dataset's ``x``/``y``/``up`` normalizers.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a known training
            dataset name.
    """
    # name -> (file prefix, train samples, test samples, augmentation family,
    #          lower bound of max_space_scale).
    # An augmentation family of None means "identity transform".
    # The darcy2d datasets currently run without augmentation.
    dataset_specs = {
        "laplace2d_simple": ("laplace2d_simple", 20000, 2500, "laplace2d", 0.5),
        "laplace2d_simple_plain": ("laplace2d_simple", 20000, 2500, None, None),
        "laplace2d_simple_rotation": ("laplace2d_simple", 20000, 2500, "laplace2d", 1.0),
        "laplace2d_simple_rotation_scale_0.2": ("laplace2d_simple", 20000, 2500, "laplace2d", 0.2),
        "laplace2d_simple_rotation_scale_0.8": ("laplace2d_simple", 20000, 2500, "laplace2d", 0.8),
        "laplace2d_simple_10000": ("laplace2d_simple", 10000, 2500, "laplace2d", 0.5),
        "laplace2d_simple_5000": ("laplace2d_simple", 5000, 2500, "laplace2d", 0.5),
        "laplace2d_simple_2500": ("laplace2d_simple", 2500, 2500, "laplace2d", 0.5),
        "laplace2d_n_simple": ("laplace2d_n_simple", 40000, 4000, "laplace2d_n", 1.0),
        "laplace2d_n_simple_20000": ("laplace2d_n_simple", 20000, 4000, "laplace2d_n", 0.8),
        "laplace2d_n_simple_10000": ("laplace2d_n_simple", 10000, 4000, "laplace2d_n", 0.8),
        "laplace2d_n_simple_5000": ("laplace2d_n_simple", 5000, 4000, "laplace2d_n", 0.8),
        "laplace2d_n_simple_2500": ("laplace2d_n_simple", 2500, 4000, "laplace2d_n", 0.8),
        "darcy2d_simple": ("darcy2d_simple", 40000, 2500, None, None),
        "darcy2d_simple_20000": ("darcy2d_simple", 20000, 2500, None, None),
        "darcy2d_simple_10000": ("darcy2d_simple", 10000, 2500, None, None),
        "darcy2d_simple_5000": ("darcy2d_simple", 5000, 2500, None, None),
        "darcy2d_simple_2500": ("darcy2d_simple", 2500, 2500, None, None),
        "heat2d_simple": ("heat2d_simple", 100000, 12500, "heat2d", 0.8),
        "heat2d_simple_80000": ("heat2d_simple", 80000, 12500, "heat2d", 0.8),
        "heat2d_simple_40000": ("heat2d_simple", 40000, 12500, "heat2d", 0.8),
        "heat2d_simple_20000": ("heat2d_simple", 20000, 12500, "heat2d", 0.8),
        "heat2d_simple_10000": ("heat2d_simple", 10000, 12500, "heat2d", 0.8),
        "heat2d_simple_5000": ("heat2d_simple", 5000, 12500, "heat2d", 0.8),
        "heat2d_simple_2500": ("heat2d_simple", 2500, 12500, "heat2d", 0.8),
    }

    # Fail early with a clear message instead of an UnboundLocalError later on.
    try:
        prefix, n_train, n_test, family, min_space_scale = dataset_specs[args.dataset]
    except KeyError:
        raise NotImplementedError(
            f"Unknown training dataset {args.dataset!r}; "
            f"expected one of {sorted(dataset_specs)}"
        ) from None

    train_path = f'./data/2d/{prefix}_{n_train}_train.pkl'
    test_path = f'./data/2d/{prefix}_{n_test}_test.pkl'

    if family is None:
        transform = lambda x, y, z: (x, y, z)
    else:
        transform_cls = {
            "laplace2d": Laplace2dTransform,
            "laplace2d_n": Laplace2dnTransform,
            "heat2d": Heat2dTransform,
        }[family]
        transform = transform_cls(max_space_scale=[min_space_scale, 1.0],
                                  max_value_scale=[1.0, 1.0],
                                  always_apply=True)

    # Only the heat equation datasets take their time settings from args.
    if family == "heat2d":
        time_dependent = args.time_dependent
        time_step = args.time_step
    else:
        time_dependent = False
        time_step = None

    args.train_num = int(args.train_num) if args.train_num not in ['all', 'none'] else args.train_num
    args.test_num = int(args.test_num) if args.test_num not in ['all', 'none'] else args.test_num

    train_dataset = DolphinxNeuralOperatorDataset(train_path, name=args.dataset, train=True, train_num=args.train_num,
                               sort_data=args.sort_data,
                               normalize_y=args.normalize_y,
                               normalize_x=args.normalize_x,
                               transform=transform,
                               time_dependent=time_dependent,
                               time_step=time_step)
    # NOTE(review): the same augmentation transform is also handed to the test
    # dataset — confirm this is intended.
    test_dataset = DolphinxNeuralOperatorDataset(test_path, name=args.dataset, train=False, test_num=args.test_num,
                              sort_data=args.sort_data,
                              normalize_y=args.normalize_y,
                              normalize_x=args.normalize_x, y_normalizer=train_dataset.y_normalizer,
                              x_normalizer=train_dataset.x_normalizer, up_normalizer=train_dataset.up_normalizer,
                              transform=transform,
                              time_dependent=time_dependent,
                              time_step=time_step)
    args.dataset_config = train_dataset.config

    return train_dataset, test_dataset

def get_inference_dolphinx_dataset(args):
    """Build the inference (test-only) dataset selected by ``args.dataset``.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``test_num``,
            ``sort_data``, ``normalize_y`` and ``normalize_x``; for the
            ``heat2d_*`` datasets also ``time_span`` and ``time_step``.
            ``test_num`` is converted to ``int`` in place unless it is
            ``'all'`` or ``'none'``, and ``dataset_config`` is set from the
            created dataset's ``config``.

    Returns:
        The ``DolphinxNeuralOperatorDataset`` built from the matching
        ``./data/2d/<dataset>_<count>_test.pkl`` file.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a known inference
            dataset name.
    """
    # name -> (sample count that appears in the file name, time dependent?)
    dataset_specs = {
        "laplace2d_schwarz": (100, False),
        "laplace2d_holes": (100, False),
        "laplace2d_bosch": (100, False),
        "laplace2d_n_schwarz": (100, False),
        "laplace2d_n_holes": (100, False),
        "laplace2d_n_bosch": (100, False),
        "darcy2d_holes": (100, False),
        "darcy2d_schwarz": (100, False),
        "darcy2d_bosch": (100, False),
        "heat2d_schwarz": (10, True),
        "heat2d_holes": (10, True),
        "heat2d_bosch": (10, True),
    }

    # Fail early with a clear message instead of an UnboundLocalError later on.
    try:
        n_samples, time_dependent = dataset_specs[args.dataset]
    except KeyError:
        raise NotImplementedError(
            f"Unknown inference dataset {args.dataset!r}; "
            f"expected one of {sorted(dataset_specs)}"
        ) from None

    test_path = f'./data/2d/{args.dataset}_{n_samples}_test.pkl'

    # Time settings are only taken from args for the time-dependent datasets.
    if time_dependent:
        time_span = args.time_span
        time_step = args.time_step
    else:
        time_span = None
        time_step = None

    args.test_num = int(args.test_num) if args.test_num not in ['all', 'none'] else args.test_num

    test_dataset = DolphinxNeuralOperatorDataset(test_path, name=args.dataset, train=False, test_num=args.test_num,
                              sort_data=args.sort_data,
                              normalize_y=args.normalize_y,
                              normalize_x=args.normalize_x,
                              time_dependent=time_dependent,
                              time_span=time_span,
                              time_step=time_step)
    args.dataset_config = test_dataset.config

    return test_dataset

def get_inference_mesh(args):
    """Load the inference mesh matching the domain suffix of ``args.dataset``.

    The domain is the part of ``args.dataset`` after the last underscore
    (``schwarz``, ``holes`` or ``bosch``).

    Returns:
        Tuple ``(gmshparser.parse(mesh_path), load_gmsh(mesh_path))``: the
        same ``.msh`` file read once by each of the two loaders.

    Raises:
        NotImplementedError: If the domain suffix has no associated mesh file.
    """
    mesh_files = {
        "schwarz": './data/mesh/A-schwarz.msh',
        "holes": './data/mesh/B-holes.msh',
        "bosch": './data/mesh/C-bosch.msh',
    }
    domain = args.dataset.split("_")[-1]

    # Fail early with a clear message instead of an UnboundLocalError below.
    try:
        mesh_path = mesh_files[domain]
    except KeyError:
        raise NotImplementedError(
            f"No inference mesh for domain {domain!r} (dataset {args.dataset!r}); "
            f"expected one of {sorted(mesh_files)}"
        ) from None

    return gmshparser.parse(mesh_path), load_gmsh(mesh_path)

def _line_element_node_tags(gmesh, entity_tags=None):
    """Collect the node tags referenced by element entities of type 1.

    Only entities whose ``get_element_type()`` equals 1 are considered
    (2-node line elements in the Gmsh MSH format). When ``entity_tags`` is
    given, entities whose ``get_tag()`` is not in it are skipped as well.
    """
    node_tags = set()
    for entity in gmesh.get_element_entities():
        if entity.get_element_type() != 1:
            continue
        if entity_tags is not None and entity.get_tag() not in entity_tags:
            continue
        for element in entity.get_elements():
            node_tags.update(element.get_connectivity())
    return node_tags


def _partition_node_coordinates(gmesh, first_tags, second_tags=frozenset()):
    """Return coordinates of nodes in ``first_tags`` and, separately, of the
    remaining nodes in ``second_tags``.

    A node whose tag is in both sets is reported only in the first list.
    """
    first_points, second_points = [], []
    for entity in gmesh.get_node_entities():
        for node in entity.get_nodes():
            tag = node.get_tag()
            if tag in first_tags:
                first_points.append(node.get_coordinates())
            elif tag in second_tags:
                second_points.append(node.get_coordinates())
    return first_points, second_points


def get_inference_boundary_marker(args, gmesh):
    """Return the boundary point coordinates of ``gmesh`` split by condition type.

    ``args.dataset`` is interpreted as ``<pde>_<domain>``, where the domain is
    the part after the last underscore.

    * For ``laplace2d``, ``darcy2d`` and ``heat2d`` every node of a type-1
      element is reported as a Dirichlet point and ``"neumann"`` is an empty
      list.
    * For ``laplace2d_n`` the type-1 element entities are split into Dirichlet
      and Neumann parts by their entity tag, using a per-domain tag table.
      Nodes shared by both parts count as Dirichlet.

    Args:
        args: Namespace-like object providing ``dataset``.
        gmesh: Mesh object exposing ``get_element_entities()`` and
            ``get_node_entities()`` (as returned by ``gmshparser.parse``).

    Returns:
        Dict with keys ``"dirichlet"`` and ``"neumann"`` holding the point
        coordinates as ``numpy`` arrays (``"neumann"`` is a plain empty list
        for the Dirichlet-only PDEs).

    Raises:
        NotImplementedError: If the pde/domain combination is not supported.
    """
    # domain -> (entity tags of the Dirichlet boundary, entity tags of the Neumann boundary)
    mixed_boundary_tags = {
        "schwarz": ((6,), (7, 8, 9)),
        "holes": ((8, 9), (1, 2, 3, 4)),
        "bosch": ((1,), (2, 3, 4, 5, 6, 8, 9, 11)),
    }

    *pde_parts, domain = args.dataset.split("_")
    pde = "_".join(pde_parts)

    if pde in ("laplace2d", "darcy2d", "heat2d"):
        # Pure Dirichlet problems: the whole boundary belongs to one group.
        boundary_points, _ = _partition_node_coordinates(
            gmesh, _line_element_node_tags(gmesh))
        return {"dirichlet": np.array(boundary_points), "neumann": []}

    if pde == "laplace2d_n" and domain in mixed_boundary_tags:
        dirichlet_tags, neumann_tags = mixed_boundary_tags[domain]
        db_index = _line_element_node_tags(gmesh, dirichlet_tags)
        nb_index = _line_element_node_tags(gmesh, neumann_tags)
        db_points, nb_points = _partition_node_coordinates(gmesh, db_index, nb_index)
        return {"dirichlet": np.array(db_points), "neumann": np.array(nb_points)}

    # Fail with a clear message instead of an UnboundLocalError on return.
    raise NotImplementedError(
        f"No boundary marker defined for dataset {args.dataset!r} "
        f"(pde={pde!r}, domain={domain!r})"
    )

def get_inference_normalizer(args):
    """Return the normalizer associated with the PDE named in ``args.dataset``.

    The PDE name is ``args.dataset`` with its last ``_``-separated component
    (the domain) removed. The matching object imported from
    ``utils.normalization`` is returned as-is; it is not called here.

    Raises:
        NotImplementedError: If no normalizer is registered for the PDE.
    """
    pde = "_".join(args.dataset.split("_")[0:-1])
    if pde == "laplace2d":
        return Laplace2dNormalizer
    if pde == "laplace2d_n":
        return Laplace2dnNormalizer
    if pde == "darcy2d":
        return Darcy2dNormalizer
    if pde == "heat2d":
        return Heat2dNormalizer
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise NotImplementedError(
        f"No inference normalizer for pde {pde!r} (dataset {args.dataset!r})"
    )

# class HDF5Dataset(Dataset):
#     """
#     Load samples of an PDE Dataset, get items according to PDE.
#     """
#     def __init__(self, path: str,
#                  mode: str,
#                  nt: int,
#                  nx: int,
#                  shift: str,
#                  pde: PDE = None,
#                  dtype=torch.float64,
#                  load_all: bool=False):
#         """Initialize the dataset object.
#         Args:
#             path: path to dataset
#             mode: [train, valid, test]
#             nt: temporal resolution
#             nx: spatial resolution
#             shift: [fourier, linear]
#             pde: PDE at hand
#             dtype: floating precision of data
#             load_all: load all the data into memory
#         """
#         super().__init__()
#         f = h5py.File(path, 'r')
#         self.mode = mode
#         self.dtype = dtype
#         self.data = f[self.mode]
#         self.dataset = f'pde_{nt}-{nx}'
#         self.pde = PDE() if pde is None else pde
#         self.augmentation = None
#         self.shift = 'fourier' if shift is None else shift

#         # Generators which are used for LSDAP
#         # Time generator is treated a bit differently and is implemented in the training loop
#         if str(self.pde) == 'KdV':
#             self.augmentation = KdV_augmentation(self.pde.max_x_shift,
#                                                  self.pde.max_velocity,
#                                                  self.pde.max_scale)
#         elif str(self.pde) == 'KS':
#             self.augmentation = KS_augmentation(self.pde.max_x_shift,
#                                                 self.pde.max_velocity)

#         # For the Heat equation, infinite subalgebra is evoked in __getitem__
#         elif str(self.pde) == 'Heat':
#             self.augmentation = Heat_augmentation(self.pde.max_x_shift)

#         if load_all:
#             data = {self.dataset: self.data[self.dataset][:]}
#             f.close()
#             self.data = data

#     def __len__(self):
#         return self.data[self.dataset].shape[0]

#     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
#         """
#         Returns data items for batched training/validation/testing.
#         Args:
#             idx: data index
#         Returns:
#             torch.Tensor: data trajectory used for training/validation/testing
#             torch.Tensor: dx
#             torch.Tensor: dt
#         """
#         u = self.data[self.dataset][idx]
#         x = self.data['x'][idx]
#         t = self.data['t'][idx]

#         if str(self.pde) == 'Heat':
#             X = to_coords(torch.tensor(x), torch.tensor(t))
#             sol = (torch.tensor(u), X)
#             if self.mode == "train" and self.augmentation is not None:
#                 # Obtain a random second trajectory which is used for data mixing
#                 idx2 = torch.randint(0, self.__len__(), (1,))
#                 u2 = self.data[self.dataset][idx2]
#                 sol2 = (torch.tensor(u2), X)
#                 # Cole-Hopf transformation
#                 # For alpha > 0 mixing occurs
#                 sol = self.pde.subalgebra(sol, sol2, alpha=self.pde.alpha)
#                 # Remaining generators for data augmentation
#                 sol = self.augmentation(sol)

#             else:
#                 # Cole-Hopf transformation
#                 # For alpha == 0 no, mixing between different trajectories
#                 sol = self.pde.subalgebra(sol, alpha=0.)

#             # Scaling of the whole trajectory, otherwise amplitudes are pretty high
#             u = sol[0]
#             u = u / 100
#             X = sol[1]
#             # Only needed when scaling generator is added
#             dx = X[0, 1, 0] - X[0, 0, 0]
#             dt = X[1, 0, 1] - X[0, 0, 1]

#         else:
#             X = to_coords(torch.tensor(x), torch.tensor(t))
#             sol = (torch.tensor(u), X)

#             # Data augmentation using the defined generators for the equation at hand
#             if self.mode == "train" and self.augmentation is not None:
#                 sol = self.augmentation(sol, self.shift)

#             u = sol[0]
#             X = sol[1]
#             dx = X[0, 1, 0] - X[0, 0, 0]
#             dt = X[1, 0, 1] - X[0, 0, 1]

#         return u.float(), X.float(), dx.float(), dt.float()


# class DecomposedDomainDataset(MIODataset):
#     def __init__(self, 
#                  *args, dim=2, n_parts=10, time_independent=False,
#                  **kwargs, 
#                  ):
#         self.time_independent = time_independent
#         self.n_parts = n_parts
#         self.dim = dim
#         super(DecomposedDomainDataset, self).__init__(*args, **kwargs)
#         self.__update_dataset_config()
#         self.data_len = len(self.dcps)
        
    
#     def process(self):
        
#         assert self.name in ["heat2d", "ns2d"], logger.error("DDM is supported only on 'heat2d'")

#         super(DecomposedDomainDataset, self).process()

#         self.dcps, self.decomp_graphs, self.decomp_u_ps, self.decomp_inputs_f = [], [], [], []

#         for idx in tqdm(range(len(self))):
#             if self.name == "heat2d":
#                 x, y = self.graphs[idx].ndata['x'].numpy(), self.graphs[idx].ndata['y'].numpy()
#                 sol = np.concatenate((x, y), axis=1)
#                 x1max = x[:, 1].max()
#                 boundary_func = sol[sol[:, 1] == x1max]
#                 interior_func = sol[~(sol[:, 1] == x1max)]
#                 u_p = self.u_p[idx]

#                 # inputs_f without taking top boundary value
#                 inputs_f = MultipleTensors(self.inputs_f[idx].x)
#                 inputs_f.x = [inputs_f.x[0]] + inputs_f.x[2:]
#             elif self.name == "ns2d":
#                 x, y = self.graphs[idx].ndata['x'].numpy(), self.graphs[idx].ndata['y'].numpy()
#                 sol = np.concatenate((x, y), axis=1)
#                 tree = KDTree(x)
#                 boundary_index = [tree.query(f.numpy())[1] for f in self.x_normalizer.transform(self.inputs_f[idx][0], inverse=False)]
#                 boundary_func = np.stack(sol[boundary_index], axis=0)

#                 interior_func = np.delete(sol, boundary_index, axis=0)
#                 u_p = self.u_p[idx]
#                 inputs_f = MultipleTensors([])
                

#             interior_func = np.insert(interior_func, self.dim, values=range(interior_func.shape[0]), axis=1)
#             boundary_func = np.insert(boundary_func, self.dim, values=range(interior_func.shape[0],interior_func.shape[0]+boundary_func.shape[0]), axis=1)
#             func = np.concatenate((interior_func, boundary_func), axis=0)
            
#             # graph representing simplices and decomposed domains
#             G, dcps = create_subdomains(interior_func[:,:self.dim+1], boundary_func[:, :self.dim+1], dim=self.dim, n_parts=self.n_parts)
            
#             # node locations in each domain
#             X = [dcps.getSubDomain(i)[1] for i in range(len(dcps.subDomain.nodes))]
#             Y = [func[dcps.getSubDomain(i)[0], self.dim+1:] for i in range(len(dcps.subDomain.nodes))]
#             graphs = []
#             for x, y in zip(X, Y):
#                 g = dgl.DGLGraph()
#                 g.add_nodes(x.shape[0])
#                 g.ndata['x'] = torch.from_numpy(x).float()
#                 g.ndata['y'] = torch.from_numpy(y).float()
#                 graphs.append(g)

            
#             xb = [dcps.getBoundary(i) for i in range(len(dcps.subDomain.nodes))]
#             yb = [func[x_[0]][:, self.dim+1:] for x_ in xb]
#             b_func = [np.concatenate((x_[1], y_), axis=-1) for x_, y_ in zip(xb, yb)]
            
#             u_ps = [u_p for _ in range(len(graphs))]
#             # boundary functions
#             inputs_f = [MultipleTensors([torch.from_numpy(f).float()] + inputs_f.x) for f in b_func]

#             self.dcps.append(dcps)
#             self.decomp_graphs.append(graphs)
#             self.decomp_u_ps.append(torch.stack(u_ps))
#             self.decomp_inputs_f.append(inputs_f)

        
#         return

#     def __update_dataset_config(self):
#         self.config = {
#             'input_dim': self.decomp_graphs[0][0].ndata['x'].shape[1],
#             'theta_dim': self.decomp_u_ps[0].shape[1],
#             'output_dim': self.decomp_graphs[0][0].ndata['y'].shape[1],
#             #'branch_sizes': [x.shape[1] for x in self.inputs_f[0]] if isinstance(self.inputs_f, list) else 0
#             'branch_sizes': [x.shape[1] for x in self.decomp_inputs_f[0][0]] if isinstance(self.decomp_inputs_f, list) else 0

#         }
#         return

#     def save(self):
#         with open(self.cached_path[:-4] + "_decomposed" + self.cached_path[-4:], 'wb') as file:
#             pickle.dump((self.dcps, self.decomp_graphs, self.decomp_u_ps, self.decomp_inputs_f), file)
    
#     def load(self):
#         with open(self.cached_path[:-4] + "_decomposed" + self.cached_path[-4:], 'rb') as file:
#             self.dcps, self.decomp_graphs, self.decomp_u_ps, self.decomp_inputs_f = pickle.load(file)

#     def has_cache(self):
#         return os.path.exists(self.cached_path[:-4] + "_decomposed" + self.cached_path[-4:])

#     def __len__(self):
#         return self.data_len

#     def __getitem__(self, idx):
        
#         return self.dcps[idx], self.decomp_graphs[idx], self.decomp_u_ps[idx], self.decomp_inputs_f[idx]
    

class DecomposedDomainDataLoader(torch.utils.data.DataLoader):
    """Batch loader for datasets whose items hold several sub-domain entries.

    Every dataset item is read as ``dataset[idx][1:]``: the first field is
    skipped (presumably the decomposition object -- TODO confirm against the
    dataset class) and each remaining field must be a sequence whose elements
    are all ``dgl.DGLGraph``, ``torch.Tensor`` or ``MultipleTensors``.

    The batching schedule is computed once in ``__init__`` and stored in
    ``self.batch_indices`` as a list of index lists (one inner list per
    batch).  ``__iter__`` only uses ``self.batch_indices`` and
    ``self.dataset``; the remaining constructor arguments are forwarded to
    ``torch.utils.data.DataLoader`` but are not consulted while iterating.

    Args:
        dataset: map-style dataset supporting ``len()`` and integer indexing.
        batch_size: number of dataset items per batch.
        sort_data: if True, batches are runs of consecutive indices and
            ``shuffle`` permutes whole batches; if False, ``shuffle`` permutes
            individual indices before they are grouped into batches.
        shuffle: shuffle with ``np.random.shuffle`` (see ``sort_data``).
        drop_last: discard a trailing batch that has fewer than
            ``batch_size`` items.
        sampler, batch_sampler, num_workers, collate_fn, pin_memory, timeout,
        worker_init_fn: forwarded unchanged to the base class.

    NOTE: shuffling happens once at construction time, so every pass over the
    loader visits the batches in the same order.
    """

    def __init__(self, dataset, batch_size=1,sort_data=True, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        super(DecomposedDomainDataLoader, self).__init__(dataset=dataset, batch_size=batch_size,
                                           shuffle=shuffle, sampler=sampler,
                                           batch_sampler=batch_sampler,
                                           num_workers=num_workers,
                                           collate_fn=collate_fn,
                                           pin_memory=pin_memory,
                                           drop_last=drop_last, timeout=timeout,
                                           worker_init_fn=worker_init_fn)

        self.sort_data = sort_data
        n_samples = len(dataset)

        if sort_data:
            # Consecutive indices stay together; only the batch order may be shuffled.
            batches = [list(range(start, min(start + batch_size, n_samples)))
                       for start in range(0, n_samples, batch_size)]
            # Only an *incomplete* trailing batch is dropped; a full one is kept.
            if drop_last and batches and len(batches[-1]) < batch_size:
                batches.pop()
            if shuffle:
                np.random.shuffle(batches)
        else:
            # Keep only as many indices as fill whole batches when drop_last is set,
            # shuffle the individual indices, then group them into batches so that
            # __iter__ always receives a list of indices per batch.
            usable = (n_samples // batch_size) * batch_size if drop_last else n_samples
            flat_indices = list(range(usable))
            if shuffle:
                np.random.shuffle(flat_indices)
            batches = [flat_indices[start:start + batch_size]
                       for start in range(0, usable, batch_size)]

        self.batch_indices = batches

    def __iter__(self):
        """Yield one batch (a list with one collated object per field) at a time."""
        for indices in self.batch_indices:
            # Skip field 0 of every item and regroup the rest field by field.
            fields = zip(*[self.dataset[idx][1:] for idx in indices])
            batched = []
            for field in fields:
                # The type of the first entry decides how the whole field is collated.
                probe = field[0][0]
                if isinstance(probe, dgl.DGLGraph):
                    graphs = [g for per_item in field for g in per_item]
                    batched.append(dgl.batch(graphs))
                elif isinstance(probe, torch.Tensor):
                    tensors = [t for per_item in field for t in per_item]
                    batched.append(torch.stack(tensors))
                elif isinstance(probe, MultipleTensors):
                    groups = [m for per_item in field for m in per_item]
                    # Pad the j-th tensor of every group to a common length, then
                    # swap the first two axes of the padded result.
                    padded = [pad_sequence([m[j] for m in groups]).permute(1, 0, 2)
                              for j in range(len(groups[0]))]
                    batched.append(MultipleTensors(padded))
                else:
                    raise NotImplementedError
            yield batched

    def __len__(self):
        """Return the number of batches."""
        return len(self.batch_indices)

# def preprocess_heat2d_data(graph, u_p, inputs_f):
#     # preprocess heat2d data for inference

#     # generate boundary conditions as input_boundary
#     x_mean, x_std = train_gnot_dataset.x_normalizer.mean, train_gnot_dataset.x_normalizer.std
#     up_mean, up_std = train_gnot_dataset.up_normalizer.mean, train_gnot_dataset.up_normalizer.std

#     boundary_x = train_gnot_dataset.x_normalizer.transform(inputs_f[1][:, 0:2], inverse=False)
#     boundary_y = train_gnot_dataset.y_normalizer.transform(inputs_f[1][:, [2]], inverse=False)

#     input_boundary = torch.cat([boundary_x, boundary_y], dim=1)

#     input_func = MultipleTensors([inputs_f.x[0]] + inputs_f.x[2:])
#     # generate interior and boundary points

#     x = graph.ndata['x'].numpy()
#     y_pred = graph.ndata['y'].numpy()
#     x1max = x[:, 1].max()
#     x0max = x[:, 0].max()
#     x0min = x[:, 0].min()
#     # spatial points with indices
#     interior = x[~(x[:, 1] == x1max)]
#     interior = np.append(interior, np.array(range(interior.shape[0]))[...,None], 1)
#     boundary = x[x[:, 1] == x1max]
#     boundary = np.append(boundary, np.array(range(interior.shape[0], interior.shape[0]+boundary.shape[0]))[...,None], 1)

#     points = np.concatenate((interior, boundary))

#     # identified indices, used for periodic boundary
#     identified = {int(p[-1]):int(q[-1]) for p, q in itertools.permutations(points, r=2) if ((p[1] == q[1]) and (p[0] == x0max) and (q[0] == x0min))}
    
#     return interior, boundary, identified, u_p, input_func, input_boundary

# class DataCreator(nn.Module):
#     """
#     Helper class to construct input data and labels.
#     """
#     def __init__(self,
#                  time_history,
#                  time_future,
#                  t_resolution,
#                  x_resolution
#                  ):
#         """
#         Initialize the DataCreator object.
#         Args:
#             time_history (int): how many time steps are used for PDE prediction
#             time_future (int): how many time does the solver predict into the future
#             t_resolution: temporal resolution
#             x_resolution: spatial resolution
#         """
#         super().__init__()
#         self.time_history = time_history
#         self.time_future = time_future
#         self.t_res = t_resolution
#         self.x_res = x_resolution

#     def create_data(self, datapoints: torch.Tensor, start_time: list, pf_steps=0) -> Tuple[torch.Tensor, torch.Tensor]:
#         """
#         Getting data of PDEs for training, validation and testing.
#         Args:
#             datapoints (torch.Tensor): trajectory input
#             start_time (int list): list of different starting times for different trajectories in one batch
#             pf_steps (int): push forward steps
#         Returns:
#             torch.Tensor: neural network input data
#             torch.Tensor: neural network labels
#         """
#         data = []
#         labels = []
#         # Loop over batch and different starting points
#         # For every starting point, we take the number of time_history points as training data
#         # and the number of time future data as labels
#         for (dp, start) in zip(datapoints, start_time):
#             end_time = start+self.time_history
#             d = dp[start:end_time]
#             target_start_time = end_time + self.time_future * pf_steps
#             target_end_time = target_start_time + self.time_future
#             l = dp[target_start_time:target_end_time]

#             data.append(d.unsqueeze(dim=0))
#             labels.append(l.unsqueeze(dim=0))

#         return torch.cat(data, dim=0), torch.cat(labels, dim=0)

class DolphinxNeuralOperatorDataset(MIODataset):
    """Dataset built from samples generated with dolfinx.

    ``self.data_list`` (provided by the base class) is expected to be a list
    of records, each either ``(sol, inputs_f)`` or ``(sol, u_p, inputs_f)``:

    * ``sol`` -- numpy array; columns ``0:2`` are used as node coordinates and
      columns ``2:`` as solution values.
    * ``u_p`` -- global input parameter(s) of the sample.
    * ``inputs_f`` -- iterable of numpy arrays (input functions such as the
      boundary condition), or ``None``.

    Args:
        time_dependent: if True, records are processed as space-time data and
            split into chunks of ``time_step`` solution columns.
        time_step: number of solution columns per chunk; required when
            ``time_dependent`` is True and ``time_span`` is None.
        time_span: if not None, each record is kept as a single chunk.
        transform: optional keyword argument; a callable applied to
            ``(graph, u_p, inputs_f)`` in ``__getitem__``.  Defaults to the
            identity.
        *args, **kwargs: forwarded to ``MIODataset``.
    """

    def __init__(self, *args, time_dependent=False, time_step=None, time_span=None, **kwargs):
        transform = kwargs.pop('transform', None)
        # These attributes are assigned before the base constructor runs;
        # presumably MIODataset.__init__ calls process(), which reads them
        # -- TODO confirm.
        self.time_dependent = time_dependent
        self.time_step = time_step
        self.time_span = time_span
        super(DolphinxNeuralOperatorDataset, self).__init__(*args, **kwargs)
        # A named function (rather than a lambda) keeps the dataset picklable.
        self.transform = transform if transform else self._identity_transform

    @staticmethod
    def _identity_transform(x, y, z):
        """Default transform: return the three inputs unchanged."""
        return (x, y, z)

    def process(self):
        """Build ``graphs``, ``u_p`` and ``inputs_f`` from ``self.data_list``.

        Afterwards the samples are optionally sorted and normalized and
        ``self.config`` is refreshed.

        Raises:
            ValueError: on a record of unexpected length, a missing
                ``time_step``, or a number of time columns that is not a
                multiple of ``time_step``.
        """
        self.graphs = []
        self.inputs_f = []
        self.u_p = []

        if not self.time_dependent:
            self._process_space_only()
        else:
            self._process_space_time()

        if len(self.inputs_f) == 0:
            self.inputs_f = torch.zeros([len(self)])  # pad values, tensor of 0, not list

        #### sort data if necessary
        if self.sort_data:
            # NOTE: ``self.__sort_dataset()`` cannot be used here: inside this
            # class the name is mangled to
            # ``_DolphinxNeuralOperatorDataset__sort_dataset``, which does not
            # exist, so the call would raise AttributeError.
            self._sort_by_graph_size()

        self.u_p = torch.stack(self.u_p)

        #### normalize_y
        if self.normalize_y != 'none':
            self.__normalize_y__()
        if self.normalize_x != 'none':
            self.__normalize_x__()

        self.__update_dataset_config__()

        return

    def _process_space_only(self):
        """Create one graph per record for time-independent data."""
        self.data_len = len(self.data_list)
        for record in self.data_list:
            if len(record) == 2:
                sol, inputs_f = record[0], record[1]
                up = torch.zeros((1, )).float()
            elif len(record) == 3:
                sol, u_p, inputs_f = record[0], record[1], record[2]
                up = torch.tensor((u_p, )).float()
            else:
                # Previously such a record silently reused the values of the
                # preceding record (or failed with NameError on the first one).
                raise ValueError(
                    'expected a record of length 2 or 3, got length {}'.format(len(record)))

            g = dgl.DGLGraph()
            g.add_nodes(sol.shape[0])
            g.ndata['x'] = torch.from_numpy(sol[:, 0:2]).float()
            g.ndata['y'] = torch.from_numpy(sol[:, 2:]).float()

            self.graphs.append(g)
            self.u_p.append(up)  # global input parameters
            if inputs_f is not None:
                inputs_f = MultipleTensors([torch.from_numpy(f).float() for f in inputs_f])
                self.inputs_f.append(inputs_f)
                self.num_inputs = len(inputs_f)

    def _process_space_time(self):
        """Create one graph per (record, time chunk) for time-dependent data."""
        for record in self.data_list:
            sol, u_p, inputs_f = record[0], record[1], record[2]
            sol_values = sol[:, 2:]

            if self.time_span is None:
                # training: split the time axis into chunks of ``time_step`` columns
                if not self.time_step:
                    raise ValueError('time_step must be set when time_span is None')
                n_chunks, remainder = divmod(sol_values.shape[1], self.time_step)
                if remainder:
                    raise ValueError(
                        'number of time columns ({}) is not a multiple of time_step ({})'.format(
                            sol_values.shape[1], self.time_step))
                n_chunks = int(n_chunks)
            else:
                n_chunks = 1

            boundary_points = inputs_f[0][:, [0, 1]]
            # The last column of the boundary array is left out of the split;
            # a zero column is appended to every chunk instead.
            boundary_values = inputs_f[0][:, 2:-1]
            for s, f in zip(np.hsplit(sol_values, n_chunks), np.hsplit(boundary_values, n_chunks)):
                g = dgl.DGLGraph()
                g.add_nodes(sol.shape[0])
                g.ndata['x'] = torch.from_numpy(sol[:, 0:2]).float()
                # float32, consistent with every other tensor built here
                g.ndata['y'] = torch.from_numpy(s).float()
                up = torch.tensor((u_p, )).float()
                self.graphs.append(g)
                self.u_p.append(up)  # global input parameters

                # boundary condition and initial condition
                bc = torch.from_numpy(np.concatenate([boundary_points,
                                                      f,
                                                      np.zeros((boundary_points.shape[0], 1))], axis=1)).float()
                # initial condition: coordinates plus the first column of this chunk
                ic = torch.from_numpy(np.concatenate([sol[:, 0:2], s[:, [0]]], axis=1)).float()
                self.inputs_f.append(MultipleTensors([ic, bc]))
            self.num_inputs = 2
        self.data_len = len(self.graphs)

    def _sort_by_graph_size(self):
        """Reorder samples by ascending number of graph nodes.

        NOTE(review): the sort key is assumed to match what the base class's
        own sorter does -- confirm against MIODataset.  The containers stay
        lists so the ``isinstance(self.inputs_f, list)`` check in
        ``__update_dataset_config__`` keeps working.
        """
        order = sorted(range(len(self.graphs)),
                       key=lambda k: self.graphs[k].number_of_nodes())
        self.graphs = [self.graphs[k] for k in order]
        self.u_p = [self.u_p[k] for k in order]
        # inputs_f is a zero tensor when no input functions exist; it is also
        # left alone if it is not aligned one-to-one with the graphs.
        if isinstance(self.inputs_f, list) and len(self.inputs_f) == len(order):
            self.inputs_f = [self.inputs_f[k] for k in order]

    def __update_dataset_config__(self):
        """Store the dimensions of the processed data in ``self.config``."""
        if isinstance(self.inputs_f, list):
            branch_sizes = [x.shape[1] for x in self.inputs_f[0]]
        else:
            branch_sizes = 0

        if not self.time_dependent:
            input_dim = self.graphs[0].ndata['x'].shape[1]
            output_dim = self.graphs[0].ndata['y'].shape[1]
        else:
            input_dim = 2
            output_dim = self.time_step if self.time_step else -1  # -1 for infinity time steps

        self.config = {
            'input_dim': input_dim,
            'theta_dim': self.u_p.shape[1],
            'output_dim': output_dim,
            'branch_sizes': branch_sizes
        }
        return

    def __getitem__(self, idx):
        """Return ``transform(graph, u_p, inputs_f)`` for sample ``idx``."""
        return self.transform(self.graphs[idx], self.u_p[idx], self.inputs_f[idx])


def transform_gt(model, graph):
    """Feed a graph's own node data to ``model.initialize`` as input functions.

    The node coordinates ``graph.ndata['x']`` and values ``graph.ndata['y']``
    are concatenated with one extra zero column to form the boundary-condition
    tensor.  When ``model.time_dependent`` is true, an initial-condition
    tensor (coordinates plus the first column of ``y``) is placed in front of
    it.

    Args:
        model: object exposing ``time_dependent`` and ``initialize(inputs_f)``.
        graph: graph whose ``ndata`` holds the ``'x'`` and ``'y'`` tensors.

    Returns:
        Whatever ``model.initialize`` returns for the assembled inputs.
    """
    coords = graph.ndata['x']
    values = graph.ndata['y']
    # Allocate the padding column on the same device as the node data so the
    # concatenation also works when the graph does not live on the CPU.
    zero_column = torch.zeros((coords.shape[0], 1), device=coords.device)
    boundary_condition = torch.concat([coords, values, zero_column], dim=1)

    if not model.time_dependent:
        inputs_f = MultipleTensors([boundary_condition])
    else:
        initial_condition = torch.concat([coords, values[:, 0:1]], dim=1)
        inputs_f = MultipleTensors([initial_condition, boundary_condition])

    return model.initialize(inputs_f)