# datamodule.py
# This is for demonstration only; we may construct a real base class in the future.
import os
import torch
from torch.utils.data import DataLoader, Dataset

##############################################################################
# 1) Example of a simple Dataset
##############################################################################
class MyDataset(Dataset):
    """
    A generic dataset that reads from CSV, images, or any other data source.
    Adjust to your own data logic.

    As written, this is a template: ``_load_samples`` returns an empty list and
    ``__getitem__`` returns placeholder scalar tensors.
    """
    def __init__(self, data_path, transform=None):
        """
        Args:
            data_path (str): Path to your data (folder, CSV, etc.).
            transform (callable, optional): Optional transform to apply to the
                input of each sample. Applied whenever it is not ``None``.
        """
        super().__init__()
        self.data_path = data_path
        self.transform = transform

        # Example: if data_path is a CSV, you might load it with pandas
        # import pandas as pd
        # self.df = pd.read_csv(data_path)
        # Or if it's images, gather file paths, etc.

        self.samples = self._load_samples()  # custom method to list or parse data

    def _load_samples(self):
        """
        Collect or parse samples from ``self.data_path``.

        Returns:
            list: One entry per sample. This template returns an empty list;
            replace it with real parsing, e.g.
            ``[{"input": ..., "label": ...}, ...]``.
        """
        return []

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return the ``(input, label)`` pair for sample ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range for ``self.samples``.
        """
        # Indexing here also acts as the bounds check, even though the
        # placeholder code below does not use the sample yet.
        sample = self.samples[idx]

        # Extract input, label, etc.
        # e.g. x = sample["input"], y = sample["label"]
        x = torch.tensor(0.0)  # placeholder
        y = torch.tensor(0.0)  # placeholder

        # Compare against None rather than relying on truthiness: a callable
        # transform object may define __bool__/__len__ and evaluate as falsy.
        if self.transform is not None:
            x = self.transform(x)

        return x, y

##############################################################################
# 2) The DataModule
##############################################################################
class Datamodule:
    """
    Build train/validation/test datasets and wrap them in DataLoaders.

    Call ``setup(stage)`` to create the datasets, then use ``train_dl``,
    ``valid_dl`` or ``test_dl`` to obtain the corresponding loaders.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: A config object (e.g., Hydra DictConfig or a simple namespace)
                 whose ``dataset`` section provides ``batch_size``,
                 ``num_workers``, ``train_path``, ``valid_path`` and
                 ``test_path``. An optional ``dataset.pin_memory`` flag is
                 forwarded to every DataLoader (defaults to True).
        """
        self.cfg = cfg

        # Commonly used fields
        self.batch_size = cfg.dataset.batch_size  # e.g. 32
        self.num_workers = cfg.dataset.num_workers  # e.g. 4
        # Optional setting; defaults to True, matching the previous hard-coded value.
        self.pin_memory = getattr(cfg.dataset, "pin_memory", True)

        self.train_path = cfg.dataset.train_path  # e.g. "data/train.csv"
        self.valid_path = cfg.dataset.valid_path  # e.g. "data/valid.csv"
        self.test_path = cfg.dataset.test_path    # e.g. "data/test.csv"

        # Optionally define transforms here or load from config
        self.transform = None  # or some Compose([...]) if needed

        # Populated by setup(); None means "not created yet".
        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """
        Create the datasets needed for ``stage``.

        Args:
            stage: ``None`` (create all splits), ``"fit"`` or ``"train"``
                (create train and valid), ``"validate"`` (create valid if it
                does not exist yet) or ``"test"`` (create test). Any other
                value, including ``"predict"``, creates nothing.
        """
        # "fit"/"train" (re)create both the training and validation splits.
        if stage in (None, "fit", "train"):
            self.train_dataset = MyDataset(self.train_path, transform=self.transform)
            self.valid_dataset = MyDataset(self.valid_path, transform=self.transform)

        # "validate" reuses an existing validation split when one is present.
        if stage in (None, "validate"):
            if self.valid_dataset is None:
                self.valid_dataset = MyDataset(self.valid_path, transform=self.transform)

        if stage in (None, "test"):
            self.test_dataset = MyDataset(self.test_path, transform=self.transform)

    def _make_dl(self, dataset, split, shuffle):
        """
        Wrap ``dataset`` in a DataLoader using the shared loader settings.

        Raises:
            RuntimeError: if ``dataset`` is None, i.e. setup() has not created it.
        """
        # Without this guard DataLoader(None, ...) either fails with an opaque
        # TypeError (shuffle=True) or builds a loader that only breaks when iterated.
        if dataset is None:
            raise RuntimeError(
                f"The {split} dataset has not been created; call setup() with "
                f"an appropriate stage before requesting its DataLoader."
            )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def train_dl(self):
        """Return a shuffled DataLoader for the training set."""
        return self._make_dl(self.train_dataset, "train", shuffle=True)

    def valid_dl(self):
        """Return an unshuffled DataLoader for the validation set."""
        return self._make_dl(self.valid_dataset, "validation", shuffle=False)

    def test_dl(self):
        """Return an unshuffled DataLoader for the test set."""
        return self._make_dl(self.test_dataset, "test", shuffle=False)
