import os
import sys
import types
import torch

# Put the parent of this file's directory first on the import path so that
# ``utils.init_data_model`` can be imported below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Replace heavy/optional third-party packages with lightweight stub modules,
# mirroring what other tests do.  The stubs are written straight into
# ``sys.modules`` (overwriting any real module already registered under the
# same name), so they must be in place before ``utils.init_data_model`` is
# imported.
import types as _types


def _noop(*a, **k):
    """Accept any arguments and return ``None``; stands in for stubbed callables."""
    return None


def _register_stub(name, **attrs):
    """Create an empty module called *name*, set *attrs* on it, and register it.

    The new module is stored in ``sys.modules[name]`` and returned.  Only the
    attributes passed in *attrs* are set; parent packages are not linked to
    their children unless the caller passes the child explicitly.
    """
    module = _types.ModuleType(name)
    for attr_name, value in attrs.items():
        setattr(module, attr_name, value)
    sys.modules[name] = module
    return module


# ``art`` package and its submodules.
art_module = _register_stub("art")
attacks_mod = _register_stub("art.attacks")
poison_mod = _register_stub(
    "art.attacks.poisoning",
    PoisoningAttackBackdoor=object,
    PoisoningAttackCleanLabelBackdoor=object,
)
pert_mod = _register_stub(
    "art.attacks.poisoning.perturbations",
    add_pattern_bd=_noop,
)
utils_mod = _register_stub(
    "art.utils",
    load_mnist=_noop,
    preprocess=_noop,
    to_categorical=_noop,
)

# Tabular / IO libraries.
pd_mod = _register_stub("pandas", read_csv=_noop, DataFrame=object)
pyarrow_mod = _register_stub("pyarrow", __version__="0.0")
parquet_mod = _register_stub("pyarrow.parquet")

# Model definitions and model zoos.
_register_stub("models.Resnet", ResNet18=object)
_register_stub("models.vggmodule", vgg=object)
_register_stub("timm")
_register_stub("transformers", ViTForImageClassification=object)

# torchvision: the submodules are registered on their own and are also
# attached as attributes of the top-level package.
datasets_mod = _register_stub("torchvision.datasets", ImageFolder=object)
transforms_mod = _register_stub(
    "torchvision.transforms",
    Compose=_noop,
    Resize=_noop,
    RandomHorizontalFlip=_noop,
    RandomRotation=_noop,
    ToTensor=_noop,
    Normalize=_noop,
)
_register_stub("torchvision", datasets=datasets_mod, transforms=transforms_mod)

# scikit-learn.
_register_stub("sklearn")
sk_mod = _register_stub("sklearn.model_selection", train_test_split=_noop)

import importlib

# Import the module under test, then reload it.  Reloading re-executes its
# top-level imports so they bind to the stubs registered above, even if the
# module was already imported earlier in the session (presumably by another
# test module with its own stubs — confirm).
init_data_model = importlib.import_module("utils.init_data_model")
importlib.reload(init_data_model)

def test_backdoor_indices_map_to_domains(monkeypatch):
    """Backdoor client indices 0-9 should lead to repeated "amazon" loads.

    ``OfficeDataset`` is replaced by a recorder that returns a one-sample
    ``TensorDataset``. ``_split_loaders`` is replaced by a pass-through that
    returns the train/test arguments unchanged. With ten splits per domain
    and backdoor clients 0-9, the test expects "amazon" to be requested at
    least twice.
    """
    requested = []

    def _recording_office_dataset(base, domain, *args, **kwargs):
        # Log the requested domain, preferring an explicit ``dataset_name``
        # keyword over the positional ``domain`` when the caller passes one.
        requested.append(kwargs.get("dataset_name", domain))
        images = torch.zeros(1, 3, 224, 224)
        targets = torch.tensor([0])
        return torch.utils.data.TensorDataset(images, targets)

    def _passthrough_split(a, tr, te, n):
        # Skip real splitting; hand back the train/test objects as-is.
        return tr, te

    monkeypatch.setattr(init_data_model, "_split_loaders", _passthrough_split)
    monkeypatch.setattr(
        init_data_model.data_utils, "OfficeDataset", _recording_office_dataset
    )

    config = types.SimpleNamespace(
        backdoor_client_idx=list(range(10)),
        dataset="office-caltech10",
        domain_split_factor=10,
        local_bs=1,
        verify="backdoor",
    )

    init_data_model.init_data(config)

    amazon_requests = requested.count("amazon")
    assert amazon_requests >= 2
