import numpy as np
import pytest
import torch

from byzantine_robust_fl.byzantine_defense import (
    FGNV,
    FLDetector,
    aggregate_multi_krum,
    calculate_gradient_cosine_similarities,
    detect_by_gradient_norm_outliers,
    update_reputations,
    wbc,
)

# Type aliases used by the fixture annotations below.
# StateDict maps a parameter name to its tensor; WeightsList is a list of such dicts.
StateDict = dict[str, torch.Tensor]
WeightsList = list[StateDict]


# --- Fixtures for Mock Data ---


@pytest.fixture
def mock_weights() -> tuple[WeightsList, StateDict]:
    """Return three client state dicts together with an all-zero global state.

    The first two clients hold small positive ``layer.w`` values (benign); the
    third holds a large negative value and stands in for a malicious client.
    """
    # One scalar per client, repeated across both entries of ``layer.w``.
    per_client_value = (1.0, 1.1, -5.0)
    client_states = [{"layer.w": torch.tensor([v, v])} for v in per_client_value]
    initial_global = {"layer.w": torch.tensor([0.0, 0.0])}
    return client_states, initial_global


# --- Test Cases ---


class TestReputationUpdates:
    """Check how ``update_reputations`` adjusts per-client alpha and beta values."""

    @pytest.fixture
    def reputation_data(self) -> tuple:
        """Return participating client indices plus the starting alpha and beta arrays."""
        num_clients = 10
        # Every client starts from the same values: alpha = 5 and beta = 2.
        starting_alphas = 5 * np.ones(num_clients)
        starting_betas = 2 * np.ones(num_clients)
        return [0, 1, 2, 3], starting_alphas, starting_betas

    def test_benign_client_update(self, reputation_data):
        """A participant that is not flagged gains one alpha and keeps its beta."""
        clients, alphas, betas = reputation_data
        flagged = {3}  # client 0 is absent from the flagged set
        benign = 0

        outcome = update_reputations(clients, flagged, alphas, betas, discount_factor=1.0)
        _, new_alphas, new_betas = outcome

        assert new_alphas[benign] == alphas[benign] + 1
        assert new_betas[benign] == betas[benign]

    def test_malicious_client_update(self, reputation_data):
        """A flagged participant gains one beta and keeps its alpha."""
        clients, alphas, betas = reputation_data
        culprit = 3
        flagged = {culprit}

        outcome = update_reputations(clients, flagged, alphas, betas, discount_factor=1.0)
        _, new_alphas, new_betas = outcome

        assert new_alphas[culprit] == alphas[culprit]
        assert new_betas[culprit] == betas[culprit] + 1

    def test_discount_factor(self, reputation_data):
        """With a discount below one, the old value is scaled before the increment."""
        clients, alphas, betas = reputation_data
        flagged = {3}
        decay = 0.9

        _, new_alphas, new_betas = update_reputations(
            clients, flagged, alphas, betas, discount_factor=decay
        )

        # Client 0 was not flagged: its alpha is discounted and then incremented.
        expected_alpha = (alphas[0] * decay) + 1
        assert new_alphas[0] == pytest.approx(expected_alpha)
        # Client 3 was flagged: its beta is discounted and then incremented.
        expected_beta = (betas[3] * decay) + 1
        assert new_betas[3] == pytest.approx(expected_beta)


class TestCosineSimilarity:
    """Check the values produced by ``calculate_gradient_cosine_similarities``."""

    def test_similarity_calculation(self, mock_weights, monkeypatch):
        """Parallel, orthogonal and opposite client gradients give 1, 0 and -1."""
        local_weights, global_weights_before = mock_weights

        # Client gradients picked relative to the server gradient [1, 0]:
        # same direction, perpendicular, and opposite direction.
        client_grads = [
            {"w": torch.tensor([1.0, 0.0])},
            {"w": torch.tensor([0.0, 1.0])},
            {"w": torch.tensor([-1.0, 0.0])},
        ]
        server_grad = {"w": torch.tensor([1.0, 0.0])}

        def fake_calculate_gradients(a, b, c):
            # Client gradients are returned only when the second argument is the
            # exact local-weights object; every other call gets the server gradient.
            if b is local_weights:
                return client_grads
            return server_grad

        monkeypatch.setattr(
            "byzantine_robust_fl.byzantine_defense.calculate_gradients",
            fake_calculate_gradients,
        )

        server_weights = {"w": torch.tensor([1.0])}
        similarities = calculate_gradient_cosine_similarities(
            local_weights, server_weights, global_weights_before, 0.1
        )

        # Raw (unclipped) cosine values are expected, so the opposite client stays at -1.
        expected = [1.0, 0.0, -1.0]
        assert np.allclose(similarities, expected)


class TestMultiKrum:
    """Check which client weights ``aggregate_multi_krum`` forwards for averaging."""

    def test_krum_selection(self, monkeypatch):
        """The three clustered clients are the ones handed to ``average_weights``."""
        module = "byzantine_robust_fl.byzantine_defense"

        # Placeholder weights: the integer stored under "w" doubles as the client
        # index, which lets the final assertion identify who was selected.
        local_weights = [{"w": client} for client in range(5)]
        global_weights_before = {"w": 0}

        # Clients 1-3 sit close together; clients 0 and 4 are far away on either side.
        fake_grads = [{"w": torch.tensor([value])} for value in (100.0, 1.0, 1.1, 1.2, -100.0)]
        monkeypatch.setattr(f"{module}.calculate_gradients", lambda a, b, c: fake_grads)

        def mock_l2_norm(grad1, grad2):
            # Euclidean distance between the two "w" entries.
            delta = grad1["w"] - grad2["w"]
            return torch.sum(delta**2).sqrt()

        monkeypatch.setattr(f"{module}.calculate_l2_norm", mock_l2_norm)

        # Holder for whatever list the function under test passes to average_weights.
        captured = {"weights": []}

        def mock_avg_weights(weights):
            captured["weights"] = weights
            return {"w": torch.tensor(0.0)}

        monkeypatch.setattr(f"{module}.average_weights", mock_avg_weights)

        # Five clients, two assumed attackers, three clients requested for averaging.
        # NOTE(review): presumably each Krum score uses n - f - 2 = 1 neighbour — confirm
        # against the implementation.
        aggregate_multi_krum(
            local_weights,
            global_weights_before,
            num_attackers=2,
            num_benign_to_select=3,
            learning_rate=0.1,
        )

        chosen = captured["weights"]
        assert len(chosen) == 3

        # The placeholder "w" values identify the selected clients: indices 1, 2 and 3.
        chosen_ids = sorted(entry["w"] for entry in chosen)
        assert chosen_ids == [1, 2, 3]


class TestNormOutlierDetection:
    """Check ``detect_by_gradient_norm_outliers`` against one obvious outlier."""

    def test_outlier_detection(self, monkeypatch):
        """Only the client whose gradients are far larger than the rest is reported."""
        # Empty placeholders; the gradients come from the patched calculate_gradients.
        placeholder_locals = [{}, {}, {}]
        placeholder_global = {}

        def make_grad(first, second):
            return {"layer1": torch.tensor(first), "layer2": torch.tensor(second)}

        # Clients 0 and 2 are similar; client 1 is two orders of magnitude larger.
        fake_grads = [
            make_grad(1.0, 2.0),
            make_grad(100.0, 200.0),
            make_grad(1.1, 2.1),
        ]
        monkeypatch.setattr(
            "byzantine_robust_fl.byzantine_defense.calculate_gradients",
            lambda a, b, c: fake_grads,
        )

        flagged = detect_by_gradient_norm_outliers(
            placeholder_locals, placeholder_global, learning_rate=0.1, outlier_threshold=1.5
        )

        assert flagged == [1]


class TestWBC:
    """Check ``wbc`` with its ``_laplace_like`` noise source replaced by fixed tensors."""

    @staticmethod
    def _mk_list_single_key(values_per_client, key="p"):
        """Wrap a clone of each client's tensor in a one-entry dict stored under ``key``."""
        return [{key: tensor.clone()} for tensor in values_per_client]

    @staticmethod
    def _patch_noise(monkeypatch, noise_fn):
        """Replace ``_laplace_like`` in the module that defines ``wbc`` with ``noise_fn``."""
        monkeypatch.setattr(f"{wbc.__module__}._laplace_like", noise_fn, raising=True)

    def test_updates_when_delta_le_noise(self, monkeypatch):
        """With constant noise 0.1 and all weights at most 0.05, every element grows by 0.1."""
        current = self._mk_list_single_key([torch.tensor([0.05, 0.05]), torch.tensor([0.00, 0.02])])
        previous = self._mk_list_single_key([torch.zeros(2), torch.zeros(2)])
        previous_delta = self._mk_list_single_key([torch.zeros(2), torch.zeros(2)])

        self._patch_noise(monkeypatch, lambda t, device=None, dtype=None: torch.full_like(t, 0.1))

        # Snapshots taken before the call; the helper clones every tensor.
        current_snapshot = self._mk_list_single_key([entry["p"] for entry in current])
        previous_snapshot = self._mk_list_single_key([entry["p"] for entry in previous])
        delta_snapshot = self._mk_list_single_key([entry["p"] for entry in previous_delta])

        result = wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

        assert len(result) == 2
        assert list(result[0].keys()) == ["p"]
        torch.testing.assert_close(result[0]["p"], torch.tensor([0.15, 0.15]))
        torch.testing.assert_close(result[1]["p"], torch.tensor([0.10, 0.12]))

        # Spot-check that the inputs still match their snapshots.
        torch.testing.assert_close(current[0]["p"], current_snapshot[0]["p"])
        torch.testing.assert_close(previous[1]["p"], previous_snapshot[1]["p"])
        torch.testing.assert_close(previous_delta[1]["p"], delta_snapshot[1]["p"])

    def test_no_update_when_noise_negative(self, monkeypatch):
        """With constant noise -1, the returned weights equal the input weights."""
        current = self._mk_list_single_key([torch.tensor([0.03, 0.04])])
        previous = self._mk_list_single_key([torch.zeros(2)])
        previous_delta = self._mk_list_single_key([torch.zeros(2)])

        self._patch_noise(monkeypatch, lambda t, device=None, dtype=None: torch.full_like(t, -1.0))

        result = wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

        torch.testing.assert_close(result[0]["p"], current[0]["p"])

    def test_partial_mask(self, monkeypatch):
        """Only the element whose weight (0.02) is below its noise (0.05) is changed."""
        current = self._mk_list_single_key([torch.tensor([0.03, 0.02])])
        previous = self._mk_list_single_key([torch.zeros(2)])
        previous_delta = self._mk_list_single_key([torch.zeros(2)])

        def per_element_noise(t, device=None, dtype=None):
            # First entry (0.025) is below the weight 0.03, second (0.05) is above 0.02.
            return torch.tensor([0.025, 0.05], device=t.device, dtype=t.dtype)

        self._patch_noise(monkeypatch, per_element_noise)

        result = wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

        torch.testing.assert_close(result[0]["p"], torch.tensor([0.03, 0.07]))

    def test_structure_with_nonfloat_keys_preserved(self, monkeypatch):
        """Float entries grow by the unit noise while the long entry ``c`` is returned unchanged."""
        counters = (3, 7)
        current = [
            {"a": torch.tensor([0.0, 0.0]), "b": torch.tensor([0.1]), "c": torch.tensor([3], dtype=torch.long)},
            {"a": torch.tensor([0.2, 0.3]), "b": torch.tensor([0.0]), "c": torch.tensor([7], dtype=torch.long)},
        ]
        previous = [
            {"a": torch.zeros(2), "b": torch.zeros(1), "c": torch.tensor([value], dtype=torch.long)}
            for value in counters
        ]
        previous_delta = [
            {"a": torch.zeros(2), "b": torch.zeros(1), "c": torch.zeros(1, dtype=torch.long)} for _ in counters
        ]

        def mock_laplace_like(t, device=None, dtype=None):
            """Return ones shaped like ``t``; non-float inputs get float32 output."""
            output_dtype = t.dtype if torch.is_floating_point(t) else torch.float32
            return torch.full_like(t, 1.0, dtype=output_dtype)

        self._patch_noise(monkeypatch, mock_laplace_like)

        result = wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

        assert len(result) == 2
        for idx in range(2):
            assert set(result[idx].keys()) == {"a", "b", "c"}

        expected_floats = [
            (torch.tensor([1.0, 1.0]), torch.tensor([1.1])),
            (torch.tensor([1.2, 1.3]), torch.tensor([1.0])),
        ]
        for idx, (expected_a, expected_b) in enumerate(expected_floats):
            torch.testing.assert_close(result[idx]["a"], expected_a)
            torch.testing.assert_close(result[idx]["b"], expected_b)

        for idx, value in enumerate(counters):
            assert result[idx]["c"].dtype == torch.long
            assert torch.equal(result[idx]["c"], torch.tensor([value], dtype=torch.long))

    def test_inputs_not_modified(self, monkeypatch):
        """All three input collections hold the same values after the call as before it."""
        current = [{"x": torch.tensor([0.1, 0.2])}]
        previous = [{"x": torch.zeros(2)}]
        previous_delta = [{"x": torch.zeros(2)}]

        self._patch_noise(monkeypatch, lambda t, device=None, dtype=None: torch.full_like(t, 0.5))

        current_snapshot = [{"x": current[0]["x"].clone()}]
        previous_snapshot = [{"x": previous[0]["x"].clone()}]
        delta_snapshot = [{"x": previous_delta[0]["x"].clone()}]

        wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

        torch.testing.assert_close(current[0]["x"], current_snapshot[0]["x"])
        torch.testing.assert_close(previous[0]["x"], previous_snapshot[0]["x"])
        torch.testing.assert_close(previous_delta[0]["x"], delta_snapshot[0]["x"])

    def test_raises_on_length_mismatch(self):
        """A ``w_before`` list shorter than ``w`` raises ``ValueError``."""
        current = self._mk_list_single_key([torch.zeros(1), torch.zeros(1)])
        previous = self._mk_list_single_key([torch.zeros(1)])  # one entry fewer than ``w``
        previous_delta = self._mk_list_single_key([torch.zeros(1), torch.zeros(1)])

        with pytest.raises(ValueError):
            wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)

    def test_raises_on_key_mismatch(self):
        """A ``w_before`` entry with different keys than ``w`` raises ``ValueError``."""
        current = [{"a": torch.zeros(1), "b": torch.zeros(1)}]
        previous = [{"a": torch.zeros(1), "c": torch.zeros(1)}]  # "c" where ``w`` has "b"
        previous_delta = [{"a": torch.zeros(1), "b": torch.zeros(1)}]

        with pytest.raises(ValueError):
            wbc(lr=1.0, w=current, w_before=previous, delta_w_before=previous_delta)


class TestFLDetector:
    """Check ``FLDetector.step_and_detect`` across single and multi-round scenarios."""

    @staticmethod
    def _make_global(w_vec, with_long=True):
        """Build a state dict with a float32 ``w`` copy and, optionally, a long ``c`` entry."""
        state = {"w": torch.as_tensor(w_vec, dtype=torch.float32).clone()}
        if not with_long:
            return state
        state["c"] = torch.tensor([3], dtype=torch.long)
        return state

    @staticmethod
    def _make_local_from_update(w_before, g_vec, lr=1.0, with_long=True):
        """Build a client state dict whose ``w`` equals ``w_before["w"] + lr * g_vec``."""
        step = lr * torch.as_tensor(g_vec, dtype=w_before["w"].dtype)
        local_state = {"w": w_before["w"] + step}
        # The long entry is carried over only when requested and actually present.
        if with_long and "c" in w_before:
            local_state["c"] = w_before["c"].clone()
        return local_state

    @staticmethod
    def _stub_hvp(monkeypatch):
        """Patch ``FLDetector._hvp_ls`` to return zeros shaped like its argument."""
        monkeypatch.setattr(FLDetector, "_hvp_ls", lambda self, v: torch.zeros_like(v), raising=True)

    @staticmethod
    def _stub_single_cluster(monkeypatch):
        """Patch ``FLDetector._cluster_via_gap`` to give every score the label 0."""
        monkeypatch.setattr(
            FLDetector,
            "_cluster_via_gap",
            lambda self, scores: torch.zeros(scores.numel(), dtype=torch.long),
            raising=True,
        )

    def test_initial_round_no_detection_history_grows(self):
        """One step with ``start_iter=3`` flags nobody and adds one entry to each history."""
        detector = FLDetector(window_size=3, start_iter=3)

        lr = 1.0
        w_before = self._make_global([0.0, 0.0])
        g0 = torch.tensor([0.1, 0.1])
        g1 = torch.tensor([0.2, 0.0])
        local_weights = [self._make_local_from_update(w_before, grad, lr) for grad in (g0, g1)]
        mean_w = ((w_before["w"] + lr * g0) + (w_before["w"] + lr * g1)) / 2
        w_after = self._make_global(mean_w, with_long=True)

        malicious, scores = detector.step_and_detect(
            chosen_users=[0, 1],
            local_weights=local_weights,
            global_weights_before=w_before,
            global_weights_after=w_after,
            lr=lr,
        )

        assert malicious == []
        # Scores may come back as a plain sequence or as an array-like with a shape.
        assert isinstance(scores, (list, tuple)) or hasattr(scores, "shape")
        assert len(scores) == 2
        assert len(detector._dW_hist) == 1
        assert len(detector._dG_hist) == 1
        assert len(detector._d_norm_window) == 1

    def test_nonfloat_keys_are_skipped(self):
        """State dicts that include a long tensor are processed and nobody is flagged."""
        detector = FLDetector(window_size=2, start_iter=2)
        lr = 1.0
        w_before = self._make_global([0.0, 0.0], with_long=True)
        g0 = torch.tensor([0.3, -0.1])
        g1 = torch.tensor([0.0, 0.2])
        local_weights = [self._make_local_from_update(w_before, grad, lr, with_long=True) for grad in (g0, g1)]
        mean_w = ((w_before["w"] + lr * g0) + (w_before["w"] + lr * g1)) / 2
        w_after = self._make_global(mean_w, with_long=True)

        flagged, _ = detector.step_and_detect(
            chosen_users=[10, 11],
            local_weights=local_weights,
            global_weights_before=w_before,
            global_weights_after=w_after,
            lr=lr,
        )

        assert flagged == []

    def test_uses_provided_global_update_vec(self):
        """An explicit ``global_update_vec`` is what ends up in ``_last_global_update``."""
        detector = FLDetector(window_size=2, start_iter=2)
        lr = 1.0
        w_before = self._make_global([0.0, 0.0])
        g0 = torch.tensor([0.2, 0.2])
        g1 = torch.tensor([0.2, 0.2])
        local_weights = [self._make_local_from_update(w_before, grad, lr) for grad in (g0, g1)]
        mean_w = ((w_before["w"] + lr * g0) + (w_before["w"] + lr * g1)) / 2
        w_after = self._make_global(mean_w)

        # Deliberately different from the update implied by the weights above.
        provided = torch.tensor([9.0, 9.0], dtype=torch.float32)

        detector.step_and_detect(
            chosen_users=[1, 2],
            local_weights=local_weights,
            global_weights_before=w_before,
            global_weights_after=w_after,
            lr=lr,
            global_update_vec=provided,
        )

        stored = detector._last_global_update.cpu()
        assert torch.allclose(stored, provided, atol=1e-6)

    def test_detection_marks_high_norm_clients_with_monkeypatch(self, monkeypatch):
        """After two uniform rounds, the client sending a much larger update is flagged."""
        detector = FLDetector(window_size=4, start_iter=1)

        self._stub_hvp(monkeypatch)

        def _cluster(self, scores: torch.Tensor) -> torch.Tensor:
            # Label 1 for every score above the mean, 0 otherwise.
            return (scores > scores.mean()).to(torch.long)

        monkeypatch.setattr(FLDetector, "_cluster_via_gap", _cluster, raising=True)

        lr = 1.0
        chosen = [7, 8, 9]
        g_small = torch.tensor([0.1, 0.1])
        g_big = torch.tensor([5.0, 5.0])

        def play_round(current_global, grads):
            """Run one detector step for three clients; return its result and the new global."""
            client_states = [self._make_local_from_update(current_global, grad, lr) for grad in grads]
            mean_w = (client_states[0]["w"] + client_states[1]["w"] + client_states[2]["w"]) / 3.0
            next_global = self._make_global(mean_w)
            result = detector.step_and_detect(
                chosen_users=chosen,
                local_weights=client_states,
                global_weights_before=current_global,
                global_weights_after=next_global,
                lr=lr,
            )
            return result, next_global

        w_before = self._make_global([0.0, 0.0])

        # Two rounds in which all three clients send the same small update.
        for _ in range(2):
            _, w_after = play_round(w_before, [g_small, g_small, g_small])
            w_before = {"w": w_after["w"].clone(), "c": torch.tensor([3], dtype=torch.long)}

        # Third round: the last client (id 9) sends the large update.
        (malicious, scores), _ = play_round(w_before, [g_small, g_small, g_big])

        assert malicious == [9]
        assert len(scores) == 3

    def test_gap_single_cluster_returns_empty(self, monkeypatch):
        """When clustering puts every client in one group, nobody is flagged."""
        detector = FLDetector(window_size=3, start_iter=1)

        self._stub_single_cluster(monkeypatch)
        self._stub_hvp(monkeypatch)

        lr = 1.0
        w_before = self._make_global([0.0, 0.0])
        g0 = torch.tensor([0.2, 0.2])
        g1 = torch.tensor([0.3, 0.0])
        client_states = [self._make_local_from_update(w_before, grad, lr) for grad in (g0, g1)]
        w_after = self._make_global((client_states[0]["w"] + client_states[1]["w"]) / 2)

        malicious, _ = detector.step_and_detect(
            chosen_users=[101, 102],
            local_weights=client_states,
            global_weights_before=w_before,
            global_weights_after=w_after,
            lr=lr,
        )

        assert malicious == []

    def test_history_window_capacity(self, monkeypatch):
        """After five rounds, no internal history holds more than ``window_size`` entries."""
        window = 3
        detector = FLDetector(window_size=window, start_iter=1)
        self._stub_hvp(monkeypatch)
        self._stub_single_cluster(monkeypatch)

        lr = 1.0
        chosen = [0, 1]
        w_before = self._make_global([0.0, 0.0])
        for t in range(5):
            # Gradients drift slightly each round so consecutive rounds differ.
            g0 = torch.tensor([0.1 + 0.01 * t, 0.1])
            g1 = torch.tensor([0.1, 0.1 + 0.02 * t])
            client_states = [self._make_local_from_update(w_before, grad, lr) for grad in (g0, g1)]
            w_after = self._make_global((client_states[0]["w"] + client_states[1]["w"]) / 2)
            detector.step_and_detect(
                chosen_users=chosen,
                local_weights=client_states,
                global_weights_before=w_before,
                global_weights_after=w_after,
                lr=lr,
            )
            w_before = {"w": w_after["w"].clone(), "c": torch.tensor([3], dtype=torch.long)}

        assert len(detector._dW_hist) <= window
        assert len(detector._dG_hist) <= window
        assert len(detector._d_norm_window) <= window


class TestFGNV:
    """Check ``FGNV`` detection results with ``calculate_gradients`` replaced by fixed gradients."""

    @staticmethod
    def _mk_grads_from_norms(norms_per_client, keys=("a", "b"), with_long=False):
        """Build one gradient dict per client from a list of per-key scalar values.

        Each value becomes a one-element float32 tensor stored under the matching
        key. When ``with_long`` is true, a long-typed entry ``"c"`` is added too.
        """
        grads = []
        for norms in norms_per_client:
            client_grad = {k: torch.tensor([float(n)], dtype=torch.float32) for k, n in zip(keys, norms)}
            if with_long:
                client_grad["c"] = torch.tensor([3], dtype=torch.long)
            grads.append(client_grad)
        return grads

    @staticmethod
    def _run_fgnv(monkeypatch, grads, chosen, w_locals=None, w_glob_before=None):
        """Patch ``calculate_gradients`` to return ``grads`` and call ``FGNV`` on the CPU device.

        The patch targets the module that defines ``FGNV``. The weight arguments
        default to ``None`` because the gradients are supplied by the patch.
        Returns whatever ``FGNV`` returns.
        """
        monkeypatch.setattr(
            f"{FGNV.__module__}.calculate_gradients",
            lambda *args, **kwargs: grads,
            raising=True,
        )
        return FGNV(
            w_locals=w_locals,
            w_glob_before=w_glob_before,
            chosenUsers=chosen,
            learning_rate=1.0,
            device=torch.device("cpu"),
        )

    def test_empty_clients_returns_empty(self, monkeypatch):
        """No clients and no gradients produce an empty result."""
        out = self._run_fgnv(monkeypatch, [], [], w_locals=[], w_glob_before={})
        assert out == []

    def test_nonfloat_keys_are_skipped(self, monkeypatch):
        """Near-identical float gradients plus a long-typed entry yield no detections."""
        grads = self._mk_grads_from_norms(
            [[1.0, 1.0], [0.98, 1.02], [1.01, 0.99]],
            keys=("a", "b"),
            with_long=True,
        )
        out = self._run_fgnv(monkeypatch, grads, [10, 11, 12])
        assert out == []

    def test_detect_single_malicious_high_norm(self, monkeypatch):
        """The one client with doubled gradient values is reported by its user id."""
        grads = self._mk_grads_from_norms(
            [
                [1.0, 1.0],
                [1.0, 1.0],
                [2.0, 2.0],
            ],
            keys=("a", "b"),
        )
        out = self._run_fgnv(monkeypatch, grads, [5, 6, 7])
        assert out == [7]

    def test_tie_returns_multiple(self, monkeypatch):
        """Two clients sharing the same larger gradient values are both reported."""
        grads = self._mk_grads_from_norms(
            [
                [1.0, 1.0],
                [1.0, 1.0],
                [2.0, 2.0],
                [2.0, 2.0],
            ],
            keys=("a", "b"),
        )
        out = self._run_fgnv(monkeypatch, grads, [100, 101, 102, 103])
        assert out == [102, 103]

    def test_identical_norms_no_detection(self, monkeypatch):
        """Clients with identical gradients produce no detections."""
        grads = self._mk_grads_from_norms([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], keys=("w1", "w2"))
        out = self._run_fgnv(monkeypatch, grads, [0, 1, 2])
        assert out == []

    def test_zero_norms_are_stable(self, monkeypatch):
        """A key that is zero for every client does not prevent detecting the outlier."""
        grads = self._mk_grads_from_norms(
            [
                [0.0, 1.0],
                [0.0, 1.0],
                [0.0, 5.0],
            ],
            keys=("k1", "k2"),
        )
        out = self._run_fgnv(monkeypatch, grads, [7, 8, 9])
        assert out == [9]

    def test_device_argument_is_honored(self, monkeypatch):
        """Passing an explicit CPU device still reports the larger-gradient client.

        NOTE(review): only ``torch.device("cpu")`` is exercised here, so this does
        not show that other devices are handled.
        """
        grads = self._mk_grads_from_norms([[1.0, 1.0], [2.0, 2.0]], keys=("a", "b"))
        out = self._run_fgnv(monkeypatch, grads, [1, 2])
        assert out == [2]
