import os
import sys
import types
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import types as _types

# ---------------------------------------------------------------------------
# Lightweight stand-ins for heavy / optional third-party packages, presumably
# imported (directly or transitively) by ``utils.init_data_model``; the stubs
# let that module be imported without the real packages. Only the attributes
# assigned below exist on the stubs.
#
# NOTE(review): these entries are written into ``sys.modules`` unconditionally
# and nothing here removes them, so they shadow any real installation for the
# rest of the interpreter session (e.g. other test modules in the same pytest
# run). Consider scoping them if that becomes a problem.
# ---------------------------------------------------------------------------


def _noop(*args, **kwargs):
    """Accept any arguments and return ``None``; stands in for stubbed callables."""
    return None


# Stub modules created by ``_stub_module``, keyed by dotted name, so that a
# child stub can be bound as an attribute on its (stub) parent package.
_stubbed = {}


def _stub_module(name, **attrs):
    """Create an empty module ``name``, register it in ``sys.modules``, return it.

    Every keyword in ``attrs`` is set as a module attribute. If the parent
    package was also created by this helper, the new module is bound on it too,
    so ``import pkg.sub`` followed by ``pkg.sub.attr`` resolves as it would for
    a real package. Parents not created here (real packages) are left untouched.
    """
    module = _types.ModuleType(name)
    for attr_name, value in attrs.items():
        setattr(module, attr_name, value)
    parent_name, _, child_name = name.rpartition(".")
    parent = _stubbed.get(parent_name)
    if parent is not None:
        setattr(parent, child_name, module)
    _stubbed[name] = module
    sys.modules[name] = module
    return module


# art (parents are created before children so the attribute binding works)
art_module = _stub_module("art")
attacks_mod = _stub_module("art.attacks")
poison_mod = _stub_module(
    "art.attacks.poisoning",
    PoisoningAttackBackdoor=object,
    PoisoningAttackCleanLabelBackdoor=object,
)
pert_mod = _stub_module("art.attacks.poisoning.perturbations", add_pattern_bd=_noop)
utils_mod = _stub_module(
    "art.utils",
    load_mnist=_noop,
    preprocess=_noop,
    to_categorical=_noop,
)

pd_mod = _stub_module("pandas", read_csv=_noop, DataFrame=object)

pyarrow_mod = _stub_module("pyarrow", __version__="0.0")
parquet_mod = _stub_module("pyarrow.parquet")

# The parent package "models" is not stubbed; presumably it is resolved from the
# repository root inserted into sys.path above.
_stub_module("models.Resnet", ResNet18=object)
_stub_module("models.vggmodule", vgg=object)

_stub_module("timm")
_stub_module("transformers", ViTForImageClassification=object)

_stub_module("torchvision")
datasets_mod = _stub_module("torchvision.datasets", ImageFolder=object)
transforms_mod = _stub_module(
    "torchvision.transforms",
    Compose=_noop,
    Resize=_noop,
    RandomHorizontalFlip=_noop,
    RandomRotation=_noop,
    ToTensor=_noop,
)

_stub_module("sklearn")
sk_mod = _stub_module("sklearn.model_selection", train_test_split=_noop)
from utils import init_data_model


def test_backdoor_label_validation(monkeypatch):
    """A backdoor target label of 9 is reset to 0 by ``create_pacs_datasets``.

    After the call, ``args.backdoor_target_label`` must be 0, and every
    ``PACSDataset`` construction that received a ``target_label`` keyword must
    have received 0 (presumably because 9 is not a valid PACS class label --
    confirm against ``create_pacs_datasets``). Filesystem and network access
    is patched out so only the label handling is exercised.
    """
    # ``target_label`` keyword seen by each PACSDataset construction (None if absent).
    call_targets = []

    def fake_dataset(*args, **kwargs):
        call_targets.append(kwargs.get("target_label"))
        # Return a tiny dataset so DataLoader can iterate.
        data = torch.zeros(1, 3, 224, 224)
        labels = torch.tensor([0])
        return torch.utils.data.TensorDataset(data, labels)

    # Patch functions that touch the filesystem or network.
    monkeypatch.setattr(init_data_model, "snapshot_download", lambda **kw: None)
    monkeypatch.setattr(init_data_model.os.path, "isdir", lambda p: True)
    # Non-empty listing, presumably so the data directory is treated as populated.
    monkeypatch.setattr(init_data_model.os, "scandir", lambda p: iter([object()]))
    monkeypatch.setattr(init_data_model.data_utils, "PACSDataset", fake_dataset)
    monkeypatch.setattr(init_data_model, "_split_loaders", lambda a, tr, te, n: (tr, te))

    args = types.SimpleNamespace(
        local_bs=1,
        backdoor_client_idx=[0],
        backdoor_target_label=9,
        backdoor_percent_poison=0.1,
    )

    init_data_model.create_pacs_datasets(args, inject_backdoor=True)

    # The target label should be reset to 0 on args ...
    assert args.backdoor_target_label == 0
    # ... and used for every dataset initialization that was given one. Require
    # at least one such call: all() over an empty sequence would pass vacuously.
    passed_targets = [t for t in call_targets if t is not None]
    assert passed_targets, "PACSDataset was never constructed with a target_label"
    assert all(t == 0 for t in passed_targets), passed_targets
