import pandas as pd
import numpy as np
import os
from logging import info

# def csv2xy(path):
#     x_all, y_all = [],[]
#     for filename in os.listdir(path):
#         if filename.endswith(".csv") :
#             df = pd.read_csv (os.path.join(path, filename),header=None)
#             x = df.iloc[:,1:]
#             if x.isnull().values.any():
#                 print(filename+" NaN in x, skip")
#                 continue
#             y =  df.iloc[:, 0]
#             x_all.append(x)
#             y_all.append(y)
#         else:
#             continue
#     x_arr = np.concatenate(x_all, axis=0 )
#     y_arr = np.concatenate(y_all, axis=0 )
#     return x_arr, y_arr

def load_all_csv_data(train_data_path, test_data_path):
  '''
    Read every training CSV found in a directory plus a single test CSV.

    All files are parsed without a header row; column 0 is returned as y
    and the remaining columns as x.

    train_data_path: directory scanned (non-recursively) for "*.csv" files
    test_data_path:  path of the test CSV file

    Returns (train_data_dict, all_uids, x_test, y_test) where
    train_data_dict maps each CSV file name to its (x, y) pair and
    all_uids is the sorted list of those file names.
  '''
  def split_csv(csv_path):
      # first column -> y, everything after it -> x
      frame = pd.read_csv(csv_path, header=None)
      labels = frame.iloc[:, 0]
      features = frame.iloc[:, 1:]
      return features, labels

  # {file name: (x, y)}; presumably the file name identifies the user the
  # chunk belongs to -- verify against the caller
  train_data_dict = {}
  csv_names = set()
  for entry in os.listdir(train_data_path):
      if not entry.endswith(".csv"):
          continue
      csv_names.add(entry)
      train_data_dict[entry] = split_csv(os.path.join(train_data_path, entry))
  all_uids = sorted(csv_names)

  x_test, y_test = split_csv(test_data_path)
  info('test x shape: {}'.format(x_test.shape))
  info('test y shape: {}'.format(y_test.shape))

  return train_data_dict, all_uids, x_test, y_test

def csv_to_dict(train_data_path):
  '''
    Read every "*.csv" file in a directory into a dict.

    Files are parsed without a header row. For each file the dict stores
    (x, y, n_vocab): y is column 0, x is the remaining columns and n_vocab
    is the number of distinct values found in y.

    Returns (train_data_dict, all_uids) where all_uids is the sorted list
    of the CSV file names used as dict keys.
  '''
  train_data_dict = {}
  csv_names = set()
  for entry in os.listdir(train_data_path):
      if not entry.endswith(".csv"):
          continue
      csv_names.add(entry)
      frame = pd.read_csv(os.path.join(train_data_path, entry), header=None)
      labels = frame.iloc[:, 0]
      features = frame.iloc[:, 1:]
      # unique() (not nunique()) so a NaN label still counts as a value
      train_data_dict[entry] = (features, labels, len(labels.unique()))

  return train_data_dict, sorted(csv_names)

def load_single_csv(path):
    '''
      Read one header-less CSV file and split it into (x, y).

      y is the first column; x holds every remaining column.
    '''
    frame = pd.read_csv(path, header=None)
    labels, features = frame.iloc[:, 0], frame.iloc[:, 1:]
    return features, labels

def nan_in_list(myarray):
    '''
      Return True when any element of myarray contains a NaN.

      Each element is passed to np.isnan individually, so myarray may be a
      list of arrays as well as a single array. Stops at the first NaN found.
    '''
    return any(np.isnan(item).any() for item in myarray)

def fed_avg_grad(master_model_weights, client_model_weights, sample_clients, lr=1):
  '''
    Basic fedavg on weight deltas.

    For every layer, average (client weights - master weights) over the
    sampled clients and return master + lr * mean_delta.

    master_model_weights: sequence of per-layer weights of the master model
    client_model_weights: mapping uid -> sequence of per-layer weights, in
                          the same layer order as master_model_weights
    sample_clients:       iterable of uids to aggregate
    lr:                   factor applied to the averaged delta (default 1)

    Returns a list with one new numpy array per master layer.

    A client's layer is skipped, with a printed message, when it contains
    NaN, when its shape differs from the master layer, or when it cannot be
    turned into a numeric array. If no client contributes to a layer the
    master layer is returned unchanged (a copy) rather than an empty array.
  '''
  print('Start aggregation.')
  # materialise once so a generator can be reused for every layer
  uids = list(sample_clients)
  new_master_weights = []
  for key in range(len(master_model_weights)):
    master_arr = np.asarray(master_model_weights[key])
    deltas = []

    for uid in uids:
        client_layer = client_model_weights[uid][key]
        try:
            client_arr = np.asarray(client_layer)
            has_nan = bool(np.isnan(client_arr).any())
        except (TypeError, ValueError):
            # ragged or non-numeric layer: nothing sensible to subtract
            print('%s fails to compute diff.' % (uid,))
            continue
        if has_nan:
            print('%s layer %d has nan, skip it.' %(uid, key))
            continue
        if client_arr.shape != master_arr.shape:
            # a mismatched layer would otherwise be truncated or broadcast
            print('%s fails to compute diff.' % (uid,))
            continue
        deltas.append(client_arr - master_arr)

    if deltas:
        # builtin sum keeps the client-by-client accumulation order
        mean_delta = sum(deltas) / len(deltas)
        res = np.array(master_arr + lr * mean_delta)
    else:
        # no usable update for this layer: keep the master weights as they are
        print('layer %d has no valid client update, keep master weights.' % key)
        res = np.array(master_arr)
    new_master_weights.append(res)
  return new_master_weights

def fed_avg(client_model_weights, sample_clients, lr=1):
  '''
    Basic fedavg: average the sampled clients' weights layer by layer.

    client_model_weights: mapping uid -> sequence of per-layer weights
    sample_clients:       iterable of uids to average; the first uid
                          determines the number of layers
    lr:                   unused, kept so existing callers keep working

    Returns a list with one averaged numpy array per layer.

    A client's layer that contains NaN is left out of that layer's average
    (a message is printed).

    Raises IndexError when sample_clients is empty, and ValueError when a
    layer has no NaN-free client weights or when the clients' shapes for a
    layer disagree (instead of returning an empty or truncated array).
  '''
  print('Start aggregation.')
  # materialise once so sets and generators work and can be re-iterated
  uids = list(sample_clients)
  n_layers = len(client_model_weights[uids[0]])
  new_master_weights = []
  for key in range(n_layers):
    valid_layers = []

    for uid in uids:
        layer = np.asarray(client_model_weights[uid][key])
        if np.isnan(layer).any():
            print('%s layer %d has nan, skip it.' %(uid, key))
            continue
        if valid_layers and layer.shape != valid_layers[0].shape:
            raise ValueError(
                'layer %d: %s has shape %s, expected %s.'
                % (key, uid, layer.shape, valid_layers[0].shape))
        valid_layers.append(layer)

    if not valid_layers:
        raise ValueError(
            'layer %d: no client has NaN-free weights to average.' % key)
    # builtin sum keeps the client-by-client accumulation order
    res = np.array(sum(valid_layers) / len(valid_layers))
    new_master_weights.append(res)

  return new_master_weights

def weight_avg(w1, w2):
  '''
    Average two weight lists layer by layer.

    w1, w2: sequences of per-layer weights; layers are paired by index and
            the number of layers is taken from w1

    Returns a list with one numpy array per layer holding the mean of the
    two inputs.
  '''
  print('Start average.')
  averaged_layers = []
  for layer_idx in range(len(w1)):
    # one entry per input, each split along its first axis
    pair = [list(w1[layer_idx]), list(w2[layer_idx])]
    # zip(*pair) lines up matching rows; it stops at the shorter input
    row_means = [sum(rows) / len(rows) for rows in zip(*pair)]
    averaged_layers.append(np.array(row_means))

  return averaged_layers

def weight_list_avg(w_list):
  '''
    Average any number of weight lists layer by layer.

    w_list: indexable collection of weight lists; the number of layers is
            taken from w_list[0]

    Returns a list with one numpy array per layer holding the mean over all
    entries of w_list.
  '''
  print('Start weight list average.')
  n_layers = len(w_list[0])
  averaged_layers = []
  for layer_idx in range(n_layers):
    # one entry per model, each split along its first axis
    per_model = [list(w_list[i][layer_idx]) for i in range(len(w_list))]
    # zip(*per_model) lines up matching rows across the models
    row_means = [sum(rows) / len(rows) for rows in zip(*per_model)]
    averaged_layers.append(np.array(row_means))

  return averaged_layers