import lightning
import torch
import logging
from typing import Optional
import os
from third_party.LDMI.ldm.data.era import ERA5Dataset
from pdb import set_trace as bb

class ERA5DataModule(lightning.LightningDataModule):
    """
    Lightning DataModule wrapper around LDMI's ``ERA5Dataset``.

    Configuration is read from optional attributes of ``args``:

    * ``train_datadir`` / ``val_datadir`` / ``test_datadir`` -- paths passed
      to ``ERA5Dataset`` for the train / validation / test splits.
    * ``batch_size`` (default 4) and ``num_workers`` (default 8) -- shared by
      every dataloader.
    * ``train_repeat`` (default 10) -- number of copies of the training set
      concatenated (via ``ConcatDataset``) into one training epoch.
    * ``eval_drop_last`` (default False) -- whether the validation and test
      loaders drop the final incomplete batch.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Dataset configuration
        self.data_root = getattr(args, "train_datadir", "./data/era5/era5_temp2m_16x_train")
        self.val_data_root = getattr(args, "val_datadir", "./data/era5/era5_temp2m_16x_val")
        self.test_data_root = getattr(args, "test_datadir", "./data/era5/era5_temp2m_16x_test")
        self.batch_size = getattr(args, "batch_size", 4)
        self.num_workers = getattr(args, "num_workers", 8)

        # How many times the training split is repeated per epoch
        # (previously hard-coded to 10; kept as the default).
        self.train_repeat = getattr(args, "train_repeat", 10)
        if self.train_repeat < 1:
            raise ValueError(f"train_repeat must be >= 1, got {self.train_repeat}")

        # Evaluation loaders keep the last partial batch by default so that
        # every validation/test sample is evaluated.
        self.eval_drop_last = getattr(args, "eval_drop_last", False)

        print(f"LDMI ERA5 DataModule initialized with data_root: {self.data_root}")

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: Lightning stage name. ``"fit"`` builds the train and
                validation datasets, ``"validate"`` builds only the validation
                dataset, ``"test"`` builds the test dataset, and ``None``
                builds all three.
        """
        if stage in ("fit", None):
            base_train = ERA5Dataset(self.data_root)
            # The same dataset object is concatenated train_repeat times,
            # lengthening each epoch without loading extra data.
            self.train_dataset = torch.utils.data.ConcatDataset([base_train] * self.train_repeat)
            print(f"Training dataset loaded: {len(self.train_dataset)} samples")

        # "validate" is the stage Lightning passes for trainer.validate();
        # without it val_dataloader() would hit a missing attribute.
        if stage in ("fit", "validate", None):
            self.val_dataset = ERA5Dataset(self.val_data_root)
            print(f"Validation dataset loaded: {len(self.val_dataset)} samples")

        if stage in ("test", None):
            self.test_dataset = ERA5Dataset(self.test_data_root)
            print(f"Test dataset loaded: {len(self.test_dataset)} samples")

    def _make_dataloader(self, dataset, *, shuffle, drop_last):
        """Create a DataLoader for ``dataset`` using the shared batch/worker settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Create the shuffled training dataloader (incomplete last batch dropped)."""
        return self._make_dataloader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Create the unshuffled validation dataloader."""
        return self._make_dataloader(self.val_dataset, shuffle=False, drop_last=self.eval_drop_last)

    def test_dataloader(self):
        """Create the unshuffled test dataloader."""
        return self._make_dataloader(self.test_dataset, shuffle=False, drop_last=self.eval_drop_last)