{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: which result file to load and which columns to plot.\n",
    "splitting_criteria = \"all\"\n",
    "exp = \"tar\"  # column suffix used for the plot (PPATE_{exp}, PPATE_std_{exp}, ATE_{exp}, ATE_std_{exp})\n",
    "task = \"or\"\n",
    "filter = \"bias_ref\"  # column whose smallest value selects the run kept per (method, encoder)\n",
    "# Display names for the raw encoder identifiers found in the CSV, and their plotting order.\n",
    "encoder_labels = {\n",
    "    \"vit_large\": \"ViT-L\",\n",
    "    \"vit\": \"ViT-S\",\n",
    "    \"clip_large\": \"CLIP-ViT-L\",\n",
    "    \"clip\": \"CLIP-ViT-S\",\n",
    "    \"mae\": \"MAE\",\n",
    "    \"dino\": \"DINOv2\"}\n",
    "encoder_order = [\"DINOv2\", \"MAE\", \"CLIP-ViT-L\", \"CLIP-ViT-S\", \"ViT-L\", \"ViT-S\"]\n",
    "# One color per encoder category. The palette previously had five entries for six\n",
    "# encoders, which made positional lookups fail for the sixth one.\n",
    "colors = ['#6DAEDB', '#A8D5BA', '#F2CAC8', '#D3CCE3', '#FFD580', '#A9A9A9']\n",
    "methods = [\"DERM\",\"IRM\",\"vREx\",\"ERM\"]\n",
    "threshold = 0.3  # minimum fraction of an interval that must overlap the ground-truth interval\n",
    "z_score = 1.96   # multiplier applied to the *_std_* columns to get the plotted half-widths\n",
    "\n",
    "results = pd.read_csv(f'results/generalization/{task}/{splitting_criteria}.csv')\n",
    "\n",
    "# Absolute biases are only used to pick the best run per (method, encoder).\n",
    "results[\"bias_ref\"] = abs(results[\"PPATE_ref\"] - results[\"ATE_ref\"])\n",
    "results[\"bias_tar\"] = abs(results[\"PPATE_tar\"] - results[\"ATE_tar\"])\n",
    "# Stable sort + head(1) keeps the first row with the smallest `filter` value in each\n",
    "# group. Unlike groupby(...).apply(...), it never drops the grouping columns.\n",
    "results = (\n",
    "    results.sort_values(filter, kind='mergesort')\n",
    "    .groupby(['method', 'encoder'])\n",
    "    .head(1)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "results[\"encoder\"] = results[\"encoder\"].replace(encoder_labels)\n",
    "results[\"encoder\"] = pd.Categorical(results[\"encoder\"], encoder_order)\n",
    "\n",
    "# From here on the bias columns hold the signed bias.\n",
    "results['bias_ref'] = results['PPATE_ref'] - results['ATE_ref']\n",
    "results['bias_tar'] = results['PPATE_tar'] - results['ATE_tar']\n",
    "print(results.groupby(['method']).agg({'bias_ref': 'mean', 'bias_tar': 'mean'}).reset_index())\n",
    "print(results.groupby(['method']).agg({'acc_ref': 'mean', 'acc_tar': 'mean'}).reset_index())\n",
    "print(results.groupby(['method']).agg({'bacc_ref': 'mean', 'bacc_tar': 'mean'}).reset_index())\n",
    "results = results.sort_values(by=['method', 'encoder'])\n",
    "\n",
    "# Color and vertical slot are keyed by encoder, so legend and points always agree,\n",
    "# even when a method has no run for some encoder.\n",
    "encoder_colors = dict(zip(encoder_order, colors))\n",
    "present = [name for name in encoder_order if (results['encoder'] == name).any()]\n",
    "slot = {name: k for k, name in enumerate(present)}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "# Empty proxy artists build the legend; reversed so the legend reads top-down like the plot.\n",
    "for encoder in reversed(present):\n",
    "    ax.plot([], [], 'o', color=encoder_colors[encoder], label=f'{encoder}')\n",
    "ax.legend(loc='lower left')\n",
    "soft_green = 'green'\n",
    "soft_red = 'red'\n",
    "\n",
    "# Ground-truth interval, taken from the first row of the selected results.\n",
    "gt_mean = results[f'ATE_{exp}'].iloc[0]\n",
    "gt_err = results[f'ATE_std_{exp}'].iloc[0] * z_score\n",
    "gt_low, gt_high = gt_mean - gt_err, gt_mean + gt_err\n",
    "\n",
    "for i, method in enumerate(methods):\n",
    "    subset = results[results['method'] == method]\n",
    "    rows = zip(subset['encoder'], subset[f'PPATE_{exp}'], subset[f'PPATE_std_{exp}'])\n",
    "    for encoder, mean, std in rows:\n",
    "        if encoder not in slot:\n",
    "            # Encoder name outside encoder_order (NaN after the categorical cast): nothing to place.\n",
    "            continue\n",
    "        j = slot[encoder]\n",
    "        err = std * z_score\n",
    "        print(f\"{method} {encoder}: {mean:.3f} ± {err:.3f}, j={j}\")\n",
    "        ci_low, ci_high = mean - err, mean + err\n",
    "        overlap_length = max(0.0, min(gt_high, ci_high) - max(gt_low, ci_low))\n",
    "        # Green when at least `threshold` of the interval lies inside the ground-truth interval.\n",
    "        ci_color = soft_green if overlap_length >= threshold * (ci_high - ci_low) else soft_red\n",
    "        y = i + j * 0.15\n",
    "        ax.errorbar(mean, y, xerr=err, fmt='none', capsize=5, ecolor=ci_color, alpha=0.4)\n",
    "        ax.plot(mean, y, 'o', color=encoder_colors[encoder])\n",
    "\n",
    "# Ground truth sits half a unit above the highest possible point. Computed explicitly\n",
    "# instead of reusing whatever values the loop variables happened to end with.\n",
    "gt_y = (len(methods) - 1) + max(len(present) - 1, 0) * 0.15 + 0.5\n",
    "ax.errorbar(gt_mean, gt_y, xerr=gt_err, fmt='o', label=\"Ground Truth\", capsize=5, color='black')\n",
    "ax.set_yticks(list(np.arange(len(methods)) + 0.3)+[gt_y], labels=methods + [\"Ground Truth\"])\n",
    "ax.set_xlabel(r'$\\hat{\\tau}$')\n",
    "plt.grid(axis='x', linestyle='--', alpha=0.7)\n",
    "plt.tight_layout()\n",
    "plt.axvline(x=results[f'ATE_{exp}'].mean(), color='black', linestyle='--', label=f'ATE_{exp}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter plot for acc_tar vs bias_tar: marker = method, color = encoder\n",
    "marker_styles = {\n",
    "    \"DERM\": \"o\",\n",
    "    \"IRM\": \"s\",\n",
    "    \"vREx\": \"^\",\n",
    "    \"ERM\": \"D\"\n",
    "}\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "results[\"encoder\"] = results[\"encoder\"].replace({\n",
    "        \"vit_large\": \"ViT-L\",\n",
    "        \"vit\": \"ViT-S\",\n",
    "        \"clip_large\": \"CLIP-ViT-L\",\n",
    "        \"clip\": \"CLIP-ViT-S\",\n",
    "        \"mae\": \"MAE\",\n",
    "        \"dino\": \"DINOv2\"})\n",
    "# order first DINO then MAE then CLIP then ViT\n",
    "results[\"encoder\"] = pd.Categorical(results[\"encoder\"], [\"DINOv2\", \"MAE\", \"CLIP-ViT-L\", \"CLIP-ViT-S\", \"ViT-L\", \"ViT-S\"])\n",
    "\n",
    "# Signed bias on the target experiment\n",
    "results[\"bias_tar\"] = results[\"PPATE_tar\"] - results[\"ATE_tar\"]\n",
    "\n",
    "# zip() stops at its shortest input, so a palette shorter than the category list would\n",
    "# silently leave the last encoders unplotted. Pad it so every category gets a color.\n",
    "encoder_names = list(results['encoder'].cat.categories)\n",
    "palette = list(colors) + ['#A9A9A9'] * max(0, len(encoder_names) - len(colors))\n",
    "\n",
    "# One scatter call per (method, encoder) pair so that each pair gets its own legend entry\n",
    "for method in methods:\n",
    "    for encoder, color in zip(encoder_names, palette):\n",
    "        subset = results[(results['method'] == method) & (results['encoder'] == encoder)]\n",
    "        if not subset.empty:\n",
    "            ax.scatter(subset['acc_tar'], subset['bias_tar'], \n",
    "                       label=f'{method} - {encoder}', \n",
    "                       color=color, \n",
    "                       marker=marker_styles[method])\n",
    "\n",
    "# Add labels and legend\n",
    "ax.set_xlabel('Accuracy')\n",
    "ax.set_ylabel('Bias')\n",
    "# add grid\n",
    "ax.grid(True, linestyle='--', alpha=0.7)\n",
    "# Legend is anchored outside the axes, to the right of the plot\n",
    "ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs. bias on the target experiment: marker shape = encoder, color = method.\n",
    "encoder_marker_styles = {\n",
    "    'DINOv2': 'o', 'MAE': 's', 'CLIP-ViT-L': '^',\n",
    "    'CLIP-ViT-S': 'D', 'ViT-L': 'P', 'ViT-S': 'X',\n",
    "}\n",
    "method_colors = {\n",
    "    'DERM': '#6DAEDB', 'IRM': '#A8D5BA',\n",
    "    'vREx': '#F2CAC8', 'ERM': '#D3CCE3',\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "# Every (method, encoder) combination, methods outermost so the legend is grouped by method.\n",
    "pairs = [(m, e) for m in methods for e in results['encoder'].cat.categories]\n",
    "for method_name, encoder_name in pairs:\n",
    "    is_pair = (results['method'] == method_name) & (results['encoder'] == encoder_name)\n",
    "    points = results[is_pair]\n",
    "    if points.empty:\n",
    "        # No run for this combination: skip it so it gets no legend entry.\n",
    "        continue\n",
    "    ax.scatter(\n",
    "        points['acc_tar'],\n",
    "        points['bias_tar'],\n",
    "        marker=encoder_marker_styles[encoder_name],\n",
    "        color=method_colors[method_name],\n",
    "        label=f'{method_name} - {encoder_name}',\n",
    "    )\n",
    "\n",
    "# Axis labels, grid, and a legend anchored outside the axes on the right.\n",
    "ax.set_xlabel('Accuracy')\n",
    "ax.set_ylabel('Bias')\n",
    "ax.grid(True, alpha=0.7, linestyle='--')\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize='small')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "crl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
