import lightning
import torch
import logging
from typing import Optional
import os
from third_party.LDMI.ldm.data.celebahq import CelebAHQ
from pdb import set_trace as bb

class CelebAHQDataModule(lightning.LightningDataModule):
    """Lightning DataModule wrapping LDMI's ``CelebAHQ`` dataset.

    The module reads its configuration from an ``args`` object and builds the
    datasets from the ``train`` and ``val`` sub-directories of ``data_root``.
    The test split reuses the ``val`` directory.

    Attributes read from ``args`` (each with a fallback default):
        train_datadir: root directory of the dataset
            (default ``"./data/celebahq"``).
        batch_size: batch size used for every dataloader (default ``1``).
        num_workers: worker processes per dataloader (default ``8``).
    """

    # Value forwarded as ``size`` to ``CelebAHQ``.
    # NOTE(review): presumably the image side length in pixels -- confirm
    # against the CelebAHQ implementation.
    IMAGE_SIZE = 64

    # Number of times the training split is concatenated with itself.
    # NOTE(review): presumably used to lengthen one training epoch -- confirm.
    TRAIN_REPEAT = 1000

    def __init__(self, args):
        """Store ``args`` and resolve the dataset configuration from it."""
        super().__init__()
        self.args = args

        # Dataset configuration
        self.data_root = getattr(args, "train_datadir", "./data/celebahq")
        self.batch_size = getattr(args, "batch_size", 1)
        self.num_workers = getattr(args, "num_workers", 8)

        print(f"LDMI CelebAHQ DataModule initialized with data_root: {self.data_root}")

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None``. ``None``
                builds every split. Other values build nothing.
        """
        if stage in ("fit", None):
            base_train = CelebAHQ(
                os.path.join(self.data_root, "train"), size=self.IMAGE_SIZE
            )
            # The same dataset object is referenced TRAIN_REPEAT times, so the
            # reported length is multiplied without copying the data.
            self.train_dataset = torch.utils.data.ConcatDataset(
                [base_train] * self.TRAIN_REPEAT
            )
            print(f"Training dataset loaded: {len(self.train_dataset)} samples")

        # "validate" is included so that a standalone validation run (which
        # calls setup(stage="validate")) finds self.val_dataset populated.
        if stage in ("fit", "validate", None):
            self.val_dataset = CelebAHQ(
                os.path.join(self.data_root, "val"), size=self.IMAGE_SIZE
            )
            print(f"Validation dataset loaded: {len(self.val_dataset)} samples")

        if stage in ("test", None):
            # The test split is read from the same "val" directory.
            self.test_dataset = CelebAHQ(
                os.path.join(self.data_root, "val"), size=self.IMAGE_SIZE
            )
            print(f"Test dataset loaded: {len(self.test_dataset)} samples")

    def _make_loader(self, dataset, shuffle):
        """Return a DataLoader over ``dataset`` with the shared settings."""
        # NOTE(review): drop_last=True is kept for every split to preserve the
        # existing behavior; for val/test it discards a final partial batch --
        # confirm that this is intended.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Create the shuffled training dataloader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Create the validation dataloader (no shuffling)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Create the test dataloader (no shuffling)."""
        return self._make_loader(self.test_dataset, shuffle=False)