import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
# Hide every GPU from TensorFlow so that it only runs on the CPU.
# NOTE(review): presumably this keeps TensorFlow (needed here only for
# dataset loading) from claiming GPU memory wanted elsewhere -- confirm.
tf.config.set_visible_devices([], 'GPU')

def preprocess_dataset(dataset):
    """Rescale the images of a dataset mapping and drop its ``'id'`` entry.

    The ``'image'`` entry is converted to ``float32`` and divided by 255.0.
    An ``'id'`` entry is deleted when there is one. The mapping is modified
    in place, and the very same object is returned for convenience.
    """
    pixels = np.float32(dataset['image'])
    dataset['image'] = pixels / 255.0

    # The 'id' entry is optional; a missing one is not an error.
    if 'id' in dataset:
        del dataset['id']

    return dataset

def pad_or_trim_data(client_data, target_size):
    """Pad or trim every client so that it holds exactly ``target_size`` samples.

    Clients with fewer samples are extended with zero-filled images and
    labels; clients with more samples keep only their first ``target_size``
    entries. The client dictionaries are updated in place.

    :param client_data: Iterable of dicts with NumPy arrays under the
        ``'image'`` and ``'label'`` keys, sample axis first.
    :param target_size: Number of samples every client must end up with.
    :return: ``client_data`` itself, after the in-place update.

    NOTE(review): padded labels are all zeros, which cannot be told apart
    from a genuine label ``0`` -- consumers may need a mask.
    """
    for client in client_data:
        images, labels = client['image'], client['label']
        num_samples = len(images)

        if num_samples < target_size:
            missing = target_size - num_samples
            # Pad with the array's own dtype: np.zeros defaults to float64,
            # which would silently upcast float32 images and turn integer
            # labels into floats when concatenated.
            image_pad = np.zeros((missing,) + images.shape[1:], dtype=images.dtype)
            label_pad = np.zeros((missing,) + labels.shape[1:], dtype=labels.dtype)
            client['image'] = np.concatenate([images, image_pad])
            client['label'] = np.concatenate([labels, label_pad])
        elif num_samples > target_size:
            client['image'] = images[:target_size]
            client['label'] = labels[:target_size]

    return client_data

def distribute_data_iid(dataset, n_clients, n_shards_per_client, use_max_padding=False):
    """Distribute data IID among clients.

    The samples are shuffled once and cut into disjoint, equally sized
    slices, one per client, so no sample is handed out twice. Each client
    receives ``n_shards_per_client`` shards of
    ``len(dataset['image']) // (n_clients * n_shards_per_client)`` samples,
    the same shard size that ``distribute_data_non_iid`` uses. Samples left
    over by the integer division are not assigned to any client.

    :param dataset: Dict with NumPy arrays under ``'image'`` and ``'label'``,
        sample axis first.
    :param n_clients: Number of clients (must be at least 1).
    :param n_shards_per_client: Number of shards per client (must be at
        least 1); it only affects how the per-client size is rounded.
    :param use_max_padding: Unused here; kept so both distribution functions
        share one signature.
    :return: List with one ``{'image': ..., 'label': ...}`` dict per client.
    :raises ValueError: If ``n_clients`` or ``n_shards_per_client`` is
        smaller than 1.
    """
    if n_clients < 1 or n_shards_per_client < 1:
        raise ValueError("n_clients and n_shards_per_client must both be >= 1")

    images, labels = dataset['image'], dataset['label']
    total = len(images)
    shard_size = total // (n_clients * n_shards_per_client)
    per_client = shard_size * n_shards_per_client

    # A single permutation guarantees that clients never share a sample and
    # that no client sees the same sample twice.
    order = np.random.permutation(total)

    client_data = []
    for client_idx in range(n_clients):
        start = client_idx * per_client
        chosen = order[start:start + per_client]
        client_data.append({'image': images[chosen], 'label': labels[chosen]})
    return client_data

def distribute_data_non_iid(dataset, n_clients, n_shards_per_client, use_max_padding=False):
    """Distribute data non-IID among clients.

    Every client is built from ``n_shards_per_client`` shards. A shard picks
    two distinct classes at random and draws half a shard's worth of samples
    (without replacement) from each of them. Draws for different shards are
    independent, so a sample may appear in more than one shard or client.

    Classes are taken from the label values actually present in the data, so
    labels do not have to be ``0..K-1``. When a class holds fewer samples
    than half a shard, all of its samples are used instead of failing; the
    affected client then ends up smaller than the others.

    :param dataset: Dict with NumPy arrays under ``'image'`` and ``'label'``,
        sample axis first.
    :param n_clients: Number of clients.
    :param n_shards_per_client: Number of two-class shards per client.
    :param use_max_padding: Unused here; kept for signature compatibility.
    :return: List with one ``{'image': ..., 'label': ...}`` dict per client.
    """
    images, labels = dataset['image'], dataset['label']
    num_samples_per_shard = len(images) // (n_clients * n_shards_per_client)
    per_class = num_samples_per_shard // 2

    # Index by the label values that really occur rather than assuming that
    # they form the contiguous range 0..num_classes-1.
    class_indices = [np.where(labels == value)[0] for value in np.unique(labels)]
    num_classes = len(class_indices)

    client_data = []
    for _client in range(n_clients):
        picks = []
        for _shard in range(n_shards_per_client):
            selected_classes = np.random.choice(num_classes, 2, replace=False)
            for cls in selected_classes:
                pool = class_indices[cls]
                # Cap the draw at the class size: sampling without
                # replacement cannot return more items than the pool holds.
                picks.append(np.random.choice(pool, min(per_class, len(pool)), replace=False))

        indices = np.concatenate(picks)
        client_data.append({'image': images[indices], 'label': labels[indices]})

    return client_data

def get_datasets(dataset_name: str,
                 n_clients: int = 1,
                 n_shards_per_client: int = 2,
                 iid: bool = True,
                 use_max_padding: bool = False):
    """
    Load a TFDS dataset, preprocess it and optionally split it among clients.

    The ``train`` and ``test`` splits are fetched in full as NumPy data and
    passed through ``preprocess_dataset``. Only the training split is
    distributed; the test split is returned as loaded, without a client axis.

    :param dataset_name: Name of the dataset, as understood by ``tfds.builder``.
    :param n_clients: Number of clients for federated learning. Values of 1
        or less keep the training split whole.
    :param n_shards_per_client: Number of shards per client.
    :param iid: If True, distribute data IID, else non-IID.
    :param use_max_padding: If True, pad every client up to the largest
        client size; otherwise trim every client down to the smallest one.
    :return: ``(train, test)``. With more than one client, ``train`` is a
        dict whose ``'image'`` and ``'label'`` arrays are stacked along a new
        leading client axis. Otherwise ``train`` is the training split with
        a leading axis of length 1 added to every entry.
    """
    builder = tfds.builder(dataset_name)
    builder.download_and_prepare()

    # batch_size=-1 requests each split as one single batch.
    raw_splits = [
        tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        for split_name in ('train', 'test')
    ]
    train_ds, test_ds = [preprocess_dataset(split) for split in raw_splits]

    if n_clients <= 1:
        # Single-client case: add the client axis to the training data only.
        single_client = {key: np.expand_dims(value, axis=0)
                         for key, value in train_ds.items()}
        return single_client, test_ds

    distribute = distribute_data_iid if iid else distribute_data_non_iid
    train_clients = distribute(train_ds, n_clients, n_shards_per_client)

    # Bring all clients to a common length so that they can be stacked.
    client_sizes = [len(client['image']) for client in train_clients]
    if use_max_padding:
        target_size = max(client_sizes)
    else:
        target_size = min(client_sizes)
    train_clients = pad_or_trim_data(train_clients, target_size)

    stacked_clients = {
        field: np.stack([client[field] for client in train_clients])
        for field in ('image', 'label')
    }
    return stacked_clients, test_ds


if __name__ == '__main__':
    # Demo: load MNIST once as a single dataset, then in three federated
    # configurations, printing the resulting array shapes each time.
    demo_settings = (
        {},  # single dataset, all defaults
        {'n_clients': 10, 'iid': False, 'n_shards_per_client': 2},
        {'n_clients': 10, 'iid': True, 'n_shards_per_client': 2},
        {'n_clients': 10, 'iid': False, 'n_shards_per_client': 10},
    )
    for settings in demo_settings:
        train_part, test_part = get_datasets("mnist", **settings)
        print(train_part['image'].shape, train_part['label'].shape)
        print(test_part['image'].shape, test_part['label'].shape)