import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dataset.dataProvider import DataProvider

def my_worker_init_fn(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Worker processes start with a copy of the parent's NumPy RNG state, so
    deriving the seed from that state gives the same NumPy stream every epoch.
    PyTorch, however, sets each worker's torch seed to ``base_seed + worker_id``
    and draws a fresh ``base_seed`` each time a loader iterator is created.
    Reusing that value therefore gives every worker, and every epoch, a
    distinct NumPy stream.

    Args:
        worker_id: Index of the worker. It is not used directly because
            ``torch.initial_seed()`` already incorporates it inside a worker,
            but the parameter is required by the ``worker_init_fn`` protocol.
    """
    # np.random.seed only accepts values in [0, 2**32 - 1]; torch seeds are 64-bit.
    np.random.seed(torch.initial_seed() % 2**32)


def _loader_kwargs(dataset, batch_size, num_workers, drop_last):
    """Return the loader keyword arguments shared by every loader built in this module."""
    kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'worker_init_fn': my_worker_init_fn,
        'collate_fn': dataset.collate_fn,
        'pin_memory': True,
        'drop_last': drop_last,
    }
    # torch's DataLoader raises ValueError if prefetch_factor is given while
    # num_workers == 0 (prefetching only exists for multi-process loading),
    # so only pass it when worker processes are actually used.
    if num_workers > 0:
        # NOTE(review): prefetch_factor counts *batches* buffered per worker,
        # yet it is set to the batch size here (value kept from the original).
        # With large batch sizes this can buffer far more data than needed —
        # confirm this is intended.
        kwargs['prefetch_factor'] = batch_size
    return kwargs


def get_TV_dl(cfg, train_dataset, val_dataset, domain='source', rank=0, world_size=1):
    """Build the training and validation loaders for one domain.

    Args:
        cfg: Config object. Reads ``cfg.DATALOADER.TRA_BATCH_SIZE``,
            ``cfg.DATALOADER.VAL_BATCH_SIZE`` and ``cfg.DATALOADER.NUM_WORKERS``.
        train_dataset: Training dataset; must expose a ``collate_fn`` attribute.
        val_dataset: Validation dataset; must expose a ``collate_fn`` attribute.
        domain: ``'source'`` wraps the training set in the project's
            ``DataProvider``. Any other value uses a plain torch ``DataLoader``.
        rank: Rank of this process, used only when ``world_size > 1``.
        world_size: Number of distributed processes. Values above 1 shard the
            training set with a ``DistributedSampler``.

    Returns:
        ``(train_loader, val_loader, train_sampler)``.

        - ``train_sampler`` is the ``DistributedSampler`` in distributed mode and
          ``None`` otherwise. Callers should call ``train_sampler.set_epoch(epoch)``
          each epoch so the shuffle order changes between epochs.
        - The training loader drops the last incomplete batch.
        - The validation loader is never sharded: every rank iterates the full
          validation set, in order.
    """
    num_workers = cfg.DATALOADER.NUM_WORKERS

    train_kwargs = _loader_kwargs(
        train_dataset, cfg.DATALOADER.TRA_BATCH_SIZE, num_workers, drop_last=True
    )

    if world_size > 1:
        # The sampler owns shuffling and sharding; `shuffle` must not also be
        # passed to the loader alongside a sampler.
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
        )
        train_kwargs['sampler'] = train_sampler
    else:
        train_sampler = None
        train_kwargs['shuffle'] = True

    loader_cls = DataProvider if domain == 'source' else DataLoader
    train_loader = loader_cls(train_dataset, **train_kwargs)

    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        **_loader_kwargs(
            val_dataset, cfg.DATALOADER.VAL_BATCH_SIZE, num_workers, drop_last=False
        ),
    )

    return train_loader, val_loader, train_sampler