import os
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader, Subset
from src.datasets.dataset_utils import AdultIncomeDataset
from src.datasets.dataset_utils import separate_data, split_data, check, save_file


# Root directory for the Adult Income data. The raw files are read from its
# train/ and test/ subdirectories, and config.json plus the train/ and test/
# paths are handed to check()/save_file() by distributed_adult_income.
dir_path = "./data/adult_income/"
# Number of label classes (2, i.e. a binary label); passed to check(),
# separate_data() and save_file().
num_classes = 2

def _subset_to_numpy(subset):
    """Collate every sample of a ``Subset`` into NumPy ``(features, labels)`` arrays.

    Parameters
    ----------
    subset : torch.utils.data.Subset
        Subset whose underlying dataset yields ``(features, label)`` pairs of
        tensors (assumed from how the collated batch is unpacked below —
        TODO confirm against ``AdultIncomeDataset.__getitem__``).

    Returns
    -------
    tuple of np.ndarray or None
        ``(features, labels)`` stacked along the first axis, or ``None`` when
        ``subset`` is empty.
    """
    if len(subset) == 0:
        # DataLoader rejects batch_size=0, and there is nothing to collate.
        return None
    # A single batch spanning the whole subset, so the default collate
    # function stacks all samples in one pass.
    loader = DataLoader(subset, batch_size=len(subset), shuffle=False)
    features, labels = next(iter(loader))
    return features.cpu().detach().numpy(), labels.cpu().detach().numpy()


def distributed_adult_income(num_clients, niid, balance, partition, data_points_per_client):
    '''
    Sample the Adult Income data, allocate it to clients, and save the result.

    Up to ``data_points_per_client * num_clients`` rows are drawn at random
    from ``train/adult.data`` under the module-level ``dir_path``. A quarter
    of that number is drawn from ``test/adult.test``. Sampling uses the
    global, unseeded ``np.random`` state. The two samples are pooled and
    passed to ``separate_data``, which allocates them to clients.

    The original design notes describe these modes:
    1. IID when ``niid`` is False.
    2. Non-IID pathological splits when ``niid`` is True and ``partition == 'pat'``.
    3. Non-IID Dirichlet splits when ``niid`` is True and ``partition == 'dir'``
       (a lower alpha gives more uneven splits).
    The exact behaviour, including classes-per-client and alpha settings,
    lives in ``separate_data`` and is not visible here.

    Inputs:
    1. num_clients (int): number of clients in the system.
    2. niid (bool): distribute the data non-IID (True) or IID (False).
    3. balance (bool): whether every client receives the same number of
       data points (True) or not (False).
    4. partition (str): ``'dir'`` (Dirichlet) or ``'pat'`` (pathological).
    5. data_points_per_client (int): number of training rows sampled per
       client. The test sample is 25% of the resulting total.

    Returns:
    None. If ``check(...)`` reports that data for this configuration
    already exists, the function returns immediately. Otherwise, the
    per-client splits produced by ``split_data`` are written by
    ``save_file`` using ``config_path``, ``train_path`` and ``test_path``.

    Raises:
    ValueError: if no rows are sampled at all. This happens when
    ``data_points_per_client * num_clients`` is 0 or the source files
    are empty.
    '''

    if not os.path.exists(dir_path):
        print("Path does not exist")
        os.makedirs(dir_path)

    # train_path/test_path keep their trailing slash because they are
    # forwarded verbatim to check()/save_file().
    config_path = dir_path + "config.json"
    train_path = dir_path + "train/"
    test_path = dir_path + "test/"

    train_file = os.path.join(train_path, "adult.data")
    test_file = os.path.join(test_path, "adult.test")

    # Skip regeneration when check() accepts the existing configuration.
    if check(config_path, train_file, test_file, num_clients, num_classes, niid, balance, partition):
        return

    # Load from the same paths that were validated above rather than from
    # duplicated literals, which would diverge if dir_path changed.
    trainset = AdultIncomeDataset(csv_file=train_file)
    testset = AdultIncomeDataset(csv_file=test_file)

    num_train = data_points_per_client * num_clients
    num_test = int(num_train * 0.25)

    # Random subsample without replacement. Slicing a permutation simply
    # caps the sample at the dataset size if more rows are requested.
    train_indices = np.random.permutation(len(trainset))[:num_train]
    test_indices = np.random.permutation(len(testset))[:num_test]

    subset_trainset = Subset(trainset, train_indices)
    subset_testset = Subset(testset, test_indices)

    # Pool the sampled train and test rows; split_data re-splits each
    # client's share afterwards. An empty subset (e.g. num_test rounding
    # down to 0) is skipped instead of crashing.
    parts = [
        part
        for part in (_subset_to_numpy(subset_trainset), _subset_to_numpy(subset_testset))
        if part is not None
    ]
    if not parts:
        raise ValueError(
            "No Adult Income samples selected "
            f"(data_points_per_client * num_clients = {num_train}); "
            "cannot distribute an empty dataset."
        )
    dataset_feature = np.concatenate([features for features, _ in parts])
    dataset_label = np.concatenate([labels for _, labels in parts])

    # Now separate and distribute the data
    X, y, statistic = separate_data((dataset_feature, dataset_label), num_clients, num_classes, niid, balance,
                                    partition)
    client_train_data, client_test_data = split_data(X, y)

    # Save the distributed data
    save_file(config_path, train_path, test_path, client_train_data, client_test_data, num_clients, num_classes,
              statistic, niid, balance, partition)

