{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HkOk1l_44ERM"
   },
   "source": [
    "# Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5mmqSh3szedt",
    "outputId": "cffd6b5b-3aef-4c21-e9e3-76ba7f9f5b7e"
   },
   "outputs": [],
   "source": [
    "%pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-ODsn-k4YCmI"
   },
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ar4O_1tszb8L",
    "outputId": "e47caf66-999a-4e53-9c0b-4121f718350e"
   },
   "outputs": [],
   "source": [
    "sys.path.append('') # Placeholder: put the directory holding run_Bert_model.py and utils.py here (presumably; both are imported further down)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "grJZp6ZSAFgx",
    "outputId": "50546bb0-c542-4094-a551-f0110d54f55c"
   },
   "outputs": [],
   "source": [
    "import datetime\n",
    "# Timestamp of this run; it is embedded in the names of the output files below\n",
    "dt = datetime.datetime.now()\n",
    "datestr = f\"{dt:%m-%d-%Y_%H%M}\"\n",
    "print(datestr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hPPfRpmnfjWo"
   },
   "source": [
    "# Train and validate BERT base models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ptimx4CB_Ml8"
   },
   "outputs": [],
   "source": [
    "from run_Bert_model import model_train_validate_test\n",
    "import pandas as pd\n",
    "from utils import Metric\n",
    "import os\n",
    "\n",
    "NUM_REPS = 5  # number of replications on training models to get mean accuracy / std\n",
    "\n",
    "# Placeholders: directory the TSV splits are read from, and directory results are written to\n",
    "data_path = \"\"\n",
    "target_dir = \"\"\n",
    "\n",
    "# Test split: tab-separated, read without a header row, columns named (similarity, s1)\n",
    "test_df = pd.read_csv(\n",
    "    os.path.join(data_path, \"test.tsv\"),\n",
    "    sep='\\t',\n",
    "    header=None,\n",
    "    names=['similarity', 's1'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XQtCfaDI3k23"
   },
   "source": [
    "## Real-world training data baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PiJk76q4fkSK"
   },
   "outputs": [],
   "source": [
    "# Dev split is used as-is\n",
    "human_dev_df = pd.read_csv(os.path.join(data_path, \"dev.tsv\"), sep='\\t', header=None, names=['similarity', 's1'])\n",
    "\n",
    "# Train split is randomly downsampled to 6000 rows (fixed random_state for repeatability)\n",
    "human_train_df = (\n",
    "    pd.read_csv(os.path.join(data_path, \"train.tsv\"), sep='\\t', header=None, names=['similarity', 's1'])\n",
    "    .sample(n=6000, random_state=42)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mGFbny5-eN7W"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Baseline trained on the human-written data. Each (i, j) pair is one run: i re-seeds CUDA,\n",
    "# j is the seed handed to the trainer, so NUM_REPS**2 runs are scored in total.\n",
    "pred_save_name = f\"human_pred-{datestr}.csv\"\n",
    "human_records = []  # one {'acc', 'f1'} dict per run; converted to a frame once, after the loop\n",
    "\n",
    "for i in range(NUM_REPS):\n",
    "  torch.cuda.manual_seed(i)\n",
    "  for j in range(NUM_REPS):\n",
    "    model_train_validate_test(human_train_df, human_dev_df, test_df, target_dir, save_name=pred_save_name,\n",
    "                              max_seq_len=50,\n",
    "                              epochs=3,\n",
    "                              batch_size=16,\n",
    "                              lr=2e-05,\n",
    "                              patience=1,\n",
    "                              max_grad_norm=10.0,\n",
    "                              if_save_model=True,\n",
    "                              checkpoint=None,\n",
    "                              seed=j)\n",
    "\n",
    "    # Score the predictions stored under pred_save_name (presumably written by the call above)\n",
    "    test_result = pd.read_csv(os.path.join(target_dir, pred_save_name))\n",
    "    acc, f1 = Metric(test_df.similarity, test_result.prediction)\n",
    "    human_records.append({'acc': acc, 'f1': f1})\n",
    "\n",
    "# Build the results frame in one go instead of concatenating inside the loop\n",
    "human_results_df = pd.DataFrame(human_records, columns=['acc', 'f1'])\n",
    "results_save_name = \"human_results.csv\"\n",
    "human_results_df.to_csv(os.path.join(target_dir, results_save_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "914qC1ZdCH5f"
   },
   "source": [
    "## Synthetic (zero-shot) training data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W6Nrc1fW3vtx"
   },
   "source": [
    "### Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0MUw6H3DD5__"
   },
   "outputs": [],
   "source": [
    "# Row 0 is read as data and then dropped (presumably the file's own header line);\n",
    "# afterwards the labels are cast to int64.\n",
    "syn_data_df = (\n",
    "    pd.read_csv(os.path.join(data_path, \"PGDG_6k.tsv\"), sep='\\t', names=['s1', 'similarity'])\n",
    "    .iloc[1:]\n",
    "    .astype({'similarity': 'int64'})\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9o30KkQi34h0"
   },
   "source": [
    "### Train model on all 6,000 synthetic samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aLYA2Tfbaf-C"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Validate prior work, train on all synthetic data\n",
    "all_data_df = syn_data_df.sample(frac=1, random_state=42) # random shuffling\n",
    "split_at = int(0.9*len(all_data_df))  # 90/10 train-dev split\n",
    "train_df = all_data_df.iloc[:split_at]\n",
    "dev_df = all_data_df.iloc[split_at:]\n",
    "\n",
    "pred_save_name = f\"all_syn_pred-{datestr}.csv\"\n",
    "all_syn_records = []  # one {'acc', 'f1'} dict per run; converted to a frame once, after the loop\n",
    "\n",
    "for i in range(NUM_REPS):\n",
    "  torch.cuda.manual_seed(i)\n",
    "  for j in range(NUM_REPS):\n",
    "    model_train_validate_test(train_df, dev_df, test_df, target_dir, save_name=pred_save_name,\n",
    "                              max_seq_len=50,\n",
    "                              epochs=3,\n",
    "                              batch_size=16,\n",
    "                              lr=2e-05,\n",
    "                              patience=1,\n",
    "                              max_grad_norm=10.0,\n",
    "                              if_save_model=True,\n",
    "                              checkpoint=None,\n",
    "                              seed=j)\n",
    "\n",
    "    # Score the predictions stored under pred_save_name (presumably written by the call above)\n",
    "    test_result = pd.read_csv(os.path.join(target_dir, pred_save_name))\n",
    "    acc, f1 = Metric(test_df.similarity, test_result.prediction)\n",
    "    all_syn_records.append({'acc': acc, 'f1': f1})\n",
    "\n",
    "# Build the results frame in one go instead of concatenating inside the loop\n",
    "all_syn_results_df = pd.DataFrame(all_syn_records, columns=['acc', 'f1'])\n",
    "# File name follows the '<name>-<datestr>.csv' pattern of the other outputs (a stray 'f' was removed)\n",
    "results_save_name = f\"all_syn_results-{datestr}.csv\"\n",
    "all_syn_results_df.to_csv(os.path.join(target_dir, results_save_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2iCFJ1p-3-DO"
   },
   "source": [
    "## Downsampling of \"representative\" synthetic data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EXnz8WcPGn7r"
   },
   "source": [
    "### Embed sentences using Vertex AI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "83gTZSD8GqeW"
   },
   "outputs": [],
   "source": [
    "# init the vertexai package\n",
    "import vertexai\n",
    "import time\n",
    "# Load the text embeddings model\n",
    "from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Number of texts sent per request by get_embeddings_wrapper\n",
    "BATCH_SIZE = 5 # TODO: how to properly set batch size?\n",
    "\n",
    "# Placeholders: fill in the Google Cloud project id and location before running\n",
    "PROJECT_ID = \"\"\n",
    "LOCATION   = \"\"\n",
    "\n",
    "vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
    "\n",
    "# Embedding model shared by the two helper functions defined in this cell\n",
    "model = TextEmbeddingModel.from_pretrained(\"text-embedding-004\")\n",
    "\n",
    "def get_embeddings_wrapper(texts):\n",
    "    '''\n",
    "    Embed `texts` in consecutive batches of BATCH_SIZE using the module-level `model`.\n",
    "\n",
    "    Each batch is converted with `.tolist()`, so slices of `texts` must offer that\n",
    "    method (presumably a pandas Series or NumPy array -- confirm against callers).\n",
    "    Returns a list with the `.values` of every embedding, in input order.\n",
    "    '''\n",
    "    vectors = []\n",
    "    for start in tqdm(range(0, len(texts), BATCH_SIZE)):\n",
    "        time.sleep(1)  # to avoid the quota error\n",
    "        batch = texts[start : start + BATCH_SIZE].tolist()\n",
    "        vectors.extend(e.values for e in model.get_embeddings(batch))\n",
    "    return vectors\n",
    "\n",
    "def get_embeddings_task(texts, task):\n",
    "    '''\n",
    "    Embed `texts` for a specific task type using the module-level `model`.\n",
    "\n",
    "    texts: sliceable collection of strings.\n",
    "    task:  task type given to TextEmbeddingInput, e.g. 'CLUSTERING' or 'SEMANTIC_SIMILARITY'.\n",
    "    Returns a list with the `.values` of every embedding, in input order.\n",
    "    '''\n",
    "    BATCH_SIZE = 5  # local batch size; shadows the module-level constant of the same name\n",
    "    vectors = []\n",
    "    for start in tqdm(range(0, len(texts), BATCH_SIZE)):\n",
    "        batch = texts[start : start + BATCH_SIZE]\n",
    "        requests = [TextEmbeddingInput(text, task) for text in batch]\n",
    "        vectors.extend(emb.values for emb in model.get_embeddings(requests))\n",
    "    return vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-G7rm2aroMAF"
   },
   "outputs": [],
   "source": [
    "# Task type handed to get_embeddings_task when embedding the synthetic sentences\n",
    "task = 'CLUSTERING'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uOOnBaPhoS6W"
   },
   "outputs": [],
   "source": [
    "# Embed every synthetic sentence in column 's1' for the chosen task\n",
    "embeddings = get_embeddings_task(syn_data_df['s1'], task)\n",
    "# Rebuild the frame with sentence, label and embedding side by side, then save it as TSV\n",
    "syn_data_df = pd.DataFrame({\"s1\": syn_data_df['s1'], \"similarity\": syn_data_df['similarity'], \"embedding\": embeddings})\n",
    "syn_data_df.to_csv(os.path.join(data_path,\"syn_data_embeddings-tCluster.tsv\"), sep='\\t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G_gwrR5T4Rp2"
   },
   "source": [
    "### Build similarity graph from embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xaz6FX8bBTwb"
   },
   "source": [
    "Similarity graph building and greedy max coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QCD7BCA2K9IN"
   },
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "def build_graph(cos_sim, sim_thresh=0.5, max_degree=None, labels=None):\n",
    "    '''\n",
    "    Build an undirected similarity graph over the rows of `cos_sim`.\n",
    "\n",
    "    cos_sim:    square matrix-like; cos_sim[i][j] is the similarity of items i and j.\n",
    "    sim_thresh: minimum similarity required for an edge.\n",
    "    max_degree: when truthy, stop adding neighbours of i once G.degree(i) reaches it.\n",
    "    labels:     optional mapping from item index i to its label. When given, only items\n",
    "                with equal labels are connected; when None, labels are ignored.\n",
    "\n",
    "    Every node also receives a weight-1 self-loop, added after its neighbours so that\n",
    "    it does not count toward max_degree. Returns the networkx Graph.\n",
    "    '''\n",
    "    G = nx.Graph()\n",
    "    for i in range(len(cos_sim)):\n",
    "        G.add_node(i)\n",
    "        # Sort neighbors by similarity in descending order\n",
    "        neighbors = sorted(enumerate(cos_sim[i]), key=lambda x: x[1], reverse=True)\n",
    "        for j, similarity in neighbors:\n",
    "            if j == i:\n",
    "                continue\n",
    "            if similarity < sim_thresh:\n",
    "                break  # descending order: no later neighbour can reach the threshold\n",
    "            if max_degree and G.degree(i) >= max_degree:\n",
    "                break  # Exit the inner loop if max_degree is reached\n",
    "            # `labels is None` must not block edges; the previous `and labels and ...`\n",
    "            # test rejected every edge whenever no labels were supplied.\n",
    "            if labels is None or labels[i] == labels[j]:\n",
    "                G.add_edge(i, j, weight=similarity)\n",
    "        # add self-loop, doesn't count toward max_degree\n",
    "        G.add_edge(i, i, weight=1)\n",
    "    return G\n",
    "\n",
    "# Graph sampling algorithms (max-cover)\n",
    "def max_cover_sampling(graph, k):\n",
    "    '''\n",
    "    Greedy max-cover: choose up to `k` nodes, each time taking the node whose\n",
    "    neighbourhood contains the most nodes not covered so far.\n",
    "\n",
    "    graph: object exposing nodes() and neighbors(node), e.g. a networkx Graph.\n",
    "    k:     maximum number of nodes to select.\n",
    "\n",
    "    Returns (selected, remaining): the list of selected nodes and the number of nodes\n",
    "    that are not a neighbour of any selected node. A selected node counts as its own\n",
    "    neighbour only if the graph has a self-loop on it.\n",
    "    '''\n",
    "    # Insertion-ordered dict used as a set: same tie-breaking order as a list of the\n",
    "    # nodes, but constant-time membership tests and removals.\n",
    "    remaining = dict.fromkeys(graph.nodes())\n",
    "    selected_nodes = set()\n",
    "    covered_nodes = set()\n",
    "    for _ in range(k):\n",
    "      candidates = [node for node in remaining if node not in selected_nodes]\n",
    "      if not candidates:\n",
    "        break  # nothing left to pick; also avoids max() on an empty sequence\n",
    "      max_cover_node = max(candidates, key=lambda n: len(set(graph.neighbors(n)) - covered_nodes))\n",
    "      selected_nodes.add(max_cover_node)\n",
    "\n",
    "      # Mark the neighbourhood as covered and drop it from the candidate pool\n",
    "      for neighbor in list(graph.neighbors(max_cover_node)):\n",
    "        covered_nodes.add(neighbor)\n",
    "        remaining.pop(neighbor, None)\n",
    "    return list(selected_nodes), len(remaining)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NBigHJqvBPyf"
   },
   "source": [
    "Compute the optimal similarity threshold using binary search on graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dQiKID9xkzIt"
   },
   "outputs": [],
   "source": [
    "def calculate_similarity_threshold(data, num_samples, coverage, epsilon=None, labels=None):\n",
    "    '''\n",
    "    Binary-search the similarity threshold at which `num_samples` greedily chosen nodes\n",
    "    cover roughly a `coverage` fraction of the similarity graph.\n",
    "\n",
    "    data:        pairwise similarity matrix passed to build_graph.\n",
    "    num_samples: number of nodes to select with max_cover_sampling.\n",
    "    coverage:    target fraction of covered nodes.\n",
    "    epsilon:     tolerance on the coverage; defaults to 50 / len(data).\n",
    "    labels:      optional per-node labels forwarded to build_graph.\n",
    "\n",
    "    Returns (threshold, graph, samples). The threshold is the one the returned graph\n",
    "    and samples were actually built with.\n",
    "    '''\n",
    "    total_num = len(data)\n",
    "    if epsilon is None:\n",
    "        # Coverage changes in steps of at least 1/total_num, so epsilon must exceed that\n",
    "        # or the search may never get close enough to terminate. Use 50 such steps.\n",
    "        epsilon = 5 * 10 / total_num  # Dynamic epsilon\n",
    "\n",
    "    if coverage < num_samples / total_num:\n",
    "        # The selected nodes alone already reach the target coverage\n",
    "        node_graph = build_graph(data, 1)\n",
    "        samples, _ = max_cover_sampling(node_graph, num_samples)\n",
    "        return 1, node_graph, samples\n",
    "\n",
    "    # Thresholds are kept as integers in thousandths to avoid floating-point drift\n",
    "    sim_upper = 1000\n",
    "    sim_lower = 707 # corresponds to 0.707\n",
    "    max_run = 20\n",
    "    count = 0\n",
    "\n",
    "    # The first iteration runs with sim_lower. If the coverage cannot be achieved\n",
    "    # even there, the bounds collapse and those samples are returned.\n",
    "    sim = sim_lower\n",
    "    cap = (2 * total_num * coverage) / num_samples\n",
    "    while True:\n",
    "        count += 1\n",
    "        print(f\"sim: {sim / 1000}\")\n",
    "\n",
    "        node_graph = build_graph(data, sim / 1000, max_degree=cap, labels=labels)\n",
    "        samples, rem_nodes = max_cover_sampling(node_graph, num_samples)\n",
    "        current_coverage = (total_num - rem_nodes) / total_num\n",
    "        # A higher threshold means fewer edges and therefore lower coverage\n",
    "        if current_coverage < coverage:\n",
    "            sim_upper = sim\n",
    "        else:\n",
    "            sim_lower = sim\n",
    "\n",
    "        if abs(current_coverage - coverage) <= epsilon or sim_upper - sim_lower <= 1:\n",
    "            break\n",
    "        if count >= max_run:\n",
    "            print(f\"Reached max number of iterations ({max_run}). Breaking...\")\n",
    "            break\n",
    "        # Integer midpoint keeps the search on the thousandths grid\n",
    "        sim = (sim_upper + sim_lower) // 2\n",
    "\n",
    "    # `sim` still holds the threshold used for the graph and samples returned here\n",
    "    return sim / 1000, node_graph, samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aviBowIfBYdQ"
   },
   "source": [
    "Model training functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E6Zcp1r02a31"
   },
   "outputs": [],
   "source": [
    "def train_model_max_cover(data_df, test_df, target_dir, seed=0):\n",
    "  '''\n",
    "  Fine-tune on a max-cover selection of the data and score it on the test set.\n",
    "\n",
    "  The 'embedding' column is dropped, the rows are shuffled with random_state=42 and\n",
    "  split 90/10 into train and dev. Returns (acc, f1) as computed by Metric.\n",
    "\n",
    "  NOTE(review): the prediction file name reads K from the notebook's global scope, and\n",
    "  predictions are taken from the return value of model_train_validate_test rather than\n",
    "  read back from disk as the sibling helpers do -- confirm both are intended.\n",
    "  '''\n",
    "  # Train just on the text data, in shuffled order\n",
    "  text_df = data_df.drop(columns=['embedding']).sample(frac=1, random_state=42)\n",
    "\n",
    "  # Train-Dev split\n",
    "  split_at = int(0.9*len(text_df))\n",
    "  train_df, dev_df = text_df.iloc[:split_at], text_df.iloc[split_at:]\n",
    "  pred_save_name = f\"max_k{K}_pred-{datestr}.csv\"\n",
    "\n",
    "  hyperparams = dict(\n",
    "      max_seq_len=50,\n",
    "      epochs=3,\n",
    "      batch_size=16,\n",
    "      lr=2e-05,\n",
    "      patience=1,\n",
    "      max_grad_norm=10.0,\n",
    "      if_save_model=True,\n",
    "      checkpoint=None,\n",
    "      seed=seed,\n",
    "  )\n",
    "  test_result = model_train_validate_test(train_df, dev_df, test_df, target_dir,\n",
    "                                          save_name=pred_save_name, **hyperparams)\n",
    "\n",
    "  acc, f1 = Metric(test_df.similarity, test_result.prediction)\n",
    "  return acc, f1\n",
    "\n",
    "def train_model_random(K, data_df, test_df, target_dir, seed=0):\n",
    "  '''\n",
    "  Fine-tune on an already-drawn random subset and score it on the test set.\n",
    "\n",
    "  K is only used in the prediction file name. The rows are shuffled with\n",
    "  random_state=42 and split 90/10 into train and dev. Returns (acc, f1) from Metric.\n",
    "  '''\n",
    "  print(len(data_df))\n",
    "  # Shuffle the training data before train-dev split\n",
    "  shuffled_df = data_df.sample(frac=1, random_state=42)\n",
    "  split_at = int(0.9*len(shuffled_df))\n",
    "  train_df, dev_df = shuffled_df.iloc[:split_at], shuffled_df.iloc[split_at:]\n",
    "\n",
    "  pred_save_name = f\"random_k{K}_pred-{datestr}.csv\"\n",
    "  hyperparams = dict(\n",
    "      max_seq_len=50,\n",
    "      epochs=3,\n",
    "      batch_size=16,\n",
    "      lr=2e-05,\n",
    "      patience=1,\n",
    "      max_grad_norm=10.0,\n",
    "      if_save_model=True,\n",
    "      checkpoint=None,\n",
    "      seed=seed,\n",
    "  )\n",
    "  model_train_validate_test(train_df, dev_df, test_df, target_dir,\n",
    "                            save_name=pred_save_name, **hyperparams)\n",
    "\n",
    "  # Score the predictions stored under pred_save_name\n",
    "  test_result = pd.read_csv(os.path.join(target_dir, pred_save_name))\n",
    "  acc, f1 = Metric(test_df.similarity, test_result.prediction)\n",
    "  return acc, f1\n",
    "\n",
    "def train_model_kmeans(K, data_df, test_df, target_dir, seed=0):\n",
    "  '''\n",
    "  Select one representative per k-means cluster, fine-tune on them and score on the test set.\n",
    "\n",
    "  K:          number of clusters, i.e. the number of samples to select.\n",
    "  data_df:    frame with an 'embedding' column holding one vector per row.\n",
    "  test_df:    test frame; its 'similarity' column is the ground truth for Metric.\n",
    "  target_dir: directory the prediction file is read from.\n",
    "  seed:       seed for both the k-means initialisation and the trainer.\n",
    "\n",
    "  Returns (acc, f1) as computed by Metric.\n",
    "  '''\n",
    "  embed_data = np.array(data_df['embedding'].tolist())\n",
    "\n",
    "  # Find kmeans centers; seeded by the `seed` argument rather than a notebook-level loop variable\n",
    "  kmeans = KMeans(n_clusters=K, random_state=seed, n_init=\"auto\").fit(embed_data)\n",
    "\n",
    "  # For every cluster, keep the row position of the point closest to the cluster center\n",
    "  closest_pt_idx = []\n",
    "  for iclust in range(kmeans.n_clusters):\n",
    "      # row positions of all points assigned to this cluster\n",
    "      cluster_pts_indices = np.where(kmeans.labels_ == iclust)[0]\n",
    "      if cluster_pts_indices.size == 0:\n",
    "          continue  # nothing to pick from an empty cluster\n",
    "\n",
    "      # Euclidean distance of each member to the center, computed with NumPy only\n",
    "      cluster_cen = kmeans.cluster_centers_[iclust]\n",
    "      dists = np.linalg.norm(embed_data[cluster_pts_indices] - cluster_cen, axis=1)\n",
    "      closest_pt_idx.append(cluster_pts_indices[np.argmin(dists)])\n",
    "\n",
    "  data_df = data_df.iloc[closest_pt_idx]\n",
    "  print(f\"Points sampled = {len(data_df)}\")\n",
    "\n",
    "  # Train-Dev split (rows are in cluster order; no extra shuffle is applied here)\n",
    "  train_df = data_df.iloc[:int(0.9*len(data_df))]\n",
    "  dev_df = data_df.iloc[int(0.9*len(data_df)):]\n",
    "  pred_save_name = f\"kmeans_k{K}_pred-{datestr}.csv\"\n",
    "\n",
    "  model_train_validate_test(train_df, dev_df, test_df, target_dir, save_name=pred_save_name,\n",
    "         max_seq_len=50,\n",
    "         epochs=3,\n",
    "         batch_size=16,\n",
    "         lr=2e-05,\n",
    "         patience=1,\n",
    "         max_grad_norm=10.0,\n",
    "         if_save_model=True,\n",
    "         checkpoint=None,\n",
    "         seed=seed)\n",
    "\n",
    "  # Score the predictions stored under pred_save_name\n",
    "  test_result = pd.read_csv(os.path.join(target_dir, pred_save_name))\n",
    "  acc, f1 = Metric(test_df.similarity, test_result.prediction)\n",
    "  return acc, f1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IFtHYSJStSEM"
   },
   "source": [
    "### Train model on $k$ Selected Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Shuu7QWEEHv5",
    "outputId": "51cd7e85-d03c-4cbe-fd3d-05fbdf97ec2e"
   },
   "outputs": [],
   "source": [
    "# Sample sizes to evaluate: 100, 400, ..., 2800 (step 300)\n",
    "Ks = list(range(100, 3000, 300))\n",
    "print(Ks)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "ACS:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 357
    },
    "id": "j2MAn8vI8jqW",
    "outputId": "c2de7dd3-ab77-41d9-8860-81b422868db4"
   },
   "outputs": [],
   "source": [
    "# Pairwise cosine similarities between all synthetic-sentence embeddings\n",
    "cos_sim = cosine_similarity(syn_data_df['embedding'].tolist())\n",
    "\n",
    "results_save_name = f\"max_k_results-{datestr}.csv\"\n",
    "max_k_columns = ['K', 'sim_thresh', 'acc', 'f1']\n",
    "max_k_records = []  # one dict per run; the results frame is rebuilt from it after every K\n",
    "max_k_results_df = pd.DataFrame(columns=max_k_columns)\n",
    "\n",
    "for K in Ks:\n",
    "  print(f\"K: {K}\")\n",
    "  # Keep the threshold so the 'sim_thresh' column is actually filled in\n",
    "  sim_thresh, _, selected_samples = calculate_similarity_threshold(cos_sim, K, 0.9)\n",
    "\n",
    "  sel_data_df = syn_data_df.iloc[selected_samples]\n",
    "\n",
    "  for i in range(NUM_REPS):\n",
    "    torch.cuda.manual_seed(i)\n",
    "    for j in range(NUM_REPS):\n",
    "      print(f\"Iteration: {(i*NUM_REPS)+j+1} of {NUM_REPS**2}\")\n",
    "      acc, f1 = train_model_max_cover(sel_data_df, test_df, target_dir,seed=j)\n",
    "      max_k_records.append({'K': K, 'sim_thresh': sim_thresh, 'acc': acc, 'f1': f1})\n",
    "\n",
    "  # Save after every K so partial results are kept if the run is interrupted\n",
    "  max_k_results_df = pd.DataFrame(max_k_records, columns=max_k_columns)\n",
    "  max_k_results_df.to_csv(os.path.join(target_dir, results_save_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "kMeans:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
    "results_save_name = f\"k_means_results-{datestr}.csv\"\n",
    "k_means_columns = ['K', 'acc', 'f1']\n",
    "k_means_records = []  # one dict per run; the results frame is rebuilt from it after every K\n",
    "k_means_results_df = pd.DataFrame(columns=k_means_columns)\n",
    "\n",
    "for K in Ks:\n",
    "  print(f\"K: {K}\")\n",
    "  for i in range(NUM_REPS):\n",
    "    torch.cuda.manual_seed(i)\n",
    "    for j in range(NUM_REPS):\n",
    "      print(f\"Iteration: {(i*NUM_REPS)+j+1} of {NUM_REPS**2}\")\n",
    "      acc, f1 = train_model_kmeans(K, syn_data_df, test_df, target_dir, seed=j)\n",
    "      k_means_records.append({'K': K, 'acc': acc, 'f1': f1})\n",
    "\n",
    "  # Save after every K so partial results are kept if the run is interrupted\n",
    "  k_means_results_df = pd.DataFrame(k_means_records, columns=k_means_columns)\n",
    "  k_means_results_df.to_csv(os.path.join(target_dir, results_save_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2ionoHMu6OHk"
   },
   "source": [
    "Random:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7wbyHUhy6QJZ",
    "outputId": "325b945e-467a-4427-c1e5-2f9c77b03a58"
   },
   "outputs": [],
   "source": [
    "results_save_name = f\"rand_k_results-{datestr}-cont.csv\"\n",
    "rand_k_columns = ['K', 'acc', 'f1']\n",
    "rand_k_records = []  # one dict per run; the results frame is rebuilt from it after every K\n",
    "rand_k_results_df = pd.DataFrame(columns=rand_k_columns)\n",
    "\n",
    "for K in Ks:\n",
    "  for i in range(NUM_REPS):\n",
    "    torch.cuda.manual_seed(i)\n",
    "    # Fresh random subset of K rows per replication; i doubles as the sampling seed\n",
    "    data_df = syn_data_df.sample(n=K, random_state=i)\n",
    "\n",
    "    # NOTE(review): no seed is forwarded, so train_model_random uses its default seed=0 -- confirm intended\n",
    "    acc, f1 = train_model_random(K, data_df, test_df, target_dir)\n",
    "    rand_k_records.append({'K': K, 'acc': acc, 'f1': f1})\n",
    "\n",
    "  # Save after every K so partial results are kept if the run is interrupted\n",
    "  rand_k_results_df = pd.DataFrame(rand_k_records, columns=rand_k_columns)\n",
    "  rand_k_results_df.to_csv(os.path.join(target_dir, results_save_name))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
