{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/workspace/repositories/offimg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from IPython.display import Image\n",
    "import glob\n",
    "from main.smid.utils import precision_recall, f1\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of cross-validation folds; selects the cv_{n_splits}{suffix} result directories.\n",
    "n_splits = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_result(data_path):\n",
    "    \"\"\"Load the per-split accuracies of one run from its prediction.csv.\n",
    "\n",
    "    The returned label is the first four characters of the grandparent directory\n",
    "    name, e.g. '0.40' for .../0.40_2.50_3.50/RUN/prediction.csv.\n",
    "\n",
    "    Returns:\n",
    "        total_acc, acc_t_gt, acc_t_lt, init_total_acc, init_acc_t_gt, init_acc_t_lt:\n",
    "            1-D arrays holding CSV columns 1..6 (one entry per data row; header skipped).\n",
    "        label: str as described above.\n",
    "    \"\"\"\n",
    "    label = os.path.basename(os.path.dirname(os.path.dirname(data_path)))\n",
    "    label = label[:4]\n",
    "    # Give loadtxt the path (not a bare open() handle) so the file is closed, and use\n",
    "    # ndmin=2 so a CSV with a single split row still supports the column slicing below.\n",
    "    data = np.loadtxt(data_path, delimiter=\",\", skiprows=1, ndmin=2)\n",
    "    print('#split,total_acc,acc_t_gt,acc_t_lt,init_total_acc,init_acc_t_gt,init_acc_t_lt')\n",
    "    total_acc,acc_t_gt,acc_t_lt,init_total_acc,init_acc_t_gt,init_acc_t_lt = data[:,1], data[:,2], data[:,3], data[:,4],data[:,5],data[:,6]\n",
    "    return total_acc,acc_t_gt,acc_t_lt,init_total_acc,init_acc_t_gt,init_acc_t_lt, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "def load_data(data_path):\n",
    "    \"\"\"Unpickle and return the object stored at ``data_path``.\n",
    "\n",
    "    The file is opened in a ``with`` block so the handle is closed right away.\n",
    "    NOTE: pickle can execute arbitrary code -- only load result files you produced.\n",
    "    \"\"\"\n",
    "    with open(data_path, 'rb') as f:\n",
    "        return pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_results(files_):\n",
    "    \"\"\"Summarise several runs (one prediction.csv per training fraction).\n",
    "\n",
    "    Prints mean +- std of the final and the initial accuracy for every file, then\n",
    "    returns both outputs ordered by increasing fraction.\n",
    "\n",
    "    Returns:\n",
    "        labels_: tick labels, the integer percentage followed by a LaTeX-escaped percent sign.\n",
    "        res_: object ndarray with one row per file:\n",
    "            (mean acc, std acc, mean init acc, std init acc, per-split init accs).\n",
    "    \"\"\"\n",
    "    res_ = []\n",
    "    labels_ = []\n",
    "    for data_path in files_:\n",
    "        total_acc, _, _, init_total_acc, _, _, label = read_result(data_path)\n",
    "\n",
    "        total_acc_mean = np.mean(total_acc)\n",
    "        total_acc_std = np.std(total_acc)\n",
    "\n",
    "        init_total_acc_mean = np.mean(init_total_acc)\n",
    "        init_total_acc_std = np.std(init_total_acc)\n",
    "\n",
    "        print(f\"{label}: {total_acc_mean:.2f} +- {total_acc_std:.2f}, Init: {init_total_acc_mean:.2f} +- {init_total_acc_std:.2f}\")\n",
    "        res_.append((total_acc_mean, total_acc_std, init_total_acc_mean, init_total_acc_std, np.array(init_total_acc, dtype=float)))\n",
    "        # round() instead of plain int() truncation: 100 * 0.29 == 28.999999999999996.\n",
    "        labels_.append(int(round(100 * float(label))))\n",
    "\n",
    "    # Sort by label only, so ties never fall through to comparing row tuples\n",
    "    # (which can reach the ndarray element and raise).\n",
    "    order = sorted(range(len(labels_)), key=labels_.__getitem__)\n",
    "    # dtype=object: rows mix scalars with a per-split array; NumPy >= 1.24 rejects\n",
    "    # such inhomogeneous input unless the dtype is object.\n",
    "    res_ = np.array([res_[i] for i in order], dtype=object)\n",
    "    labels_ = [f'{labels_[i]}\\\\%' for i in order]\n",
    "    return labels_, res_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One result slot per (backbone, training method) experiment, filled by the loading cells below.\n",
    "data = {\n",
    "    key: {'labels': None, 'res': None}\n",
    "    for key in (\n",
    "        'Clip_ViT-B-16_prompt_tuning',\n",
    "        'Clip_ViT-B-32_prompt_tuning',\n",
    "        'Clip_RN50_prompt_tuning',\n",
    "        'Clip_RN50_probe_tuning',\n",
    "        'Clip_RN50_probe_tuning_full',\n",
    "        'resnet50_probe_tuning_imagenet',\n",
    "        'resnet50_probe_tuning_imagenet_full',\n",
    "        'resnet50_probe_tuning_imagenet_21k',\n",
    "        'resnet50_probe_tuning_imagenet_21k_full',\n",
    "    )\n",
    "}\n",
    "# Which SMID/{data_type} results to read. Keep exactly one assignment active:\n",
    "# consecutive assignments silently override each other.\n",
    "#data_type = 'moral'\n",
    "#data_type = 'harm'\n",
    "data_type = 'valence'\n",
    "# Suffix of the cv_{n_splits}{suffix} result directory.\n",
    "suffix = '_new_multiTestset0dot1'\n",
    "#suffix = '_new'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ViT-B/16 CLIP prompt tuning.\n",
    "language_model, method = 'Clip_ViT-B-16', 'prompt_tuning_clip'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['Clip_ViT-B-16_prompt_tuning']['labels'], data['Clip_ViT-B-16_prompt_tuning']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ViT-B/32 CLIP prompt tuning.\n",
    "language_model, method = 'Clip_ViT-B-32', 'prompt_tuning_clip'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['Clip_ViT-B-32_prompt_tuning']['labels'], data['Clip_ViT-B-32_prompt_tuning']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ResNet50 CLIP probe tuning.\n",
    "language_model, method = 'Clip_RN50', 'probe_tuning_clip'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['Clip_RN50_probe_tuning']['labels'], data['Clip_RN50_probe_tuning']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ResNet50 CLIP full probe tuning.\n",
    "language_model, method = 'Clip_RN50', 'probe_tuning_clip_full'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['Clip_RN50_probe_tuning_full']['labels'], data['Clip_RN50_probe_tuning_full']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ResNet50 CLIP prompt tuning.\n",
    "language_model, method = 'Clip_RN50', 'prompt_tuning_clip'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['Clip_RN50_prompt_tuning']['labels'], data['Clip_RN50_prompt_tuning']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ImageNet-1k ResNet50 probe tuning.\n",
    "language_model, method = 'resnet50', 'probe_tuning_imagenet'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['resnet50_probe_tuning_imagenet']['labels'], data['resnet50_probe_tuning_imagenet']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ImageNet-1k ResNet50 full tuning.\n",
    "language_model, method = 'resnet50', 'probe_tuning_imagenet_full'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['resnet50_probe_tuning_imagenet_full']['labels'], data['resnet50_probe_tuning_imagenet_full']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ImageNet-21k ResNet50 probe tuning.\n",
    "language_model, method = 'resnet50', 'probe_tuning_imagenet_21k'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['resnet50_probe_tuning_imagenet_21k']['labels'], data['resnet50_probe_tuning_imagenet_21k']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every FRACTION_2.50_3.50/RUN/prediction.csv of ImageNet-21k ResNet50 full tuning.\n",
    "language_model, method = 'resnet50', 'probe_tuning_imagenet_21k_full'\n",
    "base_dir = os.path.join(\n",
    "    f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{language_model}/{method}\",\n",
    "    f\"cv_{n_splits}{suffix}\",\n",
    ")\n",
    "files_ = glob.glob(base_dir + \"/*2.50_3.50/*/prediction.csv\")\n",
    "print('Files found', len(files_))\n",
    "data['resnet50_probe_tuning_imagenet_21k_full']['labels'], data['resnet50_probe_tuning_imagenet_21k_full']['res'] = read_results(files_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pylab as pylab\n",
    "\n",
    "# Figure style. 'legend.fontsize' is listed once: a duplicate 'x-large' entry was\n",
    "# silently overridden by the later value 10.\n",
    "params = {\n",
    "    'figure.figsize': (8, 6),\n",
    "    'axes.labelsize': 'x-large',\n",
    "    'axes.titlesize': 'x-large',\n",
    "    'xtick.labelsize': 'xx-large',\n",
    "    'ytick.labelsize': 'xx-large',\n",
    "    \"text.usetex\": True,  # needs a LaTeX installation\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    \"font.sans-serif\": [\"Helvetica\"],\n",
    "    'legend.fontsize': 10,\n",
    "    'legend.handlelength': 1\n",
    "}\n",
    "\n",
    "pylab.rcParams.update(params)\n",
    "plt.rcParams['axes.xmargin'] = 0.05\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "skip_step = 1\n",
    "res01, labels01 = data['Clip_ViT-B-16_prompt_tuning']['res'][::skip_step], data['Clip_ViT-B-16_prompt_tuning']['labels'][::skip_step]\n",
    "res0, labels0 = data['Clip_ViT-B-32_prompt_tuning']['res'][::skip_step], data['Clip_ViT-B-32_prompt_tuning']['labels'][::skip_step]\n",
    "res1, labels1 = data['Clip_RN50_prompt_tuning']['res'][::skip_step], data['Clip_RN50_prompt_tuning']['labels'][::skip_step]\n",
    "res2, labels2 = data['Clip_RN50_probe_tuning']['res'][::skip_step], data['Clip_RN50_probe_tuning']['labels'][::skip_step]\n",
    "res3, labels3 = data['resnet50_probe_tuning_imagenet']['res'][::skip_step], data['resnet50_probe_tuning_imagenet']['labels'][::skip_step]\n",
    "res4, labels4 = data['resnet50_probe_tuning_imagenet_full']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_full']['labels'][::skip_step]\n",
    "res5, labels5 = data['resnet50_probe_tuning_imagenet_21k']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_21k']['labels'][::skip_step]\n",
    "res6, labels6 = data['resnet50_probe_tuning_imagenet_21k_full']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_21k_full']['labels'][::skip_step]\n",
    "\n",
    "# One entry per plotted series: (rows returned by read_results, errorbar style).\n",
    "# Row layout: [mean acc, std acc, mean init acc, std init acc, per-split init accs].\n",
    "palette = sns.color_palette('deep')\n",
    "markersize = 15\n",
    "series = [\n",
    "    # NOTE(review): shares palette[5] with the ImageNet21k series; only the marker tells them apart.\n",
    "    (res01, dict(marker='*', alpha=1., ms=markersize*1.2, c=palette[5],\n",
    "                 label='Prompt-tuning (ViT-B/16 CLIP-WebImageText)')),\n",
    "    (res0, dict(marker='*', alpha=1., ms=markersize*1.2, c=palette[1],\n",
    "                label='Prompt-tuning (ViT-B/32 CLIP-WebImageText)')),\n",
    "    (res1, dict(marker='*', alpha=1., ms=markersize*1.2, c=palette[3],\n",
    "                label='Prompt-tuning (ResNet50 CLIP-WebImageText)', fillstyle='none')),\n",
    "    (res2, dict(marker='*', alpha=1., ms=markersize, c=palette[2],\n",
    "                label='Probing (ResNet50 CLIP-WebImageText)', fillstyle='none')),\n",
    "    (res3, dict(marker='P', alpha=1., ms=markersize, c=palette[4],\n",
    "                label='Probing (ResNet50 Imagenet1k)')),\n",
    "    (res4, dict(marker='P', alpha=.5, ms=markersize, c=palette[4],\n",
    "                label='Fine-tuning (ResNet50 Imagenet1k)')),\n",
    "    # res5 is the probe-only ImageNet21k run (res6 is the full fine-tuning one).\n",
    "    (res5, dict(marker='P', alpha=1., ms=markersize, c=palette[5],\n",
    "                label='Probing (ResNet50 Imagenet21k)')),\n",
    "    (res6, dict(marker='P', alpha=.5, ms=markersize, c=palette[5],\n",
    "                label='Fine-tuning (ResNet50 Imagenet21k)')),\n",
    "]\n",
    "\n",
    "plt.figure(figsize=(8,4), dpi=1200)\n",
    "# Every series' zero-shot point (0 labeled examples) goes in one extra column right of all fractions.\n",
    "zero_shot_pos = len(res0)\n",
    "for res, style in series:\n",
    "    init_acc = np.reshape(np.stack(res[:,-1], axis=0), (-1))\n",
    "    # Right-align shorter series so equal fractions share an x position, and append the\n",
    "    # zero-shot point at the series' own end: inserting it at index len(res0) raised\n",
    "    # IndexError for any series with fewer fractions than res0.\n",
    "    x = np.arange(zero_shot_pos-len(res), zero_shot_pos+1)\n",
    "    plt.errorbar(x, np.append(res[:,0], np.mean(init_acc)), np.append(res[:,1], np.std(init_acc)),\n",
    "                 linestyle='None', fmt=':', capsize=10, capthick=1, **style)\n",
    "plt.xticks(ticks=np.arange(zero_shot_pos+1), labels=labels0+[\"0\\\\%\"], rotation='vertical')\n",
    "\n",
    "# Dashed separator just before the zero-shot column (was hard-coded at 10.5, which\n",
    "# only lines up when there are exactly 11 training fractions).\n",
    "plt.vlines(zero_shot_pos - 0.5, 0, 100, colors=palette[7],\n",
    "           linestyles='--', linewidth=3)\n",
    "plt.ylim(40, 100)\n",
    "ax = plt.gca()\n",
    "ax.set_xlabel('\\\\# of labeled training examples used, shown in \\\\% of training set.', fontsize=14)\n",
    "ax.set_ylabel('Test Accuracy', fontsize=20)\n",
    "\n",
    "# Shrink the axes to 80% of their width so the legend fits on the right.\n",
    "box = ax.get_position()\n",
    "ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])\n",
    "lgnd = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5),\n",
    "                 markerscale=0.7)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pylab as pylab\n",
    "\n",
    "# Figure style. 'legend.fontsize' is listed once: a duplicate 'x-large' entry was\n",
    "# silently overridden by the later value 10.\n",
    "params = {\n",
    "    'figure.figsize': (8, 6),\n",
    "    'axes.labelsize': 'x-large',\n",
    "    'axes.titlesize': 'x-large',\n",
    "    'xtick.labelsize': 'xx-large',\n",
    "    'ytick.labelsize': 'xx-large',\n",
    "    \"text.usetex\": True,  # needs a LaTeX installation\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    \"font.sans-serif\": [\"Helvetica\"],\n",
    "    'legend.fontsize': 10,\n",
    "    'legend.handlelength': 1\n",
    "}\n",
    "\n",
    "pylab.rcParams.update(params)\n",
    "plt.rcParams['axes.xmargin'] = 0.05\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "skip_step = 1\n",
    "res01, labels01 = data['Clip_ViT-B-16_prompt_tuning']['res'][::skip_step], data['Clip_ViT-B-16_prompt_tuning']['labels'][::skip_step]\n",
    "res0, labels0 = data['Clip_ViT-B-32_prompt_tuning']['res'][::skip_step], data['Clip_ViT-B-32_prompt_tuning']['labels'][::skip_step]\n",
    "res1, labels1 = data['Clip_RN50_prompt_tuning']['res'][::skip_step], data['Clip_RN50_prompt_tuning']['labels'][::skip_step]\n",
    "res2, labels2 = data['Clip_RN50_probe_tuning']['res'][::skip_step], data['Clip_RN50_probe_tuning']['labels'][::skip_step]\n",
    "res3, labels3 = data['resnet50_probe_tuning_imagenet']['res'][::skip_step], data['resnet50_probe_tuning_imagenet']['labels'][::skip_step]\n",
    "res4, labels4 = data['resnet50_probe_tuning_imagenet_full']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_full']['labels'][::skip_step]\n",
    "res5, labels5 = data['resnet50_probe_tuning_imagenet_21k']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_21k']['labels'][::skip_step]\n",
    "res6, labels6 = data['resnet50_probe_tuning_imagenet_21k_full']['res'][::skip_step], data['resnet50_probe_tuning_imagenet_21k_full']['labels'][::skip_step]\n",
    "\n",
    "# One entry per plotted series: (rows returned by read_results, errorbar style).\n",
    "# Row layout: [mean acc, std acc, mean init acc, std init acc, per-split init accs].\n",
    "palette = sns.color_palette('deep')\n",
    "markersize = 15\n",
    "alpha=0.8\n",
    "series = [\n",
    "    (res01, dict(marker='*', alpha=alpha, ms=markersize*1.2, c=palette[1],\n",
    "                 label='Prompt-tuning (ViT-B/16 CLIP-WebImageText)')),\n",
    "    (res0, dict(marker='*', alpha=alpha, ms=markersize*1.2, c=palette[2],\n",
    "                label='Prompt-tuning (ViT-B/32 CLIP-WebImageText)')),\n",
    "    (res1, dict(marker='*', alpha=alpha, ms=markersize*1.2, c=palette[3],\n",
    "                label='Prompt-tuning (ResNet50 CLIP-WebImageText)', fillstyle='none')),\n",
    "    # NOTE(review): the only series drawn fully opaque (alpha=1.) -- intentional?\n",
    "    (res2, dict(marker='*', alpha=1., ms=markersize*1.2, c=palette[4],\n",
    "                label='Linear Probing (ResNet50 CLIP-WebImageText)', fillstyle='none')),\n",
    "    (res5, dict(marker='P', alpha=alpha, ms=markersize, c=palette[5],\n",
    "                label='Linear Probing (ResNet50 Imagenet21k)', fillstyle='none')),\n",
    "    (res3, dict(marker='P', alpha=alpha, ms=markersize, c=palette[6],\n",
    "                label='Linear Probing (ResNet50 Imagenet1k)', fillstyle='none')),\n",
    "]\n",
    "\n",
    "plt.figure(figsize=(8,4), dpi=1200)\n",
    "# Every series' zero-shot point (0 labeled examples) goes in one extra column right of all fractions.\n",
    "zero_shot_pos = len(res0)\n",
    "for res, style in series:\n",
    "    init_acc = np.reshape(np.stack(res[:,-1], axis=0), (-1))\n",
    "    # Right-align shorter series so equal fractions share an x position, and append the\n",
    "    # zero-shot point at the series' own end: inserting it at index len(res0) raised\n",
    "    # IndexError for any series with fewer fractions than res0.\n",
    "    x = np.arange(zero_shot_pos-len(res), zero_shot_pos+1)\n",
    "    plt.errorbar(x, np.append(res[:,0], np.mean(init_acc)), np.append(res[:,1], np.std(init_acc)),\n",
    "                 linestyle='None', fmt=':', capsize=10, capthick=1, **style)\n",
    "plt.xticks(ticks=np.arange(zero_shot_pos+1), labels=labels0+[\"0\\\\%\"], rotation='vertical')\n",
    "\n",
    "# Dashed separator just before the zero-shot column (was hard-coded at 10.5, which\n",
    "# only lines up when there are exactly 11 training fractions).\n",
    "plt.vlines(zero_shot_pos - 0.5, 0, 100, colors=palette[7],\n",
    "           linestyles='--', linewidth=3)\n",
    "plt.ylim(40, 100)\n",
    "ax = plt.gca()\n",
    "ax.set_xlabel('\\\\# of labeled training examples used, shown in \\\\% of training set.', fontsize=14)\n",
    "ax.set_ylabel('Test Accuracy', fontsize=20)\n",
    "\n",
    "# Shrink the axes to 80% of their width so the legend fits on the right.\n",
    "box = ax.get_position()\n",
    "ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])\n",
    "lgnd = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5),\n",
    "                 markerscale=0.7)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory of one trained ViT-B/16 prompt-tuning model. The active line picks the\n",
    "# 0.01_2.50_3.50 run; the commented lines are runs for other fractions. The trailing\n",
    "# /0/ is presumably the CV split index -- TODO confirm.\n",
    "# NOTE(review): model_path is not referenced by any cell visible here.\n",
    "#model_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/1.00_2.50_3.50/1627564633.6356685/0/'\n",
    "#model_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.40_2.50_3.50/1627557728.1547391/0/'\n",
    "#model_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.10_2.50_3.50/1627557728.411988/0/'\n",
    "#model_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.06_2.50_3.50/1627557728.1343462/0/'\n",
    "model_path = f'/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B-16/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.01_2.50_3.50/1628084176.034234/0/'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sklearn\n",
    "# Public import path; sklearn.metrics._plot.* is a private module layout, not a stable API.\n",
    "from sklearn.metrics import PrecisionRecallDisplay\n",
    "import seaborn as sn\n",
    "import pandas as pd\n",
    "# Inspect one results.p file: accuracies, confusion matrix and precision/recall curve.\n",
    "#data_path = f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B/prompt_tuning_clip/cv_10_new_multiTestset0dot1/0.08_2.50_3.50/1627055133.2856226/5/results.p\"\n",
    "#data_path = f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B/prompt_tuning_clip/cv_10/0.94_2.50_3.50/1625837159.9172006/1/results.p\"\n",
    "#data_path = f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_ViT-B/prompt_tuning_clip/cv_10/0.98_2.50_3.50/1625835816.415273/9/results.p\"\n",
    "data_path = f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/Clip_RN50/probe_tuning_clip/cv_10_new_multiTestset0dot1/1.00_2.50_3.50/1627065134.5431058/0/results.p\"\n",
    "\n",
    "\n",
    "data_dict = load_data(data_path)\n",
    "print(list(data_dict.keys()))\n",
    "y_pred, y_gt, means, epoch_acc, epoch_cmt, acc_t_gt, acc_t_lt = data_dict['y_pred'], data_dict['y_gt'], data_dict['means'], data_dict['epoch_acc'], data_dict['epoch_cmt'], data_dict['acc_t_gt'], data_dict['acc_t_lt']\n",
    "#y_pred_zero, y_gt_zero = data_dict['y_pred_zero'], data_dict['y_gt_zero']\n",
    "print(f'Acc: {epoch_acc:.2f}, {acc_t_gt:.2f}, {acc_t_lt:.2f}')\n",
    "print('CM', epoch_cmt)\n",
    "# Divide each confusion-matrix row by its row total (rounded to 2 decimals).\n",
    "# Rows with no samples are set to 0 instead of producing NaN from 0/0.\n",
    "cm_counts = np.array(epoch_cmt, dtype=float)\n",
    "row_totals = cm_counts.sum(axis=1, keepdims=True)\n",
    "cm = np.round(np.divide(cm_counts, row_totals, out=np.zeros_like(cm_counts), where=row_totals > 0), 2)\n",
    "\n",
    "print('CM normalized', cm)\n",
    "class_ids = list(range(len(cm)))\n",
    "df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)\n",
    "plt.figure(figsize = (10,7))\n",
    "#plt.title(f'Test size: {np.sum(epoch_cmt)}')\n",
    "sn.heatmap(df_cm, annot=True, annot_kws={\"size\": 28},\n",
    "           yticklabels=['Moral', 'Immoral'], xticklabels=['Moral', 'Immoral'])\n",
    "\n",
    "pos_label = 1\n",
    "#precision, recall, thresholds, avg_precision = precision_recall(torch.tensor(y_pred_zero), torch.tensor(y_gt_zero), pos_label=pos_label)\n",
    "precision, recall, thresholds, avg_precision = precision_recall(torch.tensor(y_pred), torch.tensor(y_gt), pos_label=pos_label)\n",
    "f1_score = f1(torch.tensor(y_pred), torch.tensor(y_gt))\n",
    "print(' Precision:', precision.numpy().tolist(), '\\n',\n",
    "      'Recall:', recall.numpy().tolist(), '\\n',\n",
    "      'Thresholds:', thresholds.numpy().tolist(), '\\n',\n",
    "      'AP:', avg_precision.item(),'\\n',\n",
    "        'F1', f1_score)\n",
    "# Hand sklearn plain numpy arrays / a Python float rather than torch tensors.\n",
    "pr_plot = PrecisionRecallDisplay(\n",
    "    precision=precision.numpy(),\n",
    "    recall=recall.numpy(),\n",
    "    average_precision=avg_precision.item(),\n",
    "    estimator_name='Test'\n",
    ")\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_ylim(0, 1.1)\n",
    "pr_plot.plot(ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (tuning method, backbone) pairs whose cross-validation results are collected below.\n",
    "data_mapping = [\n",
    "    ('prompt_tuning_clip', 'Clip_ViT-B-16'),\n",
    "    ('prompt_tuning_clip', 'Clip_ViT-B-32'),\n",
    "    ('prompt_tuning_clip', 'Clip_RN50'),\n",
    "    ('probe_tuning_clip', 'Clip_RN50'),\n",
    "    ('probe_tuning_clip_full', 'Clip_RN50'),\n",
    "    ('probe_tuning_imagenet', 'resnet50'),\n",
    "    ('probe_tuning_imagenet_full', 'resnet50'),\n",
    "    ('probe_tuning_imagenet_21k', 'resnet50'),\n",
    "    ('probe_tuning_imagenet_21k_full', 'resnet50')\n",
    "]\n",
    "suffix = '_new_multiTestset0dot1'\n",
    "# Run sub-directory whose results are collected for every model.\n",
    "run_dir = '1.00_2.50_3.50'\n",
    "# files_[idx] is the list of results.p paths found for data_mapping[idx].\n",
    "files_ = list()\n",
    "for idx, (dm1, dm2) in enumerate(data_mapping):\n",
    "    base_dir = f\"/workspace/datasets/results/normativity/clip_stuff/results/SMID/{data_type}/{dm2}/{dm1}\"\n",
    "    base_dir = os.path.join(base_dir, f\"cv_{n_splits}{suffix}\")\n",
    "    # glob order is arbitrary; sort so the file order is reproducible across runs.\n",
    "    files_.append(sorted(glob.glob(os.path.join(base_dir, run_dir, '*', '*', 'results.p'))))\n",
    "    print(len(files_), len(files_[idx]))\n",
    "# Show one example path for the first two models (a model may have no results).\n",
    "for paths in files_[:2]:\n",
    "    print(paths[0] if paths else 'no results found')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For every (tuning method, backbone) pair, compute precision / recall / F1 on each CV split\n",
    "# and report mean and std over the splits as LaTeX table cells.\n",
    "data_paths = files_\n",
    "pos_label = 1\n",
    "for idx, data_path_cv in enumerate(data_paths):\n",
    "    method, backbone = data_mapping[idx]\n",
    "    if not data_path_cv:\n",
    "        # Nothing was found for this model; skip instead of failing on an empty score array.\n",
    "        print(method, backbone, 'no results found')\n",
    "        print('--'*42)\n",
    "        continue\n",
    "    scores = list()\n",
    "    for data_path in data_path_cv:\n",
    "        data_dict = load_data(data_path)\n",
    "        y_pred, y_gt = data_dict['y_pred'], data_dict['y_gt']\n",
    "        precision, recall, thresholds, avg_precision = precision_recall(torch.tensor(y_pred), torch.tensor(y_gt), pos_label=pos_label)\n",
    "        f1_score = f1(torch.tensor(y_pred), torch.tensor(y_gt))\n",
    "        # NOTE(review): [-2] takes the second-to-last point of the curve returned by\n",
    "        # precision_recall; confirm this is the intended operating point.\n",
    "        scores.append((precision.numpy()[-2].item(), recall.numpy()[-2].item(), f1_score.numpy().item()))\n",
    "    # Shape (n_files, 3): one row per split, columns = (precision, recall, f1).\n",
    "    scores = np.array(scores)\n",
    "    # Aggregate each metric over the splits (per column), not across one split's metrics.\n",
    "    col_mean, col_std = scores.mean(axis=0), scores.std(axis=0)\n",
    "    print('Precision mean', 'Precision std', 'Recall mean', 'Recall std', 'F1 mean', 'F1 std')\n",
    "    print(method, backbone,\n",
    "          rf'${col_mean[0]:.2f}\\pm{col_std[0]:.2f}$&',\n",
    "          rf'${col_mean[1]:.2f}\\pm{col_std[1]:.2f}$&',\n",
    "          rf'${col_mean[2]:.2f}\\pm{col_std[2]:.2f}$')\n",
    "    print('--'*42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}