import torch
import torchvision
import torchvision.transforms as transforms
from collections import Counter
import numpy as np
from torch.utils.data import DataLoader

# Normalization statistics (per-channel mean, per-channel std) and number of
# classes for every dataset_type accepted by data_split. The keys double as the
# class names looked up in torchvision.datasets. The MNIST values are the ones
# this module has always used; the others are the commonly cited training-set
# statistics for those datasets.
_DATASET_CONFIGS = {
    "MNIST": ((0.1307,), (0.3081,), 10),
    "FashionMNIST": ((0.2860,), (0.3530,), 10),
    "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616), 10),
    "CIFAR100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762), 100),
}


def _load_torchvision_splits(dataset_type, root):
    """Return (trainset, testset, num_classes) for dataset_type, downloading into root if needed."""
    mean, std, num_classes = _DATASET_CONFIGS[dataset_type]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    dataset_cls = getattr(torchvision.datasets, dataset_type)
    trainset = dataset_cls(root=root, train=True, download=True, transform=transform)
    testset = dataset_cls(root=root, train=False, download=True, transform=transform)
    return trainset, testset, num_classes


def _load_as_single_batch(dataset):
    """Materialize an entire dataset as one (images, labels) tensor pair."""
    # shuffle=True is redundant with the randperm applied afterwards, but it is
    # kept so that seeded runs consume the global RNG exactly as before and
    # therefore reproduce the same split.
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
    return next(iter(loader))


def _shuffle_pair(imgs, labels):
    """Apply one shared random permutation to imgs and labels."""
    permutation = torch.randperm(labels.shape[0])
    return imgs[permutation], labels[permutation]


def _split_among_clients(train_imgs, train_labels, test_imgs, test_labels,
                         num_devices, num_samples, num_labels, num_classes):
    """Build the per-client dict using the cyclic label assignment described in data_split."""
    client_data = {}
    num_test = num_samples // 4
    for client_idx in range(num_devices):
        # Consecutive class ids starting at client_idx * num_labels, wrapping
        # around the number of classes.
        client_labels = torch.tensor(
            [(client_idx * num_labels + j) % num_classes for j in range(num_labels)],
            dtype=train_labels.dtype,
        )
        # Data is already shuffled, so the first matches are a random subset.
        # Clients sharing a label set therefore receive identical samples.
        train_indices = torch.where(torch.isin(train_labels, client_labels))[0][:num_samples]
        test_indices = torch.where(torch.isin(test_labels, client_labels))[0][:num_test]
        client_data[client_idx] = (
            train_imgs[train_indices],
            train_labels[train_indices],
            test_imgs[test_indices],
            test_labels[test_indices],
        )
    return client_data


def _sample_server_data(client_data, server_samples):
    """Randomly sample the server's train/test subsets from the pooled client data."""
    ordered = [client_data[i] for i in range(len(client_data))]
    all_train_imgs, all_train_labels, all_test_imgs, all_test_labels = (
        torch.cat([client[k] for client in ordered], dim=0) for k in range(4)
    )
    server_train_indices = torch.randperm(all_train_labels.shape[0])[:server_samples]
    server_test_indices = torch.randperm(all_test_labels.shape[0])[:server_samples // 4]
    return {
        'global_train_imgs': all_train_imgs[server_train_indices],
        'global_train_labels': all_train_labels[server_train_indices],
        'global_test_imgs': all_test_imgs[server_test_indices],
        'global_test_labels': all_test_labels[server_test_indices],
    }


def _format_counts(labels):
    """Format a label tensor as 'label: count' pairs, sorted by label."""
    counts = Counter(labels.tolist())
    return ', '.join(f'{label}: {count}' for label, count in sorted(counts.items()))


def _print_split_summary(client_data, server_data):
    """Print per-client label sets/counts and server subset sizes/label counts."""
    print("Client data info:")
    for client_idx, (_, y_train, _, _) in client_data.items():
        print(f"Client {client_idx}")
        print(', '.join(str(label.item()) for label in y_train.unique()))
        print(_format_counts(y_train))
    print('---------------------------------------')

    print("Server data info:")
    print(f"Training set size: {server_data['global_train_imgs'].shape[0]}")
    print(f"Test set size: {server_data['global_test_imgs'].shape[0]}")
    print("Label counts in server training set:", end=' ')
    print(_format_counts(server_data['global_train_labels']))
    print('-------------')
    print("Label counts in server test set:", end=' ')
    print(_format_counts(server_data['global_test_labels']))


def data_split(num_devices, num_samples, num_labels, server_samples, dataset_type="MNIST",
               check=False, root='./data'):
    """
    Load a torchvision dataset (MNIST, FashionMNIST, CIFAR10 or CIFAR100) and split it into a label-skewed client partition plus a server subset, all as torch.Tensor.

    Client ``i`` is assigned the ``num_labels`` consecutive class ids starting at ``i * num_labels``, wrapping around the number of classes. Its data is taken from the randomly shuffled torchvision train/test splits, restricted to those classes. Clients with overlapping label sets share samples: the first matching shuffled samples are taken, so identical label sets yield identical data.

    Parameters:
    num_devices: The number of clients (must be >= 1)
    num_samples: Maximum number of training samples per client; each client also gets at most num_samples // 4 test samples (must be >= 0)
    num_labels: Number of class ids assigned to each client (must be >= 1)
    server_samples: Maximum number of training samples for the server; the server also gets at most server_samples // 4 test samples. Both are drawn at random from the pooled client data, not directly from the torchvision splits, and are intended for evaluation only, not training (must be >= 0)
    dataset_type: The dataset to load. One of "MNIST", "FashionMNIST", "CIFAR10" or "CIFAR100"
    check: Whether to print a summary of the resulting split. The default is False
    root: Directory torchvision downloads the dataset into / reads it from. The default is './data'

    Returns:
    client_data: A dict mapping client index (0 .. num_devices - 1) to a tuple (train images, train labels, test images, test labels) of torch.Tensor
    server_data: A dict with keys 'global_train_imgs', 'global_train_labels', 'global_test_imgs' and 'global_test_labels', each a torch.Tensor

    Raises:
    ValueError: If dataset_type is unsupported, num_devices < 1, num_labels < 1, or num_samples / server_samples is negative
    """
    if dataset_type not in _DATASET_CONFIGS:
        raise ValueError("Invalid dataset_type. Choose either 'MNIST', 'FashionMNIST', 'CIFAR10' or 'CIFAR100'.")
    # Validate before downloading anything. Otherwise bad values either fail
    # late (torch.cat on an empty list when num_devices < 1) or misbehave
    # silently (a negative slice bound keeps all but the last |n| samples).
    if num_devices < 1:
        raise ValueError(f"num_devices must be >= 1, got {num_devices}")
    if num_labels < 1:
        raise ValueError(f"num_labels must be >= 1, got {num_labels}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be >= 0, got {num_samples}")
    if server_samples < 0:
        raise ValueError(f"server_samples must be >= 0, got {server_samples}")

    trainset, testset, num_classes = _load_torchvision_splits(dataset_type, root)

    # Both loaders are drained before either permutation is drawn, preserving
    # the original order of global-RNG consumption.
    train_imgs, train_labels = _load_as_single_batch(trainset)
    test_imgs, test_labels = _load_as_single_batch(testset)
    train_imgs, train_labels = _shuffle_pair(train_imgs, train_labels)
    test_imgs, test_labels = _shuffle_pair(test_imgs, test_labels)

    client_data = _split_among_clients(
        train_imgs, train_labels, test_imgs, test_labels,
        num_devices, num_samples, num_labels, num_classes,
    )
    server_data = _sample_server_data(client_data, server_samples)

    if check:
        _print_split_summary(client_data, server_data)

    return client_data, server_data