import pytest
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from byzantine_robust_fl.local_update import (
    ClientUpdater,
    DatasetSplit,
    ServerEvaluator,
)

# --- Mock Objects and Fixtures for Testing ---


class MockDataset(Dataset):
    """Provide a deterministic dataset fixture for unit tests.

    Labels cycle through ``0 .. num_classes - 1`` in index order. Feature
    vectors are drawn from a standard normal distribution using a private,
    seeded generator, so every instance built with the same arguments holds
    identical data and torch's global RNG is neither reseeded nor consumed.

    Args:
        num_samples: Number of ``(features, label)`` pairs in the dataset.
        num_classes: Number of distinct labels to cycle through.
        feature_dim: Length of each feature vector.
        seed: Seed for the private generator that creates the features.
    """

    def __init__(self, num_samples=100, num_classes=10, feature_dim=8, seed=0):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        # Create predictable labels (e.g., 0, 1, 2, ..., 9, 0, 1, ...)
        self.targets = [i % num_classes for i in range(num_samples)]
        # A dedicated generator makes the features reproducible across runs
        # without touching the global RNG that other tests may depend on.
        generator = torch.Generator().manual_seed(seed)
        self.data = [
            torch.randn(feature_dim, generator=generator) for _ in range(num_samples)
        ]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position *idx*."""
        return self.data[idx], self.targets[idx]


class SimpleModel(nn.Module):
    """Tiny single-layer linear classifier for exercising training code.

    Args:
        in_features: Length of each input feature vector.
        num_classes: Number of output logits produced per sample.
    """

    def __init__(self, in_features=8, num_classes=10):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x):
        """Project a batch of feature vectors onto unnormalised class logits."""
        logits = self.fc(x)
        return logits


@pytest.fixture
def mock_dataset():
    """Build the standard 100-sample, 10-class, 8-feature mock dataset."""
    samples, classes, features = 100, 10, 8
    return MockDataset(num_samples=samples, num_classes=classes, feature_dim=features)


@pytest.fixture
def device():
    """Provide the CPU device on which all tests run."""
    cpu = torch.device("cpu")
    return cpu


# --- Test Cases ---


class TestDatasetSplit:
    """Check that DatasetSplit exposes exactly the requested subset of samples."""

    def test_len(self, mock_dataset):
        """The split's length equals the number of indices it was given."""
        every_fifth = [idx for idx in range(0, 50, 5)]  # ten positions
        subset = DatasetSplit(mock_dataset, every_fifth)
        assert len(subset) == 10

    def test_getitem(self, mock_dataset):
        """Position k of the split maps back to the k-th chosen source index."""
        chosen = [10, 25, 42]
        subset = DatasetSplit(mock_dataset, chosen)

        # Position 1 of the split must be sample 25 of the source dataset.
        expected_features, expected_label = mock_dataset[chosen[1]]
        got_features, got_label = subset[1]

        assert torch.equal(expected_features, got_features)
        assert expected_label == got_label


class TestClientUpdater:
    """Cover the local training and evaluation performed by ClientUpdater."""

    @pytest.fixture
    def client_updater(self, device, mock_dataset):
        """Build a ClientUpdater over the first twenty mock samples."""
        first_twenty = [idx for idx in range(20)]
        return ClientUpdater(device, mock_dataset, first_twenty, batch_size=10)

    def test_initialization(self, client_updater):
        """Both loaders exist and evaluation uses one full-size batch."""
        for loader in (client_updater.train_loader, client_updater.test_loader):
            assert isinstance(loader, DataLoader)
        # The evaluation loader should cover all 20 client samples at once.
        assert client_updater.test_loader.batch_size == 20

    def test_evaluate(self, client_updater, device):
        """A model that always answers class 0 scores the class-0 share."""

        class AlwaysZeroModel(SimpleModel):
            def forward(self, x):
                scores = torch.zeros(x.size(0), 10)
                scores[:, 0] = 1.0  # largest logit always on class 0
                return scores

        model = AlwaysZeroModel().to(device)
        # Labels cycle 0..9, so indices 0 and 10 are the only class-0 samples
        # among the first twenty: the expected accuracy is 2 / 20.
        accuracy, loss = client_updater.evaluate(model)
        for metric in (accuracy, loss):
            assert isinstance(metric, float)
        assert accuracy == pytest.approx(2 / 20)

    def test_train(self, client_updater, device):
        """One local epoch runs end to end and moves the linear weights."""
        model = SimpleModel().to(device)
        # Snapshot the weights before training (detached copy, no grad).
        before = model.fc.weight.detach().clone()

        weights, loss, acc = client_updater.train(
            model=model, learning_rate=0.1, local_epochs=1
        )

        assert isinstance(weights, dict)
        assert isinstance(loss, float)
        assert isinstance(acc, float)

        after = weights["fc.weight"]
        assert not torch.equal(before, after)


class TestServerEvaluator:
    """Assess server-side evaluation utilities."""

    @pytest.fixture
    def server_evaluator(self, device, mock_dataset):
        """Instantiate a ServerEvaluator fixture for downstream tests."""
        # Use all 100 samples for server evaluation
        server_indices = list(range(100))
        return ServerEvaluator(device, mock_dataset, server_indices, batch_size=10)

    def test_create_class_based_loaders(self, server_evaluator, mock_dataset):
        """Verify that data partitions are built per class."""
        loaders = server_evaluator.class_loaders
        # Should have one loader for each class
        assert len(loaders) == mock_dataset.num_classes

        # Check the loader for class 3
        class_3_loader = loaders[3]
        all_labels_in_loader = []
        for _, labels in class_3_loader:
            all_labels_in_loader.extend(labels.tolist())

        # Labels cycle 0..num_classes-1, so each class holds an equal share.
        expected_per_class = len(mock_dataset) // mock_dataset.num_classes
        assert all(label == 3 for label in all_labels_in_loader)
        assert len(all_labels_in_loader) == expected_per_class

    def test_evaluate_by_class(self, server_evaluator, device):
        """Validate per-class evaluation outputs with a predictable model."""

        # A model that ignores its input and always predicts class 0. Every
        # class-0 sample is therefore correct and every other sample wrong.
        class AlwaysClassZeroModel(SimpleModel):
            def forward(self, x):
                logits = torch.zeros(x.size(0), self.fc.out_features, device=x.device)
                logits[:, 0] = 1.0
                return logits

        model = AlwaysClassZeroModel().to(device)

        accuracies, losses = server_evaluator.evaluate_by_class(model)
        num_loaders = len(server_evaluator.class_loaders)

        assert isinstance(accuracies, list)
        assert isinstance(losses, list)
        assert len(accuracies) == num_loaders
        assert len(losses) == num_loaders
        assert all(0.0 <= acc <= 1.0 for acc in accuracies)
        # Exactly one class (class 0) is perfectly classified; all others
        # score zero. Sorting keeps the check independent of the order in
        # which evaluate_by_class reports classes.
        assert sorted(accuracies) == pytest.approx([0.0] * (num_loaders - 1) + [1.0])
