{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(active_learner, hard_dataset, easy_dataset):\n",
    "    \"\"\"Score the learner's current classifier on the hard and the easy test set.\n",
    "\n",
    "    Args:\n",
    "        active_learner: object exposing `classifier.predict(dataset)`.\n",
    "        hard_dataset: hard test samples; its `.y` attribute holds the true labels.\n",
    "        easy_dataset: easy test samples; its `.y` attribute holds the true labels.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(test_acc_easy, test_acc_hard)`. The easy accuracy comes first,\n",
    "        which is the reverse of the parameter order.\n",
    "    \"\"\"\n",
    "    y_pred_hard = active_learner.classifier.predict(hard_dataset)\n",
    "    y_pred_easy = active_learner.classifier.predict(easy_dataset)\n",
    "\n",
    "    # accuracy_score is declared as (y_true, y_pred); pass the arguments in that order.\n",
    "    test_acc_hard = accuracy_score(hard_dataset.y, y_pred_hard)\n",
    "    test_acc_easy = accuracy_score(easy_dataset.y, y_pred_easy)\n",
    "\n",
    "    return test_acc_easy, test_acc_hard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### load the pretrained ViT backbone and its matching image processor\n",
    "model_name = \"google/vit-base-patch16-224\"  # Replace with your desired model\n",
    "feature_extractor = ViTImageProcessor.from_pretrained(model_name)\n",
    "model = ViTModel.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "def preprocess(image):\n",
    "    \"\"\"Turn `image` into the PyTorch-tensor inputs produced by `feature_extractor`.\"\"\"\n",
    "    return feature_extractor(images=image, return_tensors=\"pt\")\n",
    "def load_image(image_path):\n",
    "    \"\"\"Open the image at `image_path` and return a copy converted to RGB.\n",
    "\n",
    "    The file is opened inside a `with` block so its handle is released as soon as\n",
    "    the converted copy exists, instead of staying open until garbage collection.\n",
    "    \"\"\"\n",
    "    with Image.open(image_path) as source_image:\n",
    "        return source_image.convert('RGB')\n",
    "# Run one image through the ViT model and hand back its hidden states.\n",
    "def get_vit_embeddings(image_path):\n",
    "    \"\"\"Return `model`'s `last_hidden_state` for the image stored at `image_path`.\n",
    "\n",
    "    The image is read with `load_image`, converted to model inputs with\n",
    "    `preprocess`, and passed through `model` with gradient tracking disabled.\n",
    "    \"\"\"\n",
    "    model_inputs = preprocess(load_image(image_path))\n",
    "\n",
    "    # Inference only, so do not build the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        model_outputs = model(**model_inputs)\n",
    "\n",
    "    return model_outputs.last_hidden_state\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "### locate the VOC2012 data\n",
    "path = \"data/image_classification/VOC2012\"  # dataset root\n",
    "path_images = f\"{path}/JPEGImages\"  # directory holding the JPEG files\n",
    "objects = os.listdir(f\"{path}/ImageSets/Main\")  # file names of the image-set lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each image-set name (file name up to the first '.') to its (image, label) table.\n",
    "dfs = {}\n",
    "for object_file in objects:\n",
    "    dfs[object_file.split(\".\")[0]] = pd.read_csv(\n",
    "        path + \"/ImageSets/Main/\" + object_file,\n",
    "        header=None,\n",
    "        sep=r\"\\s+\",  # raw string: '\\s' inside a plain literal is an invalid escape sequence\n",
    "        names=[\"image\", \"label\"],\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def return_image_paths(object_name):\n",
    "    \"\"\"Return the JPEG path of every image listed in `dfs[object_name]`.\"\"\"\n",
    "    directory_prefix = path_images + \"/\"\n",
    "    image_ids = dfs[object_name][\"image\"].values\n",
    "    return [directory_prefix + image_id + \".jpg\" for image_id in image_ids]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "object_name = \"person_trainval\"  # key into `dfs`\n",
    "object_name_just = object_name.partition(\"_\")[0]  # text before the first underscore\n",
    "paths = return_image_paths(object_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_input = pd.read_csv(\"./data/image_classification/task-input.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_input['imageurl'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the text before the first '.' so `image` can serve as a merge key.\n",
    "task_input['image'] = task_input['imageurl'].map(lambda url: url.partition(\".\")[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "vsd_data = pd.read_csv(\"./data/image_classification/VSD_dataset.csv\",names=[\"image\",\"difficulty_score\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold_easy = 3.1\n",
    "threshold_hard = 3.9\n",
    "# Per class: [number of scored images, how many are <= threshold_easy, how many are >= threshold_hard].\n",
    "(\n",
    "    pd.merge(task_input, vsd_data, on=\"image\")\n",
    "    .groupby(\"class\")[\"difficulty_score\"]\n",
    "    .apply(\n",
    "        lambda scores: np.array([\n",
    "            scores.shape[0],\n",
    "            (scores <= threshold_easy).sum(),\n",
    "            (scores >= threshold_hard).sum(),\n",
    "        ])\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `imageurl` of every task-input row whose class equals `object_name_just`.\n",
    "person_paths = task_input.loc[task_input['class'] == object_name_just, 'imageurl'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One 768-wide row per image, taken from position 0 along the second axis of the\n",
    "# ViT last hidden state.\n",
    "embeddings = np.zeros((len(person_paths), 768))\n",
    "for i, image_file in enumerate(tqdm(person_paths)):\n",
    "    # Reuse `path_images` instead of repeating the JPEG directory as a literal.\n",
    "    embeddings[i] = get_vit_embeddings(path_images + \"/\" + image_file)[:, 0, :].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "### save embeddings\n",
    "embeddings_dir = \"data/image_classification/embeddings\"\n",
    "os.makedirs(embeddings_dir, exist_ok=True)  # np.save does not create missing directories\n",
    "np.save(embeddings_dir + \"/\" + object_name + \".npy\", embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = np.load(\"data/image_classification/embeddings/\"+object_name+\".npy\")\n",
    "\n",
    "# Attach the image id (file name up to the first '.') to every embedding row.\n",
    "df = pd.DataFrame(embeddings)\n",
    "df[\"image\"] = [image_file.split(\".\")[0] for image_file in person_paths]\n",
    "df_person = dfs[object_name]\n",
    "# Both merges use pandas' default inner join, so rows missing a label or a\n",
    "# difficulty score are dropped.\n",
    "df = pd.merge(df, df_person, on=\"image\")\n",
    "difficulty_score = pd.read_csv(\"./data/image_classification/VSD_dataset.csv\",names=[\"image\",\"difficulty_score\"])\n",
    "df = df.merge(difficulty_score, on=\"image\")\n",
    "# The embedding occupies the first 768 columns; the merged columns follow it.\n",
    "df_embedding = df.iloc[:, :768].values\n",
    "# Binarise the label: 1 stays 1, every other value becomes 0.\n",
    "df['label'] = (df['label'] == 1).astype(\"int64\")\n",
    "df_new = df.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 222,
   "metadata": {},
   "outputs": [],
   "source": [
    "### normalize embeddings\n",
    "# Each column (embedding dimension) is divided by the sum of that column over all rows.\n",
    "# NOTE(review): a column whose entries sum to roughly zero makes this blow up - confirm\n",
    "# that sum-normalisation (rather than e.g. centring/standardising) is what is intended.\n",
    "df_embedding_normalized = df_embedding/df_embedding.sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "# Second-moment matrix of the normalised embeddings (X^T X / number of rows). It is\n",
    "# symmetric, so use the symmetric solver: eigvalsh returns real eigenvalues, whereas the\n",
    "# general-purpose eig may return complex values with spurious imaginary parts.\n",
    "second_moment = df_embedding_normalized.T @ df_embedding_normalized / df_embedding.shape[0]\n",
    "eigenvalues = np.linalg.eigvalsh(second_moment)\n",
    "# Magnitudes in descending order.\n",
    "plt.plot(np.sort(np.abs(eigenvalues))[::-1])\n",
    "plt.xlabel(\"Eigenvalue Index\")\n",
    "plt.ylabel(\"Eigenvalue Magnitude\")\n",
    "#plt.yscale(\"log\")\n",
    "plt.savefig(\"plots/eigenvalue_spectrum_image.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "object_name_1 = \"chair_trainval\"\n",
    "object_name_2 = \"bottle_trainval\"\n",
    "object_name_just_1 = object_name_1.split(\"_\")[0]\n",
    "object_name_just_2 = object_name_2.split(\"_\")[0]\n",
    "\n",
    "\n",
    "def _load_object_frame(split_name, class_name):\n",
    "    \"\"\"Build the embedding / label / difficulty table for one object class.\n",
    "\n",
    "    Loads the saved embeddings for `split_name`, pairs each row with the `imageurl`\n",
    "    of the `task_input` rows whose class is `class_name`, then inner-joins the labels\n",
    "    in `dfs[split_name]` and the VSD difficulty scores on `image`. Labels are\n",
    "    binarised: 1 stays 1, every other value becomes 0.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(raw_embeddings, image_files, frame)`.\n",
    "    \"\"\"\n",
    "    raw_embeddings = np.load(\"data/image_classification/embeddings/\"+split_name+\".npy\")\n",
    "    image_files = task_input[task_input['class'] == class_name]['imageurl'].values\n",
    "    frame = pd.DataFrame(raw_embeddings)\n",
    "    frame[\"image\"] = [image_file.split(\".\")[0] for image_file in image_files]\n",
    "    frame = pd.merge(frame, dfs[split_name], on=\"image\")\n",
    "    difficulty = pd.read_csv(\"./data/image_classification/VSD_dataset.csv\",names=[\"image\",\"difficulty_score\"])\n",
    "    frame = frame.merge(difficulty, on=\"image\")\n",
    "    frame['label'] = (frame['label'] == 1).astype(\"int64\")\n",
    "    return raw_embeddings, image_files, frame\n",
    "\n",
    "\n",
    "### one table per object, built by the same helper instead of two copy-pasted pipelines\n",
    "embeddings_1, person_paths_1, df_new_1 = _load_object_frame(object_name_1, object_name_just_1)\n",
    "embeddings_2, person_paths_2, df_new_2 = _load_object_frame(object_name_2, object_name_just_2)\n",
    "df_embedding_1 = df_new_1.iloc[:,:768].values\n",
    "df_embedding_2 = df_new_2.iloc[:,:768].values\n",
    "print(df_new_1.shape[0],df_new_2.shape[0])\n",
    "unique_images = np.unique(np.concatenate([df_new_1['image'].values,df_new_2['image'].values]))\n",
    "## Combine the two tables image by image: rows come from the first table when the image is\n",
    "## there, otherwise from the second. For an image present in both, the label is the logical\n",
    "## OR of the two labels and the difficulty score is the mean of the two scores.\n",
    "## NOTE(review): the code has always combined labels with `or`, although an earlier comment\n",
    "## here said AND - confirm which one is intended.\n",
    "images_1 = set(df_new_1['image'].values)\n",
    "images_2 = set(df_new_2['image'].values)\n",
    "merged_rows = []\n",
    "for image in unique_images:\n",
    "    in_first = image in images_1\n",
    "    in_second = image in images_2\n",
    "    if in_first:\n",
    "        df_temp = df_new_1[df_new_1['image'] == image].reset_index(drop=True)\n",
    "    else:\n",
    "        # Every image comes from the union of both tables, so it must be in the second one.\n",
    "        df_temp = df_new_2[df_new_2['image'] == image].reset_index(drop=True)\n",
    "    if in_first and in_second:\n",
    "        rows_second = df_new_2[df_new_2['image'] == image]\n",
    "        df_temp['label'] = df_temp['label'].values[0] or rows_second['label'].values[0]\n",
    "        df_temp['difficulty_score'] = (df_temp['difficulty_score'].values[0] + rows_second['difficulty_score'].values[0])/2\n",
    "    merged_rows.append(df_temp)\n",
    "# Concatenate once at the end; growing a DataFrame inside the loop copies it every iteration.\n",
    "df_new = pd.concat(merged_rows, ignore_index=True) if merged_rows else pd.DataFrame()\n",
    "print(df_new.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [],
   "source": [
    "### import Linear Regression from sklearn and Lasso\n",
    "from sklearn import linear_model\n",
    "from sklearn.svm import SVR,SVC\n",
    "\n",
    "def return_difficult_samples(df_new_train_easy_random,df_new_embedding,df_new,df_new_test_hard,df_new_test_easy,n_hard):\n",
    "    \"\"\"Pick the `n_hard` unseen rows with the highest predicted difficulty score.\n",
    "\n",
    "    A Lasso regression (alpha=0.05) on the standardised training embeddings selects the\n",
    "    features with non-zero coefficients. A plain linear regression refit on those features\n",
    "    then predicts a difficulty score for every row of `df_new_embedding`.\n",
    "\n",
    "    Args:\n",
    "        df_new_train_easy_random: indices of the rows used to fit both regressions.\n",
    "        df_new_embedding: 2-D array of embeddings, indexed by row position.\n",
    "        df_new: DataFrame whose `difficulty_score` column is the regression target.\n",
    "        df_new_test_hard: indices that must not be returned.\n",
    "        df_new_test_easy: indices that must not be returned.\n",
    "        n_hard: number of indices to return.\n",
    "\n",
    "    Returns:\n",
    "        List of at most `n_hard` row indices, ordered by increasing predicted score.\n",
    "        Empty when `n_hard` is not positive.\n",
    "    \"\"\"\n",
    "    lasso_model = linear_model.Lasso( max_iter=100000,alpha=0.05)\n",
    "    scaler = StandardScaler()\n",
    "\n",
    "    # Stage 1: Lasso on all features to find the informative ones.\n",
    "    X_train = df_new_embedding[df_new_train_easy_random]\n",
    "    y_train = df_new.difficulty_score[df_new_train_easy_random]\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    clf = make_pipeline(StandardScaler(), lasso_model)\n",
    "    clf.fit(X_train, y_train)\n",
    "    top_features = np.argwhere(clf.named_steps['lasso'].coef_!=0).flatten()\n",
    "    print(top_features.shape)\n",
    "\n",
    "    # Stage 2: linear regression on the selected features only (scaler is refit on them).\n",
    "    X = df_new_embedding[df_new_train_easy_random][:,top_features]\n",
    "    X = scaler.fit_transform(X)\n",
    "    lm = linear_model.LinearRegression()\n",
    "    lm.fit(X, y_train)\n",
    "\n",
    "    X_test = scaler.transform(df_new_embedding[:,top_features])\n",
    "    y_test = df_new.difficulty_score\n",
    "    y_pred = lm.predict(X_test)\n",
    "    print(((y_pred-y_test)**2).mean())\n",
    "\n",
    "    # Drop training and test rows in one vectorised pass instead of three list scans.\n",
    "    excluded = np.concatenate([\n",
    "        np.asarray(df_new_train_easy_random),\n",
    "        np.asarray(df_new_test_hard),\n",
    "        np.asarray(df_new_test_easy),\n",
    "    ])\n",
    "    ranked = np.argsort(y_pred)\n",
    "    candidates = ranked[~np.isin(ranked, excluded)]\n",
    "    # A plain [-n_hard:] slice would return *every* candidate when n_hard == 0.\n",
    "    selected = candidates[-n_hard:].tolist() if n_hard > 0 else []\n",
    "    print(y_pred[selected].mean(),y_test[selected].mean())\n",
    "    return selected\n",
    "def return_difficult_samples_classifier(df_new_train_easy_random,df_new_embedding,df_new,df_new_test_hard,df_new_test_easy,n_hard,threshold=3.8):\n",
    "    \"\"\"Pick the `n_hard` unseen rows the classifier rates most likely to be difficult.\n",
    "\n",
    "    A `ConfidenceEnhancedLinearSVC` is trained on the standardised training embeddings to\n",
    "    predict whether `difficulty_score > threshold`; rows are ranked by the second column\n",
    "    of its `predict_proba` output.\n",
    "\n",
    "    Args:\n",
    "        df_new_train_easy_random: indices of the rows used for training.\n",
    "        df_new_embedding: 2-D array of embeddings, indexed by row position.\n",
    "        df_new: DataFrame providing the `difficulty_score` column.\n",
    "        df_new_test_hard: indices that must not be returned.\n",
    "        df_new_test_easy: indices that must not be returned.\n",
    "        n_hard: number of indices to return.\n",
    "        threshold: score above which a sample counts as difficult (default 3.8).\n",
    "\n",
    "    Returns:\n",
    "        List of at most `n_hard` row indices, ordered by increasing predicted probability.\n",
    "        Empty when `n_hard` is not positive.\n",
    "    \"\"\"\n",
    "    scaler = StandardScaler()\n",
    "    X_train = df_new_embedding[df_new_train_easy_random]\n",
    "    y_train = df_new.difficulty_score[df_new_train_easy_random]>threshold\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    clf =  ConfidenceEnhancedLinearSVC() #SVC(kernel=\"rbf\",probability=True,class_weight=\"balanced\")\n",
    "    clf.fit(X_train, y_train)\n",
    "\n",
    "    X_test = scaler.transform(df_new_embedding[:,:])\n",
    "    y_test = df_new.difficulty_score>threshold\n",
    "    y_pred = clf.predict_proba(X_test)[:,1]\n",
    "\n",
    "    # Drop training and test rows in one vectorised pass instead of three list scans.\n",
    "    excluded = np.concatenate([\n",
    "        np.asarray(df_new_train_easy_random),\n",
    "        np.asarray(df_new_test_hard),\n",
    "        np.asarray(df_new_test_easy),\n",
    "    ])\n",
    "    ranked = np.argsort(y_pred)\n",
    "    candidates = ranked[~np.isin(ranked, excluded)]\n",
    "    # A plain [-n_hard:] slice would return *every* candidate when n_hard == 0.\n",
    "    selected = candidates[-n_hard:].tolist() if n_hard > 0 else []\n",
    "    # Number of selected rows that really are above the threshold.\n",
    "    print(y_test[selected].sum())\n",
    "    return selected\n",
    "def train_clf(df_subset,df_embeddings,df_labels):\n",
    "    \"\"\"Fit a class-balanced SVC on the rows of `df_embeddings` selected by `df_subset`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(clf, scaler)`: the fitted pipeline and the StandardScaler fitted on the\n",
    "        selected rows. The pipeline is trained on the scaler's output, so new inputs have\n",
    "        to go through `scaler.transform` before `clf.predict`.\n",
    "    \"\"\"\n",
    "    scaler = StandardScaler()\n",
    "    features = scaler.fit_transform(df_embeddings[df_subset])\n",
    "    targets = df_labels[df_subset]\n",
    "    # The pipeline standardises once more internally, on top of `scaler`.\n",
    "    clf = make_pipeline(StandardScaler(), SVC(class_weight=\"balanced\",max_iter=10000) )\n",
    "    clf.fit(features, targets)\n",
    "    return clf,scaler\n",
    "def return_metric(clf_model,df_test_easy,df_test_hard,df_embedding,df_labels,scaler):\n",
    "    \"\"\"Return `(easy_accuracy, hard_accuracy)` of `clf_model` on the two test index sets.\n",
    "\n",
    "    Each set is looked up in `df_embedding`, passed through `scaler.transform`, and the\n",
    "    predictions are compared with the matching entries of `df_labels`.\n",
    "    \"\"\"\n",
    "    def _accuracy(test_indices):\n",
    "        # Fraction of the selected rows whose prediction equals the stored label.\n",
    "        predictions = clf_model.predict(scaler.transform(df_embedding[test_indices]))\n",
    "        return (df_labels[test_indices]==predictions).mean()\n",
    "\n",
    "    easy_accuracy = _accuracy(df_test_easy)\n",
    "    hard_accuracy = _accuracy(df_test_hard)\n",
    "    return easy_accuracy,hard_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [],
   "source": [
    "### get good vectors by maximizing the minimum eigenvalue of the covariance matrix\n",
    "import cvxpy as cp\n",
    "def get_good_vectors(embeddings):\n",
    "    \"\"\"Weight the rows of `embeddings` to maximise the smallest eigenvalue of E^T diag(x) E.\n",
    "\n",
    "    Solves, with cvxpy,  max_x lambda_min(E^T diag(x) E)  subject to  sum(x) == 1, x >= 0,\n",
    "    where E is `embeddings` and x holds one weight per row.\n",
    "\n",
    "    Returns:\n",
    "        1-D array with one weight per row of `embeddings`.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the solver finishes without a solution (`x.value` is None),\n",
    "            rather than letting a None propagate into later array code.\n",
    "    \"\"\"\n",
    "    n = embeddings.shape[0]\n",
    "    x = cp.Variable(n)\n",
    "    constraints = [cp.sum(x)==1,x>=0]\n",
    "    objective_function = embeddings.T @ cp.diag(x) @ embeddings\n",
    "    objective = cp.Maximize(cp.lambda_min(objective_function))\n",
    "    prob = cp.Problem(objective,constraints)\n",
    "    prob.solve()\n",
    "    if x.value is None:\n",
    "        raise RuntimeError(f\"get_good_vectors: solver returned no solution (status: {prob.status})\")\n",
    "    return x.value\n",
    " ### pca from scikit learn\n",
    "from sklearn.decomposition import PCA\n",
    "# Reduce the 768 embedding columns of df_new to 10 principal components before solving.\n",
    "pca = PCA(n_components=10)\n",
    "pca.fit(df_new.iloc[:,:768].values)\n",
    "df_new_pca = pca.transform(df_new.iloc[:,:768].values)\n",
    "good_vectors = get_good_vectors(df_new_pca)\n",
    "# Row positions of the 500 largest weights (argsort is ascending, so take the tail).\n",
    "top_d_good_vectors = np.argsort(good_vectors)[-500:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Classifier template and factory handed to both PoolBasedActiveLearner instances below.\n",
    "clf_template = ConfidenceEnhancedLinearSVC()\n",
    "num_classes = 2\n",
    "clf_factory = SklearnClassifierFactory(clf_template, num_classes)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Difficulty-score cut-offs: <= lower_bound counts as easy, >= upper_bound as hard.\n",
    "lower_bound = 3.1\n",
    "upper_bound = 3.9\n",
    "\n",
    "# Test-set sizes (n_easy, n_difficult) and the training budget, which is\n",
    "# num_easy_train + num_difficult_train rows for the subset-based methods.\n",
    "n_easy = 120\n",
    "n_difficult = 120\n",
    "num_easy_train = 100\n",
    "num_difficult_train = 40\n",
    "# Active learning starts from num_easy_train labelled rows ...\n",
    "active_learning_init = num_easy_train\n",
    "\n",
    "# ... and queries num_difficult_train more per step, N_queries times.\n",
    "num_sample_per_query = num_difficult_train\n",
    "N_queries = 1\n",
    "\n",
    "print((df_new.difficulty_score<=lower_bound).sum(),(df_new.difficulty_score>=upper_bound).sum())\n",
    "N_random = 5\n",
    "\n",
    "\n",
    "# results[run, strategy (0 = AnchorSubsampling, 1 = SEALS), query step, (easy, hard) accuracy]\n",
    "results = np.zeros((N_random,2,N_queries+1, 2))\n",
    "# metrics[run, baseline (0 = random subset, 1 = all training rows, 2 = good vectors + predicted-difficult), (easy, hard) accuracy]\n",
    "metrics = np.zeros((N_random,3,2))\n",
    "for i in tqdm(range(N_random)):\n",
    "    # Fresh easy/hard test split for this run; the remaining easy/hard rows form the training pool.\n",
    "    df_new_test_easy = np.random.choice(df_new[df_new.difficulty_score<=lower_bound].index,n_easy,replace=False)\n",
    "    df_new_test_hard = np.random.choice(df_new[df_new.difficulty_score>=upper_bound].index,n_difficult,replace=False)\n",
    "    df_new_train_easy = df_new[(df_new.difficulty_score<=lower_bound)& ~(df_new.index.isin(df_new_test_easy))].index\n",
    "    df_new_train_hard = df_new[(df_new.difficulty_score>=upper_bound)& ~(df_new.index.isin(df_new_test_hard))].index\n",
    "    # Good-vector candidates that are not part of either test set.\n",
    "    top_d_good_vectors_allowed = [i for i in top_d_good_vectors if (i not in df_new_test_easy) and (i not in df_new_test_hard)]\n",
    "    df_new_embedding = df_new.iloc[:,:768].values\n",
    "    df_new_label = df_new.label.values\n",
    "    df_all = np.concatenate([df_new_train_easy,df_new_train_hard])\n",
    "    # ### baseline 0: SVM on a random draw from the whole training pool (easy and hard rows).\n",
    "    # NOTE: despite its name, df_new_train_easy_random is sampled from df_all here.\n",
    "    df_new_train_easy_random = np.random.choice(df_all,num_easy_train+num_difficult_train,replace=False)\n",
    "    clf,scaler = train_clf(df_new_train_easy_random,df_new_embedding,df_new_label)\n",
    "    metrics[i,0] = return_metric(clf,df_new_test_easy,df_new_test_hard,df_new_embedding,df_new_label,scaler)\n",
    "    \n",
    "    # ### train an svm on hard images\n",
    "    # clf_hard,scaler = train_clf(df_new_train_hard,df_new_embedding,df_new_label)\n",
    "    # metrics[i,1] = return_metric(clf_hard,df_new_test_easy,df_new_test_hard,df_new_embedding,df_new_label,scaler)\n",
    "\n",
    "    # ### train an svm on mixed \n",
    "    # df_new_train_easy_random = np.random.choice(df_all,df_new_train_hard.shape[0],replace=False)\n",
    "    # df_subset_mixed = np.concatenate([df_new_train_easy_random,df_new_train_hard])\n",
    "    # clf_mixed,scaler = train_clf(df_subset_mixed,df_new_embedding,df_new_label)\n",
    "    # metrics[i,1] = return_metric(clf_mixed,df_new_test_easy,df_new_test_hard,df_new_embedding,df_new_label,scaler)\n",
    "    \n",
    "    #### baseline 1: train on all images of the training pool\n",
    "    \n",
    "    clf_all,scaler = train_clf(df_all,df_new_embedding,df_new_label)\n",
    "    metrics[i,1] = return_metric(clf_all,df_new_test_easy,df_new_test_hard,df_new_embedding,df_new_label,scaler)\n",
    "\n",
    "\n",
    "    ## baseline 2: num_easy_train random good vectors plus num_difficult_train predicted-difficult rows\n",
    "    df_new_train_easy_random = np.random.choice(top_d_good_vectors_allowed,num_easy_train,replace=False)\n",
    "    #clf_random_difficult,scaler = train_clf(df_new_train_easy_random,df_new_embedding,df_new_label)\n",
    "    df_difficult = return_difficult_samples(df_new_train_easy_random,df_new_embedding,df_new,df_new_test_hard,df_new_test_easy,n_hard = num_difficult_train)\n",
    "    df_mixed = np.concatenate([df_new_train_easy_random,df_difficult])\n",
    "    clf_difficult,scaler = train_clf(df_mixed,df_new_embedding,df_new_label)\n",
    "    metrics[i,2,:] = return_metric(clf_difficult,df_new_test_easy,df_new_test_hard,df_new_embedding,df_new_label,scaler)\n",
    "    \n",
    "    ## active learning\n",
    "    # Labels become booleans from here on; they are reset from df_new.label at the top of the next run.\n",
    "    df_new_label = df_new.label.values==1\n",
    "\n",
    "    x = df_new_embedding[np.concatenate([df_new_train_easy,df_new_train_hard])]\n",
    "    y = df_new_label[np.concatenate([df_new_train_easy,df_new_train_hard])]\n",
    "    x_test_hard = df_new_embedding[df_new_test_hard]\n",
    "    y_test_hard = df_new_label[df_new_test_hard]\n",
    "    x_test_easy = df_new_embedding[df_new_test_easy]\n",
    "    y_test_easy = df_new_label[df_new_test_easy]\n",
    "\n",
    "    dataset = SklearnDataset(x, y, target_labels=np.arange(2))\n",
    "    hard_dataset = SklearnDataset(x_test_hard, y_test_hard, target_labels=np.arange(2))\n",
    "    easy_dataset = SklearnDataset(x_test_easy, y_test_easy, target_labels=np.arange(2))\n",
    "    # Strategy 0: AnchorSubsampling around PredictionEntropy.\n",
    "    learner = PoolBasedActiveLearner(\n",
    "        clf_factory,\n",
    "        AnchorSubsampling(PredictionEntropy(),num_anchors=2),\n",
    "        dataset,\n",
    "    )\n",
    "    random_indices = np.random.choice(np.arange(x.shape[0]), size=active_learning_init, replace=False)\n",
    "    indices_labeled = random_indices\n",
    "    learner.initialize_data(random_indices, y[random_indices])\n",
    "\n",
    "    # Step 0 = accuracy right after initialisation, before any query.\n",
    "    results[i,0, 0, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "\n",
    "        \n",
    "    for j in range(N_queries):\n",
    "        # ...where each iteration consists of labelling num_sample_per_query samples\n",
    "        indices_queried = learner.query(num_samples=num_sample_per_query)\n",
    "\n",
    "        # Simulate user interaction here. Replace this for real-world usage.\n",
    "        y_true = dataset.y[indices_queried]\n",
    "\n",
    "        # Return the labels for the current query to the active learner.\n",
    "        learner.update(y_true)\n",
    "\n",
    "        # Bookkeeping only: indices_labeled is not read again in this cell.\n",
    "        indices_labeled = np.concatenate([indices_queried, indices_labeled])\n",
    "        \n",
    "        #print(f'Iteration #{i} ({len(indices_labeled)} samples)')\n",
    "        results[i,0, j+1, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "        \n",
    "    # Strategy 1: SEALS around PredictionEntropy, with its own random initial labelled set.\n",
    "    learner = PoolBasedActiveLearner(\n",
    "        clf_factory,\n",
    "        SEALS(PredictionEntropy(),),\n",
    "        dataset,\n",
    "    )\n",
    "    random_indices = np.random.choice(np.arange(x.shape[0]), size=active_learning_init, replace=False)\n",
    "    indices_labeled = random_indices\n",
    "    learner.initialize_data(random_indices, y[random_indices])\n",
    "\n",
    "    results[i,1, 0, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "        \n",
    "    for j in range(N_queries):\n",
    "        # ...where each iteration consists of labelling num_sample_per_query samples\n",
    "        indices_queried = learner.query(num_samples=num_sample_per_query)\n",
    "\n",
    "        # Simulate user interaction here. Replace this for real-world usage.\n",
    "        y_true = dataset.y[indices_queried]\n",
    "\n",
    "        # Return the labels for the current query to the active learner.\n",
    "        learner.update(y_true)\n",
    "\n",
    "        # Bookkeeping only: indices_labeled is not read again in this cell.\n",
    "        indices_labeled = np.concatenate([indices_queried, indices_labeled])\n",
    "        \n",
    "        #print(f'Iteration #{i} ({len(indices_labeled)} samples)')\n",
    "        results[i,1, j+1, :] = evaluate(learner, hard_dataset, easy_dataset)\n",
    "    #print(np.vstack([metrics[i],results[i][0,-1,:],results[i][1,-1,:]]))\n",
    "# Mean over runs: the two active-learning strategies (final query step) stacked on the three baselines.\n",
    "print(np.vstack([results.mean(axis=0)[:,-1,:],metrics.mean(axis=0)]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final-step accuracy per method as percentages, flattened so that all first-column (easy)\n",
    "# values come before all second-column (hard) values. Computed once, not once per element.\n",
    "summary_means = 100*np.concatenate([results.mean(axis=0)[:,-1,:],metrics.mean(axis=0)]).T.flatten()\n",
    "summary_stds = 100*np.concatenate([results.std(axis=0)[:,-1,:],metrics.std(axis=0)]).T.flatten()\n",
    "# One \"mean+std\" string per entry, both rounded to 3 decimals.\n",
    "arr = [str(np.round(mean_value,3)) + \"+\" + str(np.round(std_value,3)) for mean_value,std_value in zip(summary_means,summary_stds)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\" & \".join(arr[5:]).replace(\"+\",r\"$\\pm$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.concatenate([results.std(axis=0)[:,-1,:],metrics.std(axis=0),]).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "object_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_new.difficulty_score.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "def train_clf(df_subset,df_embeddings,df_labels):\n",
    "    \"\"\"Fit an RBF-kernel SVC (C=2) on the rows of `df_embeddings` selected by `df_subset`.\n",
    "\n",
    "    This definition replaces the `train_clf` defined earlier in the notebook from this\n",
    "    cell onwards.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(clf, scaler)`: the fitted pipeline and the StandardScaler fitted on the\n",
    "        selected rows; inputs must go through `scaler.transform` before being given to `clf`.\n",
    "    \"\"\"\n",
    "    X = df_embeddings[df_subset]\n",
    "    y = df_labels[df_subset]\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X)\n",
    "    # NOTE(review): this used the name `SVM`, which no visible cell defines or imports;\n",
    "    # `SVC` is what the notebook imports from sklearn.svm - confirm no alias was intended.\n",
    "    clf = make_pipeline(StandardScaler(), SVC(kernel=\"rbf\",C=2))\n",
    "    clf.fit(X_train, y)\n",
    "    return clf,scaler\n",
    "clf_all,scaler = train_clf(df_all,df_new_embedding,df_new_label)\n",
    "\n",
    "frac_difficult = np.zeros(10)\n",
    "decision_fn = clf_all.decision_function(scaler.transform(df_new.iloc[:,:768].values))\n",
    "grid = np.linspace(0,np.max(decision_fn),10)\n",
    "# Bin i is (grid[i]-grid_precision, grid[i]]. Use the actual grid spacing as the width so\n",
    "# that adjacent bins tile the range: (grid[-1]-grid[0])/10 is narrower than the spacing of\n",
    "# a 10-point linspace and leaves a gap between consecutive bins.\n",
    "grid_precision = grid[1]-grid[0]\n",
    "mean_difficulty = np.zeros(10)\n",
    "for i in range(10):\n",
    "    indices_gt = np.where((decision_fn<=grid[i])&(decision_fn>grid[i]-grid_precision))[0]\n",
    "    num_total = indices_gt.shape[0]\n",
    "    num_diff = (df_new.iloc[indices_gt].difficulty_score>=3.8).sum()\n",
    "    mean_difficulty[i] = df_new.iloc[indices_gt].difficulty_score.mean()\n",
    "    # An empty bin has no defined fraction; record NaN instead of dividing by zero.\n",
    "    frac_difficult[i] = num_diff/num_total if num_total else np.nan\n",
    "plt.plot(grid,frac_difficult)\n",
    "plt.xlabel(\"Distance from the decision boundary\",fontsize=15)\n",
    "plt.ylabel(\"Fraction of difficult samples\",fontsize=15)\n",
    "plt.savefig(\"plots/fraction_difficult.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the difficulty scores in df_new, in 100 bins; also written to a PDF.\n",
    "plt.hist(df_new.difficulty_score,bins=100)\n",
    "plt.xlabel(\"Difficulty Score\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.title(\"Histogram of Difficulty Score\")\n",
    "plt.savefig(\"plots/histogram_difficulty_score.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 419,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: the two active-learning strategies (final query step of `results`) followed by the\n",
    "# three `metrics` baselines. Columns: (easy, hard). `evaluate` and `return_metric` both\n",
    "# return the easy accuracy first, so the two blocks are stacked without reversing either\n",
    "# one; reversing only the `results` columns would swap easy and hard for the first two rows.\n",
    "# np.concatenate is used (as elsewhere in this notebook) because np.concat only exists in NumPy 2.x.\n",
    "means_combined = np.concatenate([results.mean(axis=0)[:,-1,:],metrics.mean(axis=0)])\n",
    "# Despite the name, these are standard deviations across runs (used as error bars).\n",
    "quantiles_combined = np.concatenate([results.std(axis=0)[:,-1,:],metrics.std(axis=0)])\n",
    "# One label per row of means_combined, in the same order.\n",
    "method_names = [\n",
    "    \"AnchorAL\",\n",
    "    \"SEALS\",\n",
    "    \"Random\",\n",
    "    \"All\",\n",
    "    \"BSLB \\n (Our Method)\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\"text.usetex\": True})\n",
    "plt.figure(figsize=(10,5))\n",
    "# Grouped bars: easy at the integer positions, hard shifted right by one bar width.\n",
    "bar_positions = np.arange(means_combined.shape[0])\n",
    "plt.bar(bar_positions,means_combined[:,0],yerr=quantiles_combined[:,0],label=\"Easy Validation Data\",color=\"green\",width=0.4)\n",
    "plt.bar(bar_positions+0.4,means_combined[:,1],yerr=quantiles_combined[:,1],label=\"Hard Validation Data\",color=\"red\",alpha=0.5,width=0.4)\n",
    "plt.ylabel(\"Accuracy\",fontsize=20)\n",
    "plt.xlabel(\"Subset of Training Data\",fontsize=20)\n",
    "plt.ylim([0.5,1])\n",
    "plt.yticks(fontsize=15)\n",
    "# Tick labels sit between the two bars of each group.\n",
    "plt.xticks(np.arange(len(method_names))+.2,method_names,fontsize=15,rotation=0)\n",
    "plt.legend(loc=\"lower right\",fontsize=18)\n",
    "# Dashed reference lines at the last row's hard and easy means.\n",
    "for reference_level in (means_combined[-1,1], means_combined[-1,0]):\n",
    "    plt.hlines(reference_level,-.2,5,linestyles=\"--\",colors=\"black\",alpha=0.2)\n",
    "\n",
    "# Dashed box around the first two bar groups, with its caption above.\n",
    "active_learning_box = plt.Rectangle((-0.3, 0.51), 2, 0.48, fill=False, edgecolor='black', linewidth=1, linestyle='--',alpha=0.5)\n",
    "plt.gca().add_patch(active_learning_box)\n",
    "plt.text(-0.1,1,\"Active Learning Framework\")\n",
    "\n",
    "plt.savefig(f\"plots/accuracy_comparison_{object_name}.pdf\",bbox_inches=\"tight\",dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "plt.figure(figsize=(10,5))\n",
    "# Easy-validation accuracy only; error bars come from quantiles_combined.\n",
    "plt.bar(np.arange(means_combined.shape[0]),means_combined[:,0],yerr=quantiles_combined[:,0],label=\"Easy Validation Data\",color=\"green\",width=0.4)\n",
    "plt.ylabel(\"Accuracy\",fontsize=20)\n",
    "plt.xlabel(\"Subset of Training Data\",fontsize=20)\n",
    "plt.ylim([0.5,1])\n",
    "plt.yticks(fontsize=15)\n",
    "# One bar per method, centred on the integer positions, so the tick labels go there too\n",
    "# (the +0.2 offset only makes sense for the two-bar grouped plot).\n",
    "plt.xticks(np.arange(len(method_names)),method_names,fontsize=15,rotation=0)\n",
    "## legend in the lower right corner\n",
    "plt.legend(loc=\"lower right\",fontsize=18)\n",
    "### reference line at the last row's easy accuracy (a hard-accuracy line does not belong on this plot)\n",
    "plt.hlines(means_combined[-1,0],-.2,5.6,linestyles=\"--\",colors=\"black\",alpha=0.2)\n",
    "\n",
    "### draw a dotted box around the first two bars\n",
    "plt.gca().add_patch(plt.Rectangle((-0.3, 0.51), 1.6, 0.48, fill=False, edgecolor='black', linewidth=1, linestyle='--',alpha=0.5))\n",
    "\n",
    "### add a text \"active learning framework\"\n",
    "plt.text(-0.2,1,\"Active Learning Framework\")\n",
    "## Own file name: the other accuracy plots in this notebook write\n",
    "## accuracy_comparison_{object_name}.pdf, and sharing that name would overwrite them.\n",
    "plt.savefig(f\"plots/accuracy_comparison_{object_name}_easy.pdf\",bbox_inches=\"tight\",dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": True,\n",
    "})\n",
    "plt.figure(figsize=(10,5))\n",
    "# Hard-validation accuracy only; error bars come from quantiles_combined.\n",
    "plt.bar(np.arange(means_combined.shape[0]),means_combined[:,1],yerr=quantiles_combined[:,1],label=\"Hard Validation Data\",color=\"red\",width=0.4,alpha=0.5)\n",
    "plt.ylabel(\"Accuracy\",fontsize=20)\n",
    "plt.xlabel(\"Subset of Training Data\",fontsize=20)\n",
    "plt.ylim([0.5,0.8])\n",
    "plt.yticks(fontsize=15)\n",
    "# One bar per method, centred on the integer positions, so the tick labels go there too\n",
    "# (the +0.2 offset only makes sense for the two-bar grouped plot).\n",
    "plt.xticks(np.arange(len(method_names)),method_names,fontsize=15,rotation=0)\n",
    "## legend in the lower right corner\n",
    "plt.legend(loc=\"lower right\",fontsize=18)\n",
    "### reference line at the last row's hard accuracy (an easy-accuracy line does not belong on this plot)\n",
    "plt.hlines(means_combined[-1,1],-.2,5.6,linestyles=\"--\",colors=\"black\",alpha=0.2)\n",
    "\n",
    "### draw a dotted box around the first two bars\n",
    "plt.gca().add_patch(plt.Rectangle((-0.3, 0.51), 1.6, 0.28, fill=False, edgecolor='black', linewidth=1, linestyle='--',alpha=0.5))\n",
    "\n",
    "### add a text \"active learning framework\"\n",
    "plt.text(-0.2,0.8,\"Active Learning Framework\")\n",
    "## Own file name: the other accuracy plots in this notebook write\n",
    "## accuracy_comparison_{object_name}.pdf, and sharing that name would overwrite them.\n",
    "plt.savefig(f\"plots/accuracy_comparison_{object_name}_hard.pdf\",bbox_inches=\"tight\",dpi=300)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### get good vectors by maximizing the minimum eigenvalue of the covariance matrix\n",
    "import cvxpy as cp\n",
    "def get_good_vectors(embeddings):\n",
    "    \"\"\"Weight the rows of `embeddings` to maximise the smallest eigenvalue of E^T diag(x) E.\n",
    "\n",
    "    Solves, with cvxpy,  max_x lambda_min(E^T diag(x) E)  subject to  sum(x) == 1, x >= 0,\n",
    "    where E is `embeddings` and x holds one weight per row. A function with this name is\n",
    "    also defined in an earlier cell of this notebook; this definition replaces it once run.\n",
    "\n",
    "    Returns:\n",
    "        1-D array with one weight per row of `embeddings`.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the solver finishes without a solution (`x.value` is None),\n",
    "            rather than letting a None propagate into later array code.\n",
    "    \"\"\"\n",
    "    n = embeddings.shape[0]\n",
    "    x = cp.Variable(n)\n",
    "    constraints = [cp.sum(x)==1,x>=0]\n",
    "    objective_function = embeddings.T @ cp.diag(x) @ embeddings\n",
    "    objective = cp.Maximize(cp.lambda_min(objective_function))\n",
    "    prob = cp.Problem(objective,constraints)\n",
    "    prob.solve()\n",
    "    if x.value is None:\n",
    "        raise RuntimeError(f\"get_good_vectors: solver returned no solution (status: {prob.status})\")\n",
    "    return x.value\n",
    " ### pca from scikit learn\n",
    "from sklearn.decomposition import PCA\n",
    "# Reduce the 768 embedding columns of df_new to 10 principal components before solving.\n",
    "pca = PCA(n_components=10)\n",
    "pca.fit(df_new.iloc[:,:768].values)\n",
    "df_new_pca = pca.transform(df_new.iloc[:,:768].values)\n",
    "good_vectors = get_good_vectors(df_new_pca)\n",
    "# Row positions of the 1000 largest weights (argsort is ascending, so take the tail).\n",
    "# NOTE: an earlier cell runs the same steps but keeps only the top 500.\n",
    "top_d_good_vectors = np.argsort(good_vectors)[-1000:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
