{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-01 01:59:27.790879: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-05-01 01:59:27.791003: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-05-01 01:59:27.791055: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-05-01 01:59:27.800374: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-05-01 01:59:28.875440: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import csv\n",
    "import pandas as pd\n",
    "from collections import defaultdict\n",
    "from tqdm.notebook import tqdm\n",
    "import tensorflow as tf\n",
    "import tensorflow_federated as tff\n",
    "BASE_DIR = \"..\"\n",
    "TEST_DIR = \"test_dir\"\n",
    "sys.path.insert(0, BASE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c56b0098198c4ecfa859aa43e9ed4acd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/715 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45fe9de7413740b6a86c8c805f0e30f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/715 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "shakespeare_train, shakespeare_test = tff.simulation.datasets.shakespeare.load_data()\n",
    "fed_datasets = {\"train\": shakespeare_train, \"test\": shakespeare_test}\n",
    "client_sizes = {\"train\": defaultdict(int), \"test\": defaultdict(int)}\n",
    "\n",
    "for split, fed_dataset in fed_datasets.items():\n",
    "    print(split)\n",
    "    for client_id in tqdm(fed_dataset.client_ids):\n",
    "        dataloader = fed_dataset.create_tf_dataset_for_client(client_id)\n",
    "        for data in dataloader:\n",
    "            num_chars = len(tf.strings.bytes_split(data['snippets']))\n",
    "            client_sizes[split][client_id] += num_chars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "for split in client_sizes.keys():\n",
    "    fname = os.path.join(BASE_DIR, 'dataset_statistics', f'shakespeare_client_sizes_{split}.csv')\n",
    "    pd.DataFrame.from_dict(data=client_sizes[split], orient='index').to_csv(fname, header=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clients with a total size of 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': [],\n",
       " 'test': ['ALL_S_WELL_THAT_ENDS_WELL_DROMIO_OF_EPHESUS',\n",
       "  'ALL_S_WELL_THAT_ENDS_WELL_HERALD',\n",
       "  'ALL_S_WELL_THAT_ENDS_WELL_OF_EPHESUS',\n",
       "  'ALL_S_WELL_THAT_ENDS_WELL_PATRICIANS',\n",
       "  'ALL_S_WELL_THAT_ENDS_WELL_SILIUS',\n",
       "  'PERICLES__PRINCE_OF_TYRE_CHILDREN',\n",
       "  'PERICLES__PRINCE_OF_TYRE_GROOM',\n",
       "  'PERICLES__PRINCE_OF_TYRE_PAGE',\n",
       "  'PERICLES__PRINCE_OF_TYRE_SERVINGMAN',\n",
       "  'PERICLES__PRINCE_OF_TYRE_SURREY',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_BOTH',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ERPINGHAM',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GARGRAVE',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_HUNTSMAN',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_JUSTICE',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_LAWYER',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MICHAEL',\n",
       "  'THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_SHERIFF',\n",
       "  'THE_TAMING_OF_THE_SHREW_ALL_SERVANTS',\n",
       "  'THE_TAMING_OF_THE_SHREW_BANDITTI',\n",
       "  'THE_TAMING_OF_THE_SHREW_FRANCISCO',\n",
       "  'THE_TAMING_OF_THE_SHREW_PHILOTUS',\n",
       "  'THE_TAMING_OF_THE_SHREW_TIMANDRA',\n",
       "  'THE_TRAGEDY_OF_KING_LEAR_LION',\n",
       "  'THE_TRAGEDY_OF_KING_LEAR_SEYTON',\n",
       "  'THE_TRAGEDY_OF_KING_LEAR_STARVELING']}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "weird_clients = {}\n",
    "for split in client_sizes.keys():\n",
    "    weird_clients[split] = [cid for cid in client_sizes[split].keys() if client_sizes[split][cid] == 0]\n",
    "display(weird_clients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'Worth six on him.', shape=(), dtype=string)\n",
      "tf.Tensor(b\"Lepidus is high-colour'd.\", shape=(), dtype=string)\n",
      "tf.Tensor(b\"man i' th' world.\\nWho, my master?\", shape=(), dtype=string)\n",
      "tf.Tensor(b'Why, then we shall have a stirring world again.', shape=(), dtype=string)\n",
      "tf.Tensor(b'As they pinch one another by the disposition, he', shape=(), dtype=string)\n",
      "tf.Tensor(b\"Are you so brave? I'll have you talk'd with anon.\", shape=(), dtype=string)\n",
      "tf.Tensor(b'Faith, look you, one cannot tell how to say that;', shape=(), dtype=string)\n",
      "tf.Tensor(b\"Here, sir; I'd have beaten him like a dog, but for\", shape=(), dtype=string)\n",
      "tf.Tensor(b'a ravisher, so it cannot be denied but peace is a great maker of\\ncuckolds.\\n', shape=(), dtype=string)\n",
      "tf.Tensor(b\"broil'd and eaten him too.\\nAnd he's as like to do't as any man I can imagine.\", shape=(), dtype=string)\n",
      "tf.Tensor(b'disturbing the lords within.\\nBy my hand, I had thought to have strucken him with', shape=(), dtype=string)\n",
      "tf.Tensor(b'head that he gives entrance to such companions? Pray get you out.\\nAway? Get you away.', shape=(), dtype=string)\n",
      "tf.Tensor(b'hard for him, I have heard him say so himself.\\nAn he had been cannibally given, he might have', shape=(), dtype=string)\n",
      "tf.Tensor(b'      Enter a third SERVINGMAN. The first meets him\\nAnd I shall.                              Exit', shape=(), dtype=string)\n",
      "tf.Tensor(b'for the defence of a town our general is excellent.\\nCome, we are fellows and friends. He was ever too', shape=(), dtype=string)\n",
      "tf.Tensor(b\"him; he had, sir, a kind of face, methought- I cannot tell how to\\nterm it.\\nSo did I, I'll be sworn. He is simply the rarest\", shape=(), dtype=string)\n",
      "tf.Tensor(b\"cries out 'No more!'; reconciles them to his entreaty and himself\\nto th' drink.\\nWhy, this it is to have a name in great men's\", shape=(), dtype=string)\n",
      "tf.Tensor(b'a cudgel; and yet my mind gave me his clothes made a false report\\nof him.\\nNay, I knew by his face that there was something in', shape=(), dtype=string)\n",
      "tf.Tensor(b\"This peace is nothing but to rust iron, increase tailors, and\\nbreed ballad-makers.\\n'Tis so; and as war in some sort may be said to be\", shape=(), dtype=string)\n",
      "tf.Tensor(b'fellowship. I had as lief have a reed that will do me no service\\nas a partizan I could not heave.\\nWhence are you, sir? Has the porter his eyes in his', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# a random sample\n",
    "for split in client_sizes.keys():\n",
    "        dataloader = fed_datasets[split].create_tf_dataset_for_client(fed_dataset.client_ids[123])\n",
    "        for data in dataloader:\n",
    "            print(data['snippets'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n",
      "tf.Tensor(b'', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# weird clients\n",
    "for split in client_sizes.keys():\n",
    "    for weird_client in weird_clients[split]:\n",
    "        dataloader = fed_datasets[split].create_tf_dataset_for_client(weird_client)\n",
    "        for data in dataloader:\n",
    "            print(data['snippets'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "for split in client_sizes.keys():\n",
    "    for weird_client in weird_clients[split]:\n",
    "        dataloader = fed_datasets[split].create_tf_dataset_for_client(weird_client)\n",
    "        for data in dataloader:\n",
    "            num_chars = len(tf.strings.bytes_split(data['snippets']))\n",
    "            print(num_chars)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remove the zero-size clients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "for split in client_sizes.keys():\n",
    "    for weird_client in weird_clients[split]:\n",
    "        del client_sizes[split][weird_client]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': [], 'test': []}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "weird_clients = {}\n",
    "for split in client_sizes.keys():\n",
    "    weird_clients[split] = [cid for cid in client_sizes[split].keys() if client_sizes[split][cid] == 0]\n",
    "display(weird_clients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "for split in client_sizes.keys():\n",
    "    fname = os.path.join(BASE_DIR, 'dataset_statistics', f'shakespeare_client_sizes_{split}.csv')\n",
    "    pd.DataFrame.from_dict(data=client_sizes[split], orient='index').to_csv(fname, header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
