{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Client and NumPyClient\n",
    "\n",
    "Welcome to the fourth part of the Flower federated learning tutorial. In the previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html)), we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.dev/docs/tutorial/Flower-2-Strategies-in-FL-PyTorch.html)), and we built our own custom strategy from scratch ([part 3 - WIP](https://flower.dev/docs/tutorial/Flower-3-Building-a-Strategy-PyTorch.html)).\n",
    "\n",
    "In this notebook, we revisit `NumPyClient` and introduce a new base class for building clients, simply named `Client`. In previous parts of this tutorial, we've based our client on `NumPyClient`, a convenience class which makes it easy to work with machine learning libraries that have good NumPy interoperability. With `Client`, we gain a lot of flexibility that we didn't have before, but we'll also have to do a few things that we didn't have to do before.\n",
    "\n",
    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's go deeper and see what it takes to move from `NumPyClient` to `Client`!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 0: Preparation\n",
    "\n",
    "Before we begin with the actual code, let's make sure that we have everything we need."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installing dependencies\n",
    "\n",
    "First, we install the necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q flwr[simulation] torch torchvision scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have all dependencies installed, we can import everything we need for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision.datasets import CIFAR10\n",
    "\n",
    "import flwr as fl\n",
    "\n",
    "DEVICE = torch.device(\"cpu\")  # Try \"cuda\" to train on GPU\n",
    "print(\n",
    "    f\"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If you set `DEVICE = torch.device(\"cuda\")` and the runtime has GPU acceleration enabled, you should see the output `Training on cuda`; otherwise, it'll say `Training on cpu`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loading\n",
    "\n",
    "Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each of them in its own `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_CLIENTS = 10\n",
    "\n",
    "\n",
    "def load_datasets(num_clients: int):\n",
    "    # Download and transform CIFAR-10 (train and test)\n",
    "    transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
    "    )\n",
    "    trainset = CIFAR10(\"./dataset\", train=True, download=True, transform=transform)\n",
    "    testset = CIFAR10(\"./dataset\", train=False, download=True, transform=transform)\n",
    "\n",
    "    # Split training set into `num_clients` partitions to simulate different local datasets\n",
    "    partition_size = len(trainset) // num_clients\n",
    "    lengths = [partition_size] * num_clients\n",
    "    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))\n",
    "\n",
    "    # Split each partition into train/val and create DataLoader\n",
    "    trainloaders = []\n",
    "    valloaders = []\n",
    "    for ds in datasets:\n",
    "        len_val = len(ds) // 10  # 10 % validation set\n",
    "        len_train = len(ds) - len_val\n",
    "        lengths = [len_train, len_val]\n",
    "        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))\n",
    "        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))\n",
    "        valloaders.append(DataLoader(ds_val, batch_size=32))\n",
    "    testloader = DataLoader(testset, batch_size=32)\n",
    "    return trainloaders, valloaders, testloader\n",
    "\n",
    "\n",
    "trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)"
   ]
  },
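  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic that `load_datasets` performs can be sketched in plain Python as a quick sanity check. This sketch assumes the standard CIFAR-10 training set size of 50,000 images and the batch size of 32 used above; it does not load any data. It also highlights the difference between the number of *examples* and the number of *batches*: for a `DataLoader`, `len(loader.dataset)` counts examples, while `len(loader)` counts batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_clients = 10\n",
    "num_train_images = 50_000  # assumed size of the CIFAR-10 training set\n",
    "batch_size = 32  # same batch size as in load_datasets\n",
    "\n",
    "partition_size = num_train_images // num_clients  # images per client\n",
    "len_val = partition_size // 10  # 10 % of each partition for validation\n",
    "len_train = partition_size - len_val  # the rest is used for training\n",
    "\n",
    "# DataLoader keeps the last, smaller batch by default (drop_last=False)\n",
    "num_train_batches = math.ceil(len_train / batch_size)\n",
    "num_val_batches = math.ceil(len_val / batch_size)\n",
    "\n",
    "print(f\"{partition_size=}, {len_train=}, {len_val=}\")\n",
    "print(f\"{num_train_batches=}, {num_val_batches=}\")"
   ]
  },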
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model training/evaluation\n",
    "\n",
    "Let's continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self) -> None:\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def get_parameters(net) -> List[np.ndarray]:\n",
    "    return [val.cpu().numpy() for _, val in net.state_dict().items()]\n",
    "\n",
    "\n",
    "def set_parameters(net, parameters: List[np.ndarray]):\n",
    "    params_dict = zip(net.state_dict().keys(), parameters)\n",
    "    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
    "    net.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "\n",
    "def train(net, trainloader, epochs: int):\n",
    "    \"\"\"Train the network on the training set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(net.parameters())\n",
    "    net.train()\n",
    "    for epoch in range(epochs):\n",
    "        correct, total, epoch_loss = 0, 0, 0.0\n",
    "        for images, labels in trainloader:\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Metrics: accumulate the summed (not batch-mean) loss\n",
    "            epoch_loss += loss.item() * labels.size(0)\n",
    "            total += labels.size(0)\n",
    "            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n",
    "        epoch_loss /= len(trainloader.dataset)\n",
    "        epoch_acc = correct / total\n",
    "        print(f\"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}\")\n",
    "\n",
    "\n",
    "def test(net, testloader):\n",
    "    \"\"\"Evaluate the network on the entire test set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    correct, total, loss = 0, 0, 0.0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for images, labels in testloader:\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            outputs = net(images)\n",
    "            loss += criterion(outputs, labels).item() * labels.size(0)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    loss /= len(testloader.dataset)\n",
    "    accuracy = correct / total\n",
    "    return loss, accuracy"
   ]
  },
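  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_parameters` returns one NumPy array per entry of the model's `state_dict`, and these arrays are what gets serialized and sent over the network in every round. As a rough sketch, the cell below lists the shapes we expect for `Net`; they are derived by hand from the layer definitions above (an assumption you can verify with `[a.shape for a in get_parameters(Net())]`). It also estimates the raw payload at 4 bytes per `float32` value and counts the multi-dimensional arrays, which will matter later when we serialize them as sparse matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Expected state_dict shapes of Net, derived by hand from its layer definitions\n",
    "expected_shapes = {\n",
    "    \"conv1.weight\": (6, 3, 5, 5),  # (out_channels, in_channels, kernel_h, kernel_w)\n",
    "    \"conv1.bias\": (6,),\n",
    "    \"conv2.weight\": (16, 6, 5, 5),\n",
    "    \"conv2.bias\": (16,),\n",
    "    \"fc1.weight\": (120, 16 * 5 * 5),  # (out_features, in_features)\n",
    "    \"fc1.bias\": (120,),\n",
    "    \"fc2.weight\": (84, 120),\n",
    "    \"fc2.bias\": (84,),\n",
    "    \"fc3.weight\": (10, 84),\n",
    "    \"fc3.bias\": (10,),\n",
    "}\n",
    "\n",
    "num_values = sum(math.prod(shape) for shape in expected_shapes.values())\n",
    "num_multidim = sum(len(shape) > 1 for shape in expected_shapes.values())\n",
    "\n",
    "print(f\"{len(expected_shapes)} arrays, {num_multidim} of them multi-dimensional\")\n",
    "print(f\"{num_values} values, about {num_values * 4 / 1024:.1f} KiB as float32\")"
   ]
  },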
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Revisiting NumPyClient\n",
    "\n",
    "So far, we've implemented our client by subclassing `flwr.client.NumPyClient`. The three methods we implemented are `get_parameters`, `fit`, and `evaluate`. Finally, we wrap the creation of instances of this class in a function called `client_fn`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlowerNumPyClient(fl.client.NumPyClient):\n",
    "    def __init__(self, cid, net, trainloader, valloader):\n",
    "        self.cid = cid\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, config):\n",
    "        print(f\"[Client {self.cid}] get_parameters\")\n",
    "        return get_parameters(self.net)\n",
    "\n",
    "    def fit(self, parameters, config):\n",
    "        print(f\"[Client {self.cid}] fit, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        return get_parameters(self.net), len(self.trainloader.dataset), {}\n",
    "\n",
    "    def evaluate(self, parameters, config):\n",
    "        print(f\"[Client {self.cid}] evaluate, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "        return float(loss), len(self.valloader.dataset), {\"accuracy\": float(accuracy)}\n",
    "\n",
    "\n",
    "def numpyclient_fn(cid) -> FlowerNumPyClient:\n",
    "    net = Net().to(DEVICE)\n",
    "    trainloader = trainloaders[int(cid)]\n",
    "    valloader = valloaders[int(cid)]\n",
    "    return FlowerNumPyClient(cid, net, trainloader, valloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've seen this before; there's nothing new so far. The only *tiny* difference compared to the previous notebook is naming: we've changed `FlowerClient` to `FlowerNumPyClient` and `client_fn` to `numpyclient_fn`. Let's run it to see the output we get:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)\n",
    "client_resources = None\n",
    "if DEVICE.type == \"cuda\":\n",
    "    client_resources = {\"num_gpus\": 1}\n",
    "\n",
    "fl.simulation.start_simulation(\n",
    "    client_fn=numpyclient_fn,\n",
    "    num_clients=2,\n",
    "    config=fl.server.ServerConfig(num_rounds=3),\n",
    "    client_resources=client_resources,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This works as expected: two clients are training for three rounds of federated learning.\n",
    "\n",
    "Let's dive a little bit deeper and discuss how Flower executes this simulation. Whenever a client is selected to do some work, `start_simulation` calls the function `numpyclient_fn` to create an instance of our `FlowerNumPyClient` (along with loading the model and the data).\n",
    "\n",
    "But here's the perhaps surprising part: Flower doesn't actually use the `FlowerNumPyClient` object directly. Instead, it wraps the object to make it look like a subclass of `flwr.client.Client`, not `flwr.client.NumPyClient`. In fact, the Flower core framework doesn't know how to handle `NumPyClient` instances; it only knows how to handle `Client` instances. `NumPyClient` is just a convenience abstraction built on top of `Client`.\n",
    "\n",
    "Instead of building on top of `NumPyClient`, we can directly build on top of `Client`."
   ]
  },
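  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the wrapping idea more concrete, here is a deliberately simplified, hypothetical sketch in plain Python. `SimpleFitIns`, `SimpleFitRes`, `ToyNumPyClient`, and `ClientAdapter` are made-up names for illustration, not Flower's actual classes, and Flower's real wrapper additionally converts between `Parameters` and NumPy arrays, which this sketch skips. It only shows how a method that takes several arguments and returns a tuple can be exposed through a method that takes one object and returns one object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Hypothetical, simplified stand-ins for Flower's FitIns and FitRes\n",
    "@dataclass\n",
    "class SimpleFitIns:\n",
    "    parameters: List[np.ndarray]\n",
    "    config: Dict[str, float] = field(default_factory=dict)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class SimpleFitRes:\n",
    "    parameters: List[np.ndarray]\n",
    "    num_examples: int\n",
    "    metrics: Dict[str, float]\n",
    "\n",
    "\n",
    "class ToyNumPyClient:\n",
    "    \"\"\"NumPyClient-style interface: several arguments in, a tuple out.\"\"\"\n",
    "\n",
    "    def fit(\n",
    "        self, parameters: List[np.ndarray], config: Dict[str, float]\n",
    "    ) -> Tuple[List[np.ndarray], int, Dict[str, float]]:\n",
    "        updated = [p + 1.0 for p in parameters]  # stand-in for local training\n",
    "        return updated, 10, {}\n",
    "\n",
    "\n",
    "class ClientAdapter:\n",
    "    \"\"\"Client-style interface: one *Ins object in, one *Res object out.\"\"\"\n",
    "\n",
    "    def __init__(self, numpy_client: ToyNumPyClient) -> None:\n",
    "        self.numpy_client = numpy_client\n",
    "\n",
    "    def fit(self, ins: SimpleFitIns) -> SimpleFitRes:\n",
    "        # Unpack the single input object and call the NumPy-style method\n",
    "        parameters, num_examples, metrics = self.numpy_client.fit(ins.parameters, ins.config)\n",
    "        # Pack its tuple of return values into a single result object\n",
    "        return SimpleFitRes(parameters=parameters, num_examples=num_examples, metrics=metrics)\n",
    "\n",
    "\n",
    "res = ClientAdapter(ToyNumPyClient()).fit(SimpleFitIns(parameters=[np.zeros(3)]))\n",
    "print(res)"
   ]
  },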
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Moving from `NumPyClient` to `Client`\n",
    "\n",
    "Let's try to do the same thing using `Client` instead of `NumPyClient`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr.common import (\n",
    "    Code,\n",
    "    EvaluateIns,\n",
    "    EvaluateRes,\n",
    "    FitIns,\n",
    "    FitRes,\n",
    "    GetParametersIns,\n",
    "    GetParametersRes,\n",
    "    Status,\n",
    "    ndarrays_to_parameters,\n",
    "    parameters_to_ndarrays,\n",
    ")\n",
    "\n",
    "\n",
    "class FlowerClient(fl.client.Client):\n",
    "    def __init__(self, cid, net, trainloader, valloader):\n",
    "        self.cid = cid\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:\n",
    "        print(f\"[Client {self.cid}] get_parameters\")\n",
    "\n",
    "        # Get parameters as a list of NumPy ndarray's\n",
    "        ndarrays: List[np.ndarray] = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object\n",
    "        parameters = ndarrays_to_parameters(ndarrays)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return GetParametersRes(\n",
    "            status=status,\n",
    "            parameters=parameters,\n",
    "        )\n",
    "\n",
    "    def fit(self, ins: FitIns) -> FitRes:\n",
    "        print(f\"[Client {self.cid}] fit, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        # Update local model, train, get updated parameters\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        ndarrays_updated = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object\n",
    "        parameters_updated = ndarrays_to_parameters(ndarrays_updated)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return FitRes(\n",
    "            status=status,\n",
    "            parameters=parameters_updated,\n",
    "            num_examples=len(self.trainloader.dataset),\n",
    "            metrics={},\n",
    "        )\n",
    "\n",
    "    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:\n",
    "        print(f\"[Client {self.cid}] evaluate, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return EvaluateRes(\n",
    "            status=status,\n",
    "            loss=float(loss),\n",
    "            num_examples=len(self.valloader.dataset),\n",
    "            metrics={\"accuracy\": float(accuracy)},\n",
    "        )\n",
    "\n",
    "\n",
    "def client_fn(cid) -> FlowerClient:\n",
    "    net = Net().to(DEVICE)\n",
    "    trainloader = trainloaders[int(cid)]\n",
    "    valloader = valloaders[int(cid)]\n",
    "    return FlowerClient(cid, net, trainloader, valloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we discuss the code in more detail, let's try to run it! Gotta make sure our new `Client`-based client works, right?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fl.simulation.start_simulation(\n",
    "    client_fn=client_fn,\n",
    "    num_clients=2,\n",
    "    config=fl.server.ServerConfig(num_rounds=3),\n",
    "    client_resources=client_resources,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it, we're now using `Client`. It probably looks similar to what we've done with `NumPyClient`. So what's the difference?\n",
    "\n",
    "First of all, it's more code. But why? The difference comes from the fact that `Client` expects us to take care of parameter serialization and deserialization. For Flower to be able to send parameters over the network, it eventually needs to turn these parameters into `bytes`. Turning parameters (e.g., NumPy `ndarray`'s) into raw bytes is called serialization. Turning raw bytes into something more useful (like NumPy `ndarray`'s) is called deserialization. Flower needs to do both: the server serializes the parameters and sends them to the client; the client deserializes them to use them for local training, then serializes the updated parameters to send them back to the server, which (finally!) deserializes them in order to aggregate them with the updates received from other clients.\n",
    "\n",
    "The only *real* difference between `Client` and `NumPyClient` is that `NumPyClient` takes care of serialization and deserialization for you. It can do so because it expects you to return parameters as NumPy `ndarray`s, and it knows how to handle these. This makes working with machine learning libraries that have good NumPy support (most of them) a breeze.\n",
    "\n",
    "In terms of API, there's one major difference: all methods in `Client` take exactly one argument (e.g., `FitIns` in `Client.fit`) and return exactly one value (e.g., `FitRes` in `Client.fit`). The methods in `NumPyClient`, on the other hand, have multiple arguments (e.g., `parameters` and `config` in `NumPyClient.fit`) and multiple return values (e.g., `parameters`, `num_examples`, and `metrics` in `NumPyClient.fit`) if there are multiple things to handle. These `*Ins` and `*Res` objects in `Client` wrap all the individual values you're used to from `NumPyClient`."
   ]
  },
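  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of a default serialization round trip using NumPy's `.npy` format. It is meant to be similar in spirit to what `ndarrays_to_parameters` and `parameters_to_ndarrays` do for us, but it does not claim to reproduce their exact internals; the helper names `ndarray_to_bytes` and `bytes_to_ndarray` are our own. Each array becomes one `bytes` object, and loading those bytes restores the same values, shape, and dtype."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ndarray_to_bytes(ndarray: np.ndarray) -> bytes:\n",
    "    \"\"\"Serialize one ndarray into raw bytes (NumPy .npy format).\"\"\"\n",
    "    bytes_io = BytesIO()\n",
    "    np.save(bytes_io, ndarray, allow_pickle=False)  # never allow pickle\n",
    "    return bytes_io.getvalue()\n",
    "\n",
    "\n",
    "def bytes_to_ndarray(tensor: bytes) -> np.ndarray:\n",
    "    \"\"\"Deserialize raw bytes back into an ndarray.\"\"\"\n",
    "    return np.load(BytesIO(tensor), allow_pickle=False)\n",
    "\n",
    "\n",
    "weights: List[np.ndarray] = [\n",
    "    np.arange(6, dtype=np.float32).reshape(2, 3),\n",
    "    np.zeros(3, dtype=np.float32),\n",
    "]\n",
    "payload = [ndarray_to_bytes(w) for w in weights]  # what travels over the network\n",
    "restored = [bytes_to_ndarray(b) for b in payload]  # what the receiver works with\n",
    "\n",
    "print([len(b) for b in payload])\n",
    "print(all(np.array_equal(w, r) and w.dtype == r.dtype for w, r in zip(weights, restored)))"
   ]
  },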
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Custom serialization\n",
    "\n",
    "Here we will explore how to implement custom serialization with a simple example.\n",
    "\n",
    "But first, what is serialization? Serialization is the process of converting an object into raw bytes, and, equally important,\n",
    "deserialization is the process of converting raw bytes back into an object. This is essential for network communication:\n",
    "without serialization, you could not send a Python object over the internet.\n",
    "\n",
    "Federated learning relies heavily on network communication during training, sending model parameters back and forth between the clients and\n",
    "the server. This means that serialization is an essential part of federated learning.\n",
    "\n",
    "In the following section, we will write a basic example where, instead of sending a serialized version of the `ndarray`s containing our parameters,\n",
    "we first convert each multi-dimensional `ndarray` into a sparse matrix before sending it. This technique can save bandwidth: when the weights of a model are sparse\n",
    "(containing many zero entries), a sparse representation only stores the non-zero values and their positions, which can greatly reduce the size in bytes.\n",
    "\n",
    "### Our custom serialization/deserialization functions\n",
    "\n",
    "This is where the actual serialization/deserialization happens, specifically in `ndarray_to_sparse_bytes` for serialization and\n",
    "`sparse_bytes_to_ndarray` for deserialization.\n",
    "\n",
    "Note that we use PyTorch's sparse CSR tensors (`Tensor.to_sparse_csr` and `torch.sparse_csr_tensor`) to convert our arrays, while one-dimensional arrays (such as bias vectors) are sent in dense form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "from typing import cast\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from flwr.common.typing import NDArray, NDArrays, Parameters\n",
    "\n",
    "\n",
    "def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:\n",
    "    \"\"\"Convert NumPy ndarrays to parameters object.\"\"\"\n",
    "    tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]\n",
    "    return Parameters(tensors=tensors, tensor_type=\"numpy.ndarray\")\n",
    "\n",
    "\n",
    "def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:\n",
    "    \"\"\"Convert parameters object to NumPy ndarrays.\"\"\"\n",
    "    return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]\n",
    "\n",
    "\n",
    "def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:\n",
    "    \"\"\"Serialize NumPy ndarray to bytes.\"\"\"\n",
    "    bytes_io = BytesIO()\n",
    "\n",
    "    if len(ndarray.shape) > 1:\n",
    "        # Remember the dense shape, since it cannot always be inferred from the\n",
    "        # sparse indices (e.g., when the trailing columns are all zeros)\n",
    "        shape = np.array(ndarray.shape)\n",
    "\n",
    "        # We convert our ndarray into a sparse CSR tensor\n",
    "        ndarray = torch.tensor(ndarray).to_sparse_csr()\n",
    "\n",
    "        # And send it by utilizing the sparse matrix attributes and the shape\n",
    "        # WARNING: NEVER set allow_pickle to true.\n",
    "        # Reason: loading pickled data can execute arbitrary code\n",
    "        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n",
    "        np.savez(\n",
    "            bytes_io,  # type: ignore\n",
    "            shape=shape,\n",
    "            crow_indices=ndarray.crow_indices(),\n",
    "            col_indices=ndarray.col_indices(),\n",
    "            values=ndarray.values(),\n",
    "            allow_pickle=False,\n",
    "        )\n",
    "    else:\n",
    "        # WARNING: NEVER set allow_pickle to true.\n",
    "        # Reason: loading pickled data can execute arbitrary code\n",
    "        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html\n",
    "        np.save(bytes_io, ndarray, allow_pickle=False)\n",
    "    return bytes_io.getvalue()\n",
    "\n",
    "\n",
    "def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:\n",
    "    \"\"\"Deserialize NumPy ndarray from bytes.\"\"\"\n",
    "    bytes_io = BytesIO(tensor)\n",
    "    # WARNING: NEVER set allow_pickle to true.\n",
    "    # Reason: loading pickled data can execute arbitrary code\n",
    "    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html\n",
    "    loader = np.load(bytes_io, allow_pickle=False)  # type: ignore\n",
    "\n",
    "    if \"crow_indices\" in loader:\n",
    "        # We convert our sparse matrix back to an ndarray, using the attributes we sent\n",
    "        ndarray_deserialized = (\n",
    "            torch.sparse_csr_tensor(\n",
    "                crow_indices=loader[\"crow_indices\"],\n",
    "                col_indices=loader[\"col_indices\"],\n",
    "                values=loader[\"values\"],\n",
    "                size=tuple(loader[\"shape\"].tolist()),\n",
    "            )\n",
    "            .to_dense()\n",
    "            .numpy()\n",
    "        )\n",
    "    else:\n",
    "        ndarray_deserialized = loader\n",
    "    return cast(NDArray, ndarray_deserialized)"
   ]
  },
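  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for when the sparse encoding actually saves bandwidth, here is a self-contained, NumPy-only sketch. It builds the same three CSR attributes by hand (`crow_indices`, `col_indices`, `values`), checks that decoding them restores the matrix, and compares byte counts against a dense `.npy` encoding. The matrix shape mirrors `fc1.weight`, and the 95 % share of zeros is an arbitrary assumption; the helper names are our own. When a matrix has no zeros, CSR is larger than the dense encoding, because every stored value also carries an 8-byte column index, so the technique only pays off for sufficiently sparse weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "from typing import Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_csr(matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"Return (crow_indices, col_indices, values) of a 2-D matrix in CSR format.\"\"\"\n",
    "    rows, col_indices = np.nonzero(matrix)  # row-major order, as CSR requires\n",
    "    values = matrix[rows, col_indices]\n",
    "    # Non-zeros of row i live at positions crow_indices[i]:crow_indices[i + 1]\n",
    "    nnz_per_row = np.bincount(rows, minlength=matrix.shape[0])\n",
    "    crow_indices = np.concatenate(([0], np.cumsum(nnz_per_row)))\n",
    "    return crow_indices, col_indices, values\n",
    "\n",
    "\n",
    "def from_csr(crow_indices, col_indices, values, shape) -> np.ndarray:\n",
    "    \"\"\"Rebuild the dense matrix from its CSR attributes and its shape.\"\"\"\n",
    "    dense = np.zeros(shape, dtype=values.dtype)\n",
    "    for row in range(shape[0]):\n",
    "        start, end = crow_indices[row], crow_indices[row + 1]\n",
    "        dense[row, col_indices[start:end]] = values[start:end]\n",
    "    return dense\n",
    "\n",
    "\n",
    "def dense_nbytes(matrix: np.ndarray) -> int:\n",
    "    \"\"\"Size of the matrix serialized densely with np.save.\"\"\"\n",
    "    bytes_io = BytesIO()\n",
    "    np.save(bytes_io, matrix, allow_pickle=False)\n",
    "    return len(bytes_io.getvalue())\n",
    "\n",
    "\n",
    "def csr_nbytes(matrix: np.ndarray) -> int:\n",
    "    \"\"\"Size of the matrix serialized as CSR attributes with np.savez.\"\"\"\n",
    "    crow_indices, col_indices, values = to_csr(matrix)\n",
    "    bytes_io = BytesIO()\n",
    "    np.savez(bytes_io, crow_indices=crow_indices, col_indices=col_indices, values=values)\n",
    "    return len(bytes_io.getvalue())\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "sparse_weights = rng.standard_normal((120, 400)).astype(np.float32)\n",
    "sparse_weights[rng.random(sparse_weights.shape) < 0.95] = 0.0  # assume ~95 % zeros\n",
    "dense_weights = rng.standard_normal((120, 400)).astype(np.float32)  # (almost surely) no zeros\n",
    "\n",
    "# The CSR round trip must be lossless\n",
    "assert np.array_equal(from_csr(*to_csr(sparse_weights), sparse_weights.shape), sparse_weights)\n",
    "\n",
    "for name, matrix in [(\"mostly zeros\", sparse_weights), (\"no zeros\", dense_weights)]:\n",
    "    print(f\"{name}: dense {dense_nbytes(matrix)} bytes, CSR {csr_nbytes(matrix)} bytes\")"
   ]
  },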
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Client-side\n",
    "\n",
    "To be able to serialize our `ndarray`s into sparse parameters, we just have to call our custom functions in our `flwr.client.Client`.\n",
    "\n",
    "Indeed, in `get_parameters` we need to serialize the parameters we got from our network using our custom `ndarrays_to_sparse_parameters` defined above.\n",
    "\n",
    "In `fit`, we first need to deserialize the parameters coming from the server using our custom `sparse_parameters_to_ndarrays` and then we need to \n",
    "serialize our local results with `ndarrays_to_sparse_parameters`.\n",
    "\n",
    "In `evaluate`, we will only need to deserialize the global parameters with our custom function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr.common import (\n",
    "    Code,\n",
    "    EvaluateIns,\n",
    "    EvaluateRes,\n",
    "    FitIns,\n",
    "    FitRes,\n",
    "    GetParametersIns,\n",
    "    GetParametersRes,\n",
    "    Status,\n",
    ")\n",
    "\n",
    "\n",
    "class FlowerClient(fl.client.Client):\n",
    "    def __init__(self, cid, net, trainloader, valloader):\n",
    "        self.cid = cid\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:\n",
    "        print(f\"[Client {self.cid}] get_parameters\")\n",
    "\n",
    "        # Get parameters as a list of NumPy ndarray's\n",
    "        ndarrays: List[np.ndarray] = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object using our custom function\n",
    "        parameters = ndarrays_to_sparse_parameters(ndarrays)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return GetParametersRes(\n",
    "            status=status,\n",
    "            parameters=parameters,\n",
    "        )\n",
    "\n",
    "    def fit(self, ins: FitIns) -> FitRes:\n",
    "        print(f\"[Client {self.cid}] fit, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's using our custom function\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        # Update local model, train, get updated parameters\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        ndarrays_updated = get_parameters(self.net)\n",
    "\n",
    "        # Serialize ndarray's into a Parameters object using our custom function\n",
    "        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return FitRes(\n",
    "            status=status,\n",
    "            parameters=parameters_updated,\n",
    "            num_examples=len(self.trainloader.dataset),\n",
    "            metrics={},\n",
    "        )\n",
    "\n",
    "    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:\n",
    "        print(f\"[Client {self.cid}] evaluate, config: {ins.config}\")\n",
    "\n",
    "        # Deserialize parameters to NumPy ndarray's using our custom function\n",
    "        parameters_original = ins.parameters\n",
    "        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)\n",
    "\n",
    "        set_parameters(self.net, ndarrays_original)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "\n",
    "        # Build and return response\n",
    "        status = Status(code=Code.OK, message=\"Success\")\n",
    "        return EvaluateRes(\n",
    "            status=status,\n",
    "            loss=float(loss),\n",
    "            num_examples=len(self.valloader.dataset),\n",
    "            metrics={\"accuracy\": float(accuracy)},\n",
    "        )\n",
    "\n",
    "\n",
    "def client_fn(cid) -> FlowerClient:\n",
    "    net = Net().to(DEVICE)\n",
    "    trainloader = trainloaders[int(cid)]\n",
    "    valloader = valloaders[int(cid)]\n",
    "    return FlowerClient(cid, net, trainloader, valloader)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Server-side\n",
    "\n",
    "For this example, we will just use `FedAvg` as a strategy.\n",
    "To change the serialization and deserialization here, we only need to reimplement the `evaluate` and `aggregate_fit` methods of `FedAvg`.\n",
    "The other methods of the strategy will be inherited from the superclass `FedAvg`.\n",
    "\n",
    "As you can see, only one line changes in `evaluate`:\n",
    "\n",
    "```python\n",
    "parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)\n",
    "```\n",
    "\n",
    "And for `aggregate_fit`, we will first deserialize every result we received:\n",
    "\n",
    "```python\n",
    "weights_results = [\n",
    "    (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
    "    for _, fit_res in results\n",
    "]\n",
    "```\n",
    "\n",
    "And then serialize the aggregated result:\n",
    "\n",
    "```python\n",
    "parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))\n",
    "```"
   ]
  },
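  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the results have been deserialized, `aggregate` works on plain NumPy arrays. The sketch below assumes it computes a per-layer average weighted by each client's `num_examples`, which is the core idea of FedAvg; `weighted_average` is our own illustrative function, not Flower's implementation. Because the weights come directly from `num_examples`, that value should reflect how many examples each client actually trained on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def weighted_average(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:\n",
    "    \"\"\"Average each layer across clients, weighted by num_examples.\"\"\"\n",
    "    total_examples = sum(num_examples for _, num_examples in results)\n",
    "    num_layers = len(results[0][0])\n",
    "    return [\n",
    "        sum(ndarrays[layer] * num_examples for ndarrays, num_examples in results)\n",
    "        / total_examples\n",
    "        for layer in range(num_layers)\n",
    "    ]\n",
    "\n",
    "\n",
    "# Two hypothetical clients with a two-layer model; client B has 3x more data\n",
    "client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)\n",
    "client_b = ([np.array([3.0, 5.0]), np.array([4.0])], 300)\n",
    "\n",
    "averaged = weighted_average([client_a, client_b])\n",
    "print(averaged)  # layer 0 -> [2.5, 4.0], layer 1 -> [3.0]"
   ]
  },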
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from logging import WARNING\n",
    "from typing import Callable, Dict, List, Optional, Tuple, Union\n",
    "\n",
    "from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar\n",
    "from flwr.common.logger import log\n",
    "from flwr.server.client_proxy import ClientProxy\n",
    "from flwr.server.strategy import FedAvg\n",
    "from flwr.server.strategy.aggregate import aggregate\n",
    "\n",
    "WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = \"\"\"\n",
    "Setting `min_available_clients` lower than `min_fit_clients` or\n",
    "`min_evaluate_clients` can cause the server to fail when there are too few clients\n",
    "connected to the server. `min_available_clients` must be set to a value larger\n",
    "than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class FedSparse(FedAvg):\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        fraction_fit: float = 1.0,\n",
    "        fraction_evaluate: float = 1.0,\n",
    "        min_fit_clients: int = 2,\n",
    "        min_evaluate_clients: int = 2,\n",
    "        min_available_clients: int = 2,\n",
    "        evaluate_fn: Optional[\n",
    "            Callable[\n",
    "                [int, NDArrays, Dict[str, Scalar]],\n",
    "                Optional[Tuple[float, Dict[str, Scalar]]],\n",
    "            ]\n",
    "        ] = None,\n",
    "        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "        accept_failures: bool = True,\n",
    "        initial_parameters: Optional[Parameters] = None,\n",
    "        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "    ) -> None:\n",
    "        \"\"\"Custom FedAvg strategy that serializes parameters as sparse matrices.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        fraction_fit : float, optional\n",
    "            Fraction of clients used during training. Defaults to 1.0.\n",
    "        fraction_evaluate : float, optional\n",
    "            Fraction of clients used during validation. Defaults to 1.0.\n",
    "        min_fit_clients : int, optional\n",
    "            Minimum number of clients used during training. Defaults to 2.\n",
    "        min_evaluate_clients : int, optional\n",
    "            Minimum number of clients used during validation. Defaults to 2.\n",
    "        min_available_clients : int, optional\n",
    "            Minimum number of total clients in the system. Defaults to 2.\n",
    "        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]\n",
    "            Optional function used for validation. Defaults to None.\n",
    "        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional\n",
    "            Function used to configure training. Defaults to None.\n",
    "        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional\n",
    "            Function used to configure validation. Defaults to None.\n",
    "        accept_failures : bool, optional\n",
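    "            Whether or not to accept rounds containing failures. Defaults to True.\n",
    "        initial_parameters : Parameters, optional\n",
    "            Initial global model parameters. Defaults to None.\n",
    "        fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]\n",
    "            Function used to aggregate training metrics. Defaults to None.\n",
    "        evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]\n",
    "            Function used to aggregate evaluation metrics. Defaults to None.\n",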
    "        \"\"\"\n",
    "\n",
    "        if (\n",
    "            min_fit_clients > min_available_clients\n",
    "            or min_evaluate_clients > min_available_clients\n",
    "        ):\n",
    "            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)\n",
    "\n",
    "        super().__init__(\n",
    "            fraction_fit=fraction_fit,\n",
    "            fraction_evaluate=fraction_evaluate,\n",
    "            min_fit_clients=min_fit_clients,\n",
    "            min_evaluate_clients=min_evaluate_clients,\n",
    "            min_available_clients=min_available_clients,\n",
    "            evaluate_fn=evaluate_fn,\n",
    "            on_fit_config_fn=on_fit_config_fn,\n",
    "            on_evaluate_config_fn=on_evaluate_config_fn,\n",
    "            accept_failures=accept_failures,\n",
    "            initial_parameters=initial_parameters,\n",
    "            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,\n",
    "            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,\n",
    "        )\n",
    "\n",
    "    def evaluate(\n",
    "        self, server_round: int, parameters: Parameters\n",
    "    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
    "        \"\"\"Evaluate model parameters using an evaluation function.\"\"\"\n",
    "        if self.evaluate_fn is None:\n",
    "            # No evaluation function provided\n",
    "            return None\n",
    "\n",
    "        # We deserialize using our custom method\n",
    "        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)\n",
    "\n",
    "        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})\n",
    "        if eval_res is None:\n",
    "            return None\n",
    "        loss, metrics = eval_res\n",
    "        return loss, metrics\n",
    "\n",
    "    def aggregate_fit(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, FitRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n",
    "    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
    "        if not results:\n",
    "            return None, {}\n",
    "        # Do not aggregate if there are failures and failures are not accepted\n",
    "        if not self.accept_failures and failures:\n",
    "            return None, {}\n",
    "\n",
    "        # We deserialize each of the results with our custom method\n",
    "        weights_results = [\n",
    "            (sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
    "            for _, fit_res in results\n",
    "        ]\n",
    "\n",
    "        # We serialize the aggregated result using our custom method\n",
    "        parameters_aggregated = ndarrays_to_sparse_parameters(\n",
    "            aggregate(weights_results)\n",
    "        )\n",
    "\n",
    "        # Aggregate custom metrics if aggregation fn was provided\n",
    "        metrics_aggregated = {}\n",
    "        if self.fit_metrics_aggregation_fn:\n",
    "            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]\n",
    "            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)\n",
    "        elif server_round == 1:  # Only log this warning once\n",
    "            log(WARNING, \"No fit_metrics_aggregation_fn provided\")\n",
    "\n",
    "        return parameters_aggregated, metrics_aggregated"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now run our custom serialization example!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy = FedSparse()\n",
    "\n",
    "fl.simulation.start_simulation(\n",
    "    strategy=strategy,\n",
    "    client_fn=client_fn,\n",
    "    num_clients=2,\n",
    "    config=fl.server.ServerConfig(num_rounds=3),\n",
    "    client_resources=client_resources,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recap\n",
    "\n",
    "In this part of the tutorial, we've seen how we can build clients by subclassing either `NumPyClient` or `Client`. `NumPyClient` is a convenience abstraction that makes it easier to work with machine learning libraries that have good NumPy interoperability. `Client` is a more flexible abstraction that allows us to do things that are not possible with `NumPyClient`, such as sending model parameters in a custom format like the sparse matrices used above. In exchange, it requires us to handle parameter serialization and deserialization ourselves."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "Before you continue, make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)\n",
    "\n",
    "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
    "\n",
    "This is the final part of the Flower tutorial (for now!). Congratulations! You're now well equipped to understand the rest of the documentation. There are many topics we didn't cover in the tutorial; we recommend the following resources:\n",
    "\n",
    "- [Read Flower Docs](https://flower.dev/docs/)\n",
    "- [Check out Flower Code Examples](https://github.com/adap/flower/tree/main/examples)\n",
    "- [Use Flower Baselines for your research](https://flower.dev/docs/using-baselines.html)\n",
    "- [Watch Flower Summit 2022 videos](https://flower.dev/conf/flower-summit-2022/)\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Flower-4-Client-and-NumPyClient-PyTorch.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "flower-3.7.12",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
