import pytorch_lightning as pl 
from datasets.rl_waymo import RLWaymoDataset 
from datasets.rl_waymo import RLWaymoDiffDataset
from torch_geometric.loader import DataLoader
import os

def worker_init_fn(worker_id):
    """Let the calling DataLoader worker process run on every CPU.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``; it is
    then invoked once inside each worker process.
    NOTE(review): presumably this undoes CPU-affinity pinning inherited from
    the parent process (some numeric libraries restrict affinity on import)
    -- confirm the original motivation.

    Args:
        worker_id: Index of the worker, supplied by the DataLoader. Unused.
    """
    # ``os.sched_setaffinity`` is only available on some Unix platforms
    # (notably Linux). The reset is best-effort, so skip it elsewhere
    # instead of crashing every worker with an AttributeError.
    set_affinity = getattr(os, "sched_setaffinity", None)
    if set_affinity is None:
        return

    # ``os.cpu_count()`` may return None when the count is undetermined;
    # ``range(None)`` would raise TypeError, so skip in that case as well.
    cpu_total = os.cpu_count()
    if cpu_total is None:
        return

    # pid 0 addresses the calling process, i.e. the worker itself.
    set_affinity(0, range(cpu_total))

class RLWaymoDataModule(pl.LightningDataModule):
    """Lightning data module providing train/val loaders for RL Waymo data.

    Depending on ``use_diffusion`` the module reads a different pair of
    config sections and instantiates a different dataset class:

    * ``use_diffusion=False``: ``cfg.datasets.rl_waymo`` and
      ``cfg.train.datamodule`` with ``RLWaymoDataset``.
    * ``use_diffusion=True``: ``cfg.datasets.rl_waymo_diffusion`` and
      ``cfg.train_diffusion.datamodule`` with ``RLWaymoDiffDataset``.

    The datamodule config section is expected to expose
    ``train_batch_size``, ``val_batch_size``, ``num_workers`` and
    ``pin_memory``.
    """

    def __init__(self,
                 cfg,
                 use_diffusion=False):
        """Select the config sections matching the requested mode.

        Args:
            cfg: Project configuration object with attribute-style access to
                the ``datasets`` and ``train`` / ``train_diffusion`` sections.
            use_diffusion: When true, use the diffusion dataset and its
                training configuration instead of the default ones.
        """
        super().__init__()
        self.use_diffusion = use_diffusion
        if self.use_diffusion:
            self.cfg_dataset = cfg.datasets.rl_waymo_diffusion
            self.cfg_datamodule = cfg.train_diffusion.datamodule
        else:
            self.cfg_dataset = cfg.datasets.rl_waymo
            self.cfg_datamodule = cfg.train.datamodule

    def setup(self, stage=None):
        """Instantiate the train and validation datasets.

        Args:
            stage: Stage name passed by the Lightning trainer. Unused here;
                it defaults to ``None`` so the method can also be called
                manually without arguments.
        """
        # Both modes build their datasets the same way; only the class differs.
        dataset_cls = RLWaymoDiffDataset if self.use_diffusion else RLWaymoDataset
        self.train_dataset = dataset_cls(self.cfg_dataset, split_name='train')
        self.val_dataset = dataset_cls(self.cfg_dataset, split_name='val')

    def train_dataloader(self):
        """Return the shuffled training loader (last partial batch dropped)."""
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg_datamodule.train_batch_size,
                          shuffle=True,
                          num_workers=self.cfg_datamodule.num_workers,
                          pin_memory=self.cfg_datamodule.pin_memory,
                          drop_last=True,
                          worker_init_fn=worker_init_fn)

    def val_dataloader(self):
        """Return the unshuffled validation loader (last partial batch dropped)."""
        # NOTE(review): drop_last=True discards the final partial batch, so
        # those validation samples are never evaluated -- confirm this is
        # intended (e.g. the model requires a fixed batch size).
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg_datamodule.val_batch_size,
                          shuffle=False,
                          num_workers=self.cfg_datamodule.num_workers,
                          pin_memory=self.cfg_datamodule.pin_memory,
                          drop_last=True,
                          # Same per-worker initialisation as the training
                          # loader, so validation workers behave consistently.
                          worker_init_fn=worker_init_fn)