import copy

import numpy as np
import pytest
import torch
import torch.nn as nn

from byzantine_robust_fl.utils.model_averaging import (
    average_fsvrg_weights,
    average_weights,
    average_weights_resilient,
)


class SimpleNet(nn.Module):
    """Minimal single-layer model mapping 10 input features to 2 outputs.

    The submodule is deliberately named ``layer`` so that ``state_dict()``
    exposes the keys ``layer.weight`` and ``layer.bias`` used by the
    hand-built state dicts elsewhere in this test module.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 10, 2
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the single linear layer to ``x``."""
        logits = self.layer(x)
        return logits


@pytest.fixture
def state_dicts():
    """Three ``SimpleNet``-shaped state dicts filled with 1.0, 2.0 and 3.0.

    Every tensor in the i-th dict holds the constant ``i + 1``, so the
    element-wise mean across the three dicts is 2.0 everywhere.
    """
    return [
        {
            "layer.weight": torch.full((2, 10), fill),
            "layer.bias": torch.full((2,), fill),
        }
        for fill in (1.0, 2.0, 3.0)
    ]


@pytest.fixture
def numpy_arrays():
    """Three 5x5 float arrays filled with 1.0, 2.0 and 3.0 respectively."""
    return [np.full((5, 5), fill) for fill in (1.0, 2.0, 3.0)]


# --- Tests for the average_weights function ---


def test_average_weights_with_tensors(state_dicts):
    """Averaging the 1/2/3-valued state dicts yields 2.0 for every key."""
    result = average_weights(state_dicts)

    for key, shape in (("layer.weight", (2, 10)), ("layer.bias", (2,))):
        assert torch.allclose(result[key], torch.full(shape, 2.0))


def test_average_weights_with_numpy(numpy_arrays):
    """A list of plain numpy arrays is averaged element-wise as well."""
    result = average_weights(numpy_arrays)
    assert np.allclose(result, np.full((5, 5), 2.0))


def test_average_weights_single_item(state_dicts):
    """A single input is returned by value, not aliased to the caller's tensors."""
    only = state_dicts[0]
    result = average_weights([only])

    assert torch.allclose(result["layer.weight"], only["layer.weight"])

    # Writing into the result must not leak back into the input state dict.
    result["layer.weight"][0, 0] = 99.0
    assert only["layer.weight"][0, 0] == 1.0


def test_average_weights_empty_list():
    """An empty input list is rejected with a ValueError."""
    no_weights = []
    with pytest.raises(ValueError, match="cannot be empty"):
        average_weights(no_weights)


def test_average_weights_resilient_basic(state_dicts):
    """average_weights_resilient should return the reputation-weighted mean.

    Expected result: ``sum(r_i * w_i) / sum(r_i)``. The first case is
    symmetric, so it coincides with the plain mean (2.0). The second uses
    skewed reputations, so an implementation that ignored reputations
    (and returned the unweighted mean) would fail it.
    """
    cases = [
        # (1*1 + 2*2 + 1*3) / 4.0 = 2.0   (total reputation = 4.0)
        ([1.0, 2.0, 1.0], 2.0),
        # (1*1 + 1*2 + 2*3) / 4.0 = 2.25  (differs from the plain mean)
        ([1.0, 1.0, 2.0], 2.25),
    ]

    for reputations, expected in cases:
        # Fresh copy per case so any in-place work by the implementation
        # cannot contaminate the next case's inputs.
        averaged = average_weights_resilient(copy.deepcopy(state_dicts), reputations)

        assert torch.allclose(averaged["layer.weight"], torch.full((2, 10), expected))
        assert torch.allclose(averaged["layer.bias"], torch.full((2,), expected))


def test_average_weights_resilient_errors(state_dicts):
    """Each class of invalid input raises ValueError with a specific message."""
    # No client weights at all.
    with pytest.raises(ValueError, match="weights_list cannot be empty."):
        average_weights_resilient([], [])

    # Three state dicts but only two reputations.
    too_few_reputations = [1.0, 2.0]
    with pytest.raises(ValueError, match="Mismatch between number of weights and reputations"):
        average_weights_resilient(state_dicts, too_few_reputations)

    # Reputations that sum to zero.
    zero_total = [1.0, -1.0, 0.0]
    with pytest.raises(ValueError, match="Total reputation must be positive."):
        average_weights_resilient(state_dicts, zero_total)


def test_average_fsvrg_weights_basic():
    """Two clients with per-client weights 10 and 30 (presumably sample counts).

    Starting from a global model filled with 1.0 and agg_scalar = 1.0:
        weight: 1.0 + (1.0 / 40) * (10 * (2.0 - 1.0) + 30 * (1.5 - 1.0))  = 1.625
        bias:   1.0 + (1.0 / 40) * (10 * (-1.0 - 1.0) + 30 * (3.0 - 1.0)) = 2.0
    """
    global_model = SimpleNet()
    with torch.no_grad():
        for p in global_model.parameters():
            p.fill_(1.0)

    clients = [copy.deepcopy(global_model) for _ in range(2)]
    with torch.no_grad():
        clients[0].layer.weight.fill_(2.0)   # delta +1.0
        clients[0].layer.bias.fill_(-1.0)    # delta -2.0
        clients[1].layer.weight.fill_(1.5)   # delta +0.5
        clients[1].layer.bias.fill_(3.0)     # delta +2.0

    client_updates = [(n, model.state_dict()) for n, model in zip((10, 30), clients)]

    updated = average_fsvrg_weights(client_updates, 1.0, global_model)

    assert torch.allclose(updated["layer.weight"], torch.full((2, 10), 1.625))
    assert torch.allclose(updated["layer.bias"], torch.full((2,), 2.0))


def test_average_fsvrg_weights_no_updates():
    """With an empty update list the global weights come back unchanged."""
    global_model = SimpleNet()
    before = copy.deepcopy(global_model.state_dict())

    after = average_fsvrg_weights([], 1.0, global_model)

    for key in ("layer.weight", "layer.bias"):
        assert torch.allclose(before[key], after[key])


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available for this test")
def test_average_fsvrg_weights_gpu():
    """Run FSVRG aggregation with ``gpu_id`` set and check the result is on CUDA.

    Single client with weight 10, agg_scalar = 1.0, global model filled with 1.0:
        weight: 1.0 + (1.0 / 10) * (10 * (3.0 - 1.0)) = 3.0
        bias:   1.0 + (1.0 / 10) * (10 * (5.0 - 1.0)) = 5.0
    """
    gpu_id = 0
    device = torch.device(f"cuda:{gpu_id}")

    global_model = SimpleNet().to(device)
    with torch.no_grad():
        for p in global_model.parameters():
            p.fill_(1.0)

    client = SimpleNet().to(device)
    with torch.no_grad():
        client.layer.weight.fill_(3.0)   # delta +2.0
        client.layer.bias.fill_(5.0)     # delta +4.0

    updated = average_fsvrg_weights([(10, client.state_dict())], 1.0, global_model, gpu_id=gpu_id)

    assert updated["layer.weight"].device.type == "cuda"
    expectations = (("layer.weight", (2, 10), 3.0), ("layer.bias", (2,), 5.0))
    for key, shape, value in expectations:
        assert torch.allclose(updated[key], torch.full(shape, value, device=device))
