{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Parquet shards of the SST-2 dataset on the Hugging Face Hub, keyed by split name.\n",
    "splits = {\n",
    "    'train': 'data/train-00000-of-00001.parquet',\n",
    "    'validation': 'data/validation-00000-of-00001.parquet',\n",
    "    'test': 'data/test-00000-of-00001.parquet',\n",
    "}\n",
    "hub_root = \"hf://datasets/stanfordnlp/sst2/\"\n",
    "# Only the training split is read here.\n",
    "df = pd.read_parquet(hub_root + splits[\"train\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_word_rarity(df):\n",
    "    word_count = df['sentence'].str.split().explode().value_counts()\n",
    "    word_count /= word_count.sum()\n",
    "    return word_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative frequency of every whitespace token across the loaded sentences.\n",
    "word_rarity = compute_word_rarity(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_rarity(sentence, unseen_frequency=None):\n",
    "    \"\"\"Sum ``-log(relative frequency)`` over the whitespace tokens of ``sentence``.\n",
    "\n",
    "    Frequencies are looked up in the module-level ``word_rarity`` Series.\n",
    "\n",
    "    Args:\n",
    "        sentence: Text to score; tokenised with ``str.split()``.\n",
    "        unseen_frequency: Frequency assumed for tokens missing from\n",
    "            ``word_rarity``. Defaults to the smallest frequency in\n",
    "            ``word_rarity``, so an unseen token counts as much as the rarest\n",
    "            known one.\n",
    "\n",
    "    Returns:\n",
    "        The summed score (``0.0`` for a sentence without tokens).\n",
    "    \"\"\"\n",
    "    score = 0.0\n",
    "    for word in sentence.split():\n",
    "        frequency = word_rarity.get(word)\n",
    "        if frequency is None:\n",
    "            # A default of 0 would give -log(0) = inf, which makes the whole sentence\n",
    "            # infinite and breaks any later division by the maximum; use a finite floor.\n",
    "            frequency = word_rarity.min() if unseen_frequency is None else unseen_frequency\n",
    "        score += -np.log(frequency)\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence-level rarity score computed by assign_rarity for every row.\n",
    "df['rarity'] = df['sentence'].apply(assign_rarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide by the largest score so the maximum becomes 1; the thresholds used later are on this scale.\n",
    "df['rarity'] = df['rarity']/df['rarity'].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model families used by the experiments below, plus the sentence encoder class.\n",
    "from sklearn import linear_model\n",
    "from sklearn.svm import SVC as SVM\n",
    "from sentence_transformers import SentenceTransformer\n",
    "# Sentence encoder identified by the name 'all-mpnet-base-v2'.\n",
    "model = SentenceTransformer('all-mpnet-base-v2')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment for difficulty-based learning for text classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the label column; presumably labels are 0/1, making this the share of label 1 (verify).\n",
    "df.label.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 2,500 sentences per label (5,000 rows, class-balanced), then keep the rows that are\n",
    "# hard (rarity > hardness_threshold) or above the easy floor (rarity > easy_threshold).\n",
    "# The filter can remove different numbers of rows per label.\n",
    "hardness_threshold = 0.5\n",
    "easy_threshold = 0.2\n",
    "df_balanced = df.groupby('label').apply(lambda group: group.sample(n=2500, random_state=42))\n",
    "balanced_is_hard = df_balanced['rarity'] > hardness_threshold\n",
    "balanced_above_easy = df_balanced['rarity'] > easy_threshold\n",
    "df_balanced = df_balanced[balanced_is_hard | balanced_above_easy]\n",
    "df_balanced.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the index with 0..n-1 so row labels coincide with row positions in the embedding matrix.\n",
    "df_balanced = df_balanced.reset_index(drop=True)\n",
    "balanced_sentences = df_balanced['sentence'].values\n",
    "embeddings_balanced = model.encode(balanced_sentences, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "# Uncentred second-moment matrix of the embeddings (assumes one embedding per row).\n",
    "second_moment = embeddings_balanced.T @ embeddings_balanced / embeddings_balanced.shape[0]\n",
    "# The matrix is symmetric, so use eigvalsh: it returns real eigenvalues, whereas the\n",
    "# general eig routine may return complex values with spurious imaginary parts.\n",
    "eigenvalues = np.linalg.eigvalsh(second_moment)\n",
    "# Magnitudes in descending order.\n",
    "plt.plot(np.sort(np.abs(eigenvalues))[::-1])\n",
    "plt.xlabel(\"Eigenvalue Index\")\n",
    "plt.ylabel(\"Eigenvalue Magnitude\")\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "plt.savefig(\"plots/eigenvalue_spectrum_text.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "def get_good_vectors(embeddings):\n",
    "    \"\"\"Weight the rows of ``embeddings`` to maximise the smallest eigenvalue.\n",
    "\n",
    "    Builds and solves a cvxpy problem over a weight vector ``x`` with one entry\n",
    "    per row, constrained to be non-negative and to sum to 1, that maximises\n",
    "    ``lambda_min(embeddings.T @ diag(x) @ embeddings)``.\n",
    "\n",
    "    Args:\n",
    "        embeddings: 2-D array with one row per candidate vector.\n",
    "\n",
    "    Returns:\n",
    "        The value of ``x`` after ``solve()``; the solver status is not checked.\n",
    "    \"\"\"\n",
    "    num_rows = embeddings.shape[0]\n",
    "    weights = cp.Variable(num_rows)\n",
    "    simplex = [cp.sum(weights)==1, weights>=0]\n",
    "    weighted_gram = embeddings.T @ cp.diag(weights) @ embeddings\n",
    "    problem = cp.Problem(cp.Maximize(cp.lambda_min(weighted_gram)), simplex)\n",
    "    problem.solve()\n",
    "    return weights.value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out pool: every row of df whose idx does not occur in the balanced training subset.\n",
    "held_out_mask = ~df.idx.isin(df_balanced.idx.values)\n",
    "df_notin = df[held_out_mask]\n",
    "# Hard = rarity above hardness_threshold; easy = rarity strictly between the two thresholds.\n",
    "held_out_is_hard = df_notin.rarity > hardness_threshold\n",
    "held_out_is_easy = (df_notin.rarity > easy_threshold) & (df_notin.rarity < hardness_threshold)\n",
    "df_notin_hard = df_notin[held_out_is_hard]\n",
    "df_notin_easy = df_notin[held_out_is_easy]\n",
    "hard_validation_label = df_notin_hard.label.values\n",
    "easy_validation_label = df_notin_easy.label.values\n",
    "hard_validation_em = model.encode(df_notin_hard.sentence.values, show_progress_bar=True)\n",
    "easy_validation_em = model.encode(df_notin_easy.sentence.values, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_mc = 10\n",
    "N_samples = 100  # random rows mixed into the hard set / used to fit the difficulty regressor\n",
    "# metrics[mc, method, split]\n",
    "#   method: 0 = random, 1 = mixed (true hard + random), 2 = random with 3x budget, 3 = predicted hard + random\n",
    "#   split:  0 = hard validation accuracy, 1 = easy validation accuracy\n",
    "metrics = np.zeros((n_mc,4,2))\n",
    "# hardness_threshold is taken from the cell that builds df_balanced, so training and\n",
    "# validation share one definition of 'hard'.\n",
    "\n",
    "\n",
    "def _hard_easy_accuracy(clf, x_hard, y_hard, x_easy, y_easy):\n",
    "    \"\"\"Return ``(hard accuracy, easy accuracy)`` of an already fitted classifier.\"\"\"\n",
    "    hard_accuracy = (clf.predict(x_hard) == y_hard).mean()\n",
    "    easy_accuracy = (clf.predict(x_easy) == y_easy).mean()\n",
    "    return hard_accuracy, easy_accuracy\n",
    "\n",
    "\n",
    "for mc in tqdm(range(n_mc)):\n",
    "    # Fresh 500-sentence hard and easy validation subsets for this repetition.\n",
    "    hard_validation_subset = np.random.choice(hard_validation_label.shape[0],size=500,replace=False)\n",
    "    hard_validation_em_subset = hard_validation_em[hard_validation_subset]\n",
    "    hard_validation_labels = hard_validation_label[hard_validation_subset]\n",
    "    easy_validation_subset = np.random.choice(easy_validation_label.shape[0],size=500,replace=False)\n",
    "    easy_validation_em_subset = easy_validation_em[easy_validation_subset]\n",
    "    easy_validation_labels = easy_validation_label[easy_validation_subset]\n",
    "    validation = (hard_validation_em_subset, hard_validation_labels,\n",
    "                  easy_validation_em_subset, easy_validation_labels)\n",
    "    train_set_all = df_balanced.index\n",
    "\n",
    "    # Training rows whose true rarity exceeds the hardness threshold, in shuffled order.\n",
    "    train_set_hard = df_balanced[df_balanced.rarity>hardness_threshold].index\n",
    "    train_set_hard = np.random.choice(train_set_hard,size=train_set_hard.shape[0],replace=False)\n",
    "\n",
    "    # Method 1: resample the hard rows with replacement, add N_samples random rows, de-duplicate.\n",
    "    train_set_mixed = np.unique(np.concatenate([\n",
    "        np.random.choice(train_set_hard,size=train_set_hard.shape[0]),\n",
    "        np.random.choice(train_set_all,size=N_samples,replace=False),\n",
    "    ]))\n",
    "    model_mixed = SVM(kernel='linear')\n",
    "    model_mixed.fit(embeddings_balanced[train_set_mixed],df_balanced.loc[train_set_mixed,'label'])\n",
    "    metrics[mc,1,:] = _hard_easy_accuracy(model_mixed, *validation)\n",
    "\n",
    "    # Method 0: as many uniformly drawn rows as the mixed set contains.\n",
    "    N_samples_random = train_set_mixed.shape[0]\n",
    "    train_set_random = np.random.choice(train_set_all, size=N_samples_random,replace=False)\n",
    "    model_random = SVM(kernel='linear')\n",
    "    model_random.fit(embeddings_balanced[train_set_random],df_balanced.loc[train_set_random,'label'])\n",
    "    metrics[mc,0,:] = _hard_easy_accuracy(model_random, *validation)\n",
    "\n",
    "    # Method 3: regress rarity on the embeddings of N_samples random rows, then keep the rows\n",
    "    # with the highest predicted rarity (as many as there are truly hard rows) plus those N_samples rows.\n",
    "    train_set_random_predict = np.random.choice(train_set_all,size=N_samples,replace=False)\n",
    "    lm = linear_model.LinearRegression()\n",
    "    lm.fit(embeddings_balanced[train_set_random_predict],df_balanced.loc[train_set_random_predict,'rarity'])\n",
    "    predicted_difficulty = lm.predict(embeddings_balanced[train_set_all])\n",
    "    train_set_hard_predicted = train_set_all[np.argsort(predicted_difficulty)[-train_set_hard.shape[0]:]]\n",
    "    train_set_mixed_predicted = np.concatenate([train_set_hard_predicted,train_set_random_predict])\n",
    "    # NOTE(review): this is the only classifier with an RBF kernel; the others are linear. Confirm this is intended.\n",
    "    model_predicted = SVM(kernel='rbf')\n",
    "    model_predicted.fit(embeddings_balanced[train_set_mixed_predicted],df_balanced.loc[train_set_mixed_predicted,'label'])\n",
    "    metrics[mc,3,:] = _hard_easy_accuracy(model_predicted, *validation)\n",
    "\n",
    "    # Method 2: three times the mixed budget, drawn with replacement.\n",
    "    train_set_3_times = np.random.choice(train_set_all,size=3*N_samples_random,replace=True)\n",
    "    model_all = SVM(kernel='linear')\n",
    "    model_all.fit(embeddings_balanced[train_set_3_times],df_balanced.loc[train_set_3_times,'label'])\n",
    "    metrics[mc,2,:] = _hard_easy_accuracy(model_all, *validation)\n",
    "# Mean true rarity of the predicted-hard rows vs. the random rows, from the final repetition only.\n",
    "print(df_balanced.loc[train_set_hard_predicted,'rarity'].mean(),df_balanced.loc[train_set_random_predict,'rarity'].mean())\n",
    "metrics.mean(0),metrics.std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_mc = 10\n",
    "from small_text.classifiers import ConfidenceEnhancedLinearSVC\n",
    "from small_text.classifiers.factories import SklearnClassifierFactory\n",
    "from small_text import (\n",
    "    PoolBasedActiveLearner,\n",
    "    AnchorSubsampling,\n",
    "    PredictionEntropy,\n",
    "    random_initialization_balanced,\n",
    "    SEALS\n",
    ")\n",
    "from small_text.data import SklearnDataset\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "\n",
    "clf_template = ConfidenceEnhancedLinearSVC()\n",
    "num_classes = 2\n",
    "clf_factory = SklearnClassifierFactory(clf_template, num_classes)\n",
    "\n",
    "NUM_INITIAL = 100  # randomly drawn labelled rows every learner starts from\n",
    "NUM_QUERIES = 10   # query/update rounds per learner\n",
    "\n",
    "\n",
    "def evaluate(active_learner, hard_dataset, easy_dataset):\n",
    "    \"\"\"Return ``(hard accuracy, easy accuracy)`` of the learner's current classifier.\"\"\"\n",
    "    y_pred_hard = active_learner.classifier.predict(hard_dataset)\n",
    "    y_pred_easy = active_learner.classifier.predict(easy_dataset)\n",
    "    test_acc_hard = accuracy_score(hard_dataset.y, y_pred_hard)\n",
    "    test_acc_easy = accuracy_score(easy_dataset.y, y_pred_easy)\n",
    "    return test_acc_hard, test_acc_easy\n",
    "\n",
    "\n",
    "def run_active_learning(query_strategy, dataset, labels, hard_dataset, easy_dataset,\n",
    "                        query_size, num_initial=NUM_INITIAL, num_queries=NUM_QUERIES):\n",
    "    \"\"\"Run one pool-based active-learning simulation and record its accuracy curve.\n",
    "\n",
    "    Args:\n",
    "        query_strategy: small_text query strategy handed to the learner.\n",
    "        dataset: Pool the learner queries from.\n",
    "        labels: Array of true labels for ``dataset``, used to simulate the annotator.\n",
    "        hard_dataset: Hard validation dataset passed to ``evaluate``.\n",
    "        easy_dataset: Easy validation dataset passed to ``evaluate``.\n",
    "        query_size: Number of rows requested per query round.\n",
    "        num_initial: Number of randomly drawn rows labelled before the first query.\n",
    "        num_queries: Number of query/update rounds.\n",
    "\n",
    "    Returns:\n",
    "        ``(curve, indices_labeled)``: ``curve`` has shape ``(num_queries + 1, 2)`` holding\n",
    "        (hard, easy) accuracy after initialisation and after each round;\n",
    "        ``indices_labeled`` concatenates every index labelled during the run.\n",
    "    \"\"\"\n",
    "    learner = PoolBasedActiveLearner(clf_factory, query_strategy, dataset)\n",
    "    indices_labeled = np.random.choice(np.arange(labels.shape[0]), size=num_initial, replace=False)\n",
    "    learner.initialize_data(indices_labeled, labels[indices_labeled])\n",
    "\n",
    "    curve = np.zeros((num_queries + 1, 2))\n",
    "    curve[0, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "    for i in range(num_queries):\n",
    "        indices_queried = learner.query(num_samples=query_size)\n",
    "        # Simulate the annotator by looking up the true labels of the queried rows.\n",
    "        learner.update(dataset.y[indices_queried])\n",
    "        indices_labeled = np.concatenate([indices_queried, indices_labeled])\n",
    "        curve[i + 1, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "    return curve, indices_labeled\n",
    "\n",
    "\n",
    "# results[mc, strategy, round, split]\n",
    "#   strategy: 0 = AnchorSubsampling, 1 = SEALS; round 0 = after initialisation\n",
    "#   split:    0 = hard validation accuracy, 1 = easy validation accuracy\n",
    "results = np.zeros((N_mc, 2, NUM_QUERIES + 1, 2))\n",
    "for mc in tqdm(range(N_mc)):\n",
    "    # Fresh 500-sentence hard and easy validation subsets for this repetition.\n",
    "    hard_validation_subset = np.random.choice(hard_validation_label.shape[0],size=500,replace=False)\n",
    "    hard_validation_em_subset = hard_validation_em[hard_validation_subset]\n",
    "    hard_validation_labels = hard_validation_label[hard_validation_subset]\n",
    "    easy_validation_subset = np.random.choice(easy_validation_label.shape[0],size=500,replace=False)\n",
    "    easy_validation_em_subset = easy_validation_em[easy_validation_subset]\n",
    "    easy_validation_labels = easy_validation_label[easy_validation_subset]\n",
    "\n",
    "    train_set_all = df_balanced.index\n",
    "    x = embeddings_balanced[train_set_all]\n",
    "    y = df_balanced.loc[train_set_all,'label'].values\n",
    "\n",
    "    dataset = SklearnDataset(x, y, target_labels=np.arange(2))\n",
    "    hard_dataset = SklearnDataset(hard_validation_em_subset, hard_validation_labels, target_labels=np.arange(2))\n",
    "    easy_dataset = SklearnDataset(easy_validation_em_subset, easy_validation_labels, target_labels=np.arange(2))\n",
    "\n",
    "    # NOTE(review): the two strategies use different query sizes (5 vs 10), so they finish with\n",
    "    # 150 and 200 labelled rows respectively. Confirm the unequal budgets are intended.\n",
    "    results[mc, 0], indices_labeled = run_active_learning(\n",
    "        AnchorSubsampling(PredictionEntropy()), dataset, y, hard_dataset, easy_dataset, query_size=5)\n",
    "    results[mc, 1], indices_labeled = run_active_learning(\n",
    "        SEALS(PredictionEntropy()), dataset, y, hard_dataset, easy_dataset, query_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final-round (hard, easy) accuracy per active-learning strategy: mean and std over repetitions,\n",
    "# followed by the number of labelled indices accumulated in the most recently executed run.\n",
    "results.mean(0)[:,-1,:],results.std(0)[:,-1,:],len(indices_labeled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per method, columns ordered (easy, hard) because the plotting cell labels column 0\n",
    "# 'Easy Validation Data' and column 1 'Hard Validation Data'. Both `results` and `metrics` are\n",
    "# filled as (hard, easy), so BOTH are reversed here to keep every bar under the right label.\n",
    "# np.concatenate is used (as elsewhere in this notebook); np.concat only exists in NumPy >= 2.0.\n",
    "means_combined = np.concatenate([\n",
    "    results.mean(axis=0)[:, -1, ::-1],  # active learning: accuracy after the final query round\n",
    "    metrics.mean(axis=0)[:, ::-1],\n",
    "])\n",
    "# Despite the name these are standard deviations; the name is kept because the plotting cell reads it.\n",
    "quantiles_combined = np.concatenate([\n",
    "    results.std(axis=0)[:, -1, ::-1],\n",
    "    metrics.std(axis=0)[:, ::-1],\n",
    "])\n",
    "# Row order: the two `results` strategies first, then the four `metrics` methods.\n",
    "method_names = [\n",
    "    \"AnchorAL\",\n",
    "    \"SEALS\",\n",
    "    \"Random\",\n",
    "    \"Mixed\",\n",
    "    \"RandomLarge\",\n",
    "    \"BSLB \\n (Our Method)\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the combined per-method mean accuracies (one row per entry of method_names).\n",
    "means_combined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "bar_width = 0.4\n",
    "bar_positions = np.arange(means_combined.shape[0])\n",
    "plt.figure(figsize=(10,5))\n",
    "# Column 0 is drawn as the easy split and column 1 as the hard split, side by side per method.\n",
    "plt.bar(bar_positions,means_combined[:,0],yerr=quantiles_combined[:,0],label=\"Easy Validation Data\",color=\"green\",width=bar_width)\n",
    "plt.bar(bar_positions+bar_width,means_combined[:,1],yerr=quantiles_combined[:,1],label=\"Hard Validation Data\",color=\"red\",alpha=0.5,width=bar_width)\n",
    "plt.ylabel(\"Accuracy\",fontsize=20)\n",
    "plt.xlabel(\"Subset of Training Data\",fontsize=20)\n",
    "plt.ylim([0.5,0.9])\n",
    "plt.yticks(fontsize=15)\n",
    "# Centre each tick between the two bars of its method.\n",
    "plt.xticks(np.arange(len(method_names))+bar_width/2,method_names,fontsize=15,rotation=0)\n",
    "plt.legend(loc=\"lower right\",fontsize=18)\n",
    "# Horizontal reference lines at the last method's two accuracies, spanning all bar groups.\n",
    "line_end = len(method_names) - 1 + 0.6\n",
    "plt.hlines(means_combined[-1,1],-.2,line_end,linestyles=\"--\",colors=\"black\",alpha=0.2)\n",
    "plt.hlines(means_combined[-1,0],-.2,line_end,linestyles=\"--\",colors=\"black\",alpha=0.2)\n",
    "\n",
    "# Dashed box around the first two methods (the active-learning baselines) and its caption.\n",
    "plt.gca().add_patch(plt.Rectangle((-0.3, 0.41), 2, 0.48, fill=False, edgecolor='black', linewidth=1, linestyle='--',alpha=0.5))\n",
    "plt.text(-0.1,0.9,\"Active Learning Framework\")\n",
    "\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "plt.savefig(\"plots/accuracy_comparison_text_classification.pdf\",bbox_inches=\"tight\",dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
