"""
Simplified Lightning DataModule for LIBERO dataset.
"""

import copy
from pathlib import Path
from typing import Optional, Dict, Any
import lightning as pl

from torch.utils.data import DataLoader, random_split

import torch

from .libero_dataloader import LiberoDataset, LiberoDataConfig, LiberoDataLoader


class LiberoDataModule(pl.LightningDataModule):
    """Lightning DataModule wrapping ``LiberoDataset``.

    The ``"train"`` split of the dataset is divided into training and
    validation subsets with a seeded ``random_split``; the ``"test"`` split
    is loaded as-is.
    """

    def __init__(
        self,
        task_suite_name: str = "libero_spatial",
        data_root_path: str = "./data/libero",
        image_size: tuple = (224, 224),
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: float = 0.8,
        seed: int = 42,
        horizon: int = 20,
        **kwargs,
    ):
        """
        Initialize LiberoDataModule.

        Args:
            task_suite_name: LIBERO task suite name
            data_root_path: Path to LIBERO data
            image_size: Image size (height, width)
            batch_size: Batch size for dataloaders
            num_workers: Number of worker processes
            train_val_split: Fraction of the train split used for training;
                the remainder becomes the validation set. Must be in [0, 1].
            seed: Random seed for splits
            horizon: Action sequence horizon length
            **kwargs: ``eval_config`` (if present) is stored on the module;
                every other keyword is forwarded to ``LiberoDataConfig``.

        Raises:
            ValueError: If ``train_val_split`` is outside [0, 1].
        """
        super().__init__()

        # A ratio outside [0, 1] would hand a negative length to random_split
        # while the two lengths still sum to the dataset size, so fail early
        # with an explicit message instead.
        if not 0.0 <= train_val_split <= 1.0:
            raise ValueError(
                f"train_val_split must be in [0, 1], got {train_val_split!r}"
            )

        # Pop eval_config from kwargs if present (dataset-specific eval
        # configuration); it must not reach LiberoDataConfig.
        self.eval_config = kwargs.pop('eval_config', None)

        # Store parameters
        self.task_suite_name = task_suite_name
        self.data_root_path = Path(data_root_path)
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.seed = seed
        self.horizon = horizon

        # Create dataset configuration
        self.config = LiberoDataConfig(
            task_suite_name=task_suite_name,
            data_root_path=str(data_root_path),
            image_size=image_size,
            horizon=horizon,
            **kwargs,
        )

        # Datasets are created lazily in setup()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Setup datasets for training, validation, and testing.

        Args:
            stage: Stage requested by the trainer. ``"fit"`` and
                ``"validate"`` build the train/validation split, ``"test"``
                builds the test dataset, and ``None`` builds everything.
                Any other value is a no-op.
        """
        # "validate" is included so that a standalone validation run also
        # gets a populated val_dataset.
        if stage in (None, "fit", "validate"):
            # Work on a copy so the shared config keeps its original split.
            train_config = copy.deepcopy(self.config)
            train_config.split = "train"

            full_train_dataset = LiberoDataset(train_config)

            # Split into train and validation
            total = len(full_train_dataset)
            train_size = int(self.train_val_split * total)
            val_size = total - train_size

            # Seeded generator keeps the split identical across setup() calls.
            generator = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                full_train_dataset, [train_size, val_size], generator=generator
            )

            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Validation dataset size: {len(self.val_dataset)}")

        if stage in (None, "test"):
            # Load test split
            test_config = copy.deepcopy(self.config)
            test_config.split = "test"

            self.test_dataset = LiberoDataset(test_config)
            print(f"Test dataset size: {len(self.test_dataset)}")

    def _make_dataloader(
        self,
        dataset,
        *,
        name: str,
        required_stage: str,
        shuffle: bool,
        persistent_workers: bool,
    ) -> DataLoader:
        """Build a DataLoader with the module-wide batch/worker settings.

        Raises:
            RuntimeError: If ``dataset`` has not been created by ``setup()``.
        """
        if dataset is None:
            raise RuntimeError(
                f"{name} dataset is not initialized; call "
                f"setup('{required_stage}') before requesting its dataloader."
            )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=False,
            persistent_workers=persistent_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Create training dataloader (shuffled)."""
        return self._make_dataloader(
            self.train_dataset,
            name="train",
            required_stage="fit",
            shuffle=True,
            # persistent_workers is only enabled when worker processes exist.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Create validation dataloader (not shuffled)."""
        return self._make_dataloader(
            self.val_dataset,
            name="validation",
            required_stage="fit",
            shuffle=False,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Create test dataloader (not shuffled, workers not persistent)."""
        return self._make_dataloader(
            self.test_dataset,
            name="test",
            required_stage="test",
            shuffle=False,
            persistent_workers=False,
        )

    def get_config(self) -> Dict[str, Any]:
        """Get configuration dictionary of the constructor parameters."""
        return {
            "task_suite_name": self.task_suite_name,
            "data_root_path": str(self.data_root_path),
            "image_size": self.image_size,
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "train_val_split": self.train_val_split,
            "seed": self.seed,
            "horizon": self.horizon,
        }

    def get_eval_config(self) -> Optional[Dict[str, Any]]:
        """Get evaluation configuration if available (``None`` otherwise)."""
        return self.eval_config
