import torch
from torch.utils.data import Dataset, DataLoader

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ImageVoxelAdapterDataset(Dataset):
    """Dataset of position-aligned (voxel, image, adapter embedding) triples.

    Voxel features and images are loaded from disk with ``torch.load``; the
    adapter embeddings are passed in already loaded. All three are indexed
    with the same sample indices once, at construction time, so item ``i`` of
    the dataset is built from row ``indices[i]`` of each source.
    """

    def __init__(self, voxel_file, image_file, adapter_file, size=None, indices=None):
        """
        Args:
            voxel_file: Path (or file-like object) given to ``torch.load``
                for the voxel features.
            image_file: Path (or file-like object) given to ``torch.load``
                for the images.
            adapter_file (torch.Tensor): Despite the name, this is the
                already-loaded adapter embeddings, indexed along the first
                dimension with the same indices as the voxels and images.
            size (int, optional): Number of samples to draw at random,
                without replacement, when ``indices`` is None. ``None`` or a
                value larger than the number of images selects every image
                (in random order). Ignored when ``indices`` is given.
            indices (optional): Explicit sample indices to use instead of a
                random draw, e.g. the result of ``get_indices()`` from
                another instance. Stored and used as passed.
        """
        # NOTE: torch.load unpickles its input; only point it at trusted files.
        self.features = torch.load(voxel_file, map_location='cpu')
        self.images = torch.load(image_file, map_location='cpu')

        # Images containing negative values are rescaled with (x + 1) / 2
        # (i.e. [-1, 1] -> [0, 1]). min() raises on a tensor with no elements,
        # so the check is skipped when there are no samples at all.
        if len(self.images) > 0 and self.images.min() < 0:
            self.images = (self.images + 1) / 2
        self.adapter_embeds = adapter_file

        if indices is None:
            # Clamp the requested size to the number of available images and
            # draw that many distinct indices from the global NumPy RNG.
            n_images = len(self.images)
            if size is None or size > n_images:
                size = n_images
            self.indices = np.random.choice(n_images, size, replace=False)
        else:
            self.indices = indices

        # Index everything once up front so __getitem__ is a plain lookup.
        self.features = self.features[self.indices]
        self.images = self.images[self.indices]
        self.adapter_embeds = self.adapter_embeds[self.indices]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.images)

    def get_indices(self):
        """Return the indices used to select the samples.

        This is the array drawn by ``np.random.choice`` or, when ``indices``
        was supplied to the constructor, that same object unchanged.
        """
        return self.indices

    def __getitem__(self, idx):
        """Return the ``(voxel, image, adapter)`` triple at position ``idx``."""
        image = self.images[idx]
        voxel = self.features[idx]
        adapter = self.adapter_embeds[idx]
        return voxel, image, adapter


import torch
from torch.utils.data import Dataset, DataLoader

class ImageVoxelDataset(Dataset):
    """Dataset of position-aligned (voxel, image) pairs loaded with ``torch.load``."""

    def __init__(self, voxel_file, image_file):
        """
        Args:
            voxel_file: Path (or file-like object) given to ``torch.load``
                for the voxel features.
            image_file: Path (or file-like object) given to ``torch.load``
                for the images. Item ``i`` of the images is paired with item
                ``i`` of the voxel features.
        """
        # NOTE: torch.load unpickles its input; only point it at trusted files.
        self.features = torch.load(voxel_file, map_location='cpu')
        self.images = torch.load(image_file, map_location='cpu')

        # Images containing negative values are rescaled with (x + 1) / 2
        # (i.e. [-1, 1] -> [0, 1]). min() raises on a tensor with no elements,
        # so the check is skipped when there are no samples at all.
        if len(self.images) > 0 and self.images.min() < 0:
            self.images = (self.images + 1) / 2

    def __len__(self):
        """Return the number of images (the dataset length)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(voxel, image)`` pair at position ``idx``."""
        image = self.images[idx]
        voxel = self.features[idx]
        return voxel, image