from torch.utils.data import DataLoader
from .ttasampler import build_sampler

from src.data.cifar_c import CorruptionCIFAR, CorruptionCIFARRecur
from src.data.domainnet import DomainNet126
from src.data.imagenet_c import ImageNetC
from src.data.cifar_c_lazy import CorruptionCIFAR_Lazy
from ..utils.result_precess import AvgResultProcessor

def build_loader(cfg, ds_name: str, all_corruptions, all_severity):
    """Build the evaluation ``DataLoader`` and result processor for ``ds_name``.

    Args:
        cfg: Project config object. This function reads ``cfg.TEST.BATCH_SIZE``
            and ``cfg.LOADER.NUM_WORKS``; ``cfg`` is also forwarded to the
            dataset constructor and to ``build_sampler``.
        ds_name: Dataset key that selects the dataset class. Supported keys
            are the ones listed in the branches below.
        all_corruptions: Forwarded unchanged to the dataset constructor.
        all_severity: Forwarded unchanged to the dataset constructor.

    Returns:
        A ``(loader, result_processor)`` tuple, where ``loader`` is a
        ``torch.utils.data.DataLoader`` over the selected dataset and
        ``result_processor`` is an ``AvgResultProcessor`` built from the
        dataset's ``domain_id_to_name`` attribute.

    Raises:
        NotImplementedError: If ``ds_name`` is not a supported dataset key.
    """
    if ds_name in ["cifar10c", "cifar100c"]:
        # Use the lazy-loading variant for the plain CIFAR-C datasets.
        dataset_class = CorruptionCIFAR_Lazy
    elif ds_name in ["cifar10c_recur", "cifar100c_recur"]:
        dataset_class = CorruptionCIFARRecur
    elif ds_name in ["domainnet126", "domainnet126_recur"]:
        dataset_class = DomainNet126
    elif ds_name in ["imagenetc", "imagenetc_recur"]:
        dataset_class = ImageNetC
    else:
        # Report the key that was actually checked (ds_name), not a config
        # field that may differ from it or be absent.
        raise NotImplementedError(f"Not Implemented for dataset: {ds_name}")

    ds = dataset_class(cfg, all_corruptions, all_severity)

    # The sampler is built from the dataset's `data_source` attribute rather
    # than from the dataset object itself.
    sampler = build_sampler(cfg, ds.data_source)
    loader = DataLoader(
        ds,
        batch_size=cfg.TEST.BATCH_SIZE,
        sampler=sampler,
        num_workers=cfg.LOADER.NUM_WORKS,
    )

    result_processor = AvgResultProcessor(ds.domain_id_to_name)

    return loader, result_processor
