{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "import copy\n",
    "import pandas as pd\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "all_df1 = pd.read_csv(\"//exps/div_explore/toy_v2/results_seed0_sc.csv\")\n",
    "all_df2 = pd.read_csv(\"//exps/div_explore/toy_v2/results_seed0_ci.csv\")\n",
    "all_df3 = pd.read_csv(\"//exps/div_explore/toy_v2/results_seed0_ai.csv\")\n",
    "all_df4 = pd.read_csv(\"//exps/div_explore/toy_v2/results_seed0_3shifts.csv\")\n",
    "all_df = pd.concat([all_df1, all_df2, all_df3, all_df4], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_df_over = all_df[all_df['method'] == 'oversample']\n",
    "all_df_under = all_df[all_df['method'] == 'remax-margin']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "\n",
    "sns.histplot(loss_diff, kde=True, bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = np.where(loss_diff<-0.05)[0]\n",
    "checkdf = all_df_over.iloc[ids]\n",
    "print(len(ids))\n",
    "print(len(all_df_over))\n",
    "print(np.unique(checkdf['n'], return_counts=True))\n",
    "print(np.unique(checkdf['sc'], return_counts=True))\n",
    "print(np.unique(checkdf['var_causal'], return_counts=True))\n",
    "print(np.unique(checkdf['d_feat'], return_counts=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "methods = np.array([\"ERM\", \"GroupDRO\", \"oversample\", \"undersample\", \"remax-margin\"])\n",
    "combinations = np.array(list(itertools.combinations(methods, 2)))\n",
    "for m1, m2 in combinations:\n",
    "    all_df_1 = all_df[all_df['method'] == m1]\n",
    "    all_df_2 = all_df[all_df['method'] == m2]\n",
    "    # plot histogram\n",
    "    loss_diff = all_df_1['wga_te_err'].to_numpy() - all_df_2['wga_te_err'].to_numpy()\n",
    "    loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "    ids = np.where(loss_diff<-0.05)[0]\n",
    "    checkdf = all_df_over.iloc[ids]\n",
    "    print(f\"------------{m1}_{m2}--------------\")\n",
    "    print(len(ids))\n",
    "    print(len(all_df_over))\n",
    "    print(np.unique(checkdf['n'], return_counts=True))\n",
    "    print(np.unique(checkdf['sc'], return_counts=True))\n",
    "    print(np.unique(checkdf['var_causal'], return_counts=True))\n",
    "    print(np.unique(checkdf['d_feat'], return_counts=True))\n",
    "    # use a new plot every time\n",
    "    plt.figure()\n",
    "    sns.histplot(loss_diff, kde=True, bins=20)\n",
    "    plt.title(m1+m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution 1: further filter\n",
    "# solution 2: use other visualization\n",
    "# larger font\n",
    "plt.figure(figsize=(6, 5))\n",
    "\n",
    "# larger font\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'oversample']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df1 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df1['method'] = 'Under-Oversample'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'ERM']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df2 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df2['method'] = 'Undersample-ERM'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'remax-margin']\n",
    "all_df_under = all_df[all_df['method'] == 'GroupDRO']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df3 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df3['method'] = 'Maxmargin-GroupDRO'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'GroupDRO']\n",
    "all_df_under = all_df[all_df['method'] == 'undersample']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df4 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df4['method'] = 'GroupDRO-Undersample'\n",
    "\n",
    "df = pd.concat([df1, df2, df3, df4])\n",
    "\n",
    "# rename columns\n",
    "df.columns = ['Test error diff.', 'Methods']\n",
    "df['Methods'] = df['Methods'].map({'Under-Oversample': 'Under-Oversample', 'Undersample-ERM': 'Undersample-ERM', 'Maxmargin-GroupDRO': 'Logits corr.-GroupDRO', 'GroupDRO-Undersample': 'GroupDRO-Undersample'})\n",
    "# white grid\n",
    "sns.set(style=\"whitegrid\", font_scale=1.3)\n",
    "# sns.color_palette(palette='Dark2')\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', kde=False, bins=20, stat='count', multiple=\"layer\", common_norm=True, log_scale=(False, True))\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.8, palette='Set2')\n",
    "sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.7)\n",
    "\n",
    "plt.xlabel(\"Difference in worst-group test error\")\n",
    "# log transform the y axis\n",
    "plt.yscale('log')\n",
    "plt.ylim(0, 10000)\n",
    "\n",
    "# plt.title(\"Histogram of test error difference across OOD datasets\")\n",
    "\n",
    "# high resolution\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./toy_error_diff.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only for enlarge the font size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution 1: further filter\n",
    "# solution 2: use other visualization\n",
    "# larger font\n",
    "plt.figure(figsize=(6, 5))\n",
    "\n",
    "# larger font\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'oversample']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df1 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df1['method'] = 'Under-Oversample'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'ERM']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df2 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df2['method'] = 'Undersample-ERM'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'remax-margin']\n",
    "all_df_under = all_df[all_df['method'] == 'GroupDRO']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df3 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df3['method'] = 'Maxmargin-GroupDRO'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'GroupDRO']\n",
    "all_df_under = all_df[all_df['method'] == 'undersample']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df4 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df4['method'] = 'GroupDRO-Undersample'\n",
    "\n",
    "df = pd.concat([df1, df2, df3, df4])\n",
    "\n",
    "# rename columns\n",
    "df.columns = ['Test error diff.', 'Methods']\n",
    "df['Methods'] = df['Methods'].map({'Under-Oversample': 'Under-Oversample', 'Undersample-ERM': 'Undersample-ERM', 'Maxmargin-GroupDRO': 'Logits corr.-GroupDRO', 'GroupDRO-Undersample': 'GroupDRO-Undersample'})\n",
    "# white grid\n",
    "sns.set(style=\"whitegrid\", font_scale=1.6)\n",
    "# sns.color_palette(palette='Dark2')\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', kde=False, bins=20, stat='count', multiple=\"layer\", common_norm=True, log_scale=(False, True))\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.8, palette='Set2')\n",
    "sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.7)\n",
    "\n",
    "plt.xlabel(\"Difference in worst-group test error\")\n",
    "# log transform the y axis\n",
    "plt.yscale('log')\n",
    "plt.ylim(0, 10000)\n",
    "\n",
    "# adjust legend transparency\n",
    "# get current legend labels\n",
    "# leg = plt.legend()\n",
    "# # set the facecolor of the legend\n",
    "# leg.get_frame().set_facecolor('white')\n",
    "# # set the frame alpha to 1\n",
    "# leg.get_frame().set_alpha(1)\n",
    "# # set the title of the legend\n",
    "# leg.set_title('')\n",
    "# plt.title(\"Histogram of test error difference across OOD datasets\")\n",
    "\n",
    "# high resolution\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./toy_error_diff.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# solution 1: further filter\n",
    "# solution 2: use other visualization\n",
    "# larger font\n",
    "plt.figure(figsize=(6, 5))\n",
    "\n",
    "# larger font\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'oversample']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df1 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df1['method'] = 'Under-Oversample'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'undersample']\n",
    "all_df_under = all_df[all_df['method'] == 'ERM']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df2 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df2['method'] = 'Undersample-ERM'\n",
    "\n",
    "all_df_over = all_df[all_df['method'] == 'remax-margin']\n",
    "all_df_under = all_df[all_df['method'] == 'GroupDRO']\n",
    "# plot histogram\n",
    "loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "df3 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "df3['method'] = 'Maxmargin-GroupDRO'\n",
    "\n",
    "# all_df_over = all_df[all_df['method'] == 'GroupDRO']\n",
    "# all_df_under = all_df[all_df['method'] == 'undersample']\n",
    "# # plot histogram\n",
    "# loss_diff = all_df_over['wga_te_err'].to_numpy() - all_df_under['wga_te_err'].to_numpy()\n",
    "# # loss_diff = loss_diff[(loss_diff < -0.05) | (loss_diff > 0.05)]\n",
    "# df4 = pd.DataFrame(loss_diff, columns=['loss_diff'])\n",
    "# df4['method'] = 'GroupDRO-Undersample'\n",
    "\n",
    "# df = pd.concat([df1, df2, df3, df4])\n",
    "\n",
    "df = pd.concat([df1, df2, df3])\n",
    "\n",
    "# rename columns\n",
    "df.columns = ['Test error diff.', 'Methods']\n",
    "df['Methods'] = df['Methods'].map({'Under-Oversample': 'Under - Oversample', 'Undersample-ERM': 'Undersample - ERM', 'Maxmargin-GroupDRO': 'Logits corr. - GroupDRO', 'GroupDRO-Undersample': 'GroupDRO - Undersample'})\n",
    "# white grid\n",
    "sns.set(style=\"whitegrid\", font_scale=1.6)\n",
    "# sns.color_palette(palette='Dark2')\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', kde=False, bins=20, stat='count', multiple=\"layer\", common_norm=True, log_scale=(False, True))\n",
    "# sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.8, palette='Set2')\n",
    "ax = sns.histplot(data=df, x='Test error diff.', hue='Methods', bins=20, multiple=\"dodge\", common_norm=False, alpha=0.7)\n",
    "ax.grid(axis='x')\n",
    "plt.xlabel(\"Difference in worst-group test error\")\n",
    "# log transform the y axis\n",
    "plt.yscale('log')\n",
    "plt.ylim(0, 10000)\n",
    "\n",
    "# adjust legend transparency\n",
    "# get current legend labels\n",
    "# leg = plt.legend()\n",
    "# # set the facecolor of the legend\n",
    "# leg.get_frame().set_facecolor('white')\n",
    "# # set the frame alpha to 1\n",
    "# leg.get_frame().set_alpha(1)\n",
    "# # set the title of the legend\n",
    "# leg.set_title('')\n",
    "# plt.title(\"Histogram of test error difference across OOD datasets\")\n",
    "# plt.axvspan(-0.5, 0, color='gray', alpha=0.1)\n",
    "# Adjust the legend font size\n",
    "legend = ax.get_legend()\n",
    "legend.set_title(legend.get_title().get_text(), prop={'size': 14})  # Set font size for legend title\n",
    "for text in legend.get_texts():\n",
    "    text.set_fontsize(14)  # Adjust font size for legend labels\n",
    "# high resolution\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./toy_error_diff1.pdf\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "div_backup",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
