{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import necessary libraries and modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%bash\n",
    "\n",
    "# kaggle datasets download -d nageshsingh/dna-sequence-dataset\n",
    "# mkdir data data/DNA\n",
    "# unzip dna-sequence-dataset.zip -d data/DNA\n",
    "# rm dna-sequence-dataset.zip\n",
    "\n",
    "# kaggle datasets download -d masoudnickparvar/brain-tumor-mri-dataset\n",
    "# mkdir -p data/MRI\n",
    "# unzip brain-tumor-mri-dataset.zip -d data/MRI\n",
    "# rm brain-tumor-mri-dataset.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "from collections import OrderedDict\n",
    "from typing import (\n",
    "    List, Tuple, Dict, Optional, Callable, Union\n",
    ")\n",
    "import tenseal as ts\n",
    "\n",
    "import numpy as np\n",
    "import torchvision\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import flwr as fl\n",
    "from flwr.common import (\n",
    "    Metrics, EvaluateIns, EvaluateRes, FitIns, FitRes, MetricsAggregationFn, \n",
    "    Scalar, logger, ndarrays_to_parameters_custom, parameters_to_ndarrays_custom,\n",
    "    Parameters, NDArrays\n",
    ")\n",
    "from flwr.server.client_proxy import ClientProxy\n",
    "from flwr.server.client_manager import ClientManager\n",
    "from flwr.server.strategy.aggregate import weighted_loss_avg\n",
    "from logging import WARNING\n",
    "import pennylane as qml\n",
    "\n",
    "from utils import *\n",
    "\n",
    "os.environ['TOKENIZERS_PARALLELISM'] = 'false'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creation of FHE Keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combo_keys(client_path=\"secret.pkl\", server_path=\"server_key.pkl\"):\n",
    "    \"\"\"\n",
    "    To create the public/private keys combination\n",
    "    args:\n",
    "        client_path: path to save the secret key (str)\n",
    "        server_path: path to save the server public key (str)\n",
    "    \"\"\"\n",
    "    context_client = security.context()\n",
    "    security.write_query(client_path, {\"contexte\": context_client.serialize(save_secret_key=True)})\n",
    "    security.write_query(server_path, {\"contexte\": context_client.serialize()})\n",
    "\n",
    "    _, context_client = security.read_query(client_path)\n",
    "    _, context_server = security.read_query(server_path)\n",
    "\n",
    "    context_client = ts.context_from(context_client)\n",
    "    context_server = ts.context_from(context_server)\n",
    "    print(\"Is the client context private?\", (\"Yes\" if context_client.is_private() else \"No\"))\n",
    "    print(\"Is the server context private?\", (\"Yes\" if context_server.is_private() else \"No\"))\n",
    "\n",
    "\n",
    "secret_path = \"secret.pkl\"\n",
    "public_path = \"server_key.pkl\"\n",
    "if os.path.exists(secret_path):\n",
    "    print(\"it exists\")\n",
    "    _, context_client = security.read_query(secret_path)\n",
    "\n",
    "else:\n",
    "    combo_keys(client_path=secret_path, server_path=public_path) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Architecture Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantum device definitions for MRI and DNA\n",
    "mri_n_qubits = 4\n",
    "dna_n_qubits = 7\n",
    "expert_vector = (mri_n_qubits+dna_n_qubits) // 2 + 1\n",
    "num_of_expert = 2\n",
    "\n",
    "# Define the MRI network\n",
    "class MRINet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MRINet, self).__init__()\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.Conv2d(16, 32, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        )\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(32 * 56 * 56, 128),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(128, expert_vector),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "\n",
    "class DNANet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(DNANet, self).__init__()        \n",
    "        self.fc1 = nn.Linear(input_sp, 64)\n",
    "        self.fc2 = nn.Linear(64, 32)\n",
    "        self.fc3 = nn.Linear(32, 16)\n",
    "        self.fc4 = nn.Linear(16, 8)\n",
    "        self.fc5 = nn.Linear(8, expert_vector)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Forward pass of the neural network\n",
    "        \"\"\"\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = F.relu(self.fc3(x))\n",
    "        x = F.relu(self.fc4(x))\n",
    "        x = self.fc5(x)\n",
    "        return x\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, num_classes_mri, num_classes_dna):\n",
    "        super(Net, self).__init__()\n",
    "        self.mri_net = MRINet()\n",
    "        self.dna_net = DNANet()\n",
    "        \n",
    "        self.feature_dim = expert_vector\n",
    "        self.num_heads = expert_vector\n",
    "        \n",
    "        self.attention = nn.MultiheadAttention(embed_dim=num_of_expert*self.feature_dim, num_heads=self.num_heads)\n",
    "        self.fc_gate = nn.Linear(num_of_expert*self.feature_dim, 2) \n",
    "        self.fc2_mri = nn.Linear(self.feature_dim, num_classes_mri)\n",
    "        self.fc2_dna = nn.Linear(self.feature_dim, num_classes_dna)\n",
    "        \n",
    "    def forward(self, mri_input, dna_input):\n",
    "        mri_features = self.mri_net(mri_input)\n",
    "        dna_features = self.dna_net(dna_input)\n",
    "        combined_features = torch.cat((mri_features, dna_features), dim=1)\n",
    "        combined_features = combined_features.unsqueeze(0)\n",
    "        attn_output, _ = self.attention(combined_features, combined_features, combined_features)\n",
    "        attn_output = attn_output.squeeze(0)\n",
    "        gates = F.softmax(self.fc_gate(attn_output), dim=1)\n",
    "        combined_output = (gates[:, 0].unsqueeze(1) * mri_features + \n",
    "                           gates[:, 1].unsqueeze(1) * dna_features)\n",
    "        mri_output = self.fc2_mri(combined_output)\n",
    "        dna_output = self.fc2_dna(combined_output)\n",
    "        return mri_output, dna_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the FlowerClient class for federated learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlowerClient(fl.client.NumPyClient):\n",
    "    def __init__(self, cid, net, trainloader, valloader, device, batch_size, save_results, matrix_path, roc_path,\n",
    "                 yaml_path, he, classes, context_client):\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "        self.cid = cid\n",
    "        self.device = device\n",
    "        self.batch_size = batch_size\n",
    "        self.save_results = save_results\n",
    "        self.matrix_path = matrix_path\n",
    "        self.roc_path = roc_path\n",
    "        self.yaml_path = yaml_path\n",
    "        self.he = he\n",
    "        self.classes = classes\n",
    "        self.context_client = context_client\n",
    "\n",
    "    def get_parameters(self, config):\n",
    "        print(f\"[Client {self.cid}] get_parameters\")\n",
    "        return get_parameters2(self.net, self.context_client)\n",
    "\n",
    "    def fit(self, parameters, config):\n",
    "        server_round = config['server_round']\n",
    "        local_epochs = config['local_epochs']\n",
    "        lr = float(config[\"learning_rate\"])\n",
    "\n",
    "        print(f'[Client {self.cid}, round {server_round}] fit, config: {config}')\n",
    "\n",
    "        set_parameters(self.net, parameters, self.context_client)\n",
    "\n",
    "        criterion_mri = torch.nn.CrossEntropyLoss()\n",
    "        criterion_dna = torch.nn.CrossEntropyLoss()\n",
    "        optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)\n",
    "\n",
    "        results = engine.train(self.net, self.trainloader, self.valloader, optimizer=optimizer, loss_fn=(criterion_mri, criterion_dna),\n",
    "                               epochs=local_epochs, device=self.device, task=\"Multimodal\")\n",
    "\n",
    "        if self.save_results:\n",
    "            print\n",
    "            save_graphs_multimodal(self.save_results, local_epochs, results, f\"_Client {self.cid}\")\n",
    "\n",
    "        return get_parameters2(self.net, self.context_client), len(self.trainloader), {}\n",
    "\n",
    "    def evaluate(self, parameters, config):\n",
    "        print(f\"[Client {self.cid}] evaluate, config: {config}\")\n",
    "        set_parameters(self.net, parameters, self.context_client)\n",
    "\n",
    "        loss, accuracy, y_pred, y_true, y_proba = engine.test_multimodal_health(self.net, self.valloader,\n",
    "                                                              loss_fn=(torch.nn.CrossEntropyLoss(),torch.nn.CrossEntropyLoss()), device=self.device)\n",
    "\n",
    "        loss_mri, loss_dna = loss\n",
    "        accuracy_mri, accuracy_dna = accuracy\n",
    "        y_pred_mri, y_pred_dna = y_pred\n",
    "        y_true_mri, y_true_dna = y_true\n",
    "        y_proba_mri, y_proba_dna = y_proba\n",
    "\n",
    "        if self.save_results:\n",
    "            os.makedirs(self.save_results, exist_ok=True)\n",
    "            if self.matrix_path:\n",
    "                save_matrix(y_true_mri, y_pred_mri, self.save_results + \"MRI_\" + self.matrix_path, self.classes[0])\n",
    "                save_matrix(y_true_dna, y_pred_dna, self.save_results + \"DNA_\" + self.matrix_path, self.classes[1])\n",
    "            if self.roc_path:\n",
    "                save_roc(y_true_mri, y_proba_mri, self.save_results + \"MRI_\" + self.roc_path, len(self.classes[0]))\n",
    "                save_roc(y_true_dna, y_proba_dna, self.save_results + \"DNA_\" + self.roc_path, len(self.classes[1]))\n",
    "\n",
    "        return float(loss_mri), len(self.valloader), {\"accuracy\": float(accuracy_mri)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the client_common function to set up the Flower client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def client_common(cid, model_save, path_yaml, path_roc, results_save, path_matrix,\n",
    "                  batch_size, trainloaders, valloaders, DEVICE, CLASSES,\n",
    "                  he=False, secret_path=\"\", server_path=\"\"):\n",
    "    trainloader = trainloaders[int(cid)]\n",
    "    valloader = valloaders[int(cid)]\n",
    "\n",
    "    context_client = None\n",
    "    net = Net(len(CLASSES[0]), len(CLASSES[1])).to(DEVICE)\n",
    "\n",
    "    if he:\n",
    "        print(\"Run with homomorphic encryption\")\n",
    "        if os.path.exists(secret_path):\n",
    "            with open(secret_path, 'rb') as f:\n",
    "                query = pickle.load(f)\n",
    "            context_client = ts.context_from(query[\"contexte\"])\n",
    "        else:\n",
    "            context_client = security.context()\n",
    "            with open(secret_path, 'wb') as f:\n",
    "                encode = pickle.dumps({\"contexte\": context_client.serialize(save_secret_key=True)})\n",
    "                f.write(encode)\n",
    "        secret_key = context_client.secret_key()\n",
    "    else:\n",
    "        print(\"Run WITHOUT homomorphic encryption\")\n",
    "\n",
    "    if os.path.exists(model_save):\n",
    "        print(\" To get the checkpoint\")\n",
    "        checkpoint = torch.load(model_save, map_location=DEVICE)['model_state_dict']\n",
    "        if he:\n",
    "            print(\"to decrypt model\")\n",
    "            server_query, server_context = security.read_query(server_path)\n",
    "            server_context = ts.context_from(server_context)\n",
    "            for name in checkpoint:\n",
    "                print(name)\n",
    "                checkpoint[name] = torch.tensor(\n",
    "                    security.deserialized_layer(name, server_query[name], server_context).decrypt(secret_key)\n",
    "                )\n",
    "        net.load_state_dict(checkpoint)\n",
    "\n",
    "    return FlowerClient(cid, net, trainloader, valloader, device=DEVICE, batch_size=batch_size,\n",
    "                        matrix_path=path_matrix, roc_path=path_roc, save_results=results_save, yaml_path=path_yaml,\n",
    "                        he=he, context_client=context_client, classes=CLASSES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define utility functions for federated learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n",
    "    accuracies = [num_examples * m[\"accuracy\"] for num_examples, m in metrics]\n",
    "    examples = [num_examples for num_examples, _ in metrics]\n",
    "    return {\"accuracy\": sum(accuracies) / sum(examples)}\n",
    "\n",
    "def evaluate2(server_round: int, parameters: NDArrays,\n",
    "              config: Dict[str, Scalar]) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
    "    set_parameters(central, parameters)\n",
    "    loss, accuracy, y_pred, y_true, y_proba = engine.test_multimodal_health(central, testloader, loss_fn=(torch.nn.CrossEntropyLoss(),torch.nn.CrossEntropyLoss()),\n",
    "                                                          device=DEVICE)\n",
    "    loss_mri, loss_dna = loss\n",
    "    accuracy_mri, accuracy_dna = accuracy\n",
    "    y_pred_mri, y_pred_dna = y_pred\n",
    "    y_true_mri, y_true_dna = y_true\n",
    "    y_proba_mri, y_proba_dna = y_proba\n",
    "\n",
    "    print(f\"Server-side evaluation MRI loss {loss_mri} / MRI accuracy {accuracy_mri}\")\n",
    "    print(f\"Server-side evaluation DNA loss {loss_dna} / DNA accuracy {accuracy_dna}\")\n",
    "    return (loss_mri, loss_dna), {\"accuracy\": (accuracy_mri, accuracy_dna)}\n",
    "\n",
    "def get_on_fit_config_fn(epoch=2, lr=0.001, batch_size=32) -> Callable[[int], Dict[str, str]]:\n",
    "    def fit_config(server_round: int) -> Dict[str, str]:\n",
    "        config = {\n",
    "            \"learning_rate\": str(lr),\n",
    "            \"batch_size\": str(batch_size),\n",
    "            \"server_round\": server_round,\n",
    "            \"local_epochs\": epoch\n",
    "        }\n",
    "        return config\n",
    "    return fit_config\n",
    "\n",
    "def aggreg_fit_checkpoint(server_round, aggregated_parameters, central_model, path_checkpoint,\n",
    "                          context_client=None, server_path=\"\"):\n",
    "    if aggregated_parameters is not None:\n",
    "        print(f\"Saving round {server_round} aggregated_parameters...\")\n",
    "        aggregated_ndarrays: List[np.ndarray] = parameters_to_ndarrays_custom(aggregated_parameters, context_client)\n",
    "        if context_client:   \n",
    "            server_response = {\"contexte\": server_context.serialize()}\n",
    "            for i, key in enumerate(central_model.state_dict().keys()):\n",
    "                try:\n",
    "                    server_response[key] = aggregated_ndarrays[i].serialize()\n",
    "                except:\n",
    "                    server_response[key] = aggregated_ndarrays[i]\n",
    "            security.write_query(server_path, server_response)\n",
    "        else:\n",
    "            params_dict = zip(central_model.state_dict().keys(), aggregated_ndarrays)\n",
    "            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})\n",
    "            central_model.load_state_dict(state_dict, strict=True)\n",
    "            if path_checkpoint:\n",
    "                torch.save({\n",
    "                    'model_state_dict': central_model.state_dict(),\n",
    "                }, path_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the FedCustom strategy class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Strategy from scratch with the same sampling of the clients as it is in FedAvg\n",
    "# and then change the configuration dictionary\n",
    "class FedCustom(fl.server.strategy.Strategy):\n",
    "    def __init__(\n",
    "            self,\n",
    "            fraction_fit: float = 1.0,\n",
    "            fraction_evaluate: float = 1.0,\n",
    "            min_fit_clients: int = 2,\n",
    "            min_evaluate_clients: int = 2,\n",
    "            min_available_clients: int = 2,\n",
    "            evaluate_fn: Optional[\n",
    "                    Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]\n",
    "                ] = None,\n",
    "            on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "            on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,\n",
    "            accept_failures: bool = True,\n",
    "            initial_parameters: Optional[Parameters] = None,\n",
    "            fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "            evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,\n",
    "            context_client=None\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.fraction_fit = fraction_fit\n",
    "        self.fraction_evaluate = fraction_evaluate\n",
    "        self.min_fit_clients = min_fit_clients\n",
    "        self.min_evaluate_clients = min_evaluate_clients\n",
    "        self.min_available_clients = min_available_clients\n",
    "        self.evaluate_fn = evaluate_fn\n",
    "        self.on_fit_config_fn = on_fit_config_fn\n",
    "        self.on_evaluate_config_fn = on_evaluate_config_fn,\n",
    "        self.accept_failures = accept_failures\n",
    "        self.initial_parameters = initial_parameters\n",
    "        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn\n",
    "        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn\n",
    "        self.context_client = context_client\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        return f\"FedCustom (accept_failures={self.accept_failures})\"\n",
    "\n",
    "    def initialize_parameters(\n",
    "        self, client_manager: ClientManager\n",
    "    ) -> Optional[Parameters]:\n",
    "        \"\"\"Initialize global model parameters.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        initial_parameters = self.initial_parameters\n",
    "        self.initial_parameters = None  # Don't keep initial parameters in memory\n",
    "        return initial_parameters\n",
    "\n",
    "    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
    "        \"\"\"Return sample size and required number of clients.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        num_clients = int(num_available_clients * self.fraction_fit)\n",
    "        return max(num_clients, self.min_fit_clients), self.min_available_clients\n",
    "\n",
    "    def configure_fit(\n",
    "        self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
    "    ) -> List[Tuple[ClientProxy, FitIns]]:\n",
    "        \"\"\"Configure the next round of training.\"\"\"\n",
    "        # Sample clients\n",
    "        sample_size, min_num_clients = self.num_fit_clients(\n",
    "            client_manager.num_available()\n",
    "        )\n",
    "\n",
    "        clients = client_manager.sample(\n",
    "            num_clients=sample_size, min_num_clients=min_num_clients\n",
    "        )\n",
    "        # Create custom configs\n",
    "        n_clients = len(clients)\n",
    "        half_clients = n_clients // 2\n",
    "        # Custom fit config function provided\n",
    "        standard_lr = lr\n",
    "        higher_lr = 0.003\n",
    "        config = {\"server_round\": server_round, \"local_epochs\": 1}\n",
    "        if self.on_fit_config_fn is not None:\n",
    "            # Custom fit config function provided\n",
    "            config = self.on_fit_config_fn(server_round)\n",
    "\n",
    "        # fit_ins = FitIns(parameters, config)\n",
    "        # Return client/config pairs\n",
    "        fit_configurations = []\n",
    "        for idx, client in enumerate(clients):\n",
    "            config[\"learning_rate\"] = standard_lr if idx < half_clients else higher_lr\n",
    "            \"\"\"\n",
    "            Each pair of (ClientProxy, FitRes) constitutes \n",
    "            a successful update from one of the previously selected clients.\n",
    "            \"\"\"\n",
    "            fit_configurations.append(\n",
    "                (\n",
    "                    client,\n",
    "                    FitIns(\n",
    "                        parameters,\n",
    "                        config\n",
    "                    )\n",
    "                )\n",
    "            )\n",
    "        # Successful updates from the previously selected and configured clients\n",
    "        return fit_configurations\n",
    "\n",
    "    def aggregate_fit(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, FitRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n",
    "    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate fit results using weighted average. (each round)\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        if not results:\n",
    "            return None, {}\n",
    "\n",
    "        # Do not aggregate if there are failures and failures are not accepted\n",
    "        if not self.accept_failures and failures:\n",
    "            return None, {}\n",
    "\n",
    "        # Convert results parameters --> array matrix\n",
    "        weights_results = [\n",
    "            (parameters_to_ndarrays_custom(fit_res.parameters, self.context_client), fit_res.num_examples)\n",
    "            for _, fit_res in results\n",
    "        ]\n",
    "\n",
    "        # Aggregate parameters using weighted average between the clients and convert back to parameters object (bytes)\n",
    "        parameters_aggregated = ndarrays_to_parameters_custom(aggregate_custom(weights_results))\n",
    "\n",
    "        metrics_aggregated = {}\n",
    "        # Aggregate custom metrics if aggregation fn was provided\n",
    "        if self.fit_metrics_aggregation_fn:\n",
    "            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]\n",
    "            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)\n",
    "\n",
    "        elif server_round == 1:  # Only log this warning once\n",
    "            logger.log(WARNING, \"No fit_metrics_aggregation_fn provided\")\n",
    "\n",
    "        # Same function as SaveModelStrategy(fl.server.strategy.FedAvg)\n",
    "        \"\"\"Aggregate model weights using weighted average and store checkpoint\"\"\"\n",
    "        aggreg_fit_checkpoint(server_round, parameters_aggregated, central, model_save,\n",
    "                              self.context_client, path_crypted)\n",
    "        return parameters_aggregated, metrics_aggregated\n",
    "\n",
    "    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
    "        \"\"\"Use a fraction of available clients for evaluation.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        num_clients = int(num_available_clients * self.fraction_evaluate)\n",
    "        return max(num_clients, self.min_evaluate_clients), self.min_available_clients\n",
    "\n",
    "    def configure_evaluate(\n",
    "        self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
    "    ) -> List[Tuple[ClientProxy, EvaluateIns]]:\n",
    "        \"\"\"Configure the next round of evaluation.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        # Do not configure federated evaluation if fraction eval is 0.\n",
    "        if self.fraction_evaluate == 0.0:\n",
    "            return []\n",
    "\n",
    "        # Parameters and config\n",
    "        config = {}  # {\"server_round\": server_round, \"local_epochs\": 1}\n",
    "\n",
    "        evaluate_ins = EvaluateIns(parameters, config)\n",
    "\n",
    "        # Sample clients\n",
    "        sample_size, min_num_clients = self.num_evaluation_clients(\n",
    "            client_manager.num_available()\n",
    "        )\n",
    "\n",
    "        clients = client_manager.sample(\n",
    "            num_clients=sample_size, min_num_clients=min_num_clients\n",
    "        )\n",
    "\n",
    "        # Return client/config pairs\n",
    "        # Each pair of (ClientProxy, FitRes) constitutes a successful update from one of the previously selected clients\n",
    "        return [(client, evaluate_ins) for client in clients]\n",
    "\n",
    "    def aggregate_evaluate(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, EvaluateRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],\n",
    "    ) -> Tuple[Optional[float], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate evaluation losses using weighted average.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        if not results:\n",
    "            return None, {}\n",
    "\n",
    "        # Do not aggregate if there are failures and failures are not accepted\n",
    "        if not self.accept_failures and failures:\n",
    "            return None, {}\n",
    "\n",
    "        # Aggregate loss\n",
    "        loss_aggregated = weighted_loss_avg(\n",
    "            [\n",
    "                (evaluate_res.num_examples, evaluate_res.loss)\n",
    "                for _, evaluate_res in results\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        metrics_aggregated = {}\n",
    "        # Aggregate custom metrics if aggregation fn was provided\n",
    "        if self.evaluate_metrics_aggregation_fn:\n",
    "            eval_metrics = [(res.num_examples, res.metrics) for _, res in results]\n",
    "            metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)\n",
    "\n",
    "        # Only log this warning once\n",
    "        elif server_round == 1:\n",
    "            logger.log(WARNING, \"No evaluate_metrics_aggregation_fn provided\")\n",
    "\n",
    "        return loss_aggregated, metrics_aggregated\n",
    "\n",
    "    def evaluate(\n",
    "        self, server_round: int, parameters: Parameters\n",
    "    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
    "        \"\"\"Evaluate global model parameters using an evaluation function.\"\"\"\n",
    "        # Same function as FedAvg(Strategy)\n",
    "        if self.evaluate_fn is None:\n",
    "            # Let's assume we won't perform the global model evaluation on the server side.\n",
    "            return None\n",
    "\n",
    "        # if we have a global model evaluation on the server side :\n",
    "        parameters_ndarrays = parameters_to_ndarrays_custom(parameters, self.context_client)\n",
    "        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})\n",
    "\n",
    "        # if you haven't results\n",
    "        if eval_res is None:\n",
    "            return None\n",
    "\n",
    "        loss, metrics = eval_res\n",
    "        return loss, metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the federated learning strategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up your variables directly\n",
    "he = False\n",
    "data_path = 'data/'\n",
    "dataset = 'DNA+MRI'\n",
    "yaml_path = './results_classical_standard/FL_DNA+MRI/results.yml'\n",
    "seed = 0\n",
    "num_workers = 0\n",
    "max_epochs = 10\n",
    "batch_size = 32\n",
    "splitter = 10\n",
    "device = 'gpu'\n",
    "number_clients = 10\n",
    "save_results = 'results_classical_standard/FL_DNA+MRI/'\n",
    "matrix_path = 'confusion_matrix.png'\n",
    "roc_path = 'roc.png'\n",
    "model_save = 'results_classical_standard/FL_DNA+MRI/DNA+MRI_fl.pt'\n",
    "min_fit_clients = 10\n",
    "min_avail_clients = 10\n",
    "min_eval_clients = 10\n",
    "rounds = 20\n",
    "frac_fit = 1.0\n",
    "frac_eval = 0.5\n",
    "lr = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server_context = None\n",
    "DEVICE = torch.device(choice_device(device))\n",
    "CLASSES = classes_string(dataset)\n",
    "trainloaders, valloaders, testloader = data_setup.load_datasets(num_clients=number_clients, batch_size=batch_size, resize=224, seed=seed, num_workers=num_workers, splitter=splitter, dataset=dataset, data_path=data_path, data_path_val=None)\n",
    "_, input_sp = next(iter(testloader))[0][1].shape\n",
    "central = Net(len(CLASSES[0]), len(CLASSES[1])).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy = FedCustom(\n",
    "    fraction_fit=frac_fit,\n",
    "    fraction_evaluate=frac_eval,\n",
    "    min_fit_clients=min_fit_clients,\n",
    "    min_evaluate_clients=min_eval_clients if min_eval_clients else number_clients // 2,\n",
    "    min_available_clients=min_avail_clients,\n",
    "    evaluate_metrics_aggregation_fn=weighted_average,\n",
    "    initial_parameters=ndarrays_to_parameters_custom(get_parameters2(central)),\n",
    "    evaluate_fn=None if he else evaluate2,\n",
    "    on_fit_config_fn=get_on_fit_config_fn(epoch=max_epochs, batch_size=batch_size),\n",
    "    context_client=server_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainloaders, valloaders, testloader = data_setup.load_datasets(num_clients=number_clients,\n",
    "                                                                batch_size=batch_size,\n",
    "                                                                resize=224,\n",
    "                                                                seed=seed,\n",
    "                                                                num_workers=num_workers,\n",
    "                                                                splitter=splitter,\n",
    "                                                                dataset=dataset,  # Use the specified dataset\n",
    "                                                                data_path=data_path,\n",
    "                                                                data_path_val=None)  # Use the same path for validation data\n",
    "\n",
    "def client_fn(cid: str) -> FlowerClient:\n",
    "    return client_common(cid,\n",
    "                         model_save, path_yaml, path_roc, results_save, path_matrix,\n",
    "                         batch_size, trainloaders, valloaders, DEVICE, CLASSES, he, secret_path, server_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the client_fn function and set up the simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "print(\"flwr\", fl.__version__)\n",
    "print(\"numpy\", np.__version__)\n",
    "print(\"torch\", torch.__version__)\n",
    "print(\"torchvision\", torchvision.__version__)\n",
    "print(f\"Training on {DEVICE}\")\n",
    "\n",
    "client_resources = None\n",
    "\n",
    "if DEVICE.type == \"cuda\":\n",
    "    client_resources = {\"num_gpus\": 1}\n",
    "\n",
    "model_save = model_save\n",
    "path_yaml = yaml_path\n",
    "path_roc = roc_path\n",
    "results_save = save_results\n",
    "path_matrix = matrix_path\n",
    "batch_size = batch_size\n",
    "he = he\n",
    "secret_path = 'secret.pkl'\n",
    "server_path = 'secret.pkl'\n",
    "path_crypted = 'server.pkl'\n",
    "\n",
    "print(\"Start simulation\")\n",
    "start_simulation = time.time()\n",
    "fl.simulation.start_simulation(\n",
    "    client_fn=client_fn,\n",
    "    num_clients=number_clients,\n",
    "    config=fl.server.ServerConfig(num_rounds=rounds),\n",
    "    strategy=strategy,\n",
    "    client_resources=client_resources\n",
    ")\n",
    "print(f\"Simulation Time = {time.time() - start_simulation} seconds\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
