import numpy as np
import pytest
import torch
from torch.utils.data import Dataset

from byzantine_robust_fl.utils.data_sampling import (
    _get_dataset_labels,
    sample_iid,
    sample_noniid_by_dirichlet,
    sample_noniid_by_shards,
)


# --- Mock Dataset for Controlled Testing ---
class MockDataset(Dataset):
    """In-memory stand-in for a labelled image dataset used by the sampling tests.

    Labels are stored under a configurable attribute name so tests can exercise
    the different label-attribute conventions ``_get_dataset_labels`` handles.
    Label order is shuffled with NumPy's global RNG.

    Args:
        num_samples: Number of samples (and labels) in the dataset.
        num_classes: Number of distinct classes. Labels cycle through
            ``0 .. num_classes - 1``, so classes are exactly balanced whenever
            ``num_samples`` is divisible by ``num_classes``.
        label_attr_name: Attribute under which the label array is stored
            (e.g. ``"targets"``, ``"labels"``, ``"train_labels"``).
    """

    def __init__(self, num_samples=1000, num_classes=10, label_attr_name="targets"):
        self.num_samples = num_samples
        self.num_classes = num_classes
        # Remember where the labels live. ``__getitem__`` must not infer this
        # from ``__dict__`` ordering, because the first key there is ``num_samples``.
        self._label_attr_name = label_attr_name
        # ``arange % num_classes`` always yields exactly ``num_samples`` labels,
        # so ``len(labels) == len(self)`` even when the division is not exact.
        labels = np.arange(num_samples) % num_classes
        np.random.shuffle(labels)
        setattr(self, label_attr_name, labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return a random ``(1, 28, 28)`` tensor paired with the label at ``idx``."""
        labels = getattr(self, self._label_attr_name)
        return torch.randn(1, 28, 28), labels[idx]


# --- Pytest Fixtures ---
@pytest.fixture
def balanced_dataset():
    """Build a 1,000-sample ``MockDataset`` spread evenly over ten classes."""
    dataset = MockDataset(num_classes=10, num_samples=1000)
    return dataset


@pytest.fixture
def small_dataset():
    """Build a 100-sample, five-class ``MockDataset`` for edge-case tests."""
    dataset = MockDataset(num_classes=5, num_samples=100)
    return dataset


# --- Test Suite ---


class TestHelperFunctions:
    """Exercise internal helper utilities."""

    @staticmethod
    def _assert_labels_match(dataset, attr_name):
        """Extract labels and assert they equal the array stored under ``attr_name``.

        Checking the values, not just the length, ensures the helper actually
        read the expected attribute.
        """
        labels = _get_dataset_labels(dataset)
        assert len(labels) == len(dataset)
        np.testing.assert_array_equal(np.asarray(labels), getattr(dataset, attr_name))
        return labels

    def test_get_dataset_labels_with_targets(self):
        """Verify label extraction from 'targets' attribute."""
        dataset = MockDataset(label_attr_name="targets")
        labels = self._assert_labels_match(dataset, "targets")
        assert isinstance(labels, np.ndarray)

    def test_get_dataset_labels_with_labels(self):
        """Verify label extraction from 'labels' attribute."""
        dataset = MockDataset(label_attr_name="labels")
        self._assert_labels_match(dataset, "labels")

    def test_get_dataset_labels_with_train_labels(self):
        """Verify label extraction from 'train_labels' attribute."""
        dataset = MockDataset(label_attr_name="train_labels")
        self._assert_labels_match(dataset, "train_labels")

    def test_get_dataset_labels_raises_error(self):
        """Ensure an informative error is raised when labels are absent."""
        dataset = MockDataset(label_attr_name="data_labels")  # An unknown attribute
        with pytest.raises(AttributeError, match="Could not find a labels or targets attribute"):
            _get_dataset_labels(dataset)


class TestIIDSampling:
    """Validate ``sample_iid`` behaviour."""

    def test_iid_partitioning(self, balanced_dataset):
        """Check basic IID partitioning properties.

        Uniqueness of indices across users is deliberately not asserted. The
        sampler is expected to allow replacement, so the same index may be
        assigned more than once.
        """
        num_users = 10
        num_items = len(balanced_dataset)
        user_indices = sample_iid(balanced_dataset, num_users)

        assert len(user_indices) == num_users
        expected_per_user = num_items // num_users
        assert all(len(indices) == expected_per_user for indices in user_indices.values())

        all_assigned_indices = np.concatenate(list(user_indices.values()))
        # The total number of assigned indices is still fully determined.
        assert len(all_assigned_indices) == num_items
        # Every assigned index must refer to a real sample in the dataset.
        assert np.all((all_assigned_indices >= 0) & (all_assigned_indices < num_items))


class TestNonIIDSampling:
    """Validate non-IID sampling utilities."""

    def test_noniid_by_shards(self, balanced_dataset):
        """Check shard-based non-IID partitioning."""
        num_users = 10
        num_shards_per_user = 2
        user_indices = sample_noniid_by_shards(balanced_dataset, num_users, num_shards_per_user)

        assert len(user_indices) == num_users

        # Check that total items per user is correct
        total_shards = num_users * num_shards_per_user
        images_per_shard = len(balanced_dataset) // total_shards
        total_items_per_user = images_per_shard * num_shards_per_user
        assert all(len(v) == total_items_per_user for v in user_indices.values())

        # Check the non-IID property: users should have limited label diversity.
        labels = _get_dataset_labels(balanced_dataset)
        for indices in user_indices.values():
            user_labels = labels[indices]
            # With 2 shards out of 20, a user should see around 2 classes;
            # allow slack for shards that straddle a class boundary.
            assert len(np.unique(user_labels)) <= 4

    def test_noniid_by_dirichlet_highly_skewed(self, balanced_dataset):
        """Evaluate Dirichlet sampling with a highly skewed prior."""
        num_users = 10
        num_classes = balanced_dataset.num_classes
        alpha = 0.01
        user_indices, class_counts = sample_noniid_by_dirichlet(balanced_dataset, num_users, alpha)

        assert len(user_indices) == num_users
        assert class_counts.shape == (num_users, num_classes)

        # Check that total assigned indices match the sum of counts
        total_assigned = sum(len(v) for v in user_indices.values())
        assert total_assigned == np.sum(class_counts)

        # Non-IID property: each user should hold data from very few classes.
        for user_counts in class_counts:
            assert np.count_nonzero(user_counts) <= 3

    def test_noniid_by_dirichlet_uniform(self, balanced_dataset):
        """Evaluate Dirichlet sampling under a near-uniform prior."""
        num_users = 10
        num_classes = balanced_dataset.num_classes
        alpha = 1000.0
        user_indices, class_counts = sample_noniid_by_dirichlet(balanced_dataset, num_users, alpha)

        # Validate the result's structure before indexing into it.
        assert len(user_indices) == num_users
        assert class_counts.shape == (num_users, num_classes)
        total_assigned = sum(len(v) for v in user_indices.values())
        assert total_assigned == np.sum(class_counts)

        # IID-like property: samples per class should be roughly equal for each user.
        total_samples_per_user = total_assigned / num_users
        expected_samples_per_class = total_samples_per_user / num_classes

        for user_counts in class_counts:
            # All classes should be present
            assert np.all(user_counts > 0)
            # The count for each class should be close to the expected mean
            assert np.allclose(user_counts, expected_samples_per_class, atol=10)
