import pytorch_lightning as pl 
from datasets.rl_waymo import RLWaymoDatasetFineTuning 
from datasets.rl_waymo import RLWaymoDataset
from torch_geometric.loader import DataLoader

class RLWaymoDataModuleFineTuning(pl.LightningDataModule):
    """Lightning data module for RL fine-tuning on the Waymo dataset.

    Training batches come from ``RLWaymoDatasetFineTuning`` and validation
    batches from the ``'val'`` split of ``RLWaymoDataset``. Both are served
    through torch_geometric's ``DataLoader``.

    Args:
        cfg: Config object exposing ``cfg.datasets.rl_waymo`` (passed to the
            dataset constructors) and ``cfg.train.datamodule``. The
            ``cfg.train.datamodule`` section must provide ``train_batch_size``,
            ``val_batch_size``, ``num_workers`` and ``pin_memory``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg_dataset = cfg.datasets.rl_waymo
        self.cfg_datamodule = cfg.train.datamodule

    def setup(self, stage=None):
        """Build the training and validation datasets.

        Args:
            stage: Lightning stage name (e.g. ``'fit'``). It is currently
                ignored: both datasets are (re)built on every call. It
                defaults to ``None`` so the module can also be set up
                manually via ``dm.setup()``.
        """
        self.train_dataset = RLWaymoDatasetFineTuning(self.cfg_dataset)
        self.val_dataset = RLWaymoDataset(self.cfg_dataset, split_name='val')

    def sample_real_indices(self):
        """Re-sample the real-data indices of the training dataset.

        This delegates to ``self.train_dataset.sample_real_indices()``. It is
        intended to be called at the start of each training epoch. The caller
        is not visible in this file. ``setup`` must have run first, because
        that is where ``self.train_dataset`` is created.
        """
        self.train_dataset.sample_real_indices()

    def _make_loader(self, dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader using the shared worker settings."""
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=self.cfg_datamodule.num_workers,
                          pin_memory=self.cfg_datamodule.pin_memory,
                          drop_last=drop_last)

    def train_dataloader(self):
        """Return a shuffled training loader that drops the last partial batch."""
        return self._make_loader(self.train_dataset,
                                 batch_size=self.cfg_datamodule.train_batch_size,
                                 shuffle=True,
                                 drop_last=True)

    def val_dataloader(self):
        """Return an unshuffled validation loader."""
        # NOTE(review): drop_last=True excludes the final partial batch, so up
        # to val_batch_size - 1 validation samples are skipped every epoch.
        # It is kept for compatibility (the model may need fixed-size batches).
        # Confirm this is intended before relying on full-split metrics.
        return self._make_loader(self.val_dataset,
                                 batch_size=self.cfg_datamodule.val_batch_size,
                                 shuffle=False,
                                 drop_last=True)
