import numpy as np
from typing import Dict, List, Optional, Set, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding


class DatasetSplit(Dataset):
    """
    View over a parent dataset restricted to one client's sample indices.

    Only the GLUE-MRPC layout is handled for now; any other value of
    ``args.dataset`` is rejected when an item is requested.
    """

    def __init__(self, dataset, idxs: List[int], args) -> None:
        # Materialise the indices as a list so they can be addressed by
        # position, whatever iterable the caller handed in (e.g. a set).
        self.args = args
        self.idxs = [*idxs]
        self.dataset = dataset

    def __len__(self) -> int:
        """Number of samples assigned to this split."""
        return len(self.idxs)

    def __getitem__(self, item: int):
        """
        Return the ``item``-th sample of this split as a dict.

        The dict holds the ``input_ids``, ``labels``, ``token_type_ids`` and
        ``attention_mask`` entries of the underlying record.

        Raises:
            ValueError: if ``args.dataset`` is anything other than "mrpc".
        """
        # Unsupported datasets are rejected before the parent is touched.
        if self.args.dataset != "mrpc":
            raise ValueError(f"Dataset example undefined for: {self.args.dataset}")

        # Translate the split-local position into a parent-dataset index;
        # int() normalises non-builtin integer types before indexing.
        record = self.dataset[int(self.idxs[item])]
        wanted = ("input_ids", "labels", "token_type_ids", "attention_mask")
        return {key: record[key] for key in wanted}


# --------------------------------------------------------------------------- #
#                           Data Partition (IID Example)                      #
# --------------------------------------------------------------------------- #

def iid(dataset, num_users: int) -> Dict[int, Set[int]]:
    """
    IID client partitioning example for federated learning setups.

    The sample indices ``0 .. len(dataset) - 1`` are shuffled once and cut
    into ``num_users`` disjoint chunks of ``len(dataset) // num_users``
    indices each. Any remainder left by the integer division is not assigned
    to a user.

    Args:
        dataset: any object supporting ``len()``.
        num_users: number of clients to create; must be positive.

    Returns:
        Mapping from user id (``0 .. num_users - 1``) to a set of plain
        Python ``int`` sample indices.

    Raises:
        ValueError: if ``num_users`` is not positive.
    """
    # Fail with a clear message instead of a ZeroDivisionError (0) or a
    # silently empty partition (negative values).
    if num_users <= 0:
        raise ValueError(f"num_users must be a positive integer, got {num_users}")

    print("Assigning training samples (IID example)…")

    num_items = len(dataset) // num_users

    # One shuffle drawn from the global NumPy RNG (so np.random.seed still
    # controls it); tolist() yields builtin ints rather than numpy.int64,
    # matching the declared Set[int] return type.
    shuffled = np.random.permutation(len(dataset)).tolist()

    # Consecutive slices of one permutation are disjoint by construction, so
    # the remaining pool does not have to be rebuilt for every user.
    return {
        user: set(shuffled[user * num_items:(user + 1) * num_items])
        for user in range(num_users)
    }


# --------------------------------------------------------------------------- #
#                       Partition + Tokenization  (Example)                   #
# --------------------------------------------------------------------------- #

def load_partition(args):
    """
    Load, tokenize and partition a dataset; MRPC is the only worked example.

    For ``args.dataset == "mrpc"`` the GLUE-MRPC data is loaded, tokenized as
    sentence pairs with the ``bert-base-uncased`` tokenizer, stripped of its
    raw text/idx columns, and formatted as torch tensors. Additional datasets
    can be added following the same structure.

    Side effects on ``args``: sets ``args.data_collator`` (a padding collator
    built from the tokenizer) and ``args.num_classes`` (2).

    Returns:
        A 7-tuple ``(args, dataset_train, dataset_test, None, None,
        dict_users, None)``. ``dict_users`` is the IID client partition of
        the training split, or ``None`` when ``args.iid`` is set and falsy
        (a missing ``iid`` attribute counts as true).

    Raises:
        ValueError: if ``args.dataset`` is not "mrpc".
    """
    # Guard: anything other than the MRPC example is unsupported.
    if args.dataset != "mrpc":
        raise ValueError(f"Dataset example is not implemented for: {args.dataset}")

    # ============================= MRPC EXAMPLE ============================= #
    glue_mrpc = load_dataset("glue", "mrpc", keep_in_memory=True)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def encode_pairs(batch):
        # Sentence-pair encoding; padding is left to the collator.
        return tokenizer(batch["sentence1"], batch["sentence2"], truncation=True)

    encoded = glue_mrpc.map(encode_pairs, batched=True)
    args.data_collator = DataCollatorWithPadding(tokenizer)

    # Drop the raw columns and expose the target under the name "labels".
    encoded = (
        encoded
        .remove_columns(["sentence1", "sentence2", "idx"])
        .rename_column("label", "labels")
    )
    encoded.set_format("torch")

    dataset_train, dataset_test = encoded["train"], encoded["validation"]
    args.num_classes = 2

    # example client-split strategy
    dict_users: Optional[Dict[int, Set[int]]] = (
        iid(dataset_train, args.num_users) if getattr(args, "iid", True) else None
    )

    return args, dataset_train, dataset_test, None, None, dict_users, None