import pytorch_lightning as pl


from torch.utils.data import DataLoader

from . import evidence_dataset
from . import image_text_pair_dataset
from .. import builder
from . import stage3_relation_dataset


class EvidenceDataModule(pl.LightningDataModule):
    """Lightning data module serving train/valid/test splits of ``evidence_dataset.EvidenceDataset``.

    Options read from ``cfg.train``:
        batch_size: required; used for every split.
        num_workers: loader worker processes (default 16).
        persistent_workers: keep workers alive between epochs (default True);
            forced to False when ``num_workers`` is 0.
        prefetch_factor: batches prefetched per worker (default 2); only
            forwarded to ``DataLoader`` when ``num_workers`` > 0.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.dataset = evidence_dataset.EvidenceDataset
        self.collate_fn = evidence_dataset.evidence_collate_fn
        self.num_workers = getattr(cfg.train, 'num_workers', 16)
        # DataLoader rejects persistent_workers=True without worker processes.
        self.persistent_workers = getattr(cfg.train, 'persistent_workers', True) if self.num_workers > 0 else False

    def _make_loader(self, split, *, train):
        """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

        Args:
            split: split name passed to the dataset ("train", "valid" or "test").
            train: when True, shuffle and drop the last incomplete batch;
                otherwise keep sample order and every sample.
        """
        dataset = self.dataset(self.cfg, split=split, preprocess=True)
        loader_kwargs = {
            'pin_memory': True,
            'drop_last': train,
            'shuffle': train,
            'batch_size': self.cfg.train.batch_size,
            'num_workers': self.num_workers,
            'collate_fn': self.collate_fn,
            'persistent_workers': self.persistent_workers,
        }
        # prefetch_factor is only valid with worker processes: torch >= 2.0
        # raises ValueError for any non-None value when num_workers == 0, so
        # leave it at the DataLoader default in the single-process case.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = getattr(self.cfg.train, 'prefetch_factor', 2)
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled training loader (last incomplete batch dropped)."""
        return self._make_loader("train", train=True)

    def val_dataloader(self):
        """Return the validation loader (ordered, all samples kept)."""
        return self._make_loader("valid", train=False)

    def test_dataloader(self):
        """Return the test loader (ordered, all samples kept)."""
        return self._make_loader("test", train=False)

class ImageTextPairDataModule(pl.LightningDataModule):
    """Lightning data module serving splits of ``image_text_pair_dataset.ImageTextPairDataset``.

    Options read from ``cfg.data``:
        sample_ratio: forwarded to the dataset (default 0.1); presumably a
            subsampling fraction -- confirm in image_text_pair_dataset.
        paired_csv_path: forwarded to the dataset (default None).

    Options read from ``cfg.train``:
        batch_size: required; used for every split.
        num_workers: loader worker processes (default 16).
        persistent_workers: keep workers alive between epochs (default True);
            forced to False when ``num_workers`` is 0.
        prefetch_factor: batches prefetched per worker (default 2); only
            forwarded to ``DataLoader`` when ``num_workers`` > 0.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.dataset = image_text_pair_dataset.ImageTextPairDataset
        self.collate_fn = image_text_pair_dataset.image_text_pair_collate_fn
        self.sample_ratio = getattr(cfg.data, 'sample_ratio', 0.1)
        self.paired_csv_path = getattr(cfg.data, 'paired_csv_path', None)
        # Optimization: default to a larger worker pool (16); override via cfg.train.num_workers.
        self.num_workers = getattr(cfg.train, 'num_workers', 16)
        # DataLoader rejects persistent_workers=True without worker processes.
        self.persistent_workers = getattr(cfg.train, 'persistent_workers', True) if self.num_workers > 0 else False

    def _make_loader(self, split, *, train):
        """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

        Args:
            split: split name passed to the dataset ("train", "valid" or "test").
            train: when True, shuffle and drop the last incomplete batch;
                otherwise keep sample order and every sample.
        """
        dataset = self.dataset(
            self.cfg,
            split=split,
            sample_ratio=self.sample_ratio,
            preprocess_text=True,
            paired_csv_path=self.paired_csv_path,
        )
        loader_kwargs = {
            'pin_memory': True,
            'drop_last': train,
            'shuffle': train,
            'batch_size': self.cfg.train.batch_size,
            'num_workers': self.num_workers,
            'collate_fn': self.collate_fn,
            'persistent_workers': self.persistent_workers,
        }
        # prefetch_factor is only valid with worker processes: torch >= 2.0
        # raises ValueError for any non-None value when num_workers == 0, so
        # leave it at the DataLoader default in the single-process case.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = getattr(self.cfg.train, 'prefetch_factor', 2)
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled training loader (last incomplete batch dropped)."""
        return self._make_loader("train", train=True)

    def val_dataloader(self):
        """Return the validation loader (ordered, all samples kept)."""
        return self._make_loader("valid", train=False)

    def test_dataloader(self):
        """Return the test loader (ordered, all samples kept)."""
        return self._make_loader("test", train=False)

class Stage3RelationDataModule(pl.LightningDataModule):
    """Lightning data module serving splits of ``stage3_relation_dataset.Stage3RelationDataset``.

    Options read from ``cfg.data``:
        stage2_sample_ratio: forwarded to the dataset (default 0.1); presumably
            a subsampling fraction for stage-2 data -- confirm in
            stage3_relation_dataset.
        paired_csv_path: forwarded to the dataset (default None).

    Options read from ``cfg.train``:
        batch_size: required; used for every split.
        num_workers: loader worker processes (default 16).
        persistent_workers: keep workers alive between epochs (default True);
            forced to False when ``num_workers`` is 0.
        prefetch_factor: batches prefetched per worker (default 2); only
            forwarded to ``DataLoader`` when ``num_workers`` > 0.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.dataset = stage3_relation_dataset.Stage3RelationDataset
        self.collate_fn = stage3_relation_dataset.stage3_relation_collate_fn
        self.stage2_sample_ratio = getattr(cfg.data, 'stage2_sample_ratio', 0.1)
        self.paired_csv_path = getattr(cfg.data, 'paired_csv_path', None)
        # Optimization: default to a larger worker pool (16); override via cfg.train.num_workers.
        self.num_workers = getattr(cfg.train, 'num_workers', 16)
        # DataLoader rejects persistent_workers=True without worker processes.
        self.persistent_workers = getattr(cfg.train, 'persistent_workers', True) if self.num_workers > 0 else False

    def _make_loader(self, split, *, train):
        """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

        Args:
            split: split name passed to the dataset ("train", "valid" or "test").
            train: when True, shuffle and drop the last incomplete batch;
                otherwise keep sample order and every sample.
        """
        dataset = self.dataset(
            self.cfg,
            split=split,
            stage2_sample_ratio=self.stage2_sample_ratio,
            paired_csv_path=self.paired_csv_path,
        )
        loader_kwargs = {
            'pin_memory': True,
            'drop_last': train,
            'shuffle': train,
            'batch_size': self.cfg.train.batch_size,
            'num_workers': self.num_workers,
            'collate_fn': self.collate_fn,
            'persistent_workers': self.persistent_workers,
        }
        # prefetch_factor is only valid with worker processes: torch >= 2.0
        # raises ValueError for any non-None value when num_workers == 0, so
        # leave it at the DataLoader default in the single-process case.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = getattr(self.cfg.train, 'prefetch_factor', 2)
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled training loader (last incomplete batch dropped)."""
        return self._make_loader("train", train=True)

    def val_dataloader(self):
        """Return the validation loader (ordered, all samples kept)."""
        return self._make_loader("valid", train=False)

    def test_dataloader(self):
        """Return the test loader (ordered, all samples kept)."""
        return self._make_loader("test", train=False)
