{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook shows the computation of domain coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import random\n",
    "import heapq\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from contextlib import contextmanager\n",
    "from pprint import pprint as pp\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.cluster import KMeans\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from datasets import load_dataset, concatenate_datasets, load_from_disk, Dataset\n",
    "import pandas as pd\n",
    "from FlagEmbedding import FlagModel\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.cluster import MiniBatchKMeans, KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train split of every public source corpus from the Hugging Face Hub.\n",
    "code_data, fin_data, med_data, general_data, math_data = (\n",
    "    load_dataset(hub_id)[\"train\"]\n",
    "    for hub_id in (\n",
    "        \"sahil2801/CodeAlpaca-20k\",\n",
    "        \"FinGPT/fingpt-sentiment-train\",\n",
    "        \"medalpaca/medical_meadow_medical_flashcards\",\n",
    "        \"tatsu-lab/alpaca\",\n",
    "        \"TIGER-Lab/MathInstruct\",\n",
    "    )\n",
    ")\n",
    "\n",
    "def alpaca_format(example):\n",
    "    \"\"\"Merge an Alpaca-style record into instruction/response form.\n",
    "\n",
    "    When ``example['input']`` is a non-empty string it is appended to\n",
    "    ``example['instruction']`` after a single space; otherwise the instruction\n",
    "    is left untouched. ``example['output']`` is copied to ``example['response']``.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with ``instruction``, ``input`` and ``output`` keys.\n",
    "\n",
    "    Returns:\n",
    "        The same mapping, modified in place.\n",
    "    \"\"\"\n",
    "    extra_context = example['input']\n",
    "    if extra_context != \"\":\n",
    "        example[\"instruction\"] = example[\"instruction\"] + \" \" + extra_context\n",
    "    example[\"response\"] = example['output']\n",
    "    return example\n",
    "\n",
    "def process_sft_dataset(dataset_name, dataset, dataset_sample=None) -> Dataset:\n",
    "    \"\"\"Bring a supported SFT dataset into the unified instruction/response format.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Identifier that selects which conversion recipe is applied.\n",
    "        dataset: The ``datasets.Dataset`` to convert.\n",
    "        dataset_sample: Optional cap on the number of examples kept after\n",
    "            shuffling; a falsy value keeps every example.\n",
    "\n",
    "    Returns:\n",
    "        The converted dataset, shuffled with seed 42 and optionally truncated.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If no recipe matches ``dataset_name``.\n",
    "    \"\"\"\n",
    "    if dataset_name in [\"lucasmccabe-lmi/CodeAlpaca-20k\", \"yahma/alpaca-cleaned\", \"FinGPT/fingpt-sentiment-train\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"WizardLM/WizardLM_evol_instruct_70k\"]:\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif dataset_name in [\"tatsu-lab/alpaca\", \"vicgalle/alpaca-gpt4\", \"gbharti/finance-alpaca\"]:\n",
    "        dataset = dataset.map(alpaca_format, remove_columns=['input', 'output', 'text'], desc=f\"Preprocessing {dataset_name} for unified format.\")\n",
    "    elif dataset_name in [\"TIGER-Lab/MathInstruct\"]:\n",
    "        # Deduplicate on the instruction text via pandas.\n",
    "        df = dataset.to_pandas()\n",
    "        df = df.drop_duplicates(subset=['instruction'])\n",
    "        # drop_duplicates leaves a non-contiguous index; without preserve_index=False\n",
    "        # it would be stored in the dataset as an extra '__index_level_0__' column.\n",
    "        dataset = Dataset.from_pandas(df, preserve_index=False)\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "    elif dataset_name in [\"lighteval/MATH\"]:\n",
    "        dataset = dataset.rename_column(\"solution\", \"response\")\n",
    "        dataset = dataset.rename_column(\"problem\", \"instruction\")\n",
    "        dataset = dataset.remove_columns(['level', 'type'])\n",
    "    elif dataset_name in ['gsm8k']:\n",
    "        dataset = dataset.rename_column(\"question\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"answer\", \"response\")\n",
    "    elif dataset_name in ['medalpaca/medical_meadow_medical_flashcards']:\n",
    "        # The original 'instruction' column is discarded; 'input' takes its place.\n",
    "        dataset = dataset.remove_columns(['instruction'])\n",
    "        dataset = dataset.rename_column(\"input\", \"instruction\")\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    elif \"math\" in dataset_name:\n",
    "        # Catch-all for any other name containing the substring 'math'.\n",
    "        dataset = dataset.remove_columns(['source'])\n",
    "        dataset = dataset.rename_column(\"output\", \"response\")\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Dataset {dataset_name} is not supported.\")\n",
    "    # Fixed seed so the optional subsample below is reproducible.\n",
    "    dataset = dataset.shuffle(seed=42)\n",
    "    if dataset_sample:\n",
    "        num_sample = min(len(dataset), dataset_sample)\n",
    "        dataset = dataset.select(range(num_sample))\n",
    "    print(f\">> ===== After processing, Dataset {dataset_name} has {len(dataset)} examples. =====\")\n",
    "    return dataset\n",
    "\n",
    "# Each name only selects the preprocessing recipe; it is paired positionally with a loaded split.\n",
    "processed_data = [\n",
    "    process_sft_dataset(name, dataset)\n",
    "    for name, dataset in zip(\n",
    "        [\"lucasmccabe-lmi/CodeAlpaca-20k\",\"FinGPT/fingpt-sentiment-train\",\"medalpaca/medical_meadow_medical_flashcards\",\"tatsu-lab/alpaca\",\"TIGER-Lab/MathInstruct\"],\n",
    "        [code_data,fin_data,med_data,general_data,math_data],\n",
    "    )\n",
    "]\n",
    "\n",
    "# Stack the per-domain datasets in the order listed above.\n",
    "public_data = concatenate_datasets(processed_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the BGE-large English embedding model with fp16 enabled.\n",
    "model = FlagModel(\n",
    "    'BAAI/bge-large-en-v1.5',\n",
    "    use_fp16=True,\n",
    "    query_instruction_for_retrieval=\"\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch.nn.functional import cosine_similarity as cosine_similarity\n",
    "\n",
    "def coverage(A, V, chunk_size=None):\n",
    "    \"\"\"Mean best-match similarity of the reference set ``V`` against the set ``A``.\n",
    "\n",
    "    For every row ``v`` of ``V`` the largest inner product with any row of ``A``\n",
    "    is taken, and the mean of those maxima is returned. The inner product equals\n",
    "    cosine similarity only if the rows are L2-normalised; no normalisation is\n",
    "    performed here.\n",
    "\n",
    "    Args:\n",
    "        A: 2-D array-like of embeddings, shape (n_a, dim), whose coverage is measured.\n",
    "        V: 2-D array-like of reference embeddings, shape (n_v, dim).\n",
    "        chunk_size: Optional number of rows of ``V`` scored per matmul, so the\n",
    "            full (n_v, n_a) similarity matrix is never held in memory at once.\n",
    "            ``None`` (default) scores all rows in a single matmul.\n",
    "\n",
    "    Returns:\n",
    "        float: the mean over ``V`` of the maximum similarity to ``A``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``A`` or ``V`` is empty, or ``chunk_size`` is not positive.\n",
    "    \"\"\"\n",
    "    # as_tensor avoids the copy-construct warning when a tensor is passed in.\n",
    "    A_tensor = torch.as_tensor(A, dtype=torch.float32)\n",
    "    V_tensor = torch.as_tensor(V, dtype=torch.float32)\n",
    "    if A_tensor.numel() == 0 or V_tensor.numel() == 0:\n",
    "        raise ValueError(\"coverage() requires non-empty A and V\")\n",
    "    if chunk_size is None:\n",
    "        chunk_size = V_tensor.shape[0]\n",
    "    elif chunk_size <= 0:\n",
    "        raise ValueError(\"chunk_size must be a positive integer\")\n",
    "    # Maximum similarity for each v in V, computed one block of rows at a time.\n",
    "    max_similarities = torch.cat([\n",
    "        torch.matmul(V_chunk, A_tensor.T).max(dim=1).values\n",
    "        for V_chunk in torch.split(V_tensor, chunk_size)\n",
    "    ])\n",
    "    # Average (not sum) of the per-row maxima.\n",
    "    total_similarity = torch.sum(max_similarities).item() / len(max_similarities)\n",
    "    return total_similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every instruction of the concatenated public corpus in one call.\n",
    "public_data_embeddings = model.encode(\n",
    "    public_data[\"instruction\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row ranges follow the order in which processed_data was concatenated into\n",
    "# public_data: code, fin, med, general, math. Deriving the boundaries from the\n",
    "# dataset lengths keeps the slices right if an upstream dataset changes size.\n",
    "split_bounds = np.cumsum([0] + [len(part) for part in processed_data]).tolist()\n",
    "if split_bounds[-1] != len(public_data_embeddings):\n",
    "    raise ValueError(\"public_data_embeddings does not line up with processed_data\")\n",
    "code_start, fin_start, med_start, general_start, math_start, corpus_end = split_bounds\n",
    "\n",
    "code_embeddings = public_data_embeddings[code_start:fin_start]\n",
    "med_embeddings = public_data_embeddings[med_start:general_start]\n",
    "fin_embeddings = public_data_embeddings[fin_start:med_start]\n",
    "math_embeddings = public_data_embeddings[math_start:corpus_end]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "settings = [\"code_5000\"]\n",
    "root = \"\"\n",
    "domain = \"code\"\n",
    "\n",
    "# Coverage score per setting, kept so results survive past the loop.\n",
    "domain_coverages = {}\n",
    "\n",
    "for setting in tqdm(settings):\n",
    "    # Each setting is stored as 10 shards; os.path.join keeps an empty root\n",
    "    # relative to the working directory instead of pointing at '/'.\n",
    "    cross_client_datas = concatenate_datasets(\n",
    "        [load_from_disk(os.path.join(root, f\"{setting}_{i}.parquet\")) for i in range(10)]\n",
    "    )\n",
    "    # NOTE(review): the filter uses the fixed `domain` above while the reference\n",
    "    # embeddings below are picked from the setting name - confirm they agree\n",
    "    # when more settings are added.\n",
    "    cross_client_datas = cross_client_datas.filter(lambda example: example['label'] == domain) # Filter out-of-domain data.\n",
    "    if len(cross_client_datas) == 0:\n",
    "        raise ValueError(f\"No examples labelled {domain!r} found for setting {setting!r}.\")\n",
    "    datas_embeddings = model.encode(cross_client_datas[\"instruction\"])\n",
    "\n",
    "    # Reference set is chosen by substring of the setting name; anything that\n",
    "    # is not code/med/fin falls back to math.\n",
    "    if \"code\" in setting: domain_embeddings = code_embeddings\n",
    "    elif \"med\" in setting: domain_embeddings = med_embeddings\n",
    "    elif \"fin\" in setting: domain_embeddings = fin_embeddings\n",
    "    else: domain_embeddings = math_embeddings\n",
    "\n",
    "    domain_coverage = coverage(datas_embeddings, domain_embeddings)\n",
    "    domain_coverages[setting] = domain_coverage\n",
    "\n",
    "domain_coverages"
   ]
  }
 ],
 "metadata": {
  "fileId": "4dca4984-8c91-450b-a4aa-c773f4bfaa6f",
  "filePath": "/mnt/bn/data-tns-live-llm/leon/Cherry_LLM/niid_data/domain_coverage.ipynb",
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
