{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179161f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import StrMethodFormatter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a7525d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run selection: image style and generator model of the experiment to analyse.\n",
    "style, model = \"sketch\", \"FLUX.1-dev\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8611ba7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample CLIP metrics for the cluster-OT run of the configured model and style.\n",
    "clip_path_clust = f't2i_results/CHaRS/cluster_ot_{model}_{style}/seed1/calculate_clip_score/clip_score.csv'\n",
    "clip_data_clust = pd.read_csv(clip_path_clust, index_col=0)\n",
    "print(len(clip_data_clust))\n",
    "clip_data_clust.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99c4413",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample CLIP metrics for the linear-OT run of the configured model and style.\n",
    "clip_path_linear = f't2i_results/CHaRS/linear_ot_{model}_{style}/seed1/calculate_clip_score/clip_score.csv'\n",
    "clip_data_linear = pd.read_csv(clip_path_linear, index_col=0)\n",
    "print(len(clip_data_linear))\n",
    "clip_data_linear.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f638cd09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize CLIP metrics per steering strength and plot cluster-OT (left panel)\n",
    "# vs linear-OT (right panel), each with zero-shot score and CLIPScore on twin y-axes.\n",
    "\n",
    "# A sample counts as correctly classified when conditional_zero_shot_score >= cut_off_bound.\n",
    "cut_off_bound = 0.49\n",
    "output_path = Path('clipscore_plots/both_stop10_logscale.pdf')\n",
    "\n",
    "\n",
    "def summarize_by_strength(clip_data, threshold):\n",
    "    \"\"\"Aggregate per-sample CLIP metrics into one row per steering strength.\n",
    "\n",
    "    Returned columns (the same names are used for both OT methods):\n",
    "      - strength: sorted ascending\n",
    "      - 0shot_cluster: fraction of samples at that strength whose\n",
    "        conditional_zero_shot_score >= threshold\n",
    "      - clipscore_cluster: sum of conditional_similarity over samples with\n",
    "        unconditional_similarity >= 0, divided by the sample count at that\n",
    "        strength (excluded or NaN samples contribute 0)\n",
    "\n",
    "    Each strength is normalized by its own sample count instead of the count\n",
    "    at strength == 1, so strength 1 need not be present and strengths may have\n",
    "    different numbers of samples.\n",
    "    \"\"\"\n",
    "    passes = (clip_data['conditional_zero_shot_score'] >= threshold).astype(float)\n",
    "    kept_similarity = (\n",
    "        clip_data['conditional_similarity']\n",
    "        .where(clip_data['unconditional_similarity'] >= 0.0, 0.0)\n",
    "        .fillna(0.0)\n",
    "    )\n",
    "    per_sample = pd.DataFrame({\n",
    "        'strength': clip_data['strength'],\n",
    "        '0shot_cluster': passes,\n",
    "        'clipscore_cluster': kept_similarity,\n",
    "    })\n",
    "    return per_sample.groupby('strength', sort=True, as_index=False).mean()\n",
    "\n",
    "\n",
    "def plot_panel(summary, ax_left):\n",
    "    \"\"\"Plot zero-shot score (left y-axis) and CLIPScore (twin right y-axis) against strength.\n",
    "\n",
    "    Returns the twin right-hand axes so its legend handles can be collected.\n",
    "    \"\"\"\n",
    "    line_style = dict(x='strength', marker='o', linewidth=2)\n",
    "    sns.lineplot(data=summary, y='0shot_cluster', ax=ax_left,\n",
    "                 label='0-Shot Classification Score', color='#2CA02C', **line_style)\n",
    "    ax_right = ax_left.twinx()  # second y-axis sharing the same x-axis\n",
    "    sns.lineplot(data=summary, y='clipscore_cluster', ax=ax_right,\n",
    "                 label='CLIPScore', color='#1F77B4', **line_style)\n",
    "\n",
    "    ax_left.set_xlabel(r'Strength $\\lambda$', fontsize=22)\n",
    "    ax_left.set_ylabel('0-Shot Classification Score', fontsize=22)\n",
    "    ax_right.set_ylabel('CLIPScore', fontsize=22)\n",
    "    ax_left.set_ylim(0.5, 1.05)\n",
    "    ax_right.set_ylim(0.1, 0.3)\n",
    "    ax_left.tick_params(labelsize=18)\n",
    "    ax_right.tick_params(labelsize=18)\n",
    "    # Log x-axis (shared with the twin); assumes all strengths are > 0.\n",
    "    ax_left.set_xscale('log')\n",
    "    ax_left.xaxis.set_major_formatter(StrMethodFormatter('{x:g}'))\n",
    "    ax_left.grid(True, which='major', linewidth=0.6, alpha=0.7)\n",
    "    return ax_right\n",
    "\n",
    "\n",
    "df_cluster = summarize_by_strength(clip_data_clust, cut_off_bound)\n",
    "df_linear = summarize_by_strength(clip_data_linear, cut_off_bound)\n",
    "\n",
    "fig, (ax1_left, ax2_left) = plt.subplots(1, 2, figsize=(12.8, 5.6))\n",
    "ax1_right = plot_panel(df_cluster, ax1_left)\n",
    "ax2_right = plot_panel(df_linear, ax2_left)\n",
    "\n",
    "# One figure-level legend: both panels use identical labels, so the first panel's handles suffice.\n",
    "lines_left, labels_left = ax1_left.get_legend_handles_labels()\n",
    "lines_right, labels_right = ax1_right.get_legend_handles_labels()\n",
    "fig.legend(lines_left + lines_right, labels_left + labels_right,\n",
    "           loc='upper center', bbox_to_anchor=(0.5, 1.1), fontsize=20, ncol=2)\n",
    "for a in [ax1_left, ax1_right, ax2_left, ax2_right]:\n",
    "    if a.get_legend():\n",
    "        a.get_legend().remove()\n",
    "\n",
    "plt.tight_layout()\n",
    "output_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories\n",
    "fig.savefig(output_path, bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80d00e7c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
