#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide dataset partitioning helpers for federated learning experiments."""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


def _get_dataset_labels(dataset: Dataset) -> np.ndarray:
    """Extract labels from a dataset while accounting for common attribute names.

    Args:
        dataset: Dataset instance from which to retrieve labels.

    Returns:
        NumPy array containing one label per sample.

    Raises:
        AttributeError: If no supported label attribute is present.
    """
    if hasattr(dataset, "targets"):
        # For datasets like CIFAR10, FashionMNIST
        labels = dataset.targets
    elif hasattr(dataset, "labels"):
        # Common alternative attribute name
        labels = dataset.labels
    elif hasattr(dataset, "train_labels"):
        # For older versions of MNIST
        labels = dataset.train_labels
    else:
        raise AttributeError("Could not find a labels or targets attribute in the dataset.")

    if isinstance(labels, list):
        return np.array(labels)
    if isinstance(labels, torch.Tensor):
        return labels.numpy()
    return labels


def sample_iid(dataset: Dataset, num_users: int, samples_per_client: Optional[int] = None) -> Dict[int, np.ndarray]:
    """Sample IID partitions for each client while preserving class balance.

    Every client draws ``samples_per_client // num_classes`` indices from each
    class. The remainder ``samples_per_client % num_classes`` is covered by
    drawing one extra index from that many distinct, randomly chosen classes,
    so each client ends up with exactly ``samples_per_client`` indices.

    Sampling is done with replacement, so an index may appear more than once
    within a client and may be shared between clients.

    Args:
        dataset: Dataset to partition.
        num_users: Number of federated clients.
        samples_per_client: Optional fixed number of samples per client.
            Defaults to ``len(dataset) // num_users``.

    Returns:
        Mapping from client identifier to an array of dataset indices.
    """
    labels = _get_dataset_labels(dataset)
    # Use the label values actually present rather than assuming 0..K-1.
    class_labels = np.unique(labels)
    num_classes = len(class_labels)
    total_samples = len(dataset)

    if samples_per_client is None:
        samples_per_client = total_samples // num_users

    user_data_indices = {user_id: [] for user_id in range(num_users)}

    if num_classes > 0:
        indices_by_class = [np.where(labels == label)[0] for label in class_labels]
        base_num, leftover = divmod(samples_per_client, num_classes)

        for class_indices in indices_by_class:
            for user_id in range(num_users):
                chosen = np.random.choice(class_indices, base_num, replace=True)
                user_data_indices[user_id].extend(chosen)

        if leftover > 0:
            # The remainder is smaller than the class count, so spread it as
            # one extra sample from `leftover` distinct classes per client;
            # adding it to every class would overshoot samples_per_client.
            for user_id in range(num_users):
                extra_classes = np.random.choice(num_classes, leftover, replace=False)
                for class_pos in extra_classes:
                    user_data_indices[user_id].append(np.random.choice(indices_by_class[class_pos]))

    # Explicit integer dtype keeps empty partitions usable as index arrays.
    return {k: np.array(v, dtype=np.intp) for k, v in user_data_indices.items()}


def sample_noniid_by_shards(dataset: Dataset, num_users: int, num_shards_per_user: int = 2) -> Dict[int, np.ndarray]:
    """Sample non-IID partitions by allocating shards of label-sorted data to clients.

    Dataset indices are ordered by label and cut into
    ``num_users * num_shards_per_user`` consecutive shards of
    ``max(1, len(dataset) // num_shards)`` indices each. The shard order is
    shuffled and every client receives ``num_shards_per_user`` shards. A shard
    whose start offset lies beyond the dataset is replaced by indices drawn
    with replacement from the whole dataset.

    Args:
        dataset: Dataset to partition.
        num_users: Number of federated clients.
        num_shards_per_user: Number of shards handed to each client.

    Returns:
        Mapping from client identifier to an array of dataset indices.
    """
    n_samples = len(dataset)
    n_shards = num_users * num_shards_per_user
    shard_size = max(1, n_samples // n_shards)

    # Dataset indices arranged so that equal labels are contiguous.
    label_order = np.argsort(_get_dataset_labels(dataset))
    ordered_indices = np.arange(n_samples)[label_order]

    shard_order = np.arange(n_shards)
    np.random.shuffle(shard_order)

    partition = {}
    for client_id in range(num_users):
        first = client_id * num_shards_per_user
        # Seed with an empty int array so a client without shards still gets one.
        pieces = [np.array([], dtype=int)]
        for shard in shard_order[first : first + num_shards_per_user]:
            offset = shard * shard_size
            if offset < n_samples:
                stop = min(offset + shard_size, n_samples)
                pieces.append(ordered_indices[offset:stop])
            else:
                # Shard lies past the end of the data: resample with replacement.
                pieces.append(np.random.choice(ordered_indices, shard_size, replace=True))
        partition[client_id] = np.concatenate(pieces)
    return partition


def sample_noniid_by_dirichlet(
    dataset: Dataset, num_users: int, alpha: float = 0.5, samples_per_client: Optional[int] = None
) -> Tuple[Dict[int, List[int]], np.ndarray]:
    """Sample non-IID partitions by drawing class proportions from a Dirichlet prior.

    Each client draws a distribution over classes from a symmetric
    ``Dirichlet(alpha)``. For every class, a client's requested count is its
    proportion for that class times either the class size (when
    ``samples_per_client`` is ``None``) or ``samples_per_client``, truncated
    to an integer. If the requests for a class exceed the samples available,
    randomly chosen non-zero requests are decremented until they fit. Samples
    of a class are handed out without replacement, so no index is assigned to
    more than one client, and a client may receive fewer than
    ``samples_per_client`` indices.

    Args:
        dataset: Dataset to partition.
        num_users: Number of federated clients.
        alpha: Dirichlet concentration parameter controlling heterogeneity.
        samples_per_client: Optional fixed number of samples per client.

    Returns:
        Tuple of (client-to-indices mapping, per-client class counts). The
        counts array has shape ``(num_users, num_classes)``; its columns follow
        the sorted distinct label values found in the dataset.
    """
    labels = _get_dataset_labels(dataset)
    # Use the label values actually present rather than assuming 0..K-1.
    class_labels = np.unique(labels)
    num_classes = len(class_labels)

    indices_by_class = [np.where(labels == label)[0] for label in class_labels]
    class_distribution_per_user = np.random.dirichlet(alpha=np.full(num_classes, alpha), size=num_users)

    user_data_indices = {i: [] for i in range(num_users)}
    class_counts_per_user = np.zeros((num_users, num_classes), dtype=int)

    for class_idx, class_indices in enumerate(indices_by_class):
        np.random.shuffle(class_indices)
        available = len(class_indices)

        proportions = class_distribution_per_user[:, class_idx]
        scale = available if samples_per_client is None else samples_per_client
        num_samples_for_class = (proportions * scale).astype(int)

        # Correct over-allocation: the proportions for one class do not sum to
        # one across clients, so the requests can exceed the class size.
        while num_samples_for_class.sum() > available:
            idx = np.random.choice(np.where(num_samples_for_class > 0)[0])
            num_samples_for_class[idx] -= 1

        class_counts_per_user[:, class_idx] = num_samples_for_class

        # Consecutive slices of the shuffled class indices, one per client.
        # After the trimming above, the last boundary never exceeds `available`.
        boundaries = np.insert(np.cumsum(num_samples_for_class), 0, 0)
        for user_idx in range(num_users):
            start, end = boundaries[user_idx], boundaries[user_idx + 1]
            user_data_indices[user_idx].extend(class_indices[start:end])

    return user_data_indices, class_counts_per_user
