{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This notebook simulates FedDCA in the financial domain with 10 clients: each client starts with 100 base samples and retrieves 5,000 additional samples from the public data. I hope it helps you gain a deeper understanding of FedDCA!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img src=\"./src/overview.png\" alt=\"\" width=\"1000\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import random\n",
    "import heapq\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from contextlib import contextmanager\n",
    "from pprint import pprint as pp\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.cluster import KMeans\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "import datasets\n",
    "from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset\n",
    "import pandas as pd\n",
    "from FlagEmbedding import FlagModel\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.cluster import MiniBatchKMeans, KMeans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load pretrained encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BGE checkpoint whose `encode` method produces every embedding used in this notebook.\n",
    "encoder_name = 'BAAI/bge-large-en-v1.5'\n",
    "encoder_options = {\n",
    "    \"query_instruction_for_retrieval\": \"\",  # empty retrieval instruction\n",
    "    \"use_fp16\": True,\n",
    "}\n",
    "model = FlagModel(encoder_name, **encoder_options)\n",
    "# Alternative encoder, kept for reference:\n",
    "# model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', model_kwargs={\"torch_dtype\":torch.bfloat16})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Composition of public datasets (code, medical, financial, mathematical, general)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw training splits, one per domain.\n",
    "code_data = load_dataset(\"sahil2801/CodeAlpaca-20k\")[\"train\"]\n",
    "fin_data = load_dataset(\"FinGPT/fingpt-sentiment-train\")[\"train\"]\n",
    "med_data = load_dataset(\"medalpaca/medical_meadow_medical_flashcards\")[\"train\"]\n",
    "general_data = load_dataset(\"tatsu-lab/alpaca\")[\"train\"]\n",
    "math_data = load_dataset(\"TIGER-Lab/MathInstruct\")[\"train\"]\n",
    "\n",
    "def alpaca_format(example):\n",
    "    \"\"\"Convert one Alpaca-style record to the unified instruction/response schema.\n",
    "\n",
    "    A non-empty `input` is appended to `instruction` after a single space, and\n",
    "    `output` is copied into a new `response` field. The example dict is updated\n",
    "    in place and returned.\n",
    "    \"\"\"\n",
    "    if example['input'] != \"\":\n",
    "        example[\"instruction\"] = example[\"instruction\"] + \" \" + example['input']\n",
    "    example[\"response\"] = example['output']\n",
    "    return example\n",
    "\n",
    "def process_sft_dataset(dataset_name, dataset, dataset_sample=None)->datasets.Dataset:\n",
    "    \"\"\"Normalize a raw dataset to `instruction`/`response` columns, shuffle it, and optionally subsample.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Identifier that selects the preprocessing branch below.\n",
    "        dataset: The raw dataset to convert.\n",
    "        dataset_sample: If truthy, keep at most this many examples after shuffling.\n",
    "\n",
    "    Returns:\n",
    "        The processed dataset, shuffled with seed 42.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If no branch matches `dataset_name`.\n",
    "    \"\"\"\n",
    "    if dataset_name in [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"yahma/alpaca-cleaned\", \"FinGPT/fingpt-sentiment-train\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"sujet-ai/Sujet-Finance-Instruct-177k\"]:\n",
    "        dataset = dataset.filter(lambda example: example['task_type'] == \"qa\")\n",
    "        dataset = dataset.rename_column(\"inputs\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"answer\", \"response\")\n",
    "    elif dataset_name in [\"WizardLM/WizardLM_evol_instruct_70k\"]:\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif dataset_name in [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\", \"gbharti/finance-alpaca\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"TIGER-Lab/MathInstruct\"]:\n",
    "        df = pd.DataFrame(dataset)\n",
    "        df = df.drop_duplicates(subset=['instruction'])\n",
    "        # drop_duplicates leaves gaps in the index; without preserve_index=False,\n",
    "        # from_pandas would store that index as an extra `__index_level_0__` column,\n",
    "        # so this split would no longer share the other splits' schema.\n",
    "        dataset = datasets.Dataset.from_pandas(df, preserve_index=False)\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "    elif dataset_name in [\"lighteval/MATH\"]:\n",
    "        dataset = dataset.rename_column(\"solution\", \"response\")\n",
    "        dataset = dataset.rename_column(\"problem\", \"instruction\")\n",
    "        dataset = dataset.remove_columns(['level', 'type'])\n",
    "    elif dataset_name in ['gsm8k']:\n",
    "        dataset = dataset.rename_column(\"question\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"answer\", \"response\")\n",
    "    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:\n",
    "        # The original `instruction` column is dropped; `input` becomes the instruction.\n",
    "        dataset = dataset.remove_columns(['instruction'])\n",
    "        dataset = dataset.rename_column(\"input\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif \"math\" in dataset_name:\n",
    "        # Fallback for any other dataset whose name contains \"math\".\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Dataset {dataset_name} is not supported.\")\n",
    "    dataset = dataset.shuffle(seed=42)\n",
    "    if dataset_sample:\n",
    "        num_sample = min(len(dataset), dataset_sample)\n",
    "        dataset = dataset.select(range(num_sample))\n",
    "    print(f\">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====\")\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each raw split is paired with the dataset name that selects its branch in process_sft_dataset.\n",
    "named_splits = [\n",
    "    (\"FinGPT/fingpt-sentiment-train\", fin_data),\n",
    "    (\"lucasmccabe-lmi/CodeAlpaca-20k\", code_data),\n",
    "    (\"TIGER-Lab/MathInstruct\", math_data),\n",
    "    (\"medalpaca/medical_meadow_medical_flashcards\", med_data),\n",
    "    (\"tatsu-lab/alpaca\", general_data),\n",
    "]\n",
    "\n",
    "processed_data = []\n",
    "for split_name, raw_split in named_splits:\n",
    "    unified_split = process_sft_dataset(split_name, raw_split)\n",
    "    print(unified_split.column_names)\n",
    "    processed_data.append(unified_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def labeling(example, label):\n",
    "    \"\"\"Store `label` under the example's \"label\" key and return the example.\"\"\"\n",
    "    example[\"label\"] = label\n",
    "    return example\n",
    "\n",
    "# One domain tag per entry of processed_data (same order), kept for the later domain coverage computation.\n",
    "label = [\"fin\",\"code\",\"math\",\"med\",\"gen\",]\n",
    "for position in range(len(processed_data)):\n",
    "    domain_tag = label[position]\n",
    "    processed_data[position] = processed_data[position].map(\n",
    "        lambda example: labeling(example, domain_tag), batched=False\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The financial split was processed first, so it sits at position 0 of processed_data.\n",
    "fin_position = 0\n",
    "in_domain_data = processed_data[fin_position]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 1000 in-domain examples with a fixed seed; these form the clients' private base data.\n",
    "random.seed(42)\n",
    "selected_idxs = random.sample(range(len(in_domain_data)), 1000)\n",
    "base_data = in_domain_data.select(selected_idxs)\n",
    "\n",
    "# Split the base data into 10 shards, one per client.\n",
    "clients_data = [base_data.shard(10, client_id) for client_id in range(10)]\n",
    "\n",
    "# Whatever was not sampled stays available as public in-domain data.\n",
    "all_idxs = set(range(len(in_domain_data)))\n",
    "remaining_idxs = all_idxs - set(selected_idxs)\n",
    "in_domain_data = in_domain_data.select(list(remaining_idxs))\n",
    "print(len(in_domain_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start greedy client center selection in FedDCA!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform k-means clustering for client 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 10\n",
    "\n",
    "# Embed client 0's instructions and cluster them into k groups.\n",
    "base_0_embeddings = model.encode(clients_data[0][\"instruction\"])\n",
    "kmeans = KMeans(n_clusters=k, random_state=0)\n",
    "kmeans.fit(base_0_embeddings)\n",
    "labels = kmeans.labels_\n",
    "\n",
    "# Number of points assigned to each cluster.\n",
    "counts = np.bincount(labels)\n",
    "\n",
    "# The center of the most populated cluster becomes client 0's center.\n",
    "largest_cluster_label = counts.argmax()\n",
    "cluster_center_0 = kmeans.cluster_centers_[largest_cluster_label]\n",
    "print(cluster_center_0.shape)\n",
    "\n",
    "# Keep the centers as a 2-D array with one row per selected client center.\n",
    "client_centers = cluster_center_0.reshape((1,-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform greedy client center selection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img src=\"./src/greedy_client_center_selection.png\" alt=\"\" width=\"1000\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy client center selection for the other nine clients (client 0's center is cluster_center_0).\n",
    "# The running set is rebuilt from cluster_center_0 rather than from client_centers, so this cell\n",
    "# can be re-run even after client_centers has been converted to a tensor below.\n",
    "selected_centers = cluster_center_0.reshape((1,-1))\n",
    "\n",
    "for i in range(1, 10):\n",
    "    base_i_embeddings = model.encode(clients_data[i][\"instruction\"])\n",
    "    kmeans = KMeans(n_clusters=k, random_state=0).fit(base_i_embeddings)\n",
    "    labels = kmeans.labels_\n",
    "    counts = np.bincount(labels)\n",
    "\n",
    "    # We consider maximizing the domain coverage from two aspects:\n",
    "    # 1) Select a client center that can represent the distribution of the local data.\n",
    "    # 2) To optimize the cross-client domain coverage, we filter client centers that are\n",
    "    #    close to the previously selected client centers.\n",
    "\n",
    "    # Aspect 2: sum each cluster center's dot products with all previously selected centers and\n",
    "    # keep the 10 - i clusters with the lowest totals as candidates.\n",
    "    # This must NOT be stored in `selected_idxs`: that name holds the sampled base-data indices,\n",
    "    # which are needed later to filter the public data.\n",
    "    similarity_scores = np.sum(kmeans.cluster_centers_ @ selected_centers.T, axis=-1)\n",
    "    candidate_idxs = np.argsort(similarity_scores)[:10-i]\n",
    "\n",
    "    # Aspect 1: walk the clusters from largest to smallest and take the first candidate.\n",
    "    largest_cluster_labels = np.argsort(-counts) # Descending order.\n",
    "    largest_cluster_label = -1 # Falls back to the last cluster if no candidate matches.\n",
    "    for j in largest_cluster_labels:\n",
    "        if j in candidate_idxs:\n",
    "            largest_cluster_label = j\n",
    "            break\n",
    "\n",
    "    selected_cluster_center = kmeans.cluster_centers_[largest_cluster_label]\n",
    "    selected_centers = np.concatenate([selected_centers, selected_cluster_center.reshape((1,-1))])\n",
    "\n",
    "client_centers = torch.tensor(selected_centers, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct the public data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "369347\n"
     ]
    }
   ],
   "source": [
    "public_data = concatenate_datasets(processed_data)\n",
    "\n",
    "# Filter the selected data for local datasets. The financial split comes first in the\n",
    "# concatenation, so indices into it are also valid indices into public_data.\n",
    "# NOTE(review): this relies on `selected_idxs` still holding the sampled base-data indices - confirm no cell in between reassigns it.\n",
    "excluded_idxs = set(selected_idxs)\n",
    "kept_idxs = list(set(range(len(public_data))) - excluded_idxs)\n",
    "public_data = public_data.select(kept_idxs)\n",
    "print(len(public_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every public instruction with the shared encoder and keep the result as a float32 tensor.\n",
    "public_instructions = public_data[\"instruction\"]\n",
    "public_embeddings = torch.tensor(model.encode(public_instructions), dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform dense retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrival_nums = [5000] # Retrieval num\n",
    "domain = \"fin\"\n",
    "root = \"\" # Config your own root path for save dataset.\n",
    "similarity_threshold = 0.7 # Public samples scoring at or above this are excluded from retrieval.\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Move the transposed public embedding matrix to the device once instead of once per client.\n",
    "public_embeddings_t = public_embeddings.T.to(device)\n",
    "\n",
    "for retrival_num in retrival_nums:\n",
    "    client_datasets = []\n",
    "    for i, sampled_data in enumerate(clients_data):\n",
    "        # Dot product between this client's center and every public embedding.\n",
    "        similarity_scores = torch.matmul(client_centers[i,:].to(device), public_embeddings_t).cpu()\n",
    "\n",
    "        # Filter public data with client center similarity greater than or equal to the threshold:\n",
    "        # masked entries are set to -inf so they can never be picked by topk.\n",
    "        eligible = similarity_scores < similarity_threshold\n",
    "        masked_scores = similarity_scores.masked_fill(~eligible, float(\"-inf\"))\n",
    "\n",
    "        # topk requires k <= number of elements; capping at the eligible count also keeps\n",
    "        # masked entries out of the result.\n",
    "        num_retrieved = min(retrival_num, int(eligible.sum()))\n",
    "        # These are positions in similarity_scores, i.e. row indices of public_data.\n",
    "        top_idxs = torch.topk(masked_scores, num_retrieved).indices.tolist()\n",
    "\n",
    "        client_data = public_data.select(top_idxs)\n",
    "        client_data = concatenate_datasets([sampled_data, client_data])\n",
    "        client_datasets.append(client_data)\n",
    "\n",
    "    for i, client_data in enumerate(client_datasets):\n",
    "        # os.path.join keeps the path relative when `root` is empty instead of writing to \"/\".\n",
    "        # NOTE(review): the target keeps its \".parquet\" suffix although it is written with save_to_disk - confirm downstream loaders expect that.\n",
    "        client_data.save_to_disk(os.path.join(root, f\"{domain}_{retrival_num}_{i}.parquet\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize similarity score distribution of dense retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One histogram panel per client, laid out on a 5 x 2 grid.\n",
    "fig, axes = plt.subplots(5, 2, figsize=(15, 25))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# client_centers and public_embeddings are already tensors, so they are moved, not re-wrapped.\n",
    "public_embeddings_t = public_embeddings.T.to(device)\n",
    "\n",
    "# Histogram bin edges starting at 0 with a step of 0.05 (upper limit 1.1, exclusive).\n",
    "bins = np.arange(0, 1.1, 0.05)\n",
    "\n",
    "for i in range(len(clients_data)):\n",
    "    print(i)\n",
    "\n",
    "    # Dot product between this client's center and every public embedding, as a NumPy array.\n",
    "    similarity_scores = torch.matmul(client_centers[i,:].to(device), public_embeddings_t).cpu().numpy()\n",
    "\n",
    "    # Print the number of public samples falling into each bin.\n",
    "    hist, bin_edges = np.histogram(similarity_scores, bins=bins)\n",
    "    for j in range(len(hist)):\n",
    "        print(f'Range [{bin_edges[j]}, {bin_edges[j + 1]}): {hist[j]}')\n",
    "\n",
    "    ax = axes[i]\n",
    "    ax.hist(similarity_scores, bins=bins, edgecolor='black')\n",
    "    ax.set_xlabel('Similarity Score')\n",
    "    ax.set_ylabel('Frequency')\n",
    "    ax.set_title(f'Histogram of Similarity Scores {i + 1}')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "fileId": "69f353c0-21ed-4f1d-8832-470ac96015d8",
  "filePath": "/mnt/bn/data-tns-live-llm/leon/Cherry_LLM/niid_data/iid2niid_fin_QA.ipynb",
  "kernelspec": {
   "display_name": "Python 3.9.2 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
