import numpy as np
from typing import Any

class DynamicalSystemDataset:
    """In-memory dataset of per-id trajectories, parameters and metadata.

    Everything is loaded eagerly from a single ``.npz`` archive that contains an
    ``ids`` array plus, for every id, ``traj_<id>`` and ``params_<id>`` entries
    and an optional ``meta_<id>`` entry.
    """

    def __init__(self, data_path: str):
        """Load all samples listed in the archive's ``ids`` entry.

        Args:
            data_path: Path handed to ``np.load``; expected to be an ``.npz``
                archive (the code indexes it by key and tests key membership).

        Raises:
            KeyError: If ``ids`` or any ``traj_<id>`` / ``params_<id>`` entry is
                missing from the archive.
        """
        # Use the archive as a context manager so its file handle is released
        # deterministically, including when a key lookup below raises.
        with np.load(data_path) as data:
            # NpzFile entries are loaded lazily on every lookup, so read once.
            ids = data['ids']

            traj = []
            params = []
            meta = []
            for id_ in ids:
                traj.append(data[f'traj_{id_}'])
                params.append(data[f'params_{id_}'])

                meta_key = f'meta_{id_}'
                # Missing metadata falls back to a fresh empty dict per sample.
                # NOTE(review): present entries are returned by np.load as
                # arrays, not dicts; metadata saved as a Python dict would need
                # allow_pickle=True to load at all — confirm intended format.
                meta.append(data[meta_key] if meta_key in data else {})

        self.ids = list(ids)
        # np.array stacks the per-id arrays along a new first axis; presumably
        # all trajectories (and all parameter vectors) share one shape — ragged
        # inputs are rejected or turned into object arrays depending on the
        # NumPy version. TODO confirm with the data producer.
        self.traj = np.array(traj)
        self.params = np.array(params)
        self.meta = meta

    def __len__(self) -> int:
        """Return the number of loaded samples (size of the first axis)."""
        return len(self.traj)

    def get_traj(self, idx: int) -> np.ndarray:
        """Return the trajectory stored at position ``idx``."""
        return self.traj[idx]

    def get_batch(self, indices: np.ndarray) -> np.ndarray:
        """Return the trajectories at ``indices`` stacked along the first axis."""
        return self.traj[indices]

    def get_params(self, idx: int) -> np.ndarray:
        """Return the parameter entry stored at position ``idx``."""
        return self.params[idx]

    def get_meta(self, idx: int) -> dict[str, Any]:
        """Return the metadata for position ``idx``.

        This is the object loaded from ``meta_<id>`` when present (as returned
        by ``np.load``), otherwise an empty dict.
        """
        return self.meta[idx]