import os
import sys
import importlib.util
import types
import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Provide stub modules so increase_loss can be imported in isolation
stubs = {
    'models.Update_domain': types.ModuleType('models.Update_domain'),
    'utils.options': types.ModuleType('utils.options'),
    'utils.evaluate': types.ModuleType('utils.evaluate'),
    'utils.init_data_model': types.ModuleType('utils.init_data_model'),
    'utils.increase_loss_utils': types.ModuleType('utils.increase_loss_utils'),
}

# minimal attributes required during import
stubs['models.Update_domain'].DomainClientUpdate = object
stubs['models.Update_domain'].DomainClientUpdate_avg = object
stubs['models.Update_domain'].DomainClientUnlearningBl3 = object
stubs['utils.options'].args_parser = lambda: None
stubs['utils.evaluate'].evaluate = lambda *a, **k: (None, None)
stubs['utils.init_data_model'].init_data = lambda *a, **k: None
stubs['utils.init_data_model'].init_model = lambda *a, **k: None
stubs['utils.init_data_model'].init_data_methodone = lambda *a, **k: None
stubs['utils.init_data_model'].get_dataset = lambda *a, **k: []
stubs['utils.increase_loss_utils'].get_distance = lambda *a, **k: torch.tensor(0.0)

# Install the stubs only while increase_loss.py executes, then restore whatever
# sys.modules held before. If left in place, these fake modules would shadow the
# real ``models.*`` / ``utils.*`` modules for every other test collected in the
# same pytest session.
# NOTE(review): this assumes increase_loss binds the stubbed names at import
# time (``from ... import ...``), which is consistent with the test below
# monkeypatching them as attributes of ``increase_loss_mod``. A lazy import
# inside a function body would bypass the stubs.
_previous_modules = {name: sys.modules.get(name) for name in stubs}
sys.modules.update(stubs)
try:
    spec = importlib.util.spec_from_file_location(
        'increase_loss_mod',
        os.path.join(os.path.dirname(__file__), '..', 'unlearning_methods', 'increase_loss.py')
    )
    increase_loss_mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(increase_loss_mod)
finally:
    for _name, _previous in _previous_modules.items():
        if _previous is None:
            sys.modules.pop(_name, None)
        else:
            sys.modules[_name] = _previous


def create_state(value):
    """Return the state dict of a fresh ``torch.nn.Linear(1, 1)`` filled with *value*.

    Both the weight and the bias entries are set to ``value``.
    """
    layer = torch.nn.Linear(1, 1)
    # Parameters require grad, so the in-place fill must run without autograd.
    with torch.no_grad():
        for param in (layer.weight, layer.bias):
            param.fill_(value)
    return layer.state_dict()

class DummyUnlearning:
    """Stand-in for ``DomainClientUnlearningBl3`` that performs no real training.

    The ``train_loader`` it receives is just an index (a plain number). ``train``
    returns a ``Linear(1, 1)`` state filled with that index plus 10, so the
    result produced for each client is distinguishable.
    """

    def __init__(self, args, train_loader, threshold):
        # ``args`` and ``threshold`` are accepted but unused.
        self.value = train_loader

    def train(self, net, net_ref, net_unlearning_client):
        # The model arguments are ignored; the output depends only on the index.
        fill_value = self.value + 10
        return create_state(fill_value)

def test_multiple_unlearning_clients(monkeypatch, tmp_path):
    """Unlearning two clients should save the mean of their unlearned weights.

    The faked ``init_data_methodone`` hands out loaders ``[0, 1, 2]``, which
    ``DummyUnlearning`` treats as plain numbers: its ``train`` returns a
    ``Linear(1, 1)`` state filled with ``loader + 10``. Presumably client ``i``
    receives loader ``i``, so unlearning clients 0 and 1 yields states of 10.0
    and 11.0. The checkpoint saved as ``*_weight_global.pth`` is expected to be
    their mean, 10.5.
    """
    args = types.SimpleNamespace(
        seed=0,
        gpu=-1,
        dataset_fullparti=False,
        domain_split_factor=1,
        dataset='dummy',
        save='tmp',
        unlearning_client=[0, 1],
        backdoor_client_idx=[0, 1],
        num_users=3,
        unlearn_epoch=0,
        lr=0.1,
        local_ep=1,
        num_local_unlearn_epochs=1,
        clip_grad=1,
        mask_ratio=0.1,
        diff_mask_ratio=0.1,
        target='increase_loss'
    )

    # Run from a scratch directory. If increase_loss creates files or folders
    # relative to the working directory (e.g. from ``args.save``), they then
    # land in pytest's tmp_path instead of the repository checkout.
    monkeypatch.chdir(tmp_path)

    monkeypatch.setattr(increase_loss_mod, 'init_data_methodone', lambda args: ([0, 1, 2], [0, 1, 2], None))
    monkeypatch.setattr(increase_loss_mod, 'get_dataset', lambda args: [f'c{i}' for i in range(args.num_users)])
    monkeypatch.setattr(increase_loss_mod, 'init_model', lambda args: torch.nn.Linear(1, 1))
    monkeypatch.setattr(increase_loss_mod, 'evaluate', lambda **kw: ([[], []], [0] * args.num_users))
    monkeypatch.setattr(increase_loss_mod, 'DomainClientUnlearningBl3', DummyUnlearning)

    weight_global = create_state(0.0)
    weight_local = {0: create_state(1.0), 1: create_state(2.0)}

    def fake_load(path, map_location=None, *_args, **_kwargs):
        # str() so PathLike arguments work too; the extra parameters absorb
        # any additional torch.load options (e.g. ``weights_only``).
        path = str(path)
        if 'weight_global.pth' in path:
            return weight_global
        if 'weight_local.pth' in path:
            return weight_local
        return weight_global

    saved = {}

    def fake_save(obj, path, *_args, **_kwargs):
        # Capture only the global checkpoint; every other save is discarded.
        if str(path).endswith('_weight_global.pth'):
            saved['w'] = obj

    monkeypatch.setattr(increase_loss_mod.torch, 'load', fake_load)
    monkeypatch.setattr(increase_loss_mod.torch, 'save', fake_save)

    increase_loss_mod.increase_loss(args)

    assert 'w' in saved, "increase_loss never saved a '*_weight_global.pth' checkpoint"
    expected = create_state(10.5)
    for key, tensor in expected.items():
        assert torch.allclose(saved['w'][key], tensor), f"mismatch in {key!r}"
