from sklearn.datasets import load_svmlight_file, dump_svmlight_file
from sklearn.model_selection import train_test_split
from scipy.sparse import coo_matrix, csr_matrix
import numpy as np
import warnings
from sklearn.cluster import KMeans

# Directory prefix (with trailing slash) that every loader below prepends to the
# dataset file name; being relative, it is resolved against the current working dir.
DATASET_PATH = '../datasets/'
# Seed constant. NOTE(review): not referenced by any function visible in this
# module (KMeans / train_test_split / np.random are unseeded) — confirm usage elsewhere.
RANDOM_STATE = 100


def load_data(dataset_name, n_workers, logreg, ordered, multilabel=False, d_y=None, npz=False):
    """Load a dataset and split it into ``n_workers`` near-equal, disjoint shards.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH``: an svmlight file, or, when
        ``npz`` is True, a NumPy ``.npz`` archive holding arrays ``'X'`` and ``'y'``.
    n_workers : int
        Number of shards (passed to ``np.array_split``).
    logreg : bool
        If True and ``multilabel`` is False, labels are binarized:
        ``-1`` maps to 0 and every other value maps to 1.
    ordered : bool
        If False, rows are shuffled with the global NumPy RNG before splitting
        and a warning is emitted; if True, shards are contiguous row ranges.
    multilabel : bool
        Load svmlight data with ``multilabel=True`` and expand the label tuples
        into a dense ``(n, d_y)`` 0/1 indicator matrix. Requires ``logreg``.
    d_y : int or None
        Number of label columns; required when ``multilabel`` is True.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.

    Returns
    -------
    list of tuple
        One ``(x_shard, y_shard)`` pair per worker.

    Raises
    ------
    ValueError
        If ``multilabel`` is True while ``logreg`` is False or ``d_y`` is None.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1. NOTE(review): {0, 1}-labelled data would
        # collapse to a single class here — confirm datasets use {-1, +1} labels.
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples;
            # expand it into an (n, d_y) indicator matrix so rows can be indexed.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)
    return [(x[part], y[part]) for part in np.array_split(indices, n_workers)]
 

def load_real_iid_sampling_data(dataset_name, n_workers, logreg, ordered, multilabel=False, d_y=None, npz=False, k=100):
    """Give each worker an independent IID sample of ``k`` rows of the dataset.

    Every worker draws ``k`` rows uniformly at random *without* replacement
    from the full dataset. Draws of different workers are independent, so
    workers may share rows.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH`` (svmlight, or ``.npz`` with
        arrays ``'X'`` and ``'y'`` when ``npz`` is True).
    n_workers : int
        Number of workers to sample for.
    logreg : bool
        If True and ``multilabel`` is False, ``-1`` labels map to 0, others to 1.
    ordered : bool
        If False, ``indices`` is shuffled first and a warning is emitted. The
        per-worker sampling is random either way.
    multilabel : bool
        Expand svmlight multilabel tuples into a dense ``(n, d_y)`` 0/1 matrix.
    d_y : int or None
        Number of label columns; required when ``multilabel`` is True.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.
    k : int or None
        Sample size per worker; ``None`` or a value above the dataset size
        means "all rows".

    Returns
    -------
    list of tuple
        One ``(x_sample, y_sample)`` pair per worker.

    Raises
    ------
    ValueError
        If ``multilabel`` is True while ``logreg`` is False or ``d_y`` is None.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1 (NOTE(review): assumes {-1, +1} labels).
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)

    if k is None or k > n:
        k = n

    data_workers = []
    for _ in range(n_workers):
        sampled_indices = np.random.choice(indices, k, replace=False)
        data_workers.append((x[sampled_indices], y[sampled_indices]))
    return data_workers


def load_non_iid_data(dataset_name, n_workers, logreg, ordered, multilabel=False, d_y=None, npz=False):
    """Give each worker a label-skewed sample of a binary (0/1) dataset.

    Worker ``i`` (0-based) draws, without replacement, ``int((i + 1) / n_workers * P)``
    of the ``P`` rows labelled 1 and ``int((1 - (i + 1) / n_workers) * N)`` of the
    ``N`` rows labelled 0, then shuffles them. Each worker draws independently,
    so workers may share rows. Rows whose label is neither 0 nor 1 are never used.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH`` (svmlight, or ``.npz`` with
        arrays ``'X'`` and ``'y'`` when ``npz`` is True).
    n_workers : int
        Number of workers.
    logreg : bool
        If True and ``multilabel`` is False, ``-1`` labels map to 0, others to 1.
    ordered : bool
        If False, ``indices`` is shuffled first and a warning is emitted.
    multilabel, d_y : bool, int or None
        Accepted for signature parity with the other loaders. Multilabel
        targets are 2-D and therefore rejected below.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.

    Returns
    -------
    list of tuple
        One ``(x_worker, y_worker)`` pair per worker.

    Raises
    ------
    ValueError
        On inconsistent ``multilabel``/``logreg``/``d_y`` arguments, or when
        the labels are not one-dimensional.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1 (NOTE(review): assumes {-1, +1} labels).
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)

    if np.ndim(y) != 1:
        raise ValueError('load_non_iid_data supports one-dimensional binary labels only.')

    # Look labels up *through* the (possibly shuffled) indices: masking the shuffled
    # index array with the positional mask `y == 1` would pick random rows.
    shuffled_labels = y[indices]
    positive_indices = indices[shuffled_labels == 1]
    negative_indices = indices[shuffled_labels == 0]

    data_workers = []
    for i in range(n_workers):
        # Worker i's share of positives grows linearly with i; negatives shrink.
        pos_ratio = (i + 1) / n_workers
        neg_ratio = 1 - pos_ratio

        pos_samples = int(pos_ratio * len(positive_indices))
        neg_samples = int(neg_ratio * len(negative_indices))

        worker_indices = np.concatenate((np.random.choice(positive_indices, pos_samples, replace=False),
                                         np.random.choice(negative_indices, neg_samples, replace=False)))
        np.random.shuffle(worker_indices)
        data_workers.append((x[worker_indices], y[worker_indices]))

    return data_workers


def load_feature_skewed_data(dataset_name, n_workers, logreg, ordered, multilabel=False, d_y=None, npz=False):
    """Partition a dataset across workers by KMeans cluster of the features.

    The rows are clustered into ``n_workers`` clusters, and worker ``i``
    receives every row of cluster ``i``. The order within a worker follows
    ``indices``, which is shuffled when ``ordered`` is False.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH`` (svmlight, or ``.npz`` with
        arrays ``'X'`` and ``'y'`` when ``npz`` is True).
    n_workers : int
        Number of workers, which is also the number of KMeans clusters.
    logreg : bool
        If True and ``multilabel`` is False, ``-1`` labels map to 0, others to 1.
    ordered : bool
        If False, ``indices`` is shuffled first and a warning is emitted.
    multilabel : bool
        Expand svmlight multilabel tuples into a dense ``(n, d_y)`` 0/1 matrix.
    d_y : int or None
        Number of label columns; required when ``multilabel`` is True.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.

    Returns
    -------
    list of tuple
        One ``(x_worker, y_worker)`` pair per cluster. ``x`` is always dense.

    Raises
    ------
    ValueError
        If ``multilabel`` is True while ``logreg`` is False or ``d_y`` is None.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1 (NOTE(review): assumes {-1, +1} labels).
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)

    # Densify so that every worker receives a dense ndarray. sklearn's KMeans
    # also accepts CSR input directly; densifying large sparse data costs memory.
    if isinstance(x, (coo_matrix, csr_matrix)):
        x = x.toarray()

    # KMeans is unseeded here, so the clustering itself is not reproducible.
    kmeans = KMeans(n_clusters=n_workers, n_init=10)
    cluster_labels = kmeans.fit_predict(x)

    # cluster_labels[j] is the cluster of row j. Look it up *through* the (possibly
    # shuffled) indices so each worker receives exactly its own cluster's rows.
    shuffled_clusters = cluster_labels[indices]
    data_workers = []
    for i in range(n_workers):
        worker_data_indices = indices[shuffled_clusters == i]
        data_workers.append((x[worker_data_indices], y[worker_data_indices]))

    return data_workers


def load_feature_skewed_data_cluster(dataset_name, n_clusters, n_clients_per_cluster, logreg, ordered,
                                     multilabel=False, d_y=None, npz=False):
    """Cluster a dataset with KMeans and split every cluster among several clients.

    Each cluster's rows (in ``indices`` order, which is shuffled when ``ordered``
    is False) are cut into ``n_clients_per_cluster`` contiguous pieces of
    ``cluster_size // n_clients_per_cluster`` rows. The last client also takes
    the remainder, so with small clusters the earlier clients may be empty.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH`` (svmlight, or ``.npz`` with
        arrays ``'X'`` and ``'y'`` when ``npz`` is True).
    n_clusters : int
        Number of KMeans clusters.
    n_clients_per_cluster : int
        Number of clients each cluster is split into (must be positive).
    logreg : bool
        If True and ``multilabel`` is False, ``-1`` labels map to 0, others to 1.
    ordered : bool
        If False, ``indices`` is shuffled first and a warning is emitted.
    multilabel : bool
        Expand svmlight multilabel tuples into a dense ``(n, d_y)`` 0/1 matrix.
    d_y : int or None
        Number of label columns; required when ``multilabel`` is True.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.

    Returns
    -------
    dict
        ``{cluster_id: [(x_client, y_client), ...]}`` with
        ``n_clients_per_cluster`` entries per cluster. ``x`` is always dense.

    Raises
    ------
    ValueError
        If ``multilabel`` is True while ``logreg`` is False or ``d_y`` is None.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1 (NOTE(review): assumes {-1, +1} labels).
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)

    # Densify so that every client receives a dense ndarray. sklearn's KMeans
    # also accepts CSR input directly; densifying large sparse data costs memory.
    if isinstance(x, (coo_matrix, csr_matrix)):
        x = x.toarray()

    # KMeans is unseeded here, so the clustering itself is not reproducible.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    cluster_labels = kmeans.fit_predict(x)

    # cluster_labels[j] is the cluster of row j. Look it up *through* the (possibly
    # shuffled) indices so each cluster holds exactly its own rows.
    shuffled_clusters = cluster_labels[indices]
    data_clusters = {i: [] for i in range(n_clusters)}
    for i in range(n_clusters):
        cluster_indices = indices[shuffled_clusters == i]
        cluster_data_x = x[cluster_indices]
        cluster_data_y = y[cluster_indices]
        cluster_size = len(cluster_indices)
        split_size = cluster_size // n_clients_per_cluster

        for j in range(n_clients_per_cluster):
            start_idx = j * split_size
            # The last client also takes the remainder of the integer division.
            end_idx = cluster_size if j == n_clients_per_cluster - 1 else start_idx + split_size
            data_clusters[i].append((cluster_data_x[start_idx:end_idx], cluster_data_y[start_idx:end_idx]))

    return data_clusters



def load_quantity_skewed_data(dataset_name, n_workers, alpha, logreg, ordered, multilabel=False, d_y=None, npz=False):
    """Split a dataset into contiguous shards whose sizes follow a Dirichlet draw.

    One row is reserved per worker. The remaining ``n - n_workers`` rows are
    apportioned by ``np.random.dirichlet([alpha] * n_workers)``, and each share
    is floored. Shards are consecutive slices of ``indices``. Because of the
    flooring, up to ``n_workers - 1`` rows at the end of ``indices`` may be left
    unassigned.

    Parameters
    ----------
    dataset_name : str
        File name appended to ``DATASET_PATH`` (svmlight, or ``.npz`` with
        arrays ``'X'`` and ``'y'`` when ``npz`` is True).
    n_workers : int
        Number of workers; must not exceed the number of rows.
    alpha : float
        Dirichlet concentration. Smaller values give more skewed sizes;
        NumPy rejects non-positive values.
    logreg : bool
        If True and ``multilabel`` is False, ``-1`` labels map to 0, others to 1.
    ordered : bool
        If False, ``indices`` is shuffled first and a warning is emitted.
    multilabel : bool
        Expand svmlight multilabel tuples into a dense ``(n, d_y)`` 0/1 matrix.
    d_y : int or None
        Number of label columns; required when ``multilabel`` is True.
    npz : bool
        Read the file with ``np.load`` instead of ``load_svmlight_file``.

    Returns
    -------
    list of tuple
        One ``(x_shard, y_shard)`` pair per worker, each with at least one row.

    Raises
    ------
    ValueError
        On inconsistent ``multilabel``/``logreg``/``d_y`` arguments, or when
        there are fewer rows than workers.
    """
    if not ordered:
        warnings.warn('<ordered> is set to False.')
    if multilabel and not logreg:
        raise ValueError('If <multilabel> is True, <logreg> must be True, too.')
    if multilabel and d_y is None:
        raise ValueError('<d_y> must be an integer.')

    if npz:
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(DATASET_PATH + dataset_name) as arrays:
            x = arrays['X']
            y = arrays['y']
    else:
        x, y = load_svmlight_file(DATASET_PATH + dataset_name, multilabel=multilabel)

    n = x.shape[0]
    if not multilabel and logreg:
        # -1 -> 0, every other value -> 1 (NOTE(review): assumes {-1, +1} labels).
        y = (y != -1.) * 1
    if multilabel:
        if isinstance(y, (coo_matrix, csr_matrix)):
            y = y.toarray()
        elif not isinstance(y, np.ndarray):
            # load_svmlight_file(multilabel=True) returns a list of label tuples.
            rows = np.array([i for i, row in enumerate(y) for _ in row], dtype=np.intp)
            cols = np.array([int(c) for row in y for c in row], dtype=np.intp)
            labels = np.zeros(shape=(n, d_y))
            labels[rows, cols] = 1
            y = labels

    indices = np.arange(n)
    if not ordered:
        np.random.shuffle(indices)

    min_samples_per_client = 1
    n_samples_reserved = min_samples_per_client * n_workers
    if n < n_samples_reserved:
        # Otherwise remaining_samples goes negative and workers silently get 0 rows.
        raise ValueError('Cannot give each of {} workers at least {} sample(s): the dataset has only {} rows.'
                         .format(n_workers, min_samples_per_client, n))
    remaining_samples = n - n_samples_reserved

    distribution_ratios = np.random.dirichlet([alpha] * n_workers) * remaining_samples
    # Floor each share to an integer count, then add the reserved minimum.
    distribution_ratios = np.floor(distribution_ratios) + min_samples_per_client

    # Defensive: flooring already keeps the total <= n; guard against float error.
    while distribution_ratios.sum() > n:
        distribution_ratios[np.argmax(distribution_ratios)] -= 1

    data_workers = []
    start_index = 0
    for n_samples in distribution_ratios:
        end_index = start_index + int(n_samples)
        worker_data_indices = indices[start_index:end_index]
        data_workers.append((x[worker_data_indices], y[worker_data_indices]))
        start_index = end_index

    return data_workers


def load_data_npy(prefix, node_id, postfix_X='_X', postfix_y='_y'):
    """Load one node's feature/label ``.npy`` pair and prepend an intercept column.

    Files are read from ``DATASET_PATH + prefix + str(node_id) + postfix + '.npy'``
    for ``postfix_X`` (features) and ``postfix_y`` (labels).

    Returns
    -------
    tuple
        ``(features_with_bias, labels)``, where the first column of
        ``features_with_bias`` is all ones.
    """
    stem = DATASET_PATH + prefix + str(node_id)
    file_paths = [stem + postfix_X + '.npy', stem + postfix_y + '.npy']

    loaded = []
    for file_path in file_paths:
        with open(file_path, 'rb') as handle:
            loaded.append(np.load(handle))
    features, targets = loaded

    bias_column = np.ones((features.shape[0], 1))
    return np.hstack([bias_column, features]), targets


def number_of_features(dataset_name):
    """Return the feature dimension of an svmlight dataset plus one for the intercept.

    The column count is whatever ``load_svmlight_file`` infers from the file
    (no explicit ``n_features`` is passed).
    """
    feature_matrix = load_svmlight_file(DATASET_PATH + dataset_name)[0]
    n_columns = feature_matrix.shape[1]
    return n_columns + 1  # extra slot for the intercept term


def split_data(dataset_name, validation_proportion=0.1, random_state=None):
    """Split an svmlight dataset into train/validation files and return their names.

    The two parts are written with ``dump_svmlight_file`` to
    ``DATASET_PATH + dataset_name + 'train'`` and
    ``DATASET_PATH + dataset_name + 'validation'``, overwriting existing files.

    Parameters
    ----------
    dataset_name : str
        svmlight file name relative to ``DATASET_PATH``.
    validation_proportion : float
        Passed to ``train_test_split`` as ``test_size``.
    random_state : int, RandomState or None
        Passed to ``train_test_split``. The default ``None`` gives a different
        split on every call; pass a fixed value (e.g. the module-level
        ``RANDOM_STATE``) for a reproducible split.

    Returns
    -------
    tuple of str
        ``(train_name, validation_name)``, relative to ``DATASET_PATH``.
    """
    x, y = load_svmlight_file(DATASET_PATH + dataset_name)
    x_train, x_validation, y_train, y_validation = train_test_split(
        x, y, test_size=validation_proportion, random_state=random_state)

    train_name = dataset_name + 'train'
    validation_name = dataset_name + 'validation'
    dump_svmlight_file(x_train, y_train, DATASET_PATH + train_name)
    dump_svmlight_file(x_validation, y_validation, DATASET_PATH + validation_name)
    return train_name, validation_name

