{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook shows the training process of the projector for FedDCA*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import random\n",
    "import heapq\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from contextlib import contextmanager\n",
    "from pprint import pprint as pp\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.cluster import KMeans\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset\n",
    "import pandas as pd\n",
    "from FlagEmbedding import FlagModel\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.cluster import MiniBatchKMeans, KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------using 8*GPUs----------\n"
     ]
    }
   ],
   "source": [
    "# model_s: BGE-large loaded through FlagEmbedding with fp16 enabled and an empty retrieval instruction.\n",
    "# NOTE(review): the `_s` / `_c` suffixes presumably stand for server-side / client-side encoders -- confirm.\n",
    "model_s = FlagModel('BAAI/bge-large-en-v1.5', \n",
    "                  query_instruction_for_retrieval=\"\",\n",
    "                  use_fp16=True,\n",
    "                )\n",
    "# model_c: MiniLM sentence-transformer whose weights are requested as float16.\n",
    "# NOTE(review): the Projector defined later maps 384 -> 1024 features, presumably the output widths of model_c and model_s -- confirm.\n",
    "model_c = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', model_kwargs={\"torch_dtype\":torch.float16})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train split of every public corpus; the hub ids are listed in the\n",
    "# same order as the variables they are unpacked into.\n",
    "code_data, fin_data, med_data, general_data, math_data = (\n",
    "    load_dataset(hub_id)[\"train\"]\n",
    "    for hub_id in (\n",
    "        \"sahil2801/CodeAlpaca-20k\",\n",
    "        \"FinGPT/fingpt-sentiment-train\",\n",
    "        \"medalpaca/medical_meadow_medical_flashcards\",\n",
    "        \"tatsu-lab/alpaca\",\n",
    "        \"TIGER-Lab/MathInstruct\",\n",
    "    )\n",
    ")\n",
    "\n",
    "def alpaca_format(example):\n",
    "    if example['input'] == \"\":\n",
    "        example[\"instruction\"] = example[\"instruction\"]\n",
    "    else:\n",
    "        example[\"instruction\"] = example[\"instruction\"] + \" \" + example['input']\n",
    "    example[\"response\"] = example['output']\n",
    "    return example\n",
    "\n",
    "def process_sft_dataset(dataset_name, dataset, dataset_sample=None) -> Dataset:\n",
    "    \"\"\"Bring a supported SFT dataset into the unified instruction/response format.\n",
    "\n",
    "    The branch selected by ``dataset_name`` renames or drops columns; the result\n",
    "    is then shuffled with a fixed seed (42) and optionally truncated.\n",
    "\n",
    "    :param dataset_name: identifier that selects the preprocessing branch.\n",
    "    :param dataset: the loaded split to preprocess.\n",
    "    :param dataset_sample: if truthy, keep at most this many shuffled examples.\n",
    "    :return: the preprocessed dataset.\n",
    "    :raises NotImplementedError: if ``dataset_name`` matches no branch.\n",
    "    \"\"\"\n",
    "    if dataset_name in [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"yahma/alpaca-cleaned\", \"FinGPT/fingpt-sentiment-train\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"WizardLM/WizardLM_evol_instruct_70k\"]:\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif dataset_name in [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\", \"gbharti/finance-alpaca\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"TIGER-Lab/MathInstruct\"]:\n",
    "        # Deduplicate on the instruction text via pandas.\n",
    "        df = dataset.to_pandas().drop_duplicates(subset=['instruction'])\n",
    "        # drop_duplicates leaves a non-contiguous index. Without\n",
    "        # preserve_index=False it would be stored as an extra\n",
    "        # '__index_level_0__' column, so this dataset's schema would differ\n",
    "        # from the others when they are concatenated later.\n",
    "        dataset = Dataset.from_pandas(df, preserve_index=False)\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "    elif dataset_name in [\"lighteval/MATH\"]:\n",
    "        dataset = dataset.rename_column(\"solution\", \"response\")\n",
    "        dataset = dataset.rename_column(\"problem\", \"instruction\")\n",
    "        dataset = dataset.remove_columns(['level', 'type'])\n",
    "    elif dataset_name in ['gsm8k']:\n",
    "        dataset = dataset.rename_column(\"question\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"answer\", \"response\")\n",
    "    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:\n",
    "        # The original 'instruction' column is dropped and 'input' takes its place.\n",
    "        dataset = dataset.remove_columns(['instruction'])\n",
    "        dataset = dataset.rename_column(\"input\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif \"math\" in dataset_name:\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Dataset {dataset_name} is not supported.\")\n",
    "    # Fixed seed keeps the shuffled order (and any truncation) reproducible.\n",
    "    dataset = dataset.shuffle(seed=42)\n",
    "    if dataset_sample:\n",
    "        num_sample = min(len(dataset), dataset_sample)\n",
    "        dataset = dataset.select(range(num_sample))\n",
    "    print(f\">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====\")\n",
    "    return dataset\n",
    "\n",
    "# Pair every loaded split with the name that selects its branch in\n",
    "# process_sft_dataset. The CodeAlpaca split (loaded from 'sahil2801/CodeAlpaca-20k')\n",
    "# is passed under the 'lucasmccabe-lmi/CodeAlpaca-20k' key that the function recognises.\n",
    "named_splits = [\n",
    "    (\"lucasmccabe-lmi/CodeAlpaca-20k\", code_data),\n",
    "    (\"FinGPT/fingpt-sentiment-train\", fin_data),\n",
    "    (\"medalpaca/medical_meadow_medical_flashcards\", med_data),\n",
    "    (\"tatsu-lab/alpaca\", general_data),\n",
    "    (\"TIGER-Lab/MathInstruct\", math_data),\n",
    "]\n",
    "processed_data = []\n",
    "for name, dataset in named_splits:\n",
    "    processed_data.append(process_sft_dataset(name, dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the projector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the instruction text of every processed dataset into one plain list.\n",
    "public_data = list(concatenate_datasets(processed_data)[\"instruction\"])\n",
    "\n",
    "train_data_size = 10000 # Random select 10000 public data for training.\n",
    "# Never request more rows than exist, otherwise sampling raises ValueError.\n",
    "train_data_size = min(train_data_size, len(public_data))\n",
    "# A dedicated, seeded generator makes the subset reproducible across kernel\n",
    "# restarts without touching the global `random` state.\n",
    "train_data = random.Random(42).sample(public_data, train_data_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct the training set for contrastive learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the sampled instructions with model_s and keep the result as a float32 tensor.\n",
    "# torch.as_tensor with an explicit dtype replaces the legacy torch.Tensor(...) constructor,\n",
    "# so the stored dtype does not depend on what the encoder happens to return.\n",
    "embeddings_s = torch.as_tensor(model_s.encode(train_data), dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the same instructions with model_c across worker processes.\n",
    "pool = model_c.start_multi_process_pool()\n",
    "try:\n",
    "    embeddings_c = torch.tensor(model_c.encode_multi_process(train_data, pool, precision='float32'))\n",
    "finally:\n",
    "    # Always shut the workers down, even if encoding fails, so no processes are leaked.\n",
    "    model_c.stop_multi_process_pool(pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The model structure of the projector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Projector(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Projector, self).__init__()\n",
    "        self.fc1 = nn.Linear(384, 128)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(128, 1024)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContrastiveLoss(nn.Module):\n",
    "    def __init__(self, temperature=0.5):\n",
    "        super(ContrastiveLoss, self).__init__()\n",
    "        self.temperature = temperature\n",
    "\n",
    "    def forward(self, anchor, positive, negatives) -> torch.Tensor:\n",
    "        anchor_pos_similarity = (anchor * positive).sum(dim=1) / self.temperature\n",
    "        anchor_neg_similarity = (anchor.unsqueeze(1) * negatives).sum(dim=2) / self.temperature\n",
    "\n",
    "        logits = torch.cat([anchor_pos_similarity.unsqueeze(1), anchor_neg_similarity], dim=1)\n",
    "        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)\n",
    "\n",
    "        loss = nn.functional.cross_entropy(logits, labels)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, embeddings_c, embeddings_s):\n",
    "        \"\"\"\n",
    "        Initialize the dataset\n",
    "        :param embeddings_c\n",
    "        :param embeddings_s\n",
    "        \"\"\"\n",
    "        self.embeddings_c = embeddings_c\n",
    "        self.embeddings_s = embeddings_s\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.embeddings_c)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        embeddings_c_sample = self.embeddings_c[idx]\n",
    "        embeddings_s_sample = self.embeddings_s[idx]\n",
    "\n",
    "        return embeddings_c_sample, embeddings_s_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair the two embedding sets index by index and serve them in shuffled\n",
    "# mini-batches of 32, loaded by 4 worker processes.\n",
    "dataset = CustomDataset(embeddings_c, embeddings_s)\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the CPU when CUDA is unavailable instead of failing on .cuda().\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "projector = Projector().to(device)\n",
    "criterion = ContrastiveLoss(temperature=0.5)\n",
    "num_epochs = 3\n",
    "optimizer = torch.optim.Adam(projector.parameters(), lr=1e-4)\n",
    "\n",
    "projector.train()\n",
    "# The dataloader yields (embeddings_c_batch, embeddings_s_batch) pairs.\n",
    "for epoch in range(num_epochs):\n",
    "    tqdm_dataloader = tqdm(enumerate(dataloader), desc=f'Epoch {epoch+1}/{num_epochs}', total=len(dataloader))\n",
    "    for batch_idx, (embeddings_c_batch, embeddings_s_batch) in tqdm_dataloader:\n",
    "        embeddings_c_batch, embeddings_s_batch = embeddings_c_batch.to(device), embeddings_s_batch.to(device)\n",
    "        # Project embeddings_c into the higher-dimensional space of embeddings_s.\n",
    "        projected_c_batch = projector(embeddings_c_batch)\n",
    "\n",
    "        # For sample i the positive is embeddings_s_batch[i] and the negatives are\n",
    "        # all other rows of the batch. Build every negative set in one shot: the\n",
    "        # off-diagonal mask removes row i from the i-th copy of the batch.\n",
    "        batch_size, embed_dim = embeddings_s_batch.shape\n",
    "        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=device)\n",
    "        negatives = (\n",
    "            embeddings_s_batch.unsqueeze(0)\n",
    "            .expand(batch_size, -1, -1)[off_diagonal]\n",
    "            .view(batch_size, batch_size - 1, embed_dim)\n",
    "        )\n",
    "        # The criterion averages over the batch, which equals the mean of the\n",
    "        # per-sample losses computed one anchor at a time.\n",
    "        batch_loss = criterion(projected_c_batch, embeddings_s_batch, negatives)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # Report the loss as a postfix so the epoch label stays visible.\n",
    "        tqdm_dataloader.set_postfix(loss=f'{batch_loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save cannot write to an empty path, so default to a file in the\n",
    "# working directory; change this to wherever the projector should be stored.\n",
    "output_path = \"projector.pt\"\n",
    "# Create the parent directory first when output_path points into a folder.\n",
    "os.makedirs(os.path.dirname(output_path) or \".\", exist_ok=True)\n",
    "torch.save(projector.state_dict(), output_path) # Save the trained model"
   ]
  }
 ],
 "metadata": {
  "fileId": "3fca732b-5388-49f4-8bd7-abc17446225a",
  "filePath": "/mnt/bn/data-tns-live-llm/leon/Cherry_LLM/niid_data/projector_for_show.ipynb",
  "kernelspec": {
   "display_name": "Python 3.9.2 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
