import os
import sys
import types
import importlib.util
import torch

_HERE = os.path.dirname(__file__)

# Make the project root importable so fedsalun.py's own imports resolve.
sys.path.insert(0, os.path.abspath(os.path.join(_HERE, '..')))

spec = importlib.util.spec_from_file_location(
    'fedsalun_mod', os.path.join(_HERE, '..', 'unlearning_methods', 'fedsalun.py')
)
fedsalun_mod = importlib.util.module_from_spec(spec)


def _stub_module(name, **attrs):
    """Return a fresh module object called *name* carrying *attrs* as attributes."""
    module = types.ModuleType(name)
    module.__dict__.update(attrs)
    return module


def _fake_client_update(*a, **k):
    """Stand-in for DomainClientUpdate: ``train`` yields a new 1x1 linear model."""
    return types.SimpleNamespace(
        train=lambda *args, **kwargs: (torch.nn.Linear(1, 1), {}, {}, [0.0])
    )


def _fake_client_update_avg(*a, **k):
    """Stand-in for DomainClientUpdate_avg: ``train`` hands the net back with loss 0.0."""
    return types.SimpleNamespace(train=lambda net: (net, 0.0))


def _fake_evaluate(**kw):
    """Stand-in for utils.evaluate.evaluate: empty curves, one zero score per user."""
    return [[], []], [0] * kw['args'].num_users


def _fake_init_data(args, **k):
    """Stand-in for init_data / init_data_methodone: one placeholder per user."""
    return [0] * args.num_users, [0] * args.num_users, None


stubs = {
    'models.Update_domain': _stub_module(
        'models.Update_domain',
        DomainClientUpdate=_fake_client_update,
        DomainClientUpdate_avg=_fake_client_update_avg,
        DomainClientUnlearningBl3=object,
        DomainClientUpdate_avg_sal=object,
    ),
    'utils.options': _stub_module('utils.options', args_parser=lambda: None),
    'utils.evaluate': _stub_module('utils.evaluate', evaluate=_fake_evaluate),
    'utils.init_data_model': _stub_module(
        'utils.init_data_model',
        init_data=_fake_init_data,
        init_data_methodone=_fake_init_data,
        get_dataset=lambda args: list(map('c{}'.format, range(args.num_users))),
        init_model=lambda args: torch.nn.Linear(1, 1),
    ),
}

# NOTE(review): the stubs are never removed from sys.modules, so they remain for
# the rest of the pytest session and replace any real modules of the same name
# that were already imported. Restoring them after exec_module would only be safe
# if fedsalun.py performs all of these imports at module top level — confirm.
sys.modules.update(stubs)

spec.loader.exec_module(fedsalun_mod)

def test_fedsalun_multi_client(monkeypatch):
    """Run ``fedsalun`` with two unlearning clients and check both deltas reach masking.

    ``generate_federated_mask``, ``torch.save`` and ``torch.load`` are patched so the
    run touches no files: ``fake_load`` serves lightweight stand-ins chosen by
    substrings of the requested path.
    """
    recorded = {}

    def fake_generate_federated_mask(deltas, args):
        # Record how many client deltas were aggregated and return a trivial mask.
        recorded['num'] = len(deltas)
        return {args.mask_ratio: {"weight": torch.tensor(0.)}}

    monkeypatch.setattr(fedsalun_mod, 'generate_federated_mask', fake_generate_federated_mask)
    monkeypatch.setattr(fedsalun_mod.torch, 'save', lambda *a, **k: None)

    class DummyDS:
        """Minimal dataset: a single sample whose target is 0."""

        def __init__(self):
            self.targets = [0]

        def __len__(self):
            return 1

    def fake_load(path, map_location=None, **kwargs):
        # str() so pathlib.Path arguments work; **kwargs absorbs extra torch.load
        # options (e.g. weights_only=) the code under test may pass.
        path = str(path)
        if 'weight_global.pth' in path:
            return torch.nn.Linear(1, 1).state_dict()
        if 'global_epoch' in path:
            return torch.nn.Linear(1, 1)
        if 'weight_local.pth' in path:
            return {0: torch.nn.Linear(1, 1).state_dict(), 1: torch.nn.Linear(1, 1).state_dict()}
        if 'train_loaders.pth' in path or 'test_loaders.pth' in path:
            # One shared dummy loader per client (two clients).
            loader = types.SimpleNamespace(dataset=DummyDS())
            return [loader, loader]
        # Fallback: anything exposing state_dict() like a saved model.
        return types.SimpleNamespace(state_dict=lambda: torch.nn.Linear(1, 1).state_dict())

    monkeypatch.setattr(fedsalun_mod.torch, 'load', fake_load)

    args = types.SimpleNamespace(
        seed=0,
        gpu=-1,
        dataset_fullparti=False,
        dataset='dummy',
        save='tmp',
        model='tmp',
        verify='ver',
        domain_split_factor=1,
        backdoor_client_idx=[0, 1],
        unlearning_client=[0, 1],
        num_users=2,
        fedsalun_epoch=1,
        unlearn_epoch=0,
        epochs=1,
        lamb=0.5,
        unlearn_lr=0.1,
        num_classes=10,
        backdoor_target_label=0,
        mask_ratio=0.1,
        diff_mask_ratio=0.1,
        target='fedsalun',
    )
    fedsalun_mod.fedsalun(args)
    # .get() so a never-called mask step fails with a clear message, not a KeyError.
    num = recorded.get('num')
    assert num == 2, f"expected 2 client deltas passed to generate_federated_mask, got {num}"
