{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Reproducibility: the sampling cells below draw from numpy's global RNG (np.random.choice).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Synthetic labelled dataset: NUM_CLASSES classes with SAMPLES_PER_CLASS samples each,\n",
    "# stored as float labels grouped by class (0,...,0, 1,...,1, ..., 9,...,9).\n",
    "NUM_CLASSES = 10\n",
    "SAMPLES_PER_CLASS = 100\n",
    "# np.repeat builds the array in one allocation instead of re-copying it with np.append per class.\n",
    "dataset = np.repeat(np.arange(NUM_CLASSES, dtype=float), SAMPLES_PER_CLASS)\n",
    "\n",
    "num_users = 10\n",
    "\n",
    "# Show the shape and per-class counts instead of dumping every sample.\n",
    "print(dataset.shape, np.bincount(dataset.astype(int)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-IID split: each user gets fraction p of its data from one majority class\n",
    "# (distinct per user) and the remainder uniformly from whatever is still unassigned.\n",
    "n_data = len(dataset) // num_users  # samples per client\n",
    "# Candidate majority classes, derived from the data rather than hard-coded to range(10).\n",
    "label_list = np.unique(dataset)\n",
    "assert num_users <= len(label_list), 'each user needs its own majority class'\n",
    "\n",
    "p = 0.7  # fraction of each user's data drawn from its majority class\n",
    "# round() instead of int(): int() truncates float error (int(0.29*100) == 28),\n",
    "# and taking the complement guarantees n_majority + n_rest == n_data.\n",
    "n_majority = int(round(p * n_data))\n",
    "n_rest = n_data - n_majority\n",
    "\n",
    "idxs = np.arange(len(dataset), dtype=int)\n",
    "labels = dataset\n",
    "\n",
    "# sort sample indices by label\n",
    "idxs_labels = np.vstack((idxs, labels))\n",
    "idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]\n",
    "idxs = idxs_labels[0, :].astype(int)\n",
    "\n",
    "dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}\n",
    "\n",
    "# Sample the majority class for each user; each class is the majority of at most one user.\n",
    "for i in range(num_users):\n",
    "    majority_label = np.random.choice(label_list, 1, replace=False)\n",
    "    label_list = np.setdiff1d(label_list, majority_label)\n",
    "    print(majority_label)\n",
    "\n",
    "    majority_pool = idxs[labels[idxs] == majority_label[0]]\n",
    "    assert len(majority_pool) >= n_majority, 'majority class has too few unassigned samples'\n",
    "    sub_data_idxs = np.random.choice(majority_pool, n_majority, replace=False)\n",
    "\n",
    "    dict_users[i] = sub_data_idxs\n",
    "    # setdiff1d keeps an integer dtype; np.array(list(set(...))) turns float64 once it is empty.\n",
    "    idxs = np.setdiff1d(idxs, sub_data_idxs)\n",
    "\n",
    "# Fill the rest of each user's data uniformly from the remaining (unassigned) pool.\n",
    "for i in range(num_users):\n",
    "    sub_data_idxs = np.random.choice(idxs, n_rest, replace=False)\n",
    "    dict_users[i] = np.concatenate((dict_users[i], sub_data_idxs))\n",
    "    idxs = np.setdiff1d(idxs, sub_data_idxs)\n",
    "\n",
    "# Every user holds exactly n_data samples and no sample is assigned twice.\n",
    "assigned = np.concatenate(list(dict_users.values()))\n",
    "assert all(len(user_idxs) == n_data for user_idxs in dict_users.values())\n",
    "assert len(np.unique(assigned)) == len(assigned)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variant: every user has TWO majority classes that together make up fraction p of its data;\n",
    "# the remainder comes only from classes outside that pair.\n",
    "n_data = 50\n",
    "num_users = 5\n",
    "idxs = np.arange(len(dataset), dtype=int)\n",
    "labels = dataset\n",
    "label_list = np.unique(dataset)\n",
    "assert 2 * num_users <= len(label_list), 'each user needs two classes of its own'\n",
    "\n",
    "# sort sample indices by label\n",
    "idxs_labels = np.vstack((idxs, labels))\n",
    "idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]\n",
    "idxs = idxs_labels[0, :].astype(int)\n",
    "\n",
    "dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}\n",
    "p = 0.1  # fraction of each user's data drawn from its two majority classes\n",
    "n_majority = int(round(p * n_data))  # round rather than truncate the float product\n",
    "n_rest = n_data - n_majority         # complement keeps the per-user total at n_data\n",
    "\n",
    "# Sample the majority classes for each user\n",
    "user_majority_labels = []\n",
    "for i in range(num_users):\n",
    "    majority_labels = np.random.choice(label_list, 2, replace=False)\n",
    "    user_majority_labels.append(majority_labels)\n",
    "    label_list = np.setdiff1d(label_list, majority_labels)\n",
    "\n",
    "    majority_label_idxs = np.isin(labels[idxs], majority_labels)\n",
    "    sub_data_idxs = np.random.choice(idxs[majority_label_idxs], n_majority, replace=False)\n",
    "\n",
    "    dict_users[i] = np.concatenate((dict_users[i], sub_data_idxs))\n",
    "    idxs = np.setdiff1d(idxs, sub_data_idxs)  # keeps integer dtype, unlike list(set(...))\n",
    "\n",
    "# Fill the rest of each user's data from classes that are NOT its majority classes.\n",
    "if n_rest > 0:\n",
    "    for i in range(num_users):\n",
    "        majority_labels = user_majority_labels[i]\n",
    "\n",
    "        # Was `(l0 != x) | (l1 != x)`, which is always True for two distinct labels,\n",
    "        # so majority-class samples were never excluded. A sample is non-majority only\n",
    "        # if it differs from BOTH labels.\n",
    "        non_majority_label_idxs = ~np.isin(labels[idxs], majority_labels)\n",
    "        sub_data_idxs = np.random.choice(idxs[non_majority_label_idxs], n_rest, replace=False)\n",
    "        dict_users[i] = np.concatenate((dict_users[i], sub_data_idxs))\n",
    "        idxs = np.setdiff1d(idxs, sub_data_idxs)\n",
    "\n",
    "# Each user's majority-class fraction must now be exactly n_majority / n_data (close to p).\n",
    "for i in range(num_users):\n",
    "    majority_labels = user_majority_labels[i]\n",
    "    perc = np.isin(labels[dict_users[i]], majority_labels).mean()\n",
    "    print(majority_labels[0], majority_labels[1], perc)\n",
    "    assert len(dict_users[i]) == n_data\n",
    "    assert np.isclose(perc, n_majority / n_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
