import json
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import einops
import h5py
import lightning as L
import numpy as np
import scipy.spatial
import torch
from torch import Tensor
from torch.nn.functional import one_hot
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import nearest, radius, radius_graph
from torch_geometric.transforms import KNNGraph
from torch_scatter import scatter_sum

from src.utils.data_utils import (
    QuinticKernel,
    pos_init_cartesian_2d,
    pos_init_cartesian_3d,
)


@dataclass
class MMDataset(InMemoryDataset):
    """Particle-trajectory dataset backed by ``<split>.h5`` and ``metadata.json``.

    The ``mode`` field selects which ``getitem_*`` method ``__getitem__``
    dispatches to. All fields below are configuration; derived state is set in
    ``__post_init__``.
    """

    # Number of consecutive velocities (differences of neighbouring positions)
    # used as input; ``__post_init__`` asserts that it is > 1.
    n_velocities: int = 2
    # Stride, in timesteps, between two consecutive jumps of the targets.
    n_jump_ahead_timesteps: int = 1
    # Number of jumps to build targets for; recomputed from the sequence length
    # in the "full_traj" and "full_traj_gns_f" modes.
    n_jumps: int = 1
    # One of "train", "valid", "test"; names the HDF5 file that is opened.
    split: str = "train"
    # One of the modes asserted in ``__post_init__``.
    mode: str = "train_autoencoder"
    # Must be set to a value > 0 (asserted in ``__post_init__``).
    n_particle_types: int = None
    # Length of the ``supernode_index`` attached to the samples.
    n_supernodes: int = None
    # A single int when ``random_particles`` is set (it is then expanded to
    # ``[n, n]``); otherwise ``get_permutation`` reads it as a ``[lb, ub]`` pair.
    num_points_range: int = None
    # Root directories that contain ``dataset_rel_path``; ``local_root`` is
    # tried first and ``global_root`` is the fallback.
    global_root: str = None
    local_root: str = None
    # NOTE(review): not read anywhere in the visible part of the file.
    seed: int = None
    # NOTE(review): presumably the number of occupancy query points used by
    # ``get_occupancy`` (not visible here) — confirm.
    n_occupancy: int = 1000
    # Query radius and grid spacing used by ``get_grid_occ_and_field``.
    occupancy_radius: float = 0.03  # should correspond to ~2*metadata["dx"]
    # Number of points drawn for the decoder-side permutation.
    num_points_decode: int = None
    # Dataset directory relative to the chosen root.
    dataset_rel_path: str = "multi_material"
    # NOTE(review): not read in the visible part of the file.
    occ_nearest_backend: str = "scipy"  # pyg or scipy
    # Forces ``split = "train"`` and restricts the dataset to one trajectory.
    overfit_single_trajectory: bool = False
    # Replace the stored positions by a static regular 3D lattice of particles.
    random_particles: bool = False

    def __post_init__(self):
        """Validate the configuration, open the split file and derive sampling state.

        Sets ``source_root``, ``trajectories``, ``traj_keys``, ``n_traj``,
        ``metadata``, ``connectivity_radius``, ``bounds``, ``ndim``, the
        velocity/acceleration normalization statistics, ``box``,
        ``pos_offset``, ``n_seq``, ``n_per_traj`` and ``pos_getter``. Every mode
        except ``"train_autoencoder"`` additionally gets ``target_vel_idx`` and
        ``target_pos_idx``; the full-trajectory modes also overwrite
        ``n_jumps``.

        Raises:
            AssertionError: If a configuration value is invalid or the dataset
                directory cannot be found.
        """
        super().__init__()
        assert self.split in [
            "train",
            "valid",
            "test",
        ], f"Split {self.split} not available."
        # The check requires at least two velocities; the message used to
        # claim "greater than 0", contradicting the condition.
        assert self.n_velocities > 1, "n_velocities must be greater than 1."
        assert self.mode in [
            "train_autoencoder",
            "train_physics",
            "train_physics_full",
            "val_physics",
            "full_traj",
            "gns_f",
            "full_traj_gns_f",
        ], f"Mode {self.mode} not available."
        # Guard against the ``None`` default, which would otherwise surface as
        # an opaque TypeError from the comparison.
        assert (
            self.n_particle_types is not None and self.n_particle_types > 0
        ), "n_particle_types must be a positive integer."

        # Resolve the dataset directory: prefer local_root, fall back to
        # global_root. ``Path(None)`` raises TypeError, so unset roots are
        # never passed to ``Path``.
        assert (
            self.local_root is not None or self.global_root is not None
        ), "Either local_root or global_root must be set."
        if self.local_root is None:
            self.source_root = Path(self.global_root) / self.dataset_rel_path
        else:
            self.source_root = Path(self.local_root) / self.dataset_rel_path
            # If local data does not exist, try global root (if configured)
            if not self.source_root.exists() and self.global_root is not None:
                self.source_root = Path(self.global_root) / self.dataset_rel_path
        assert self.source_root.exists(), f"'{self.source_root.as_posix()}' doesn't exist"

        # Overfitting always uses the first trajectory of the train split
        if self.overfit_single_trajectory:
            self.split = "train"
        self.trajectories = self.load_dataset(self.source_root, self.split)
        self.traj_keys = list(self.trajectories.keys())
        if self.overfit_single_trajectory:
            self.n_traj = 1
        else:
            self.n_traj = len(self.traj_keys)

        self.metadata = self.load_metadata(self.source_root)
        self.connectivity_radius = self.metadata["default_connectivity_radius"]
        self.bounds = torch.tensor(self.metadata["bounds"])
        self.ndim = self.metadata["dim"]

        if self.random_particles:
            # Only possible with 3D data
            assert self.ndim == 3, "random particles only works with 3D data at the moment"
            assert isinstance(
                self.num_points_range, int
            ), "num_points_range must be a single integer."
            # Sample always the same number of points
            self.num_points_range = [self.num_points_range, self.num_points_range]
            # Points are placed on a lattice with dx = connectivity_radius / 1.5
            # and a fixed nx * ny cross-section, so the number of points must be
            # a multiple of nx * ny.
            nx = 16
            ny = 16
            assert (
                self.num_points_range[0] % (nx * ny) == 0
            ), f"num_points_range must be divisible by {nx*ny}."
            nz = self.num_points_range[0] // (nx * ny)
            dx = self.connectivity_radius / 1.5
            # Set the upper bounds so that all the particles fit into the box
            self.bounds[:, 1] = torch.tensor([nx * dx, ny * dx, nz * dx])

        # Normalization stats
        self.vel_mean = torch.tensor(self.metadata["vel_mean"])
        self.vel_std = torch.tensor(self.metadata["vel_std"])
        self.acc_mean = torch.tensor(self.metadata["acc_mean"])
        self.acc_std = torch.tensor(self.metadata["acc_std"])

        # Box extent and position offset are derived from the bounds in the
        # same way with and without periodic boundary conditions.
        self.box = self.bounds[:, 1] - self.bounds[:, 0]
        self.pos_offset = self.bounds[:, 0]

        if self.mode == "train_autoencoder":
            # Every sample in the trajectory can be used
            self.n_seq = self._resolve_sequence_length("sequence_length_train")
            self.n_per_traj = self.n_seq - self.n_velocities
            self.pos_getter = partial(self.get_positions, length=self.n_velocities + 1)
        elif self.mode in ["train_physics", "val_physics", "train_physics_full", "gns_f"]:
            # NOTE(review): the physics *training* modes also read
            # "sequence_length_val" — confirm that this is intended.
            self.n_seq = self._resolve_sequence_length("sequence_length_val")
            n_rollout_steps = self.n_jump_ahead_timesteps * self.n_jumps
            self.n_per_traj = self.n_seq - n_rollout_steps - self.n_velocities
            self.pos_getter = partial(
                self.get_positions,
                length=n_rollout_steps + self.n_velocities + 1,
            )
            self._init_target_indices()
        elif self.mode in ["full_traj", "full_traj_gns_f"]:
            self.n_seq = self._resolve_sequence_length("sequence_length_val")
            # Only the rollout that starts at the first timestep is used, i.e.
            # one sample per trajectory.
            self.n_per_traj = 1
            self.pos_getter = partial(self.get_positions, length=self.n_seq + 1)
            # As many jumps as fit into the sequence after the input window
            self.n_jumps = (
                self.n_seq - self.n_velocities
            ) // self.n_jump_ahead_timesteps
            self._init_target_indices()

    def _resolve_sequence_length(self, key: str) -> int:
        """Return ``metadata[key]`` if present, else ``metadata["sequence_length"]``."""
        if key in self.metadata:
            return self.metadata[key]
        return self.metadata["sequence_length"]

    def _init_target_indices(self) -> None:
        """Set ``target_vel_idx`` and ``target_pos_idx`` for jumps ``1..n_jumps``.

        For jump ``i`` the velocity indices are
        ``[i * stride, ..., i * stride + n_velocities - 1]`` and the position
        index is ``i * stride + n_velocities``, with
        ``stride = n_jump_ahead_timesteps``.
        """
        jump_offsets = [i * self.n_jump_ahead_timesteps for i in range(1, self.n_jumps + 1)]
        self.target_vel_idx = torch.tensor(
            [list(range(offset, offset + self.n_velocities)) for offset in jump_offsets]
        )
        self.target_pos_idx = torch.tensor(
            [offset + self.n_velocities for offset in jump_offsets]
        )

    def __len__(self) -> int:
        return self.n_traj * self.n_per_traj

    def __getitem__(self, idx: int):
        if self.mode == "train_autoencoder":
            return self.getitem_autoencoder(idx)
        elif self.mode == "train_physics":
            return self.getitem_train_physics(idx)
        elif self.mode == "train_physics_full":
            return self.getitem_train_physics_full(idx)
        elif self.mode == "val_physics":
            return self.getitem_val_physics(idx)
        elif self.mode == "full_traj":
            return self.getitem_full_traj(idx)
        elif self.mode == "gns_f":
            return self.getitem_gns_f(idx)
        elif self.mode == "full_traj_gns_f":
            return self.getitem_full_traj_gns_f(idx)

    def getitem_autoencoder(self, idx: int) -> Data:
        """Build one autoencoder sample in which encoder and decoder see the same timestep.

        Two independent permutations are drawn: the default-sized one selects
        the encoder points, the one of size ``num_points_decode`` selects the
        decoder points. Velocities are differences of consecutive positions,
        normalized with ``normalize_vel``. Occupancy targets come from
        ``get_occupancy`` evaluated at position index ``n_velocities``.

        Returns:
            A ``Data`` object with ``enc_*``, ``dec_*``, ``supernode_index``
            (also stored as ``supernode_idx``) and ``timestep`` entries.
        """
        sample = self.pos_getter(idx)
        pos = sample["position"]
        ptype = sample["particle_type"]
        n_particles = pos.shape[0]

        # Keep the draw order: encoder permutation first, decoder second
        perm_enc = self.get_permutation(n_particles=n_particles)
        perm_dec = self.get_permutation(n_particles=n_particles, n_perm=self.num_points_decode)

        # Normalized finite-difference velocities
        vel = self.normalize_vel(pos[:, 1:, :] - pos[:, :-1, :])

        occupancy = self.get_occupancy(
            position=pos[:, self.n_velocities, :], particle_type=ptype
        )

        # The same index tensor is stored under both supernode keys
        supernode_index = torch.arange(self.n_supernodes)

        return Data(
            # The latest position is the one that is encoded ...
            enc_pos=pos[perm_enc, -1, :],
            enc_field=vel[perm_enc],
            enc_particle_type=ptype[perm_enc],
            # ... and also the one that is decoded
            dec_pos=pos[perm_dec, -1, :],
            dec_field=vel[perm_dec],
            dec_occ_type=occupancy["occ_type"],
            dec_occ_pos=occupancy["occ_pos"],
            supernode_index=supernode_index,
            supernode_idx=supernode_index,
            timestep=sample["time_idx"][-1],
            num_nodes=len(perm_enc),
        )

    def getitem_train_physics(self, idx: int) -> Data:
        """Build one physics-training sample with input and targets in a single ``Data``.

        The target permutation is drawn with the same size as the input
        permutation so that both can be stored in one ``Data`` object.

        Returns:
            One ``Data`` object with ``input_*`` and ``target_*`` entries. The
            return annotation previously claimed a tuple of two ``Data``
            objects, which did not match the returned value.
        """
        position_dict = self.pos_getter(idx)
        position = position_dict["position"]
        particle_type = position_dict["particle_type"]
        perm_input = self.get_permutation(n_particles=position.shape[0])
        # Target permutation should be the same size
        # so we can stack everything in one data object
        perm_target = self.get_permutation(n_particles=position.shape[0], n_perm=len(perm_input))

        # VELOCITIES: differences of consecutive positions
        velocities = position[:, 1:, :] - position[:, :-1, :]
        input_velocities = velocities[perm_input, : self.n_velocities]
        # Indexing with the 2D ``target_vel_idx`` yields
        # [n_particles, n_jumps, n_velocities, dim]
        target_velocities = velocities[:, self.target_vel_idx][perm_target]
        # Normalization
        input_velocities = self.normalize_vel(input_velocities)
        target_velocities = self.normalize_vel(target_velocities)

        # POSITION
        input_position = position[perm_input, self.n_velocities]
        # Shape of target position [n_particles, n_jumps, dim]
        target_position = position[:, self.target_pos_idx][perm_target]

        return Data(
            input_enc_pos=input_position,
            input_enc_field=input_velocities,
            input_enc_particle_type=particle_type[perm_input],
            supernode_index=torch.arange(self.n_supernodes),  # Is the same for input and target
            input_timestep=position_dict["time_idx"][self.n_velocities],
            target_enc_pos=target_position,
            target_enc_field=target_velocities,
            target_enc_particle_type=particle_type[perm_target],
            target_timestep=position_dict["time_idx"][self.target_pos_idx].unsqueeze(0),
            num_nodes=len(input_position),
        )

    def getitem_train_physics_full(self, idx: int) -> Tuple[Data, Data]:
        """Build one physics-training sample as separate encoder-input and decoder-target objects.

        The input uses the default-sized permutation; the target fields use a
        permutation of size ``num_points_decode``. Occupancy targets are
        computed per jump from the unpermuted target positions.

        Returns:
            ``(input_data, target_data)``.
        """
        sample = self.pos_getter(idx)
        pos = sample["position"]
        ptype = sample["particle_type"]
        time_idx = sample["time_idx"]
        n_particles = pos.shape[0]

        # Keep the draw order: input permutation first, target second
        perm_input = self.get_permutation(n_particles=n_particles)
        perm_target = self.get_permutation(n_particles=n_particles, n_perm=self.num_points_decode)

        # Velocities: differences of consecutive positions, then normalized.
        # Target velocities have shape [n_points, n_jumps, n_velocities, dim].
        vel = pos[:, 1:, :] - pos[:, :-1, :]
        input_velocities = self.normalize_vel(vel[perm_input, : self.n_velocities])
        target_velocities = self.normalize_vel(vel[:, self.target_vel_idx][perm_target])

        # Positions: target positions have shape [n_particles, n_jumps, dim]
        input_position = pos[perm_input, self.n_velocities]
        target_position = pos[:, self.target_pos_idx]

        # Occupancies, one query per jump, stacked along the jump axis
        occupancies = [
            self.get_occupancy(position=target_position[:, jump], particle_type=ptype)
            for jump in range(target_position.shape[1])
        ]
        occ_type = torch.stack([occ["occ_type"] for occ in occupancies], dim=1)
        occ_pos = torch.stack([occ["occ_pos"] for occ in occupancies], dim=1)

        input_data = Data(
            enc_pos=input_position,
            enc_field=input_velocities,
            enc_particle_type=ptype[perm_input],
            supernode_index=torch.arange(self.n_supernodes),
            timestep=time_idx[self.n_velocities],
            num_nodes=len(input_position),
        )
        target_data = Data(
            dec_pos=target_position[perm_target],
            dec_field=target_velocities,
            dec_occ_type=occ_type,
            dec_occ_pos=occ_pos,
            timestep=time_idx[self.target_pos_idx].unsqueeze(0),
            num_nodes=len(perm_target),
        )
        return input_data, target_data

    def getitem_val_physics(self, idx: int) -> Tuple[Data, Data]:
        """Build one physics-validation sample as ``(input_data, target_data)``.

        Same construction as ``getitem_train_physics_full``, except that the
        target object additionally carries a ``supernode_index``.

        Returns:
            ``(input_data, target_data)``.
        """
        sample = self.pos_getter(idx)
        pos = sample["position"]
        ptype = sample["particle_type"]
        time_idx = sample["time_idx"]
        n_particles = pos.shape[0]

        # Keep the draw order: input permutation first, target second
        perm_input = self.get_permutation(n_particles=n_particles)
        perm_target = self.get_permutation(n_particles=n_particles, n_perm=self.num_points_decode)

        # Velocities: differences of consecutive positions, then normalized.
        # Target velocities have shape [n_points, n_jumps, n_velocities, dim].
        vel = pos[:, 1:, :] - pos[:, :-1, :]
        input_velocities = self.normalize_vel(vel[perm_input, : self.n_velocities])
        target_velocities = self.normalize_vel(vel[:, self.target_vel_idx][perm_target])

        # Positions: target positions have shape [n_particles, n_jumps, dim]
        input_position = pos[perm_input, self.n_velocities]
        target_position = pos[:, self.target_pos_idx]

        # Occupancies, one query per jump, stacked along the jump axis
        occupancies = [
            self.get_occupancy(position=target_position[:, jump], particle_type=ptype)
            for jump in range(target_position.shape[1])
        ]
        occ_type = torch.stack([occ["occ_type"] for occ in occupancies], dim=1)
        occ_pos = torch.stack([occ["occ_pos"] for occ in occupancies], dim=1)

        input_data = Data(
            enc_pos=input_position,
            enc_field=input_velocities,
            enc_particle_type=ptype[perm_input],
            supernode_index=torch.arange(self.n_supernodes),
            timestep=time_idx[self.n_velocities],
            num_nodes=len(input_position),
        )
        target_data = Data(
            dec_pos=target_position[perm_target],
            dec_field=target_velocities,
            dec_occ_type=occ_type,
            dec_occ_pos=occ_pos,
            supernode_index=torch.arange(self.n_supernodes),
            timestep=time_idx[self.target_pos_idx].unsqueeze(0),
            num_nodes=len(perm_target),
        )
        return input_data, target_data

    def getitem_full_traj(self, idx: int) -> Tuple[Data, Data]:
        """Build a full-trajectory rollout sample on all particles (no sub-sampling).

        In addition to the particle data, occupancy and velocity fields are
        evaluated on the regular grid (``get_grid_occ_and_field``) for the
        input step and for every jump.

        Returns:
            ``(input_data, target_data)``. The return annotation previously
            claimed a single ``Data`` object, which did not match the returned
            tuple.
        """
        position_dict = self.pos_getter(idx)
        position = position_dict["position"]
        particle_type = position_dict["particle_type"]

        # VELOCITIES: differences of consecutive positions
        velocities = position[:, 1:, :] - position[:, :-1, :]
        input_velocities = velocities[:, : self.n_velocities]
        # Shape of target velocities [n_particles, n_jumps, n_velocities, dim]
        target_velocities = velocities[:, self.target_vel_idx]
        # Normalization
        input_velocities = self.normalize_vel(input_velocities)
        target_velocities = self.normalize_vel(target_velocities)

        # POSITION
        input_position = position[:, self.n_velocities]
        # Shape of target position [n_particles, n_jumps, dim]
        target_position = position[:, self.target_pos_idx]

        # Grid occupancy and field for the input step. The velocities are
        # already normalized above, hence is_normalize_vel=False.
        input_gt_occ, input_gt_vel = self.get_grid_occ_and_field(
            particle_pos=input_position,
            particle_vel=input_velocities[:, -1, :],
            is_normalize_vel=False,
        )

        # Same for every jump, stacked along the jump axis
        target_gt_occ = []
        target_gt_vel = []
        for jump_idx in range(target_position.shape[1]):
            gt_occ, gt_vel = self.get_grid_occ_and_field(
                particle_pos=target_position[:, jump_idx],
                particle_vel=target_velocities[:, jump_idx, -1, :],
                is_normalize_vel=False,
            )
            target_gt_occ.append(gt_occ)
            target_gt_vel.append(gt_vel)

        target_gt_occ = torch.stack(target_gt_occ, dim=1)
        target_gt_vel = torch.stack(target_gt_vel, dim=1)

        # ``self.grid_pos`` is created lazily by get_grid_occ_and_field, which
        # has been called above.
        input_data = Data(
            enc_pos=input_position,
            enc_field=input_velocities,
            enc_particle_type=particle_type,
            supernode_index=torch.randperm(input_position.shape[0], generator=None)[
                : self.n_supernodes
            ],  # Because no permutation is used here, supernode index must be shuffled
            timestep=position_dict["time_idx"][self.n_velocities],
            grid_pos=self.grid_pos,
            gt_occ=input_gt_occ,
            gt_vel=input_gt_vel,
            num_nodes=len(input_position),
        )
        target_data = Data(
            dec_pos=target_position,
            dec_field=target_velocities,
            # Drawn independently of the input's supernode index
            supernode_index=torch.randperm(input_position.shape[0], generator=None)[
                : self.n_supernodes
            ],  # Because no permutation is used here, supernode index must be shuffled
            timestep=position_dict["time_idx"][self.target_pos_idx].unsqueeze(0),
            grid_pos=self.grid_pos,
            gt_occ=target_gt_occ,
            gt_vel=target_gt_vel,
            num_nodes=len(target_position),
        )

        return input_data, target_data

    def getitem_full_traj_gns_f(self, idx: int) -> Data:
        """Build a full-trajectory sample that carries raw position histories.

        The first ``n_velocities + 1`` positions are the input; all remaining
        positions are targets. Normalized velocities are attached for
        visualization.

        Returns:
            A single ``Data`` object defined on the particle positions.
        """
        sample = self.pos_getter(idx)
        pos = sample["position"]
        n_input_pos = self.n_velocities + 1

        # Velocities: split into input window and remainder, each normalized.
        # Both keep the layout [n_particles, time, dim].
        vel = pos[:, 1:, :] - pos[:, :-1, :]
        input_velocities = self.normalize_vel(vel[:, : self.n_velocities])
        target_velocities = self.normalize_vel(vel[:, self.n_velocities :])

        # Positions: input history followed by everything after it
        input_positions = pos[:, :n_input_pos]
        target_positions = pos[:, n_input_pos:]

        return Data(
            enc_pos=input_positions,
            n_particles_per_example=input_positions.shape[0],
            particle_type=sample["particle_type"],
            target_pos=target_positions,
            # for visualization:
            enc_field=input_velocities,
            target_field=target_velocities,
        )

    def get_grid_occ_and_field(
        self, particle_pos: Tensor, particle_vel: Tensor, is_normalize_vel: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Evaluate occupancy and a kernel-averaged velocity field on a regular grid.

        The grid (``self.grid_pos``) and the kernel (``self.sph_kernel``) are
        created on the first call and cached on the instance.

        Args:
            particle_pos: Particle positions, one row per particle.
            particle_vel: Particle velocities, one row per particle.
            is_normalize_vel: If true, ``particle_vel`` is first passed through
                ``normalize_vel``.

        Returns:
            ``(target_grid_occ, target_grid_field)``: a float tensor that is 1
            for grid points with at least one particle within
            ``occupancy_radius`` and 0 otherwise, and a tensor shaped like
            ``self.grid_pos`` holding the kernel-weighted mean velocity (zero
            where the summed weight is not positive).
        """
        if is_normalize_vel:
            particle_vel = self.normalize_vel(particle_vel)

        # Lazily instantiate the regular grid on which occupancy and field are
        # evaluated; spacing equals the occupancy radius.
        if not hasattr(self, "grid_pos"):
            if self.ndim == 2:
                self.grid_pos = pos_init_cartesian_2d(
                    self.box.cpu().numpy(), dx=self.occupancy_radius
                )
            elif self.ndim == 3:
                self.grid_pos = pos_init_cartesian_3d(
                    self.box.cpu().numpy(), dx=self.occupancy_radius
                )
            self.grid_pos = torch.tensor(self.grid_pos, dtype=torch.float) + self.pos_offset

        # Lazily instantiate the kernel used for the field average
        if not hasattr(self, "sph_kernel"):
            self.sph_kernel = QuinticKernel(h=self.occupancy_radius / 3, dim=self.ndim)

        # Occupancy: for every grid point, the particles within the radius
        kd_tree = scipy.spatial.cKDTree(particle_pos.cpu().numpy())
        neighbors_per_cell = kd_tree.query_ball_point(
            self.grid_pos.cpu().numpy(), self.occupancy_radius
        )
        neighbor_counts = [len(hits) for hits in neighbors_per_cell]
        target_grid_occ = torch.tensor([count > 0 for count in neighbor_counts]).to(torch.float)

        # Flatten the neighbor lists into (grid index, particle index) pairs
        n_grid = len(self.grid_pos)
        grid_idx = torch.tensor(np.repeat(range(n_grid), neighbor_counts))
        particle_idx = torch.tensor(
            np.concatenate(neighbors_per_cell, axis=0), dtype=torch.int64
        )

        # Kernel weights from the particle-to-grid-point distances
        offsets = particle_pos[particle_idx] - self.grid_pos[grid_idx]
        distances = torch.norm(offsets, dim=-1, keepdim=True)
        weights = self.sph_kernel.w(distances)

        # Weighted average of the velocities of the particles in the support
        # of the kernel: sum(w * v) / sum(w) per grid point.
        weighted_sum = scatter_sum(weights * particle_vel[particle_idx], grid_idx, dim=0, dim_size=n_grid)
        weight_sum = scatter_sum(weights, grid_idx, dim=0, dim_size=n_grid)
        # NOTE(review): squeeze() without a dim would also drop the grid axis
        # if the grid had a single point.
        has_support = weight_sum.squeeze() > 0
        target_grid_field = torch.zeros_like(self.grid_pos)
        target_grid_field[has_support] = weighted_sum[has_support] / weight_sum[has_support]

        return target_grid_occ, target_grid_field

    def getitem_gns_f(self, idx: int) -> Tuple[Data, Data]:
        """Build one sample of raw position history plus optional grid targets.

        The particle object holds all but the last position as input and the
        last position as target. For the "valid"/"test" splits, or when
        overfitting a single trajectory, occupancy and velocity field at the
        last position are additionally evaluated on the regular grid; for the
        "train" split the grid object is empty.

        No node or edge features are computed here; only positions and
        particle types are returned.

        Returns:
            ``(particle_data, grid_data)``.

        Raises:
            AssertionError: If grid targets are requested and the sample
                contains more than one particle type.
        """
        sample = self.pos_getter(idx)
        positions = sample["position"]
        particle_types = sample["particle_type"]

        # Data on the particle positions
        particle_data = Data(
            enc_pos=positions[:, :-1],
            n_particles_per_example=positions.shape[0],
            particle_type=particle_types,
            target_pos=positions[:, -1],
        )

        is_eval_split = self.split in ["valid", "test"]
        if is_eval_split or self.overfit_single_trajectory:
            assert (
                particle_types.unique().shape[0] == 1
            ), "Only one particle type is supported for validation and test splits."
            # Velocity of the last step = last position minus the one before
            target_grid_occ, target_grid_field = self.get_grid_occ_and_field(
                particle_pos=positions[:, -1],
                particle_vel=(positions[:, -1] - positions[:, -2]),
            )
            # ``self.grid_pos`` exists after the call above
            grid_data = Data(
                enc_pos=self.grid_pos,
                target_grid_occ=target_grid_occ,
                target_grid_field=target_grid_field,
            )
        elif self.split == "train":
            grid_data = Data()

        return particle_data, grid_data

    def load_dataset(self, path: Path, split):
        """Open ``<path>/<split>.h5`` read-only and return the open ``h5py.File``.

        The handle is not closed here; ``get_positions`` reads trajectories
        from it on demand.

        Args:
            path: Directory that contains the split files.
            split: Split name, used as the file stem.
        """
        # Request read-only access explicitly: a loader must never create or
        # modify the file, and older h5py releases did not default to "r".
        return h5py.File(path / (split + ".h5"), "r")

    def load_metadata(self, path: Path):
        # Load metadata
        with open(path / "metadata.json") as f:
            metadata = json.loads(f.read())
        return metadata

    def get_positions(self, idx: int, length: int = 1) -> Dict[str, Tensor]:
        """Return ``length`` consecutive positions of the trajectory addressed by ``idx``.

        ``idx`` is split into a trajectory index (``idx // n_per_traj``) and a
        start timestep (``idx % n_per_traj``). With ``random_particles`` the
        stored positions are replaced by a static regular 3D lattice that is
        repeated for every timestep, and all particle types are zero.

        Args:
            idx: Flat sample index.
            length: Number of consecutive timesteps to return.

        Returns:
            A dict with string keys (the annotation previously declared tensor
            keys): ``"position"`` laid out as
            ``[n_particles, n_timesteps, n_dim]``, ``"particle_type"`` and
            ``"time_idx"`` (the timestep indices that were read).
        """
        # Index where to start in traj
        start_idx = idx % self.n_per_traj
        end_idx = start_idx + length
        time_idx = torch.arange(start_idx, end_idx)
        # Trajectory index
        i_traj = idx // self.n_per_traj
        traj = self.trajectories[self.traj_keys[i_traj]]
        if self.random_particles:
            # Lattice with the same spacing that __post_init__ used to size the box
            position = torch.tensor(
                pos_init_cartesian_3d(self.box.cpu().numpy(), dx=self.connectivity_radius / 1.5),
                dtype=torch.float32,
            )
            # Same lattice at every timestep: [n_particles, n_timesteps, n_dim]
            position = position.unsqueeze(1).repeat(1, len(time_idx), 1)
            particle_type = torch.zeros([self.num_points_range[-1]], dtype=torch.int)
        else:
            position = traj["position"][time_idx]
            particle_type = traj["particle_type"][:]
            position = torch.tensor(position)
            particle_type = torch.tensor(particle_type)
            # Stored time-major; samples are particle-major
            position = einops.rearrange(
                position, "n_timesteps n_particles n_dim -> n_particles n_timesteps n_dim"
            )
        return {
            "position": position,  # [n_particles, n_timesteps, n_dim]
            "particle_type": particle_type,  # [n_particles]
            "time_idx": time_idx,  # [n_timesteps]
        }

    def get_permutation(self, n_particles, n_perm=None):
        """Draw a random subset of particle indices.

        Args:
            n_particles: Number of particles to permute (indices 0..n_particles-1).
            n_perm: Number of indices to keep. If None, the count is taken from
                ``self.num_points_range``: the fixed value when both bounds are
                equal, otherwise a uniform draw between the lower bound and
                ``min(upper bound, n_particles)``.

        Returns:
            1-D tensor with the first ``count`` entries of a random permutation
            (fewer if ``count`` exceeds ``n_particles``).
        """
        if n_perm is not None:
            count = n_perm
        elif self.num_points_range[0] == self.num_points_range[1]:
            # fixed num_points_range
            count = self.num_points_range[1]
        else:
            lower, upper = self.num_points_range
            upper = min(upper, n_particles)
            fraction = torch.rand(size=(1,), generator=None).item()
            count = int(fraction * (upper - lower) + lower)
        # uniform sampling without replacement
        return torch.randperm(n_particles, generator=None)[:count]

    def get_occupancy(self, position: Tensor, particle_type: Tensor, occ_pos=None):
        """Label query points as empty or as occupied by a particle type.

        A query point is occupied when a particle lies within
        ``self.occupancy_radius`` of it; its label is then derived from the
        nearest such particle.

        Args:
            position: Particle coordinates, indexed as [n_particles, n_dim].
            particle_type: Integer type id per particle.
            occ_pos: Optional query points. If None, ``self.n_occupancy`` points
                are drawn uniformly from ``[0, self.box)`` and shifted by
                ``self.pos_offset``.

        Returns:
            Dict with ``"occ_pos"`` (the query points) and ``"occ_type"``, a
            float one-hot tensor with ``self.n_particle_types + 1`` classes where
            class 0 means "not occupied".

        Raises:
            ValueError: If ``self.occ_nearest_backend`` is neither ``"pyg"`` nor
                ``"scipy"``.
        """
        if occ_pos is None:
            occ_pos = torch.rand((self.n_occupancy, self.box.shape[0])) * self.box
            occ_pos += self.pos_offset

        # Calculate occupancy
        if self.occ_nearest_backend == "pyg":
            # This implementation takes ~50s on dam3d
            nearest_neighbour = nearest(occ_pos, position)
            distances = torch.norm(occ_pos - position[nearest_neighbour], dim=-1)
            occ = (distances < self.occupancy_radius).to(torch.float)
            occ_type = (particle_type[nearest_neighbour] + 1) * occ
        elif self.occ_nearest_backend == "scipy":
            # tree.query_ball_point takes ~0.5s on dam3d
            tree = scipy.spatial.cKDTree(position.cpu().numpy())
            nbrs = tree.query_ball_point(occ_pos.cpu().numpy(), self.occupancy_radius)
            occ = torch.tensor([len(points) > 0 for points in nbrs]).to(torch.float)

            if len(particle_type.unique()) > 1:
                # This implementation adds 2x the runtime of query_ball_point on dam3d,
                # but it is only used for multiple particle types. Alternatively, using
                # `tree.query(occ_pos, k=1)` is 10x this runtime rather than 2x.
                # Using the single-type shortcut below adds 0 time.
                #
                # For every query point pick the closest particle inside the ball.
                # Empty balls fall back to index 0; the multiplication by `occ`
                # below zeroes their label anyway.
                nearest_neighbour = torch.tensor(
                    [
                        (
                            p[torch.argmin(torch.norm(occ_pos[i] - position[p], dim=-1))]
                            if len(p) > 0
                            else 0
                        )
                        for i, p in enumerate(nbrs)
                    ],
                    # Explicit dtype keeps this a valid index tensor even when
                    # there are no query points at all.
                    dtype=torch.long,
                )

                occ_type = (particle_type[nearest_neighbour] + 1) * occ
            else:
                # NOTE(review): with a single particle type every occupied point
                # is mapped to class 1 regardless of the actual type id, unlike
                # the "pyg" backend which uses type + 1. Confirm this is intended.
                occ_type = occ
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `occ_type`.
            raise ValueError(
                f"Unknown occ_nearest_backend {self.occ_nearest_backend!r}; "
                "expected 'pyg' or 'scipy'."
            )

        # Calculate occupancy with particle type enumerated
        # particle_type + 1 because particle types begin with 0
        # We want 0 to be for no occupancy, then start with different occupancies
        occ_type = one_hot(occ_type.to(torch.long), num_classes=self.n_particle_types + 1).to(
            torch.float
        )
        return {"occ_pos": occ_pos, "occ_type": occ_type}

    def normalize_vel(self, vel):
        """Standardize velocities: subtract ``self.vel_mean``, divide by ``self.vel_std``.

        The statistics are moved to ``vel.device`` before use.
        """
        mean = self.vel_mean.to(vel.device)
        std = self.vel_std.to(vel.device)
        return (vel - mean) / std

    def unnormalize_vel(self, vel):
        """Undo velocity standardization: multiply by ``self.vel_std``, add ``self.vel_mean``.

        The statistics are moved to ``vel.device`` before use.
        """
        std = self.vel_std.to(vel.device)
        mean = self.vel_mean.to(vel.device)
        return vel * std + mean

    def normalize_acc(self, acc):
        """Standardize accelerations: subtract ``self.acc_mean``, divide by ``self.acc_std``.

        The statistics are moved to ``acc.device`` before use.
        """
        mean = self.acc_mean.to(acc.device)
        std = self.acc_std.to(acc.device)
        return (acc - mean) / std

    def unnormalize_acc(self, acc):
        """Undo acceleration standardization: multiply by ``self.acc_std``, add ``self.acc_mean``.

        The statistics are moved to ``acc.device`` before use.
        """
        std = self.acc_std.to(acc.device)
        mean = self.acc_mean.to(acc.device)
        return acc * std + mean


@dataclass
class MMDataModule(L.LightningDataModule):
    """Lightning data module that builds ``MMDataset`` train/validation loaders.

    The ``stage`` field selects which ``MMDataset`` mode is used for the
    training and the validation dataset (see ``_STAGE_TO_MODES``). All other
    dataset-related fields are forwarded unchanged to ``MMDataset``; the
    loader-related fields (``batch_size``, ``num_workers``, ``pin_memory``,
    ``follow_batch``) are forwarded to the PyG ``DataLoader``.
    """

    train_split: str = "train"
    eval_split: str = "valid"
    n_velocities: int = 2
    n_jump_ahead_timesteps: int = 1
    n_jumps: int = 1
    stage: str = "autoencoder"
    n_particle_types: int = None
    n_supernodes: int = None
    # Despite the `int` annotation, MMDataset indexes and unpacks this value as
    # a two-element (min, max) sequence.
    num_points_range: int = None
    num_points_decode: int = None
    global_root: str = None
    local_root: str = None
    seed: int = None
    n_occupancy: int = 1000
    occupancy_radius: float = 0.03
    batch_size: int = 4
    num_workers: int = 1
    pin_memory: bool = True
    overfit_single_trajectory: bool = False
    random_particles: bool = False
    follow_batch: Optional[List[str]] = field(
        default_factory=lambda: ["enc_pos", "supernode_index"]
    )
    dataset_rel_path: str = "multi_material"

    # stage -> (MMDataset mode for training, MMDataset mode for validation).
    # Deliberately left unannotated so that it is a plain class attribute and
    # not a dataclass field.
    _STAGE_TO_MODES = {
        "autoencoder": ("train_autoencoder", "train_autoencoder"),
        "physics": ("train_physics", "val_physics"),
        # Training the encoder and decoder also in physics stage
        "physics_full": ("train_physics_full", "train_physics_full"),
        "rollout": ("full_traj", "full_traj"),
        "gns_f": ("gns_f", "gns_f"),
        "rollout_gns_f": ("full_traj_gns_f", "full_traj_gns_f"),
    }

    def __post_init__(self):
        super().__init__()
        assert self.stage in self._STAGE_TO_MODES

    def setup(self, stage: Optional[str] = None) -> None:
        """Create ``self.train_dataset`` and ``self.val_dataset``.

        Args:
            stage: Accepted for compatibility with the Lightning hook signature
                but not used; the dataset modes are chosen from ``self.stage``.

        Raises:
            ValueError: If ``self.stage`` is not a known stage (only reachable
                when the check in ``__post_init__`` was skipped or ``stage``
                was changed afterwards).
        """
        try:
            mode_train, mode_val = self._STAGE_TO_MODES[self.stage]
        except KeyError:
            raise ValueError(
                f"Unknown stage {self.stage!r}; expected one of {sorted(self._STAGE_TO_MODES)}."
            ) from None

        self.train_dataset = self._make_dataset(self.train_split, mode_train)
        self.val_dataset = self._make_dataset(self.eval_split, mode_val)

    def _make_dataset(self, split, mode):
        """Build an ``MMDataset`` for ``split``/``mode`` from the shared settings."""
        return MMDataset(
            n_velocities=self.n_velocities,
            n_jump_ahead_timesteps=self.n_jump_ahead_timesteps,
            n_jumps=self.n_jumps,
            split=split,
            mode=mode,
            n_particle_types=self.n_particle_types,
            n_supernodes=self.n_supernodes,
            num_points_range=self.num_points_range,
            global_root=self.global_root,
            local_root=self.local_root,
            seed=self.seed,
            n_occupancy=self.n_occupancy,
            occupancy_radius=self.occupancy_radius,
            num_points_decode=self.num_points_decode,
            dataset_rel_path=self.dataset_rel_path,
            overfit_single_trajectory=self.overfit_single_trajectory,
            random_particles=self.random_particles,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            follow_batch=self.follow_batch,
        )

    def train_dataloader(self):
        """Return the shuffled loader over ``self.train_dataset``."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the non-shuffled loader over ``self.val_dataset``."""
        return self._make_loader(self.val_dataset, shuffle=False)
