import os
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Subset
from src.datasets.dataset_utils import separate_data, split_data, check, save_file


# Root directory for everything distributed_mnist reads or writes: the raw MNIST
# download ("rawdata"), "config.json", and the "train/" and "test/" folders.
# The trailing slash matters because paths are built by plain string concatenation.
dir_path = "./data/mnist/"
# Number of label classes, forwarded to check, separate_data and save_file.
num_classes = 10

def _load_subset_arrays(dataset, indices):
    '''
    Load the samples of `dataset` selected by `indices` as NumPy arrays.

    Inputs:
    1. dataset: indexable dataset yielding (image, label) pairs (here a torchvision MNIST split)
    2. indices (non-empty sequence of int): positions of the samples to load

    Outputs:
    1. images (np.ndarray): the transformed images, stacked along the first axis
    2. labels (np.ndarray): the matching labels, in the same order as `indices`
    '''
    subset = Subset(dataset, indices)
    # A single batch spanning the whole subset lets the DataLoader apply the dataset's
    # transform to every sample and stack the results, so no accumulation loop is needed.
    loader = torch.utils.data.DataLoader(subset, batch_size=len(subset), shuffle=False)
    images, labels = next(iter(loader))
    return images.cpu().detach().numpy(), labels.cpu().detach().numpy()


def distributed_mnist(num_clients, niid, balance, partition, data_points_per_client):
    '''
    Draw a random MNIST sample, partition it across clients and hand the result to save_file.

    The partitioning itself is delegated to separate_data / split_data. According to the
    project's convention (to be confirmed against separate_data):
    1. (niid=False, partition='pat') gives an IID allocation.
    2. (niid=True, partition='pat') gives pathological splits, where each client only
       receives data from a few classes.
    3. (niid=True, partition='dir') gives Dirichlet-based splits, where a lower alpha
       creates more uneven (more strongly non-IID) allocations.

    The output location (module-level `dir_path`) and the number of classes (module-level
    `num_classes`) are not parameters of this function.

    Inputs:
    1. num_clients (int): number of all clients in system
    2. niid (bool): whether the data should be distributed non-iid (True) or iid (False)
    3. balance (bool): forwarded to check / separate_data / save_file; per the project's
       convention it controls whether every client gets the same number of data points
    4. partition (string): partition scheme identifier ('pat' or 'dir'), forwarded unchanged
    5. data_points_per_client (int): No. data points per client. In total
       data_points_per_client * num_clients samples are drawn from the MNIST train split
       and a further 25% of that amount (rounded down) from the MNIST test split. Requests
       larger than a split simply yield every sample of that split.

    Outputs:
    None. The per-client train/test data and the statistic returned by separate_data are
    passed to save_file together with config_path, train_path and test_path. If check(...)
    returns a truthy value (presumably: a dataset with this configuration already exists)
    the function returns early without regenerating anything.

    Raises:
    ValueError: if data_points_per_client * num_clients is not positive.
    '''

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(dir_path, exist_ok=True)

    # Setup directory for train/test data
    config_path = dir_path + "config.json"
    train_path = dir_path + "train/"
    test_path = dir_path + "test/"

    if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
        return

    num_train_samples = data_points_per_client * num_clients
    if num_train_samples <= 0:
        # A zero total used to crash later with an AttributeError, and a negative total
        # silently sliced from the end of the permutation; fail early and clearly instead.
        raise ValueError(
            "data_points_per_client * num_clients must be positive, got "
            f"{data_points_per_client} * {num_clients}")
    num_test_samples = int(num_train_samples * 0.25)

    # Get MNIST data
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    trainset = torchvision.datasets.MNIST(
        root=dir_path + 'rawdata', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(
        root=dir_path + 'rawdata', train=False, download=True, transform=transform)

    # Random sample without replacement. The train permutation is drawn before the test
    # permutation so that seeded runs select the same samples as before.
    train_indices = np.random.permutation(len(trainset))[:num_train_samples]
    test_indices = np.random.permutation(len(testset))[:num_test_samples]

    parts = [_load_subset_arrays(trainset, train_indices)]
    # For very small requests the 25% test share rounds down to zero; skip it rather than
    # failing on an empty loader.
    if len(test_indices) > 0:
        parts.append(_load_subset_arrays(testset, test_indices))

    # Train and test samples are pooled; the per-client train/test split is made by split_data.
    dataset_image = np.concatenate([images for images, _ in parts])
    dataset_label = np.concatenate([labels for _, labels in parts])

    X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes,
                                    niid, balance, partition)
    train_data, test_data = split_data(X, y)

    # Progress output kept from the original implementation.
    print(len(train_data))

    save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes,
              statistic, niid, balance, partition)

