# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at

#   http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import os
import json
import math
import torch
import numpy
import argparse
from scipy.io import arff
import re
import copy
from collections import Counter
from collections import defaultdict
from sklearn import metrics
from scipy.spatial.distance import cdist
from numpy import dot
from numpy.linalg import norm
import matplotlib.pyplot as plt


import pickle
from sklearn.model_selection import train_test_split

def dirichlet_sampling(labels, num_users, alpha, num_classes, vis=False, fig_name="cluster"):
    """
    Split sample indices among users with class-wise Dirichlet proportions.

    For each class ``0 .. num_classes - 1`` the indices carrying that label
    are shuffled and cut into ``num_users`` pieces whose relative sizes are
    drawn from a symmetric Dirichlet distribution with concentration
    ``alpha``. The whole split is redrawn (at most 1000 times) until every
    user's share lies within +/-50% of ``N / num_users``; if that never
    happens, the last attempt is returned as-is.

    :param labels: 1-D numpy array of integer class labels
    :param num_users: number of partitions to create
    :param alpha: Dirichlet concentration parameter
    :param num_classes: number of distinct classes to distribute
    :param vis: unused by this function
    :param fig_name: unused by this function
    :return: dict mapping user id -> list of sample indices
    """
    total = labels.shape[0]
    fair_share = total / num_users
    tolerance = 0.5
    lower_bound = fair_share * (1 - tolerance)
    upper_bound = fair_share * (1 + tolerance)

    smallest, largest = 0, 1e6
    attempts = 0
    max_attempts = 1000

    while attempts < max_attempts and (
            smallest < lower_bound or largest > upper_bound
    ):
        # Every attempt starts from empty buckets.
        buckets = [[] for _ in range(num_users)]

        for cls in range(num_classes):
            members = numpy.where(labels == cls)[0]
            numpy.random.shuffle(members)
            weights = numpy.random.dirichlet(numpy.repeat(alpha, num_users))

            # Users that already hold their fair share get no more samples
            # from this class; the remaining weights are renormalised.
            has_room = numpy.array([len(bucket) < fair_share for bucket in buckets])
            weights = weights * has_room
            weights = weights / weights.sum()

            # Cumulative weights become the cut positions inside `members`.
            cut_points = (numpy.cumsum(weights) * len(members)).astype(int)[:-1]
            for bucket, piece in zip(buckets, numpy.split(members, cut_points)):
                bucket.extend(piece.tolist())

            sizes = [len(bucket) for bucket in buckets]
            smallest, largest = min(sizes), max(sizes)

        attempts += 1

    return dict(enumerate(buckets))

def load_RTD_dataset_final(path='RTD/',smoothing=False, cluster_beta=0.1, num_examples_per_client=2400, num_labeled_examples=640, num_clients_per_clusters=[33,33,34], num_clusters = 3):
    """
    Load the RTD dataset and build clustered per-client data splits.

    The pickled ``features`` and ``labels`` files found in ``path`` are
    loaded, the features are rearranged into 3 channels, and each channel is
    shifted so that its first time step is zero. The samples are then
    partitioned into ``num_clusters`` clusters with ``dirichlet_sampling``
    (10 classes), and every client of a cluster draws
    ``num_examples_per_client`` indices without replacement from its
    cluster. Each client's indices are split 90/10 into train/test and the
    train part again 90/10 into train/validation.

    :param path: directory containing the pickled ``features`` and ``labels``
    :param smoothing: if True, smooth every channel with a 6-tap mean filter
    :param cluster_beta: Dirichlet concentration used for the cluster split
    :param num_examples_per_client: indices drawn per client from its cluster
    :param num_labeled_examples: size of the labeled subset drawn without
        replacement from each client's training indices
    :param num_clients_per_clusters: number of clients for each cluster; only
        the first ``num_clusters`` entries are used (the default list is
        only read, never modified)
    :param num_clusters: number of clusters to create
    :return: tuple ``(train, train_labels, test, test_labels, token_set_x,
        token_set_labels, train_labeled, train_labeled_labels, val,
        val_labels)``; every entry except the two token-set arrays is a list
        with one array per client, ordered cluster by cluster. The token set
        holds one randomly chosen sample for each of the 10 classes.
    :raises ValueError: if ``num_clients_per_clusters`` has fewer entries
        than ``num_clusters``, or if one of the 10 classes has no sample
    """
    nb_dims = 3
    num_classes = 10

    # Fail early with a clear message instead of an IndexError raised from
    # inside the nested sampling comprehension further down.
    if len(num_clients_per_clusters) < num_clusters:
        raise ValueError(
            f"num_clients_per_clusters has {len(num_clients_per_clusters)} "
            f"entries but num_clusters is {num_clusters}"
        )

    # Load the RTD Dataset
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # `path` at files from a trusted source.
    # os.path.join keeps this working when `path` lacks a trailing separator.
    with open(os.path.join(path, 'features'), 'rb') as fp:
        features = pickle.load(fp)

    with open(os.path.join(path, 'labels'), 'rb') as fp:
        labels = pickle.load(fp)

    train_size = features.shape[0]
    load_length = features.shape[1]
    length = load_length // nb_dims

    # Channel d takes every nb_dims-th feature column, starting at column d.
    train = numpy.empty((train_size, nb_dims, length))
    for dim in range(nb_dims):
        train[:, dim, :] = features[:, dim:load_length:nb_dims]

    # presumably the labels are one-hot rows - TODO confirm against the data
    train_labels = numpy.argmax(labels, axis=1)

    # Make every channel relative to its own first time step. The copy
    # avoids reading from the array that is being modified in place.
    train -= train[:, :, :1].copy()

    # Smoothing
    if smoothing:
        # Same float32 6-tap mean kernel as before, built once.
        kernel = numpy.array([1 / 6] * 6, 'f')
        for sample in train:
            for dim in range(nb_dims):
                sample[dim, :] = numpy.convolve(sample[dim, :], kernel, mode='same')

    # Sampling indices from Dirichlet Distribution
    dict_users = dirichlet_sampling(
        train_labels, num_clusters, cluster_beta, num_classes, vis=False, fig_name="y_shift"
    )
    cluster_indices = [dict_users[i] for i in range(len(dict_users))]

    # The order of random draws (client samples, then token set, then the
    # labeled subsets) is kept unchanged so seeded runs stay reproducible.
    client_indices = [
        [
            numpy.random.choice(pool, num_examples_per_client, replace=False)
            for _ in range(num_clients_per_clusters[cluster_id])
        ]
        for cluster_id, pool in enumerate(cluster_indices)
    ]

    # Token set: one random sample per class.
    token_set_x = numpy.empty((num_classes, nb_dims, length))
    token_set_labels = numpy.empty(num_classes, dtype=int)
    for k in range(num_classes):
        idx_k = numpy.where(train_labels == k)[0]
        if len(idx_k) == 0:
            raise ValueError(f"no sample with label {k} found for the token set")
        rand_idx = idx_k[numpy.random.randint(len(idx_k))]
        token_set_x[k, :, :] = train[rand_idx]
        token_set_labels[k] = train_labels[rand_idx]

    dir_train = []
    dir_train_labels = []
    dir_test = []
    dir_test_labels = []
    dir_val = []
    dir_val_labels = []
    dir_train_labeled_dat = []
    dir_train_labeled_dat_labels = []

    for clients in client_indices:
        for client_idx in clients:
            # 10% test, then 10% of the remainder for validation.
            idx_train, idx_test = train_test_split(client_idx, test_size=0.1, random_state=42)
            idx_train, idx_val = train_test_split(idx_train, test_size=0.1, random_state=42)
            idx_train_lab = numpy.random.choice(idx_train, num_labeled_examples, replace=False)

            dir_train.append(train[idx_train])
            dir_train_labels.append(train_labels[idx_train])

            dir_train_labeled_dat.append(train[idx_train_lab])
            dir_train_labeled_dat_labels.append(train_labels[idx_train_lab])

            dir_test.append(train[idx_test])
            dir_test_labels.append(train_labels[idx_test])

            dir_val.append(train[idx_val])
            dir_val_labels.append(train_labels[idx_val])

    return dir_train, dir_train_labels, dir_test, dir_test_labels, token_set_x, token_set_labels, dir_train_labeled_dat, dir_train_labeled_dat_labels, dir_val, dir_val_labels

def main(args):
    """
    Build the clustered RTD client splits and pickle them to disk.

    Seeds torch and numpy with fixed values, calls
    ``load_RTD_dataset_final`` with the parsed options and writes every
    returned collection to its own pickle file in the current working
    directory.

    :param args: namespace with ``path``, ``cluster_beta``,
        ``num_clients_per_clusters``, ``num_examples_per_client``,
        ``num_labeled_examples`` and, optionally, ``num_clusters``
    :return: None
    """
    # Fixed seeds so that repeated runs produce the same splits.
    torch.manual_seed(4321)
    numpy.random.seed(1234)

    # --num_clusters used to be parsed but never forwarded, so the loader
    # always used its default. getattr keeps callers working that build
    # `args` without this attribute (3 is the loader's own default).
    num_clusters = getattr(args, 'num_clusters', 3)

    (
        train,
        train_labels,
        test_,
        test_labels_,
        tokenset,
        tokenset_labels,
        train_labeled_dat,
        train_labeled_dat_labels,
        val_,
        val_labels_,
    ) = load_RTD_dataset_final(
        args.path,
        cluster_beta=args.cluster_beta,
        num_clients_per_clusters=args.num_clients_per_clusters,
        num_examples_per_client=args.num_examples_per_client,
        num_labeled_examples=args.num_labeled_examples,
        num_clusters=num_clusters,
    )

    # Output file name -> object, written in the same order as before.
    outputs = (
        ('train_encoder_x', train),
        ('train_encoder_y', train_labels),
        ('test_x', test_),
        ('test_y', test_labels_),
        ('val_x', val_),
        ('val_y', val_labels_),
        ('token_x', tokenset),
        ('token_y', tokenset_labels),
        ('train_labeled_x', train_labeled_dat),
        ('train_labeled_y', train_labeled_dat_labels),
    )
    for filename, payload in outputs:
        with open(filename, 'wb') as fp:
            pickle.dump(payload, fp)




if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    cli_options = (
        ('--path', dict(type=str, default='RTD/')),
        ('--cluster_beta', dict(type=float, default=0.1)),
        ('--num_clusters', dict(type=int, default=3)),
        ('--num_clients_per_clusters', dict(type=int, nargs='+', default=[33,33,34])),
        ('--num_examples_per_client', dict(type=int, default=2400)),
        ('--num_labeled_examples', dict(type=int, default=1944)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)