from torch.utils.data import Dataset
import numpy as np
import torch
from pathlib import Path

class PlaceholderDataset(Dataset):
    """
    Data-free dataset that reports a fixed length and echoes indices back.

    Args:
        num (int): Value reported by ``len()`` for this dataset.
    """

    def __init__(self, num: int):
        # Only the advertised length is kept; no samples are stored.
        self.num = num

    def __len__(self):
        """Return the length supplied at construction time."""
        size = self.num
        return size

    def __getitem__(self, idx):
        """Return ``idx`` unchanged; the index is not range-checked."""
        sample = idx
        return sample



class RobotDataset(Dataset):
    """
    Custom dataset for loading robot numpy files from directories.

    Every directory is searched recursively for files whose name ends with
    ``file_extension``. Matches are sorted within each directory so that the
    index -> file mapping is reproducible across runs and machines.

    Args:
        data_dirs (str | Path | iterable of str/Path): A single directory or
            a collection (list, tuple, ...) of directories to search.
        file_extension (str): File extension to look for (default: '.npy')

    Raises:
        ValueError: If a directory does not exist, or if no matching file is
            found in any of the directories.
    """

    def __init__(self, data_dirs, file_extension: str = '.npy'):
        # A lone path is wrapped; any other iterable (list, tuple, ...) is
        # copied into a list so later mutation by the caller has no effect.
        if isinstance(data_dirs, (str, Path)):
            self.data_dirs = [data_dirs]
        else:
            self.data_dirs = list(data_dirs)
        self.file_extension = file_extension

        self.file_paths = []
        for data_dir in self.data_dirs:
            data_dir = Path(data_dir)
            if not data_dir.exists():
                raise ValueError(f"Directory {data_dir} does not exist")

            # glob() yields entries in filesystem order, which is not stable;
            # sorting keeps sample indices deterministic.
            self.file_paths.extend(sorted(data_dir.glob(f"**/*{file_extension}")))

        print(f"Found {len(self.file_paths)} robot files in {len(self.data_dirs)} directories")

        if not self.file_paths:
            raise ValueError("No robot files found in the specified directories")

    def __len__(self) -> int:
        """Return the number of discovered files."""
        return len(self.file_paths)

    def __getitem__(self, idx) -> torch.Tensor:
        """
        Load and return a robot sample.

        Returns:
            torch.Tensor: Contents of the file converted to float32. The shape
                is whatever is stored in the file; it is not validated here.

        Raises:
            RuntimeError: If the file cannot be loaded or converted. The
                underlying exception is attached as ``__cause__``.
        """
        # Indexing stays outside the try so a bad idx surfaces as IndexError.
        file_path = self.file_paths[idx]

        try:
            robot_data = np.load(file_path)
            return torch.from_numpy(robot_data).float()
        except Exception as e:
            raise RuntimeError(f"Error loading robot file {file_path}: {str(e)}") from e

    def get_sample_info(self, idx) -> dict:
        """
        Load the file at ``idx`` and describe the raw array it contains.

        Returns:
            dict: ``file_path``, ``original_shape``, ``original_dtype``,
                ``data_range`` as ``(min, max)`` and ``non_zero_count``.
                ``data_range`` is ``(None, None)`` for a zero-size array.
        """
        file_path = self.file_paths[idx]
        data = np.load(file_path)

        # min()/max() raise on zero-size arrays, so guard that case.
        if data.size:
            data_range = (data.min(), data.max())
        else:
            data_range = (None, None)

        return {
            'file_path': str(file_path),
            'original_shape': data.shape,
            'original_dtype': data.dtype,
            'data_range': data_range,
            'non_zero_count': np.count_nonzero(data),
        }