import lightning
import torch
import logging
from typing import Optional

from third_party.LDMI.ldm.data.imagenet import ImageNetTest, ImageNetTrain, ImageNetValidation
from pdb import set_trace as bb
class ImageNetDataModule(lightning.LightningDataModule):
    """
    Lightning DataModule wrapping LDMI's ImageNet datasets.

    Builds ``ImageNetTrain`` / ``ImageNetValidation`` / ``ImageNetTest`` (all
    constructed with ``data_root=self.data_root``) and serves them through
    ``torch.utils.data.DataLoader``.

    Args:
        args: Namespace-like config object. These attributes are read with
            fallbacks when absent:
            ``train_datadir`` (default ``"./data/imagenet"``),
            ``batch_size`` (default ``24``),
            ``num_workers`` (default ``4``),
            ``pin_memory`` (default ``True``).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Dataset / loader configuration; getattr fallbacks let a partial
        # args object work.
        self.data_root = getattr(args, "train_datadir", "./data/imagenet")
        self.batch_size = getattr(args, "batch_size", 24)
        self.num_workers = getattr(args, "num_workers", 4)
        self.pin_memory = getattr(args, "pin_memory", True)

        print(f"ImageNet DataModule initialized with data_root: {self.data_root}")

    def setup(self, stage=None):
        """Instantiate the datasets required for ``stage``.

        Args:
            stage: Lightning stage name (``"fit"``, ``"validate"``, ``"test"``,
                ``"predict"``) or ``None`` to build every split.
        """
        if stage in ("fit", None):
            self.train_dataset = ImageNetTrain(data_root=self.data_root)
            print(f"Training dataset loaded: {len(self.train_dataset)} samples")

        # The validation split is needed for fit and validate, and is also
        # built for test (as before). Build it exactly once even when
        # stage is None and several of those apply.
        if stage in ("fit", "validate", "test", None):
            self.val_dataset = ImageNetValidation(data_root=self.data_root)
            print(f"Validation dataset loaded: {len(self.val_dataset)} samples")

        if stage in ("test", None):
            self.test_dataset = ImageNetTest(data_root=self.data_root)
            print(f"Test dataset loaded: {len(self.test_dataset)} samples")

    def _make_dataloader(self, dataset, *, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader using this module's shared settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        """Create the shuffled training dataloader.

        ``drop_last=True`` discards the final partial batch, so every training
        step sees exactly ``batch_size`` samples.
        """
        return self._make_dataloader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Create the validation dataloader (fixed order, keeps every sample)."""
        return self._make_dataloader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Create the test dataloader (fixed order, keeps every sample)."""
        return self._make_dataloader(self.test_dataset, shuffle=False, drop_last=False)