"""Data loader factory for creating training and validation data loaders.

This module provides functions that build training and validation data loaders
for 1D PDE datasets from separate training and validation data files. A data
path that does not exist as given is also looked up relative to a fixed root
directory above this module.
"""

from typing import Dict, Optional, Any, Tuple
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from src.dataset.pde_1d_dataset import PDE1DDataset


def build_data_loaders(train_path: str, val_path: str, batch_size: int,
                       n_workers: int, dataset_params: Optional[Dict[str, Any]] = None,
                       debug: bool = False) -> Dict[str, DataLoader]:
    """
    Create data loaders for training and validation.

    Both data paths are first resolved with ``find_data_paths``. A path that
    only exists relative to its fallback root is therefore opened in its
    resolved form rather than as originally given.

    Parameters
    ----------
    train_path : str
        Path to training data file.
    val_path : str
        Path to validation data file.
    batch_size : int
        Batch size for training and validation.
    n_workers : int
        Number of data loader worker processes. ``0`` loads data in the main
        process; persistent workers are only enabled when this is positive.
    dataset_params : dict, optional
        Additional keyword arguments forwarded to ``PDE1DDataset`` for both
        the training and validation datasets.
    debug : bool
        If True, cap each dataset's ``n_samples`` at 20 to shorten runs.

    Returns
    -------
    dict
        ``{"train": DataLoader, "val": DataLoader}``. The training loader
        shuffles; the validation loader does not.

    Raises
    ------
    FileNotFoundError
        If either data path cannot be found (raised by ``find_data_paths``).
    """
    # Use the resolved paths; previously they were computed and thrown away.
    resolved_train, resolved_val = find_data_paths(train_path, val_path)
    params = dataset_params or {}

    # str() keeps the data_path argument type the same as what callers pass in.
    train_set = PDE1DDataset(data_path=str(resolved_train), **params)
    val_set = PDE1DDataset(data_path=str(resolved_val), **params)

    if debug:
        debug_n_samples = 20
        # min() prevents debug mode from claiming more samples than exist.
        # NOTE(review): assumes n_samples bounds what PDE1DDataset indexes — confirm.
        train_set.n_samples = min(debug_n_samples, train_set.n_samples)
        val_set.n_samples = min(debug_n_samples, val_set.n_samples)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=torch.cuda.is_available(),
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only enable it when worker processes exist.
        persistent_workers=n_workers > 0,
    )
    return {
        "train": DataLoader(train_set, shuffle=True, **loader_kwargs),
        "val": DataLoader(val_set, shuffle=False, **loader_kwargs),
    }


def _resolve_data_path(path: str, label: str) -> Path:
    """Return ``path`` as a Path, trying a module-relative fallback root if needed.

    Raises FileNotFoundError (message prefixed with ``label``) when the path
    exists neither as given nor under the fallback root.
    """
    candidate = Path(path)
    if candidate.exists():
        return candidate
    # Fallback root: three directory levels above this module.
    # NOTE(review): presumably the project root — confirm against repo layout.
    fallback = Path(__file__).parent.parent.parent / path
    if fallback.exists():
        return fallback
    raise FileNotFoundError(f"{label} data path {path} not found")


def find_data_paths(train_path: str, val_path: str) -> Tuple[Path, Path]:
    """
    Resolve the training and validation data paths.

    Each path is used as given if it exists. Otherwise it is looked up
    relative to the directory three levels above this module.

    Parameters
    ----------
    train_path : str
        Path to the training data file.
    val_path : str
        Path to the validation data file.

    Returns
    -------
    tuple of Path
        The resolved ``(train_path, val_path)``, always as ``Path`` objects.

    Raises
    ------
    FileNotFoundError
        If a path exists neither as given nor under the fallback root. The
        training path is checked first.
    """
    return (
        _resolve_data_path(train_path, "Training"),
        _resolve_data_path(val_path, "Validation"),
    )