import os
import sys
import types
import importlib.util
import torch

# Make the repository root importable so rapid_retrain's own imports resolve.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Create stub modules so rapid_retrain can be imported in isolation.  They only
# carry placeholder attributes; the test below swaps in the fakes it needs on
# the loaded module via monkeypatch.
stubs = {
    'models.Fed': types.ModuleType('models.Fed'),
    'utils.init_data_model': types.ModuleType('utils.init_data_model'),
    'utils.evaluate': types.ModuleType('utils.evaluate'),
    'models.Update_domain': types.ModuleType('models.Update_domain'),
    'utils.ada_hessain': types.ModuleType('utils.ada_hessain'),
    'matplotlib': types.ModuleType('matplotlib'),
}

stubs['models.Fed'].FedAvg = lambda *a, **k: None
stubs['utils.init_data_model'].init_data = lambda *a, **k: None
stubs['utils.init_data_model'].init_model = lambda *a, **k: None
stubs['utils.init_data_model'].init_data_methodone = lambda *a, **k: None
stubs['utils.init_data_model'].get_dataset = lambda *a, **k: []
stubs['utils.evaluate'].DeltaWeight = lambda *a, **k: 0.0
stubs['utils.evaluate'].evaluate = lambda *a, **k: (a[0] if a else [], [0])
stubs['models.Update_domain'].DomainClientUpdate_Hesian_record = object
stubs['matplotlib'].use = lambda *a, **k: None

_RAPID_RETRAIN_PATH = os.path.join(
    os.path.dirname(__file__), '..', 'unlearning_methods', 'rapid_retrain.py'
)
spec = importlib.util.spec_from_file_location('rapid_retrain_mod', _RAPID_RETRAIN_PATH)
if spec is None or spec.loader is None:
    raise ImportError(f'cannot build an import spec for {_RAPID_RETRAIN_PATH}')
rapid_retrain_mod = importlib.util.module_from_spec(spec)

# Install the stubs only while rapid_retrain.py executes, then restore whatever
# was registered before (or drop the entry).  Leaving them in sys.modules would
# leak into every later test in the session -- e.g. the fake ``matplotlib``
# above, which exposes nothing but ``use``, would shadow the real package.
_MISSING = object()
_previous_modules = {name: sys.modules.get(name, _MISSING) for name in stubs}
sys.modules.update(stubs)
try:
    spec.loader.exec_module(rapid_retrain_mod)
finally:
    for _name, _previous in _previous_modules.items():
        if _previous is _MISSING:
            sys.modules.pop(_name, None)
        else:
            sys.modules[_name] = _previous


def create_state(val):
    """Build the state_dict of a 1->1 ``torch.nn.Linear`` with known parameters.

    The lone weight entry is set to ``val`` and the bias to zero, giving a
    deterministic parameter set that is easy to recognise in assertions.
    """
    layer = torch.nn.Linear(1, 1)
    # nn.init helpers write in place under no_grad, like fill_ inside no_grad().
    torch.nn.init.constant_(layer.weight, val)
    torch.nn.init.zeros_(layer.bias)
    return layer.state_dict()


def test_rapid_retrain_multi_client(monkeypatch):
    """rapid_retrain trains only the retained clients and saves their losses.

    With clients 0 and 1 marked for unlearning out of three users, only
    client 2 should be trained locally.  The saved ``*_loss_train.pth`` record
    should be empty for clients 0 and 1, and every saved artefact should sit
    under a path containing ``0_1/0_1`` (presumably derived from
    ``unlearning_client=[0, 1]``).
    """
    trained = []  # client indices whose local training was invoked

    def fake_init_data(args):
        assert args.dataset == 'PACS'
        # Each client's "loader" is just its index, so DummyLocal can tell
        # which client it is training.
        loaders = list(range(args.num_users))
        return loaders, loaders, None

    def fake_init_model(args):
        assert args.dataset == 'PACS'
        return torch.nn.Linear(1, 1)

    class DummyLocal:
        """Stand-in for DomainClientUpdate_Hesian_record that records calls."""

        def __init__(self, args, train_loader):
            self.idx = train_loader

        def train(self, net):
            trained.append(self.idx)
            # Distinct weights per client; loss equals the client index.
            return create_state(self.idx + 10), float(self.idx)

    def fake_FedAvg(w_locals):
        return w_locals[0]

    def fake_evaluate(**kwargs):
        return kwargs['example_stats'], list(range(kwargs['args'].num_users))

    base_state = create_state(0.0)

    def fake_load(path, map_location=None, **kwargs):
        # Tolerate extra torch.load keywords (e.g. weights_only).
        return base_state

    saved = []  # (path, obj) pairs passed to torch.save

    def fake_save(obj, path, *args, **kwargs):
        saved.append((path, obj))

    real_exists = rapid_retrain_mod.os.path.exists

    def fake_exists(path):
        # Pretend anything under a "learning" directory exists.  Normalise
        # separators (as the final assertion does) so Windows paths match too.
        if '/learning/' in os.fspath(path).replace('\\', '/'):
            return True
        return real_exists(path)

    monkeypatch.setattr(rapid_retrain_mod, 'init_data_methodone', fake_init_data)
    monkeypatch.setattr(rapid_retrain_mod, 'init_data', fake_init_data)
    monkeypatch.setattr(rapid_retrain_mod, 'init_model', fake_init_model)
    monkeypatch.setattr(rapid_retrain_mod, 'get_dataset', lambda args: [f'c{i}' for i in range(args.num_users)])
    monkeypatch.setattr(rapid_retrain_mod, 'DomainClientUpdate_Hesian_record', DummyLocal)
    monkeypatch.setattr(rapid_retrain_mod, 'FedAvg', fake_FedAvg)
    monkeypatch.setattr(rapid_retrain_mod, 'evaluate', fake_evaluate)
    monkeypatch.setattr(rapid_retrain_mod, 'DeltaWeight', lambda *a, **k: 0.0)
    monkeypatch.setattr(rapid_retrain_mod.torch, 'load', fake_load)
    monkeypatch.setattr(rapid_retrain_mod.torch, 'save', fake_save)
    monkeypatch.setattr(rapid_retrain_mod.os.path, 'exists', fake_exists)

    args = types.SimpleNamespace(
        seed=0,
        gpu=-1,
        dataset_fullparti=False,
        dataset='PACS',
        domain_split_factor=1,
        domain_times_factor=1,
        save='tmp',
        unlearning_client=[0, 1],
        backdoor_client_idx=[0, 1],
        num_users=3,
        unlearn_epoch=1,
        mask_ratio=0.1,
        diff_mask_ratio=0.1,
        target='rapid_retrain',
    )

    rapid_retrain_mod.rapid_retrain(args)

    assert trained == [2]

    loss_obj = next(
        (obj for path, obj in saved if path.endswith('_loss_train.pth')),
        None,
    )
    assert loss_obj is not None, (
        f'no *_loss_train.pth was saved; saved paths: {[p for p, _ in saved]}'
    )
    # Separate asserts so a failure pinpoints the offending client.
    assert loss_obj[0] == []
    assert loss_obj[1] == []
    assert loss_obj[2] == [2]

    for path, _ in saved:
        assert '0_1/0_1' in path.replace('\\', '/'), path
