import os
import json
import numpy.random as rv
from collections import Counter

#def calc_data_stat(labels):



def projection(vec):
    '''Return the Euclidean projection of ``vec`` onto the probability simplex.

    The result has the same length and ordering as ``vec``, every entry is
    non-negative and the entries sum to 1.  This is the sort-based
    algorithm: sort descending, find the largest index ``rho`` whose entry
    stays positive after subtracting the shared threshold, then shift and
    clip every entry by that threshold.

    Args:
        vec: sequence of real numbers (list, tuple or 1-D array).

    Returns:
        list of projected values; an empty list when ``vec`` is empty.
    '''
    if len(vec) == 0:
        return []

    ordered = sorted(vec, reverse=True)

    # Single pass with a running prefix sum.  The prefix sum must INCLUDE the
    # current element; leaving it out overestimates rho and yields a vector
    # that does not sum to 1.
    prefix_sum = 0
    rho = 0
    rho_sum = 0
    for i, val in enumerate(ordered):
        prefix_sum += val
        if val - (prefix_sum - 1) / (i + 1) > 0:
            rho = i
            rho_sum = prefix_sum

    # Threshold subtracted from every coordinate before clipping at zero.
    theta = (rho_sum - 1) / (rho + 1)

    return [max(v - theta, 0) for v in vec]


import pickle
import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.pylab as pylab

# Named colormaps shared by the plotting code.
# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
# 3.9, which made importing this module fail on current releases;
# ``pyplot.get_cmap`` looks up the same registered colormaps and is supported
# on both old and new versions.
c_r = plt.get_cmap('Reds')
c_b = plt.get_cmap('Blues')
c_g = plt.get_cmap('Greens')
c_o = plt.get_cmap('Oranges')
c_p = plt.get_cmap('Purples')
c_t = plt.get_cmap('tab10')


def plot(train_data, test_data, alpha):
    '''Scatter-plot per-client label counts for train, test and combined data.

    Draws three side-by-side panels ("Train Data", "Test Data", "All Data").
    Each marker sits at (client index, label) and its area is the number of
    samples of that label held by that client.  The figure is shown and then
    written to ``save/1_K100_clust2_tl50000_diffdiven05_alpha_<alpha>.pdf``;
    the ``save`` directory is created when missing.

    Args:
        train_data: container indexable by client position ``0..len-1`` whose
            entries map label -> sample count.
        test_data: same structure as ``train_data`` for the test split.
        alpha: value shown in the figure title and embedded in the file name.
    '''
    num_labels = 10
    label_ticks = np.arange(0, num_labels, 1)

    def counts_for(split, client):
        # A client missing from one split simply contributes nothing from it.
        return split[client] if client < len(split) else {}

    # Combined counts per client: sum of train and test counts for each label.
    # Covering max(len(train), len(test)) clients avoids both the lookup
    # failure when test has more clients than train and the silent omission
    # of clients that only appear in train.
    combined = []
    for client in range(max(len(train_data), len(test_data))):
        merged = {}
        for split in (test_data, train_data):
            for label, count in counts_for(split, client).items():
                merged[label] = merged.get(label, 0) + count
        combined.append(merged)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    fig.suptitle('For alpha={}'.format(alpha))

    panels = (
        (ax1, train_data, 'Train Data'),
        (ax2, test_data, 'Test Data'),
        (ax3, combined, 'All Data'),
    )
    for ax, split, title in panels:
        for client in range(len(split)):
            for label, count in split[client].items():
                ax.scatter(client, label, s=count, c='b')
        ax.set_xlabel('Client')
        ax.set_ylabel('Label')
        ax.set_yticks(label_ticks)
        ax.set_title(title)

    plt.show()

    out_path = 'save/1_K100_clust2_tl50000_diffdiven05_alpha_{}.pdf'.format(alpha)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight')


def main_plot(train_data, val_data, test_data, args):


    num_labels = 10
    '''
    ftsize = 30
    params = {'legend.fontsize': 23,
              'axes.labelsize': ftsize,
              'axes.titlesize': ftsize,
              'xtick.labelsize': 25,
              #'ytick.labelsize': 25}
              'ytick.labelsize': 15}
    pylab.rcParams.update(params)
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["font.weight"] = "normal"
    plt.rcParams["axes.labelweight"] = "normal"
    '''
    fig1, (ax11, ax12, ax13) = plt.subplots(1, 3, figsize=(15,5))

    plt.setp(ax11,
             yticks=list(np.arange(0, num_labels, 1)))
    plt.setp(ax12,
             yticks=list(np.arange(0, num_labels, 1)))
    plt.setp(ax13,
             yticks=list(np.arange(0, num_labels, 1)))

    for i in range(len(train_data)):
        for j in list(train_data[i].keys()):
                ax11.scatter(i, j, s=train_data[i][j]*0.9, c='darkblue')

    for i in range(len(val_data)):
        for j in list(val_data[i].keys()):
                ax12.scatter(i, j, s=val_data[i][j]*0.9, c='darkblue')

    for i in range(len(test_data)):
        for j in list(test_data[i].keys()):
                ax13.scatter(i, j, s=test_data[i][j]*0.9, c='darkblue')

    ax11.set_xlabel('Client')
    ax11.set_ylabel('Label')
    ax11.set_yticks(np.arange(0, num_labels, 5))
    #fig1.tight_layout()
    ax11.set_title('train Data-Stat')

    y_minor_ticks = np.arange(0, num_labels, 1)
    x_minor_ticks = np.arange(0, 30, 1)

    ax12.set_xlabel('Client')
    ax12.set_ylabel('Label')
    ax12.set_yticks(np.arange(0, num_labels, 5))
    # fig1.tight_layout()
    ax12.set_title('val Data-Stat')

    ax13.set_xlabel('Client')
    ax13.set_ylabel('Label')
    ax13.set_yticks(np.arange(0, num_labels, 5))
    # fig1.tight_layout()
    ax13.set_title('test Data-Stat')

    #plt.show()
    ax11.set_xticks(x_minor_ticks, minor=True)
    ax12.set_xticks(x_minor_ticks, minor=True)
    ax13.set_xticks(x_minor_ticks, minor=True)
    ax11.set_yticks(y_minor_ticks, minor=True)
    ax12.set_yticks(y_minor_ticks, minor=True)
    ax13.set_yticks(y_minor_ticks, minor=True)
    ax11.grid(which='both')
    ax12.grid(which='both')
    ax13.grid(which='both')

    train_ratio = args.train_ratio
    args.test_ratio = 1 - (args.val_ratio+args.train_ratio)
    plt.savefig('fig/data'+args.hetero+"_valr{}_testr{}_train{}.pdf".format(args.val_ratio, args.test_ratio, train_ratio))

    plt.clf()
    #legend_properties = {'weight': 'bold'}
    # plt.yscale('log')


def _load_split_files(data_dir, suffix):
    '''Merge 'user_data' (and optional 'hierarchies') of matching .json files.

    Files in ``data_dir`` whose name ends with ``suffix`` are read in sorted
    name order so the merge order does not depend on the file system.
    '''
    user_data = {}
    hierarchies = []
    for name in sorted(os.listdir(data_dir)):
        if not name.endswith(suffix):
            continue
        with open(os.path.join(data_dir, name), 'r') as inf:
            cdata = json.load(inf)
        hierarchies.extend(cdata.get('hierarchies', []))
        user_data.update(cdata['user_data'])
    return user_data, hierarchies


def read_data_withval(train_data_dir, test_data_dir, args, alg):
    '''Load per-client data and split it into train / validation / test parts.

    Assumes:
    - the input files are .json files named
      ``<numusers>_mytrain_<hetero>.json`` / ``<numusers>_mytest_<hetero>.json``
      with a 'user_data' key (and optionally 'hierarchies')
    - every train client also appears in the test files
    - each client's 'x' and 'y' are lists

    For each client with ``n`` training samples:
    - validation = first ``int(args.val_ratio * n)`` samples
    - train      = next ``int(args.train_ratio * n)`` samples
    - test       = remaining training samples followed by the client's
                   original test samples
    When the validation split would be empty, the client's validation entry
    is the very same dict object as its train entry.

    When ``alg == 'fedavg'`` and ``args.samedata`` is truthy, each client's
    validation samples are appended to its train samples.  The statistics
    returned below describe the splits before that merge.

    Args:
        train_data_dir: directory holding the train .json files.
        test_data_dir: directory holding the test .json files.
        args: object providing ``numusers``, ``hetero``, ``val_ratio``,
            ``train_ratio`` and (for fedavg) ``samedata``.
        alg: algorithm name; only ``'fedavg'`` changes behaviour.

    Return:
        clients: list of client ids (keys of the train data)
        groups: list of group ids; empty list if none found
        train_data: dict client id -> train split
        new_test_data: dict client id -> test split
        val_data: dict client id -> validation split
        data_statistics_train / _val / _test: dict client index ->
            {label: count} for the respective split
        ratios_train / _val / _test: dict client index -> number of samples
            in the respective split
    '''
    train_file_name = str(args.numusers) + '_mytrain_' + args.hetero + '.json'
    test_file_name = str(args.numusers) + '_mytest_' + args.hetero + '.json'

    train_data, groups = _load_split_files(train_data_dir, train_file_name)
    test_data, _ = _load_split_files(test_data_dir, test_file_name)

    clients = list(train_data.keys())

    val_data = {}
    new_test_data = {}
    data_statistics_train = {}
    data_statistics_val = {}
    data_statistics_test = {}
    ratios_train = {}
    ratios_val = {}
    ratios_test = {}

    for ii, client_id in enumerate(clients):
        samples = train_data[client_id]['x']
        labels = train_data[client_id]['y']

        val_size = int(args.val_ratio * len(samples))
        train_size = int(args.train_ratio * len(samples))
        train_end = val_size + train_size

        # Update the existing dict in place so any extra per-client keys
        # loaded from the file are kept.
        train_data[client_id]['x'] = samples[val_size:train_end]
        train_data[client_id]['y'] = labels[val_size:train_end]

        if val_size == 0:
            # No validation samples available: fall back to the train split.
            val_data[client_id] = train_data[client_id]
        else:
            val_data[client_id] = {
                'x': samples[:val_size],
                'y': labels[:val_size],
            }

        new_test_data[client_id] = {
            'x': samples[train_end:] + test_data[client_id]['x'],
            'y': labels[train_end:] + test_data[client_id]['y'],
        }

        data_statistics_train[ii] = dict(Counter(train_data[client_id]['y']))
        data_statistics_val[ii] = dict(Counter(val_data[client_id]['y']))
        data_statistics_test[ii] = dict(Counter(new_test_data[client_id]['y']))
        ratios_train[ii] = len(train_data[client_id]['y'])
        ratios_val[ii] = len(val_data[client_id]['y'])
        ratios_test[ii] = len(new_test_data[client_id]['y'])

    if alg == 'fedavg' and args.samedata:
        # Merge validation into train for EVERY client.  Doing this once
        # after the loop would only affect the last client (and fail when
        # there are no clients at all).
        for client_id in clients:
            if val_data[client_id] is train_data[client_id]:
                # Validation already is the train split; appending it to
                # itself would duplicate every sample.
                continue
            train_data[client_id]['x'] = (
                train_data[client_id]['x'] + val_data[client_id]['x'])
            train_data[client_id]['y'] = (
                train_data[client_id]['y'] + val_data[client_id]['y'])

    return (clients, groups, train_data, new_test_data, val_data,
            data_statistics_train, data_statistics_val, data_statistics_test,
            ratios_train, ratios_val, ratios_test)


def read_data(train_data_dir, test_data_dir):
    '''Load and merge the train and test .json files of two directories.

    Assumes:
    - train files end in 'train.json', test files end in 'test.json'
    - every file has the keys 'users' (train files) and 'user_data'
    - the set of train set users is the same as the set of test set users

    Return:
        clients: list of client ids (keys of the merged train data)
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    '''
    listed_users = []
    groups = []
    train_data = {}
    test_data = {}

    for fname in os.listdir(train_data_dir):
        if not fname.endswith('train.json'):
            continue
        with open(os.path.join(train_data_dir, fname), 'r') as handle:
            payload = json.load(handle)
        # 'users' is read (and therefore required) even though the returned
        # client list is derived from the 'user_data' keys below.
        listed_users.extend(payload['users'])
        if 'hierarchies' in payload:
            groups.extend(payload['hierarchies'])
        train_data.update(payload['user_data'])

    for fname in os.listdir(test_data_dir):
        if not fname.endswith('test.json'):
            continue
        with open(os.path.join(test_data_dir, fname), 'r') as handle:
            payload = json.load(handle)
        test_data.update(payload['user_data'])

    return list(train_data.keys()), groups, train_data, test_data

def main():
    '''Load the train and test .json data found under ./data/.

    The loaded values are not used further; the call only exercises the
    loader.
    '''
    data_root = './data/'
    _clients, _groups, _train, _test = read_data(data_root, data_root)




# Run the data-loading entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()