import lightning
import torch
import logging
from typing import Optional

from third_party.LDMI.ldm.data.pointcloud import VoxelTrain, VoxelVal, VoxelTest
from pdb import set_trace as bb
class ShapeNetVoxelDataModule(lightning.LightningDataModule):
    """Lightning DataModule wrapping LDMI's ShapeNet voxel datasets.

    Builds ``VoxelTrain`` / ``VoxelVal`` / ``VoxelTest`` (from
    ``third_party.LDMI.ldm.data.pointcloud``) rooted at one data directory
    and serves them through the standard Lightning ``*_dataloader`` hooks.

    Args:
        args: Configuration namespace. Optional attributes read, with their
            defaults: ``train_datadir`` ("./data/hypernet-data"),
            ``batch_size`` (4), ``num_workers`` (4) and ``train_repeats``
            (1000; how many times the training set is concatenated with
            itself to form one epoch).

    Raises:
        ValueError: If ``train_repeats`` is smaller than 1.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.data_root = getattr(args, "train_datadir", "./data/hypernet-data")
        self.batch_size = getattr(args, "batch_size", 4)
        self.num_workers = getattr(args, "num_workers", 4)
        # The training set is repeated so one Lightning "epoch" spans many
        # passes over it; 1000 matches the previously hard-coded value.
        self.train_repeats = getattr(args, "train_repeats", 1000)
        if self.train_repeats < 1:
            # ConcatDataset rejects an empty list, so fail early and clearly.
            raise ValueError(
                f"train_repeats must be >= 1, got {self.train_repeats}"
            )

        print(f"LDMI Voxel DataModule initialized with data_root: {self.data_root}")

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        Args:
            stage: Lightning stage name ("fit", "validate", "test",
                "predict") or ``None`` to build every split.

        Raises:
            ValueError: If the training split is empty (only checked for
                "fit" / ``None``).
        """
        if stage in ("fit", None):
            self.train_dataset = self._build_train_dataset()
            print(f"Training dataset loaded: {len(self.train_dataset)} samples")

        # val_dataloader() is needed by fit, by trainer.validate() (the
        # "validate" stage, previously unhandled) and was also prepared for
        # "test"; build it once rather than once per matching branch.
        if stage in ("fit", "validate", "test", None):
            self.val_dataset = VoxelVal(data_root=self.data_root)
            print(f"Validation dataset loaded: {len(self.val_dataset)} samples")

        if stage in ("test", None):
            self.test_dataset = VoxelTest(data_root=self.data_root)
            print(f"Test dataset loaded: {len(self.test_dataset)} samples")

    def _build_train_dataset(self):
        """Return the training set repeated ``self.train_repeats`` times.

        Also prints a short diagnostic of the expected voxel directory.

        Raises:
            ValueError: If a single ``VoxelTrain`` instance has no samples.
        """
        import os

        expected_voxel_dir = os.path.join(self.data_root, "shapenet_voxels")
        print(f"Looking for voxel data in: {expected_voxel_dir}")
        # isdir rather than exists: os.listdir raises if the path is a file.
        if os.path.isdir(expected_voxel_dir):
            # Sorted so the "Sample files" diagnostic is deterministic.
            voxel_files = sorted(
                f for f in os.listdir(expected_voxel_dir) if f.endswith('.pt')
            )
            print(f"Found {len(voxel_files)} .pt files in {expected_voxel_dir}")
            if voxel_files:
                print(f"Sample files: {voxel_files[:5]}")
        else:
            print(f"WARNING: Voxel data directory does not exist: {expected_voxel_dir}")

        train_ds = VoxelTrain(data_root=self.data_root)
        print(f"Single VoxelTrain dataset length: {len(train_ds)}")

        if len(train_ds) == 0:
            raise ValueError(
                f"Training dataset is empty! Check if voxel files exist in {expected_voxel_dir}. "
                f"Data root: {self.data_root}"
            )

        # Reuse the already-validated instance instead of constructing
        # train_repeats fresh VoxelTrain objects. ConcatDataset only indexes
        # into its members, so repeating one object yields the same index
        # space without re-running the constructor (and duplicating whatever
        # it loads) each time.
        # NOTE(review): assumes VoxelTrain instances built from the same
        # data_root are interchangeable (no per-instance random state that
        # relied on being distinct) — confirm against the LDMI implementation.
        return torch.utils.data.ConcatDataset([train_ds] * self.train_repeats)

    def _make_loader(self, dataset, *, shuffle, drop_last):
        """Build a DataLoader using this module's shared batching settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (last partial batch dropped)."""
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the ordered validation loader (keeps the last partial batch)."""
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the ordered test loader (keeps the last partial batch)."""
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)