import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from transformers import AutoTokenizer

def distribute_data_iid(dataset, n_clients, n_shards_per_client, seed=None):
    """
    Distribute the dataset in an IID fashion across clients.

    The sample indices are shuffled once and cut into
    ``n_clients * n_shards_per_client`` equally sized shards. Each client
    receives ``n_shards_per_client`` consecutive shards. Samples left over
    after integer division into shards are dropped.

    :param dataset: Dict of arrays sharing the same first-axis length; must
        contain ``'input_ids'``.
    :param n_clients: Number of clients.
    :param n_shards_per_client: Number of shards per client.
    :param seed: Optional seed for a reproducible shuffle. ``None`` (default)
        uses NumPy's global random state, as before.
    :return: A list of datasets (dicts with the same keys as ``dataset``),
        one for each client.
    """
    n_samples = len(dataset['input_ids'])
    n_shards = n_clients * n_shards_per_client
    shard_size = n_samples // n_shards
    samples_per_client = shard_size * n_shards_per_client

    # Shuffle the data indices; an explicit seed makes the split reproducible
    # without touching the global RNG.
    if seed is None:
        indices = np.random.permutation(n_samples)
    else:
        indices = np.random.default_rng(seed).permutation(n_samples)

    # A client's shards are adjacent in the shuffled order, so concatenating
    # them is the same as taking one contiguous slice of the permutation.
    client_datasets = []
    for i in range(n_clients):
        client_indices = indices[i * samples_per_client:(i + 1) * samples_per_client]
        client_datasets.append({key: dataset[key][client_indices] for key in dataset})

    return client_datasets

def distribute_data_non_iid(dataset, n_clients, n_shards_per_client):
    """
    Split the dataset across clients so that each client sees a skewed label mix.

    Samples are ordered by ``dataset['label']`` and cut into
    ``n_clients * n_shards_per_client`` equally sized shards. Client ``i``
    receives the ``n_shards_per_client`` shards that follow one another in
    that order. Remainder samples that do not fill a shard are discarded.

    :param dataset: Dict of arrays sharing the same first-axis length; must
        contain ``'input_ids'`` and ``'label'``.
    :param n_clients: Number of clients.
    :param n_shards_per_client: Number of shards per client.
    :return: A list with one dict per client, keyed like ``dataset``.
    """
    total_shards = n_clients * n_shards_per_client
    samples_per_shard = len(dataset['input_ids']) // total_shards
    samples_per_client = samples_per_shard * n_shards_per_client

    # Ordering by label groups equal labels together, so each contiguous
    # slice handed to a client covers only a narrow band of labels.
    order = np.argsort(dataset['label'])

    clients = []
    for client_idx in range(n_clients):
        start = client_idx * samples_per_client
        picked = order[start:start + samples_per_client]
        clients.append({field: values[picked] for field, values in dataset.items()})
    return clients



def get_nlp_datasets(dataset_name: str,
                     n_clients: int = 1,
                     n_shards_per_client: int = 2,
                     iid: bool = True,
                     use_max_padding: bool = False,
                     model_name: str = "bert-base-uncased"):
    """
    Load and preprocess NLP datasets with options for federated learning settings.

    :param dataset_name: Name of the TFDS dataset (must have 'train' and 'test' splits).
    :param n_clients: Number of clients for federated learning.
    :param n_shards_per_client: Number of shards per client.
    :param iid: If True, distribute data IID, else non-IID.
    :param use_max_padding: If True, pad every client up to the largest client's
        sample count; otherwise trim every client down to the smallest one.
    :param model_name: Name of the tokenizer/model to use (e.g., BERT, GPT).
    :return: ``(train, test)``. ``test`` is the preprocessed test split.
        ``train`` is a dict of arrays with a leading client axis. With
        ``n_clients > 1`` it holds only ``'input_ids'``, ``'attention_mask'``
        and ``'label'``, stacked per client. Otherwise every key of the
        preprocessed train split is returned with a leading axis of size 1.
    """
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Decoder-only tokenizers (e.g. GPT-2) ship without a pad token, which makes
    # padding=True in preprocessing and pad_token_id below unusable. Reusing EOS
    # as the pad token is the usual workaround.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and prepare dataset (batch_size=-1 materializes each split at once)
    ds_builder = tfds.builder(dataset_name)
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))

    # Preprocess the datasets (tokenization and padding)
    train_ds = preprocess_nlp_dataset(train_ds, tokenizer)
    test_ds = preprocess_nlp_dataset(test_ds, tokenizer)

    if n_clients <= 1:
        # Single "client": add a leading client axis so the shape matches the
        # federated case.
        train_ds = {k: np.expand_dims(v, axis=0) for k, v in train_ds.items()}
        return train_ds, test_ds

    distribute = distribute_data_iid if iid else distribute_data_non_iid
    train_clients = distribute(train_ds, n_clients, n_shards_per_client)

    # Clients must share a sample count before they can be stacked.
    sizes = [len(client['input_ids']) for client in train_clients]
    target_size = max(sizes) if use_max_padding else min(sizes)
    train_clients = pad_or_trim_data(train_clients, target_size, padding_value=tokenizer.pad_token_id)

    stacked_clients = {
        key: np.stack([client[key] for client in train_clients])
        for key in ('input_ids', 'attention_mask', 'label')
    }
    return stacked_clients, test_ds

def preprocess_nlp_dataset(dataset, tokenizer):
    """
    Preprocess an NLP dataset by tokenizing the text and preparing input tensors.

    :param dataset: Dict of arrays with a ``'text'`` entry. It is modified in place.
    :param tokenizer: Tokenizer called as
        ``tokenizer(texts, padding=True, truncation=True, return_tensors='np')``
        that returns ``'input_ids'`` and ``'attention_mask'``.
    :return: The same dict, with ``'input_ids'`` and ``'attention_mask'`` added.
    """
    # tfds.as_numpy returns string features as bytes, but Hugging Face
    # tokenizers only accept str input, so decode before tokenizing.
    texts = [t.decode('utf-8') if isinstance(t, bytes) else t for t in dataset['text']]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='np')
    dataset['input_ids'] = inputs['input_ids']
    dataset['attention_mask'] = inputs['attention_mask']
    return dataset

def pad_or_trim_data(clients, target_size, padding_value, label_padding_value=0):
    """
    Pad or trim each client's data to a uniform number of samples.

    Only ``'input_ids'``, ``'attention_mask'`` and ``'label'`` are resized.
    Any other keys are left untouched. The client dicts are modified in
    place, and the same list is returned.

    :param clients: List of client datasets.
    :param target_size: Target number of samples per client.
    :param padding_value: Value used to fill padded ``input_ids`` rows.
    :param label_padding_value: Placeholder label for padded rows (default 0).
        Padded rows are identifiable by an all-zero attention mask. Callers
        should exclude them from losses and metrics.
    :return: List of padded or trimmed client datasets.
    """
    for client in clients:
        client['input_ids'] = pad_or_trim(client['input_ids'], target_size, padding_value)
        client['attention_mask'] = pad_or_trim(client['attention_mask'], target_size, 0)  # Mask padding should be 0
        # Labels must be padded as well as trimmed. Otherwise clients end up
        # with label arrays of different lengths and stacking them fails.
        client['label'] = pad_or_trim(client['label'], target_size, label_padding_value)
    return clients

def pad_or_trim(data, target_size, padding_value):
    """
    Pad or trim ``data`` along its first axis to exactly ``target_size`` entries.

    Trailing axes (e.g. the token dimension of 2-D ``input_ids``) are left
    unchanged.

    :param data: Array-like to pad or trim. Its first axis is the sample axis.
    :param target_size: Target length of the first axis.
    :param padding_value: Value used to fill appended entries.
    :return: Padded array, or a slice of ``data`` when trimming.
    """
    n = len(data)
    if n < target_size:
        # Pad axis 0 only. A bare (0, k) pad width is broadcast by np.pad to
        # every axis, which would also widen each row of a 2-D input.
        pad_width = [(0, target_size - n)] + [(0, 0)] * (np.ndim(data) - 1)
        return np.pad(data, pad_width, 'constant', constant_values=padding_value)
    return data[:target_size]
