import copy
from typing import Dict, List, Set

import pytest
import torch

# Attack functions under test, from byzantine_robust_fl/byzantine_attacks.py.
from byzantine_robust_fl.byzantine_attacks import (
    fang_attack,
    gaussian_attack,
    get_malicious_client_updates,
    lie_attack,
)

# Define Type Aliases for clarity
StateDict = Dict[str, torch.Tensor]  # parameter name -> tensor for one model
WeightsList = List[Dict[str, torch.Tensor]]  # one StateDict per client, in client order


# --- Fixtures for Mock Data ---


@pytest.fixture
def mock_updates() -> WeightsList:
    """Build five client updates: three small benign ones, then two outsized malicious ones.

    Positions 3 and 4 (client ids 40 and 50 in ``all_client_indices``) hold the
    large values; ``malicious_client_indices`` flags exactly those clients.
    """
    # (weight value, bias value) for each client, in client order.
    specs = [
        (1.0, 0.1),  # Benign
        (2.0, 0.2),  # Benign
        (3.0, 0.3),  # Benign
        (10.0, 1.0),  # Malicious
        (11.0, 1.1),  # Malicious
    ]
    return [{"layer.w": torch.tensor([w, w]), "layer.b": torch.tensor([b])} for w, b in specs]


@pytest.fixture
def all_client_indices() -> List[int]:
    """Client ids 10, 20, ..., 50, aligned position-for-position with ``mock_updates``."""
    return list(range(10, 60, 10))


@pytest.fixture
def malicious_client_indices() -> Set[int]:
    """Ids of the attacking clients: the last two ids in ``all_client_indices``."""
    return set(range(40, 60, 10))


@pytest.fixture
def global_weights_before() -> StateDict:
    """Pre-round global weights that the tests pass to ``fang_attack`` as its reference point."""
    weights: StateDict = {}
    weights["layer.w"] = torch.tensor([0.5, 0.5])
    weights["layer.b"] = torch.tensor([0.0])
    return weights


@pytest.fixture
def device() -> torch.device:
    """Run the attacks on the CPU so the tests need no accelerator."""
    cpu = torch.device("cpu")
    return cpu


# --- Tests for Byzantine Attack Functions ---


def test_get_malicious_client_updates(
    mock_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
):
    """Verify that exactly the malicious clients' updates are returned, in client order."""
    malicious_updates = get_malicious_client_updates(mock_updates, all_client_indices, malicious_client_indices)

    # Clients 40 and 50 sit at positions 3 and 4 of the mock_updates fixture.
    expected = [
        {"layer.w": torch.tensor([10.0, 10.0]), "layer.b": torch.tensor([1.0])},
        {"layer.w": torch.tensor([11.0, 11.0]), "layer.b": torch.tensor([1.1])},
    ]
    assert len(malicious_updates) == len(expected)
    for got, want in zip(malicious_updates, expected):
        # Compare every parameter, not just the weights, so a partially wrong update fails.
        assert got.keys() == want.keys()
        for key, tensor in want.items():
            assert torch.equal(got[key], tensor)


def test_gaussian_attack(
    mock_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
    device: torch.device,
):
    """Verify that Gaussian noise is added only to malicious clients."""
    original_updates = copy.deepcopy(mock_updates)

    # Seed inside a forked CPU RNG scope: the noise is reproducible run-to-run,
    # and the fixed seed does not leak into the global RNG state seen by other tests.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        attacked_updates = gaussian_attack(
            mock_updates, all_client_indices, malicious_client_indices, attack_scale=1.0, device=device
        )

    assert len(attacked_updates) == len(original_updates)
    for client_idx, original, attacked in zip(all_client_indices, original_updates, attacked_updates):
        if client_idx in malicious_client_indices:
            # Malicious updates should be perturbed but keep their shape.
            assert attacked["layer.w"].shape == original["layer.w"].shape
            assert not torch.equal(original["layer.w"], attacked["layer.w"])
        else:
            # Benign updates should be untouched in every parameter.
            for key, tensor in original.items():
                assert torch.equal(tensor, attacked[key])


def test_lie_attack(
    mock_updates: WeightsList,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
    monkeypatch,
):
    """Verify the LIE attack logic using mocked helper functions."""
    from statistics import NormalDist

    original_updates = copy.deepcopy(mock_updates)

    # Mock the dependencies to return predictable values
    mock_avg = {"layer.w": torch.tensor([2.0, 2.0]), "layer.b": torch.tensor([0.2])}
    mock_std = {"layer.w": torch.tensor([0.8, 0.8]), "layer.b": torch.tensor([0.08])}
    monkeypatch.setattr("byzantine_robust_fl.byzantine_attacks.average_gradients", lambda _: mock_avg)
    monkeypatch.setattr("byzantine_robust_fl.byzantine_attacks.calculate_gradient_std_dev", lambda _: mock_std)

    # Derive the z-score from the fixture sizes instead of hard-coding it:
    #   s = (n // 2 + 1) - m
    #   z = Phi^-1((n - m - s) / (n - m))
    # For n=5 clients and m=2 attackers: s = 1 and z = Phi^-1(2/3) ~= 0.4307.
    # NOTE(review): Baruch et al. ("A Little Is Enough") state z via (n - s) / n;
    # this test assumes lie_attack uses the (n - m - s) / (n - m) variant. Confirm.
    num_clients = len(all_client_indices)
    num_attackers = len(malicious_client_indices)
    s = (num_clients // 2 + 1) - num_attackers
    z_score = NormalDist().inv_cdf((num_clients - num_attackers - s) / (num_clients - num_attackers))

    # Expected malicious update: mean shifted by z standard deviations, per parameter.
    expected_malicious_update = {key: mock_avg[key] + mock_std[key] * z_score for key in mock_avg}

    attacked_updates = lie_attack(mock_updates, all_client_indices, malicious_client_indices, learning_rate=0.1)

    assert len(attacked_updates) == len(original_updates)
    for client_idx, original, attacked in zip(all_client_indices, original_updates, attacked_updates):
        if client_idx in malicious_client_indices:
            # Malicious updates should match the crafted one in every parameter.
            for key, tensor in expected_malicious_update.items():
                assert torch.allclose(attacked[key], tensor)
        else:
            # Benign updates should be unchanged in every parameter.
            for key, tensor in original.items():
                assert torch.equal(tensor, attacked[key])


def test_fang_attack(
    mock_updates: WeightsList,
    global_weights_before: StateDict,
    all_client_indices: List[int],
    malicious_client_indices: Set[int],
):
    """Verify the Fang attack logic."""
    original_updates = copy.deepcopy(mock_updates)

    # In our mock_updates, the benign updates are [1,1], [2,2], [3,3].
    # The aggregated benign update will be [2,2].
    # The global weights before were [0.5, 0.5].
    # The update direction is sign([2,2] - [0.5,0.5]) = sign([1.5,1.5]) = [1,1].
    # Since the direction is positive, the attack should use the element-wise min
    # of the benign updates, which is [1,1] (and 0.1 for the bias, by the same logic).
    expected_w = torch.tensor([1.0, 1.0])
    expected_b = torch.tensor([0.1])

    attacked_updates = fang_attack(mock_updates, global_weights_before, all_client_indices, malicious_client_indices)

    assert len(attacked_updates) == len(original_updates)
    for client_idx, original, attacked in zip(all_client_indices, original_updates, attacked_updates):
        if client_idx in malicious_client_indices:
            # Malicious updates should match the crafted one
            assert torch.equal(attacked["layer.w"], expected_w)
            assert torch.equal(attacked["layer.b"], expected_b)
        else:
            # Benign updates should be unchanged in every parameter.
            for key, tensor in original.items():
                assert torch.equal(tensor, attacked[key])


def test_attack_with_no_benign_clients(
    mock_updates: WeightsList,
    all_client_indices: List[int],
    global_weights_before: StateDict,
):
    """Ensure attacks leave every update unchanged when there are no benign clients to learn from."""
    # All clients are malicious
    malicious_indices = set(all_client_indices)
    original_updates = copy.deepcopy(mock_updates)

    lie_poisoned = lie_attack(mock_updates, all_client_indices, malicious_indices, learning_rate=0.1)
    fang_poisoned = fang_attack(mock_updates, global_weights_before, all_client_indices, malicious_indices)

    for poisoned in (lie_poisoned, fang_poisoned):
        # Check the length first: zip() truncates silently, so an empty or short
        # result would otherwise pass vacuously.
        assert len(poisoned) == len(original_updates)
        # The updates should be returned unmodified, in every parameter.
        for original, attacked in zip(original_updates, poisoned):
            for key, tensor in original.items():
                assert torch.equal(tensor, attacked[key])
