{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179161f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a7525d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment selectors; both are interpolated into the results paths and figure paths below.\n",
    "style, model = \"steampunk\", \"FLUX.1-schnell\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8611ba7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CLIP-score CSV from the `mean_ot` results directory of the selected\n",
    "# model/style; the first CSV column becomes the index.\n",
    "clip_data = pd.read_csv(\n",
    "    f'results_{model}_mean_ot_{style}_42/calculate_clip_score/clip_score.csv',\n",
    "    index_col=0,\n",
    ")\n",
    "\n",
    "clip_data.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99c4413",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CLIP-score CSV from the `mean_ot_pid` results directory of the selected\n",
    "# model/style; the first CSV column becomes the index.\n",
    "clip_data_pid = pd.read_csv(\n",
    "    f'results_{model}_mean_ot_pid_{style}_42/calculate_clip_score/clip_score.csv',\n",
    "    index_col=0,\n",
    ")\n",
    "\n",
    "# Preview the frame that was just loaded (the original cell previewed `clip_data` again).\n",
    "clip_data_pid.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f3524b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot figure: for each strength, the number of rows whose\n",
    "# `conditional_zero_shot_score` is >= 0.49, divided by the number of rows at\n",
    "# strength == 1, for both result sets.\n",
    "fig = go.Figure()\n",
    "\n",
    "# Sorted so the line traces connect points in increasing strength\n",
    "# (matches the combined figure further down).\n",
    "lbs = sorted(clip_data['strength'].unique())\n",
    "\n",
    "# Loop-invariant denominators, hoisted out of the comprehensions.\n",
    "n_ref = clip_data[clip_data['strength'] == 1].shape[0]\n",
    "n_ref_pid = clip_data_pid[clip_data_pid['strength'] == 1].shape[0]\n",
    "\n",
    "zeroshot_mean_ot = [\n",
    "    clip_data[\n",
    "        (clip_data['strength'] == lb) & (clip_data['conditional_zero_shot_score'] >= 0.49)\n",
    "    ].shape[0] / n_ref\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "zeroshot_mean_ot_pid = [\n",
    "    clip_data_pid[\n",
    "        (clip_data_pid['strength'] == lb) & (clip_data_pid['conditional_zero_shot_score'] >= 0.49)\n",
    "    ].shape[0] / n_ref_pid\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "# --- Mean-AcT trace ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs,\n",
    "    y=zeroshot_mean_ot,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#4C9AFF\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"Mean-AcT\"\n",
    "))\n",
    "\n",
    "# --- PID trace (label aligned with the other figures in this notebook) ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs,\n",
    "    y=zeroshot_mean_ot_pid,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#FFB347\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"PID-AcT\"\n",
    "))\n",
    "\n",
    "# Layout\n",
    "fig.update_layout(\n",
    "    paper_bgcolor=\"white\",\n",
    "    width=700,\n",
    "    height=500,\n",
    "    showlegend=False,                 # legend is hidden in this figure\n",
    "    legend=dict(font=dict(size=18)),  # legend font (only used if the legend is re-enabled)\n",
    "    font=dict(size=18),               # default font for all text\n",
    "    xaxis=dict(\n",
    "        domain=[0, 1.0],\n",
    "        title=\"Strength\",\n",
    "        tickfont=dict(size=16)\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=\"0-shot(style) (→)\",\n",
    "        tickfont=dict(size=16)\n",
    "    ),\n",
    "    margin=dict(t=40, b=60, l=80, r=20)\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# Make sure the output directory exists before exporting the PDF.\n",
    "out_dir = f\"visualization/{model}/{style}\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "fig.write_image(f\"{out_dir}/0shot_clip_{style}.pdf\", width=700, height=500, scale=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d78edde8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLIPScore figure: for each strength, the sum of `conditional_similarity` over rows\n",
    "# with `unconditional_similarity` >= 0, divided by the number of rows at strength == 1.\n",
    "# NOTE(review): rows are filtered on `unconditional_similarity` while\n",
    "# `conditional_similarity` is summed, and the combined figure below sums\n",
    "# `unconditional_similarity` instead - confirm which column is intended.\n",
    "fig = go.Figure()\n",
    "\n",
    "# Sorted so the line traces connect points in increasing strength\n",
    "# (matches the combined figure further down).\n",
    "lbs = sorted(clip_data['strength'].unique())\n",
    "\n",
    "# Loop-invariant denominators, hoisted out of the comprehensions.\n",
    "n_ref = clip_data[clip_data['strength'] == 1].shape[0]\n",
    "n_ref_pid = clip_data_pid[clip_data_pid['strength'] == 1].shape[0]\n",
    "\n",
    "clipscore_mean_ot = [\n",
    "    clip_data.loc[\n",
    "        (clip_data['strength'] == lb) & (clip_data['unconditional_similarity'] >= 0.0),\n",
    "        'conditional_similarity',\n",
    "    ].sum() / n_ref\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "clipscore_mean_ot_pid = [\n",
    "    clip_data_pid.loc[\n",
    "        (clip_data_pid['strength'] == lb) & (clip_data_pid['unconditional_similarity'] >= 0.0),\n",
    "        'conditional_similarity',\n",
    "    ].sum() / n_ref_pid\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "# --- Mean-AcT trace ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs,\n",
    "    y=clipscore_mean_ot,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color='#4C9AFF', width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"Mean-AcT\"\n",
    "))\n",
    "\n",
    "# --- PID trace ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs,\n",
    "    y=clipscore_mean_ot_pid,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color='#FFB347', width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"PID-AcT\"\n",
    "))\n",
    "\n",
    "# Layout\n",
    "fig.update_layout(\n",
    "    width=700,\n",
    "    height=500,\n",
    "    showlegend=True,\n",
    "    xaxis=dict(domain=[0, 1.0], title=\"Strength\"),\n",
    "    yaxis=dict(title=\"ClipScores (style) (→)\"),\n",
    "    margin=dict(t=20, b=60)\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# Make sure the output directory exists before exporting the PDF.\n",
    "out_dir = f\"visualization/{model}/{style}\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "fig.write_image(f\"{out_dir}/clipscore_{style}.pdf\", width=700, height=500, scale=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98308caa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined figure: two side-by-side panels (1 row x 2 columns) -\n",
    "# zero-shot rate on the left, CLIPScore on the right, both against strength.\n",
    "fig = make_subplots(\n",
    "    rows=1, cols=2,\n",
    "    shared_xaxes=True,\n",
    "    horizontal_spacing=0.1  # space between the two panels\n",
    ")\n",
    "\n",
    "# Sorted so the line traces connect points in increasing strength.\n",
    "lbs = sorted(clip_data['strength'].unique())\n",
    "\n",
    "# Loop-invariant denominators: number of rows at strength == 1 in each result set.\n",
    "n_ref = clip_data[clip_data['strength'] == 1].shape[0]\n",
    "n_ref_pid = clip_data_pid[clip_data_pid['strength'] == 1].shape[0]\n",
    "\n",
    "# --- Zeroshot values: rows with conditional_zero_shot_score >= 0.49, per strength ---\n",
    "zeroshot_mean_ot = [\n",
    "    clip_data[\n",
    "        (clip_data['strength'] == lb) & (clip_data['conditional_zero_shot_score'] >= 0.49)\n",
    "    ].shape[0] / n_ref\n",
    "    for lb in lbs\n",
    "]\n",
    "zeroshot_mean_ot_pid = [\n",
    "    clip_data_pid[\n",
    "        (clip_data_pid['strength'] == lb) & (clip_data_pid['conditional_zero_shot_score'] >= 0.49)\n",
    "    ].shape[0] / n_ref_pid\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "# --- CLIPScore values: sum of non-negative unconditional_similarity, per strength ---\n",
    "# NOTE(review): the standalone CLIPScore figure above sums `conditional_similarity`\n",
    "# instead - confirm which column is intended here.\n",
    "clipscore_mean_ot = [\n",
    "    clip_data.loc[\n",
    "        (clip_data['strength'] == lb) & (clip_data['unconditional_similarity'] >= 0.0),\n",
    "        'unconditional_similarity',\n",
    "    ].sum() / n_ref\n",
    "    for lb in lbs\n",
    "]\n",
    "clipscore_mean_ot_pid = [\n",
    "    clip_data_pid.loc[\n",
    "        (clip_data_pid['strength'] == lb) & (clip_data_pid['unconditional_similarity'] >= 0.0),\n",
    "        'unconditional_similarity',\n",
    "    ].sum() / n_ref_pid\n",
    "    for lb in lbs\n",
    "]\n",
    "\n",
    "# --- Left panel (row 1, col 1): Zeroshot ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs, y=zeroshot_mean_ot,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#4C9AFF\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"Mean-AcT\"\n",
    "), row=1, col=1)\n",
    "\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs, y=zeroshot_mean_ot_pid,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#FFB347\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"PID-AcT\"\n",
    "), row=1, col=1)\n",
    "\n",
    "# --- Right panel (row 1, col 2): CLIPScore ---\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs, y=clipscore_mean_ot,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#4C9AFF\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"Mean-AcT\",\n",
    "    showlegend=False  # avoid duplicate legend entries\n",
    "), row=1, col=2)\n",
    "\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=lbs, y=clipscore_mean_ot_pid,\n",
    "    mode=\"lines+markers\",\n",
    "    line=dict(color=\"#FFB347\", width=2),\n",
    "    marker=dict(size=8),\n",
    "    name=\"PID-AcT\",\n",
    "    showlegend=False\n",
    "), row=1, col=2)\n",
    "\n",
    "# Layout: short, wide figure with a boxed legend placed inside the figure area\n",
    "fig.update_layout(\n",
    "    width=900,\n",
    "    height=250,\n",
    "    showlegend=True,\n",
    "    legend=dict(\n",
    "        orientation=\"v\",     # stack legend entries vertically\n",
    "        yanchor=\"top\",       # y refers to the top edge of the legend box\n",
    "        y=0.28,\n",
    "        xanchor=\"right\",     # x refers to the right edge of the legend box\n",
    "        x=0.80,\n",
    "        bgcolor=\"rgba(255,255,255,0.6)\",  # semi-transparent background\n",
    "        bordercolor=\"black\",\n",
    "        borderwidth=1,\n",
    "        font=dict(size=15)\n",
    "    ),\n",
    "    margin=dict(t=10, b=40, r=40),\n",
    ")\n",
    "\n",
    "# Axis titles. The grid has a single row, so the CLIPScore panel is (row=1, col=2);\n",
    "# the previous row=2 selectors matched no subplot and the titles were never applied.\n",
    "fig.update_yaxes(title_text=\"0-shot (→)\", row=1, col=1)\n",
    "fig.update_yaxes(title_text=\"CLIPScore (→)\", row=1, col=2)\n",
    "fig.update_xaxes(title_text=\"Strength\")  # both panels sit on the bottom row\n",
    "\n",
    "# Fonts / ticks\n",
    "fig.update_layout(font=dict(size=30))\n",
    "fig.update_xaxes(tickfont=dict(size=25), title_font=dict(size=30), tickformat=\".2f\")\n",
    "fig.update_yaxes(tickfont=dict(size=25), title_font=dict(size=30), tickformat=\".2f\")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# Save: make sure the output directory exists, then export the PDF\n",
    "# (exported at width 1000, wider than the 900 used for on-screen display).\n",
    "out_dir = f\"visualization/{model}/{style}\"\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "fig.write_image(\n",
    "    f\"{out_dir}/combined_{style}.pdf\",\n",
    "    width=1000, height=250, scale=2\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace20e06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa23c89c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_dir = \"results_42\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0a0819",
   "metadata": {},
   "outputs": [],
   "source": [
    "# zero_shot_res = pd.read_csv(data_dir + '/' + 'evaluate_0shot/0shot_eval.csv', index_col=0)\n",
    "# zero_shot_res[zero_shot_res['q0_llm_answer'] == \"Yes\"].shape[0] / zero_shot_res.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caac4073",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mmlu_res = pd.read_pickle(data_dir+'/evaluate_eleuther/eleuther.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa00718e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mmlu_res['results']['mmlu']['acc,none']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d09ad0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mistral_res = pd.read_csv(data_dir+'/evaluate_perplexity/model_perplexity.csv', index_col=0)\n",
    "# mistral_res[mistral_res['strength']==1.0]['ppl_Mistral-7B-v0.1'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ed6291a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-act",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
