import numpy as np
import pytest
import torch

# Import the functions under test
from byzantine_robust_fl.utils.federated_metrics import (
    _flatten_state_dict,
    average_gradients,
    calculate_cosine_similarity_to_normal,
    calculate_gradient_std_dev,
    calculate_gradients,
    calculate_inner_product,
    calculate_l2_norm,
    delta_parameters,
    normalize_gradients,
    perform_t_test,
)

# Type alias used throughout these tests: a mapping from a string key
# (e.g. "layer1.weight") to the torch.Tensor stored under that key.
StateDict = dict[str, torch.Tensor]


# --- Pytest Fixtures for Reusable Test Data ---
@pytest.fixture
def state_dict_a() -> StateDict:
    """Build a sample state dictionary in which every entry equals 1.0.

    The dictionary holds 10*5 + 10 + 5 = 65 scalar values in total.
    """
    weight = torch.ones((10, 5))
    bias = torch.ones(10)
    # The running-mean entry is included by default (the flatten test only
    # expects it to be dropped when an ignore key is supplied).
    running_mean = torch.ones(5)
    return {
        "layer1.weight": weight,
        "layer1.bias": bias,
        "batch_norm.running_mean": running_mean,
    }


@pytest.fixture
def state_dict_b() -> StateDict:
    """Build a sample state dictionary in which every entry equals 2.0."""
    shapes = {
        "layer1.weight": (10, 5),
        "layer1.bias": (10,),
        "batch_norm.running_mean": (5,),
    }
    return {name: torch.full(shape, 2.0) for name, shape in shapes.items()}


@pytest.fixture
def state_dict_c() -> StateDict:
    """Build a sample state dictionary in which every entry equals 3.0."""
    fill_value = 3.0
    layout = (
        ("layer1.weight", (10, 5)),
        ("layer1.bias", (10,)),
        ("batch_norm.running_mean", (5,)),
    )
    return {key: torch.full(shape, fill_value) for key, shape in layout}


@pytest.fixture
def gradient_list(state_dict_a, state_dict_b, state_dict_c) -> list[StateDict]:
    """Return the ones, twos and threes state dicts, in that order, as a list."""
    ordered_dicts = (state_dict_a, state_dict_b, state_dict_c)
    return list(ordered_dicts)


# --- Test Suite ---


def test_flatten_state_dict(state_dict_a):
    """Check the flattened length and contents, with and without ignore keys."""
    # 10*5 (weight) + 10 (bias) + 5 (running mean) = 65 elements, all ones.
    full_vector = _flatten_state_dict(state_dict_a)
    assert full_vector.shape == (65,)
    assert np.all(full_vector == 1.0)

    # Each ignore entry is expected to drop exactly one tensor:
    #   "mean" -> 'batch_norm.running_mean' (5 elements)  => 60 remain
    #   "bias" -> 'layer1.bias' (10 elements)             => 55 remain
    ignore_cases = (("mean", 60), ("bias", 55))
    for ignored, remaining in ignore_cases:
        trimmed_vector = _flatten_state_dict(state_dict_a, ignore_keys=[ignored])
        assert trimmed_vector.shape == (remaining,)


def test_calculate_l2_norm(state_dict_a, state_dict_b):
    """Check the L2 norm with one argument and the L2 distance with two."""
    # Both cases reduce to the norm of 65 ones: sqrt(65 * 1^2).
    sqrt_65 = np.sqrt(65)

    # Single input: the state dict itself is a vector of 65 ones.
    single_norm = calculate_l2_norm(state_dict_a)
    assert np.isclose(single_norm, sqrt_65)

    # Two inputs: the element-wise difference between ones and twos has
    # magnitude one in each of the 65 positions.
    pair_distance = calculate_l2_norm(state_dict_a, state_dict_b)
    assert np.isclose(pair_distance, sqrt_65)


def test_calculate_inner_product(state_dict_a, state_dict_b):
    """Check the inner product of the ones and twos state dicts."""
    # 65 positions, each contributing 1 * 2, gives 130.
    expected_product = 130.0
    actual_product = calculate_inner_product(state_dict_a, state_dict_b)
    assert np.isclose(actual_product, expected_product)


def test_calculate_gradients(state_dict_a, state_dict_b):
    """Check gradient calculation for a single state dict and for a list."""
    step_size = 0.1
    # Expected value per element: (a - b) / lr = (1 - 2) / 0.1 = -10.
    expected_weight_grad = torch.full((10, 5), -10.0)

    # Second argument given as a single state dict.
    single_result = calculate_gradients(state_dict_a, state_dict_b, learning_rate=step_size)
    assert torch.allclose(single_result["layer1.weight"], expected_weight_grad)

    # Second argument given as a list: a list is expected back.
    listed_result = calculate_gradients(state_dict_a, [state_dict_b], learning_rate=step_size)
    assert isinstance(listed_result, list)
    assert torch.allclose(listed_result[0]["layer1.weight"], expected_weight_grad)


def test_average_gradients(gradient_list):
    """Check that averaging the ones/twos/threes gradients yields twos."""
    mean_grad = average_gradients(gradient_list)

    # The mean of 1, 2 and 3 is 2 for every element.
    mean_value = 2.0
    checked_entries = (("layer1.weight", (10, 5)), ("layer1.bias", (10,)))
    for key, shape in checked_entries:
        assert torch.allclose(mean_grad[key], torch.full(shape, mean_value))


def test_calculate_gradient_std_dev(gradient_list):
    """Check the element-wise standard deviation and the too-few-inputs error."""
    # Sample std dev of {1, 2, 3}:
    #   sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / (3-1)) = sqrt(1) = 1
    expected_weight_std = torch.full((10, 5), 1.0)
    spread = calculate_gradient_std_dev(gradient_list)
    assert torch.allclose(spread["layer1.weight"], expected_weight_std)

    # A single gradient must be rejected.
    lone_gradient = [gradient_list[0]]
    with pytest.raises(ValueError, match="at least two gradients"):
        calculate_gradient_std_dev(lone_gradient)


def test_normalize_gradients(gradient_list):
    """Check normalisation of the gradients by their own mean and std dev."""
    center = average_gradients(gradient_list)  # every element is 2
    scale = calculate_gradient_std_dev(gradient_list)  # every element is 1

    standardized = normalize_gradients(gradient_list, center, scale)

    # Expected per-client values: (1-2)/1 = -1, (2-2)/1 = 0, (3-2)/1 = 1.
    # Index explicitly so a result shorter than three entries still fails.
    for position, expected_value in enumerate((-1.0, 0.0, 1.0)):
        wanted = torch.full((10, 5), expected_value)
        assert torch.allclose(standardized[position]["layer1.weight"], wanted)


def test_calculate_cosine_similarity_to_normal(state_dict_a):
    """Check the similarity score for constant values and for an empty dict."""
    # All values in state_dict_a are identical, which is far from a normal
    # distribution, so a low score is expected.
    score = calculate_cosine_similarity_to_normal(state_dict_a, num_bins=100)
    assert isinstance(score, float)
    assert score >= 0.0 and score <= 1.0
    assert score < 0.5

    # An empty state dict is expected to score exactly zero.
    score_for_empty = calculate_cosine_similarity_to_normal({}, num_bins=10)
    assert score_for_empty == 0.0


def test_perform_t_test():
    """Check the t-test helper on an obvious outlier and on a typical value."""
    # Values 0..4 form the bulk of the sample; 100 is the clear outlier.
    samples = np.array([0, 1, 2, 3, 4, 100])
    alpha = 0.05

    # Index 5 holds the value 100, which should be flagged as significant.
    outlier_index = 5
    outlier_results = perform_t_test(samples, [outlier_index], significance_level=alpha)
    outlier_p, outlier_flag = outlier_results[outlier_index]
    assert outlier_flag
    assert outlier_p < alpha

    # Index 2 holds the value 2, which should not be flagged.
    typical_index = 2
    typical_results = perform_t_test(samples, [typical_index], significance_level=alpha)
    typical_p, typical_flag = typical_results[typical_index]
    assert not typical_flag
    assert typical_p > alpha


class TestDeltaParameters:
    """Tests for ``delta_parameters`` on per-client lists of state dicts."""

    @staticmethod
    def _make_list_dict(vals_per_client):
        """Return a new list of dicts whose tensors are clones of the inputs."""
        return [
            {name: tensor.clone() for name, tensor in client.items()}
            for client in vals_per_client
        ]

    def test_basic_difference_single_client_single_key(self):
        """One client, one key: the delta is current minus previous."""
        now = self._make_list_dict([{"w": torch.tensor([1.0, 3.0, -2.0])}])
        before = self._make_list_dict([{"w": torch.tensor([0.5, 1.0, -2.5])}])

        result = delta_parameters(now, before)

        expected = torch.tensor([0.5, 2.0, 0.5])
        torch.testing.assert_close(result[0]["w"], expected)

    def test_multi_clients_multi_keys(self):
        """Two clients with two keys each: values and structure are checked."""
        now = self._make_list_dict(
            [
                {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([0.0])},
                {"a": torch.tensor([0.2, -0.3]), "b": torch.tensor([5.0])},
            ]
        )
        before = self._make_list_dict(
            [
                {"a": torch.tensor([0.5, 1.5]), "b": torch.tensor([0.0])},
                {"a": torch.tensor([-0.8, -0.1]), "b": torch.tensor([4.0])},
            ]
        )

        result = delta_parameters(now, before)

        # Expected element-wise differences, per client and per key.
        expected_per_client = [
            {"a": torch.tensor([0.5, 0.5]), "b": torch.tensor([0.0])},
            {"a": torch.tensor([1.0, -0.2]), "b": torch.tensor([1.0])},
        ]
        for client_index, expected in enumerate(expected_per_client):
            for key, wanted in expected.items():
                torch.testing.assert_close(result[client_index][key], wanted)

        # The output has one dict per client, with the same key set.
        assert len(result) == len(now) == 2
        for client_index in (0, 1):
            assert set(result[client_index].keys()) == {"a", "b"}

    def test_inputs_are_not_modified(self):
        """Both input lists keep their values after the call."""
        now = self._make_list_dict([{"w": torch.tensor([1.0])}])
        before = self._make_list_dict([{"w": torch.tensor([0.25])}])
        # Snapshot clones taken before the call, for comparison afterwards.
        now_snapshot = self._make_list_dict(now)
        before_snapshot = self._make_list_dict(before)

        _ = delta_parameters(now, before)

        torch.testing.assert_close(now[0]["w"], now_snapshot[0]["w"])
        torch.testing.assert_close(before[0]["w"], before_snapshot[0]["w"])

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_dtype_and_device_preserved(self, dtype):
        """The delta has the dtype and device of the input tensors."""
        newer = torch.tensor([1, 2], dtype=dtype)
        older = torch.tensor([0, 1], dtype=dtype)
        now = self._make_list_dict([{"w": newer}])
        before = self._make_list_dict([{"w": older}])

        delta = delta_parameters(now, before)[0]["w"]

        assert delta.dtype == dtype
        assert delta.device == newer.device

    def test_zero_when_identical(self):
        """Identical current and previous values give an all-zero delta."""
        values = torch.randn(3)
        now = self._make_list_dict([{"w": values}])
        before = self._make_list_dict([{"w": values}])

        result = delta_parameters(now, before)

        torch.testing.assert_close(result[0]["w"], torch.zeros_like(values))

    def test_raises_on_length_mismatch(self):
        """Lists with different numbers of clients raise ValueError."""
        two_clients = self._make_list_dict(
            [{"w": torch.tensor([1.0])}, {"w": torch.tensor([2.0])}]
        )
        one_client = self._make_list_dict([{"w": torch.tensor([0.5])}])

        with pytest.raises(ValueError):
            _ = delta_parameters(two_clients, one_client)

    def test_raises_on_key_mismatch(self):
        """Client dicts with different key sets raise ValueError."""
        keys_a_b = self._make_list_dict(
            [{"a": torch.tensor([1.0]), "b": torch.tensor([2.0])}]
        )
        keys_a_c = self._make_list_dict(
            [{"a": torch.tensor([0.5]), "c": torch.tensor([1.5])}]
        )

        with pytest.raises(ValueError):
            _ = delta_parameters(keys_a_b, keys_a_c)
