{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-07-19T16:49:14.812437Z",
     "start_time": "2024-07-19T16:49:12.170015Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yuenc2/anaconda3/envs/fishr/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'dlopen(/Users/yuenc2/anaconda3/envs/fishr/lib/python3.9/site-packages/torchvision/image.so, 0x0006): Symbol not found: __ZN3c1017RegisterOperatorsD1Ev\n",
      "  Referenced from: <2D1B8D5C-7891-3680-9CF9-F771AE880676> /Users/yuenc2/anaconda3/envs/fishr/lib/python3.9/site-packages/torchvision/image.so\n",
      "  Expected in:     <8D9916BF-BDB2-35FE-B43C-E28A9CA7D22C> /Users/yuenc2/anaconda3/envs/fishr/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
      "  warn(\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import functools\n",
    "import glob\n",
    "import pickle\n",
    "import itertools\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "import numpy as np\n",
    "import tqdm\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# If domainbed is a custom module, you might need to adjust its imports or setup\n",
    "from domainbed import datasets\n",
    "from domainbed import algorithms\n",
    "from domainbed.lib import misc, reporting\n",
    "from domainbed import model_selection\n",
    "from domainbed.lib.query import Q\n",
    "import warnings\n"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total subdirectories: 0\n",
      "Empty subdirectories: 0 in results_vits_hessian_real_bias directory\n",
      "Total subdirectories: 0\n",
      "Empty subdirectories: 0 in results_vits_hessian_bias directory\n",
      "Total subdirectories: 0\n",
      "Empty subdirectories: 0 in results_vits_coral directory\n"
     ]
    }
   ],
   "source": [
    "def count_subdirectories_and_empty_ones(directory):\n",
    "    \"\"\"Recursively count the subdirectories of `directory` and how many are empty.\n",
    "\n",
    "    Returns a ``(num_subdirectories, num_empty_subdirectories)`` tuple. A\n",
    "    directory that is missing or cannot be read triggers a warning (and adds\n",
    "    nothing to the counts) instead of being skipped silently, so a mistyped\n",
    "    sweep path is visible.\n",
    "    \"\"\"\n",
    "    num_subdirectories = 0\n",
    "    num_empty_subdirectories = 0\n",
    "\n",
    "    def _warn(err):\n",
    "        # os.walk silently ignores errors (including a missing top directory)\n",
    "        # unless an onerror callback is supplied.\n",
    "        warnings.warn(f\"Could not read directory {err.filename!r}: {err}\")\n",
    "\n",
    "    for root, dirs, _files in os.walk(directory, onerror=_warn):\n",
    "        for d in dirs:\n",
    "            num_subdirectories += 1\n",
    "            subdirectory_path = os.path.join(root, d)\n",
    "            try:\n",
    "                is_empty = not os.listdir(subdirectory_path)\n",
    "            except OSError as err:\n",
    "                _warn(err)\n",
    "                continue\n",
    "            if is_empty:\n",
    "                num_empty_subdirectories += 1\n",
    "\n",
    "    return num_subdirectories, num_empty_subdirectories\n",
    "\n",
    "\n",
    "# Sweep output directories to inspect (paths relative to the working directory).\n",
    "directory_path_list = ['results_vits_hessian_real_bias', 'results_vits_hessian_bias', 'results_vits_coral']\n",
    "\n",
    "for directory_path in directory_path_list:\n",
    "    # Count all nested subdirectories and how many of them are empty.\n",
    "    subdirectories, empty_subdirectories = count_subdirectories_and_empty_ones(directory_path)\n",
    "    print(f\"Total subdirectories: {subdirectories}\")\n",
    "    print(f\"Empty subdirectories: {empty_subdirectories} in {directory_path} directory\")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T16:49:14.817640Z",
     "start_time": "2024-07-19T16:49:14.814606Z"
    }
   },
   "id": "8ab783a36ad7a66e",
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n",
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "def get_test_records(records):\n",
    "    \"\"\"Return only the records whose run held out exactly one test env.\n",
    "\n",
    "    Among records that share a test env, this keeps those with\n",
    "    ``len(args.test_envs) == 1`` and drops runs that held out extra envs.\n",
    "    \"\"\"\n",
    "    def _has_single_test_env(record):\n",
    "        return len(record['args']['test_envs']) == 1\n",
    "\n",
    "    return records.filter(_has_single_test_env)\n",
    "\n",
    "class SelectionMethod:\n",
    "    \"\"\"Abstract class whose subclasses implement strategies for model\n",
    "    selection across hparams and timesteps.\n",
    "\n",
    "    All behaviour is exposed through classmethods; subclasses override\n",
    "    `run_acc`. The class is not meant to be instantiated.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Selection methods are used only through their classmethods.\n",
    "        raise TypeError\n",
    "\n",
    "    @classmethod\n",
    "    def run_acc(cls, run_records):\n",
    "        \"\"\"\n",
    "        Given records from a run, return a {val_acc, test_acc} dict representing\n",
    "        the best val-acc and corresponding test-acc for that run, or None if\n",
    "        the run cannot be scored.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    @classmethod\n",
    "    def hparams_accs(cls, records):\n",
    "        \"\"\"\n",
    "        Given all records from a single (dataset, algorithm, test env) pair,\n",
    "        return a sequence of (run_acc, run_records) tuples, one per hparams\n",
    "        seed, ordered by descending run_acc['val_acc']. Runs whose run_acc is\n",
    "        None are dropped.\n",
    "        \"\"\"\n",
    "        return (records.group('args.hparams_seed')\n",
    "            .map(lambda _, run_records:\n",
    "                (\n",
    "                    cls.run_acc(run_records),\n",
    "                    run_records\n",
    "                )\n",
    "            ).filter(lambda x: x[0] is not None)\n",
    "            .sorted(key=lambda x: x[0]['val_acc'])[::-1]\n",
    "        )\n",
    "\n",
    "    @classmethod\n",
    "    def sweep_acc(cls, records):\n",
    "        \"\"\"\n",
    "        Given all records from a single (dataset, algorithm, test env) pair,\n",
    "        return the test acc of the hparams run with the highest val acc, or\n",
    "        None when no run could be scored.\n",
    "        \"\"\"\n",
    "        ranked = cls.hparams_accs(records)\n",
    "        if len(ranked):\n",
    "            return ranked[0][0]['test_acc']\n",
    "        return None\n",
    "\n",
    "class OracleSelectionMethod(SelectionMethod):\n",
    "    \"\"\"Oracle selection: rank runs by the test env's out-split accuracy at\n",
    "    their final checkpoint.\n",
    "\n",
    "    Unlike an argmax(test_out_acc) over every checkpoint, only the last step\n",
    "    of each run is used, i.e. there is no early stopping.\"\"\"\n",
    "    name = \"test-domain validation set (oracle)\"\n",
    "\n",
    "    @classmethod\n",
    "    def run_acc(cls, run_records):\n",
    "        # Only records that held out a single env (the real test env) count.\n",
    "        single_env = run_records.filter(\n",
    "            lambda rec: len(rec['args']['test_envs']) == 1)\n",
    "        if len(single_env) == 0:\n",
    "            return None\n",
    "\n",
    "        held_out = single_env[0]['args']['test_envs'][0]\n",
    "        # Last checkpoint by step; no early stopping.\n",
    "        final_record = single_env.sorted(lambda rec: rec['step'])[-1]\n",
    "        return {\n",
    "            'val_acc': final_record[f'env{held_out}_out_acc'],\n",
    "            'test_acc': final_record[f'env{held_out}_in_acc'],\n",
    "        }\n",
    "\n",
    "class IIDAccuracySelectionMethod(SelectionMethod):\n",
    "    \"\"\"Picks argmax(mean(env_out_acc for env in train_envs))\"\"\"\n",
    "    name = \"training-domain validation set\"\n",
    "\n",
    "    @classmethod\n",
    "    def _step_acc(cls, record):\n",
    "        \"\"\"Given a single record, return a {val_acc, test_acc} dict.\n",
    "\n",
    "        val_acc is the mean out-split accuracy over every env other than the\n",
    "        test env; test_acc is the test env's in-split accuracy.\n",
    "        \"\"\"\n",
    "        test_env = record['args']['test_envs'][0]\n",
    "        val_env_keys = []\n",
    "        # Probe env0, env1, ... until the first missing out-acc key.\n",
    "        for i in itertools.count():\n",
    "            if f'env{i}_out_acc' not in record:\n",
    "                break\n",
    "            if i != test_env:\n",
    "                val_env_keys.append(f'env{i}_out_acc')\n",
    "        test_in_acc_key = 'env{}_in_acc'.format(test_env)\n",
    "        return {\n",
    "            'val_acc': np.mean([record[key] for key in val_env_keys]),\n",
    "            'test_acc': record[test_in_acc_key]\n",
    "        }\n",
    "\n",
    "    @classmethod\n",
    "    def run_acc(cls, run_records):\n",
    "        \"\"\"Return the {val_acc, test_acc} of the checkpoint with the highest\n",
    "        training-domain val_acc among the run's single-test-env records, or\n",
    "        None if the run has no such records.\"\"\"\n",
    "        test_records = get_test_records(run_records)\n",
    "        if not len(test_records):\n",
    "            return None\n",
    "        # Score each checkpoint once and return the dict with the largest val_acc.\n",
    "        return test_records.map(cls._step_acc).argmax('val_acc')\n",
    "\n",
    "class LeaveOneOutSelectionMethod(SelectionMethod):\n",
    "    \"\"\"Picks (hparams, step) by leave-one-out cross validation.\"\"\"\n",
    "    name = \"leave-one-domain-out cross-validation\"\n",
    "\n",
    "    @classmethod\n",
    "    def _step_acc(cls, records):\n",
    "        \"\"\"Score the group of records logged at a single step.\n",
    "\n",
    "        Exactly one record must hold out only the test env; every record that\n",
    "        holds out two envs supplies the in-split accuracy of its second\n",
    "        (validation) env. Returns {val_acc, test_acc}, or None if the\n",
    "        single-test-env record is missing/duplicated or any non-test env\n",
    "        lacks a validation score.\n",
    "        \"\"\"\n",
    "        single_env_records = get_test_records(records)\n",
    "        if len(single_env_records) != 1:\n",
    "            return None\n",
    "\n",
    "        test_env = single_env_records[0]['args']['test_envs'][0]\n",
    "\n",
    "        # Count envs by probing env0_out_acc, env1_out_acc, ... on the first record.\n",
    "        n_envs = 0\n",
    "        while f'env{n_envs}_out_acc' in records[0]:\n",
    "            n_envs += 1\n",
    "\n",
    "        # -1 is a sentinel for \"no validation score recorded for this env\".\n",
    "        held_out_accs = np.full(n_envs, -1.0)\n",
    "        pair_records = records.filter(lambda r: len(r['args']['test_envs']) == 2)\n",
    "        for rec in pair_records:\n",
    "            val_env = (set(rec['args']['test_envs']) - {test_env}).pop()\n",
    "            held_out_accs[val_env] = rec[f'env{val_env}_in_acc']\n",
    "\n",
    "        remaining = (list(held_out_accs[:test_env])\n",
    "                     + list(held_out_accs[test_env + 1:]))\n",
    "        if any(acc == -1 for acc in remaining):\n",
    "            return None\n",
    "\n",
    "        return {\n",
    "            'val_acc': np.sum(remaining) / (n_envs - 1),\n",
    "            'test_acc': single_env_records[0][f'env{test_env}_in_acc'],\n",
    "        }\n",
    "\n",
    "    @classmethod\n",
    "    def run_acc(cls, records):\n",
    "        \"\"\"Return the {val_acc, test_acc} of the best-scoring step of a run,\n",
    "        or None if no step could be scored.\"\"\"\n",
    "        per_step = records.group('step').map(\n",
    "            lambda step, step_records: cls._step_acc(step_records))\n",
    "        scored_steps = per_step.filter_not_none()\n",
    "        if not len(scored_steps):\n",
    "            return None\n",
    "        return scored_steps.argmax('val_acc')\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T16:49:14.827732Z",
     "start_time": "2024-07-19T16:49:14.825650Z"
    }
   },
   "id": "9a40c1c70267ef5f",
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def format_mean(data, latex):\n",
    "    \"\"\"Given an iterable of datapoints, return (mean, err, text) scaled by 100.\n",
    "\n",
    "    err is the standard error np.std(data) / sqrt(n), using numpy's default\n",
    "    ddof=0. text is \"mean +/- err\" (LaTeX pm sign if `latex`) to one\n",
    "    decimal. Empty input yields (None, None, \"X\").\n",
    "    \"\"\"\n",
    "    values = list(data)\n",
    "    if len(values) == 0:\n",
    "        return None, None, \"X\"\n",
    "    mean = 100 * np.mean(values)\n",
    "    err = 100 * np.std(values) / np.sqrt(len(values))\n",
    "    if latex:\n",
    "        return mean, err, \"{:.1f} $\\\\pm$ {:.1f}\".format(mean, err)\n",
    "    else:\n",
    "        return mean, err, \"{:.1f} +/- {:.1f}\".format(mean, err)\n",
    "\n",
    "def print_table(table, header_text, row_labels, col_labels, colwidth=10,\n",
    "    latex=True):\n",
    "    \"\"\"Pretty-print a 2D array of data, optionally with row/col labels.\n",
    "\n",
    "    Each row of `table` is printed prefixed by the matching entry of\n",
    "    `row_labels`, below a header row built from `col_labels`. `table` is not\n",
    "    modified. With `latex` the output is a LaTeX tabular (toprule/midrule/\n",
    "    bottomrule); otherwise a \"--------\" line with `header_text` comes first.\n",
    "    \"\"\"\n",
    "    print(\"\")\n",
    "\n",
    "    # Copy the rows so inserting labels/headers never mutates the caller's table.\n",
    "    rows = [list(row) for row in table]\n",
    "\n",
    "    if latex:\n",
    "        # Data columns only (the label column is the leading \"l\"). For an empty\n",
    "        # table, fall back to col_labels minus its leading row-label header.\n",
    "        num_cols = len(rows[0]) if rows else max(len(col_labels) - 1, 0)\n",
    "        print(\"\\\\begin{center}\")\n",
    "        print(\"\\\\adjustbox{max width=\\\\textwidth}{%\")\n",
    "        print(\"\\\\begin{tabular}{l\" + \"c\" * num_cols + \"}\")\n",
    "        print(\"\\\\toprule\")\n",
    "    else:\n",
    "        print(\"--------\", header_text)\n",
    "\n",
    "    for row, label in zip(rows, row_labels):\n",
    "        row.insert(0, label)\n",
    "\n",
    "    if latex:\n",
    "        col_labels = [\"\\\\textbf{\" + str(col_label).replace(\"%\", \"\\\\%\") + \"}\"\n",
    "            for col_label in col_labels]\n",
    "    rows.insert(0, col_labels)\n",
    "\n",
    "    for r, row in enumerate(rows):\n",
    "        misc.print_row(row, colwidth=colwidth, latex=latex)\n",
    "        if latex and r == 0:\n",
    "            print(\"\\\\midrule\")\n",
    "    if latex:\n",
    "        print(\"\\\\bottomrule\")\n",
    "        print(\"\\\\end{tabular}}\")\n",
    "        print(\"\\\\end{center}\")\n",
    "\n",
    "def _avg_cell(means):\n",
    "    \"\"\"Format an \"Avg\" cell: mean of `means` to one decimal, or \"X\" when\n",
    "    any entry is None (missing) or there are no entries.\"\"\"\n",
    "    if not means or None in means:\n",
    "        return \"X\"\n",
    "    return \"{:.1f}\".format(sum(means) / len(means))\n",
    "\n",
    "\n",
    "def print_results_tables(records, selection_method, latex):\n",
    "    \"\"\"Given all records, print a results table for each dataset, then a table\n",
    "    of per-dataset averages, scoring each sweep with `selection_method`.\"\"\"\n",
    "    grouped_records = reporting.get_grouped_records(records).map(lambda group:\n",
    "        { **group, \"sweep_acc\": selection_method.sweep_acc(group[\"records\"]) }\n",
    "    ).filter(lambda g: g[\"sweep_acc\"] is not None)\n",
    "\n",
    "    # read algorithm names and sort (predefined order, then unregistered ones)\n",
    "    alg_names = Q(records).select(\"args.algorithm\").unique()\n",
    "    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +\n",
    "        [n for n in alg_names if n not in algorithms.ALGORITHMS])\n",
    "\n",
    "    # read dataset names, ordered as in datasets.DATASETS; datasets not listed\n",
    "    # there are skipped\n",
    "    present_datasets = set(Q(records).select(\"args.dataset\").unique())\n",
    "    dataset_names = [d for d in datasets.DATASETS if d in present_datasets]\n",
    "\n",
    "    for dataset in dataset_names:\n",
    "        if latex:\n",
    "            print()\n",
    "            print(\"\\\\subsubsection{{{}}}\".format(dataset))\n",
    "        test_envs = range(datasets.num_environments(dataset))\n",
    "        table = [[None for _ in [*test_envs, \"Avg\"]] for _ in alg_names]\n",
    "        for i, algorithm in enumerate(alg_names):\n",
    "            means = []\n",
    "            for j, test_env in enumerate(test_envs):\n",
    "                trial_accs = (grouped_records\n",
    "                    .filter_equals(\n",
    "                        \"dataset, algorithm, test_env\",\n",
    "                        (dataset, algorithm, test_env)\n",
    "                    ).select(\"sweep_acc\"))\n",
    "                mean, _err, table[i][j] = format_mean(trial_accs, latex)\n",
    "                means.append(mean)\n",
    "            table[i][-1] = _avg_cell(means)\n",
    "\n",
    "        col_labels = [\n",
    "            \"Algorithm\",\n",
    "            *datasets.get_dataset_class(dataset).ENVIRONMENTS,\n",
    "            \"Avg\"\n",
    "        ]\n",
    "        header_text = (f\"Dataset: {dataset}, \"\n",
    "            f\"model selection method: {selection_method.name}\")\n",
    "        print_table(table, header_text, alg_names, list(col_labels),\n",
    "            colwidth=20, latex=latex)\n",
    "\n",
    "    # Print an \"averages\" table\n",
    "    if latex:\n",
    "        print()\n",
    "        print(\"\\\\subsubsection{Averages}\")\n",
    "\n",
    "    table = [[None for _ in [*dataset_names, \"Avg\"]] for _ in alg_names]\n",
    "    for i, algorithm in enumerate(alg_names):\n",
    "        means = []\n",
    "        for j, dataset in enumerate(dataset_names):\n",
    "            # Mean sweep_acc over test envs within each trial seed; format_mean\n",
    "            # then summarises those per-trial averages.\n",
    "            trial_averages = (grouped_records\n",
    "                .filter_equals(\"algorithm, dataset\", (algorithm, dataset))\n",
    "                .group(\"trial_seed\")\n",
    "                .map(lambda trial_seed, group:\n",
    "                    group.select(\"sweep_acc\").mean()\n",
    "                )\n",
    "            )\n",
    "            mean, _err, table[i][j] = format_mean(trial_averages, latex)\n",
    "            means.append(mean)\n",
    "        table[i][-1] = _avg_cell(means)\n",
    "\n",
    "    col_labels = [\"Algorithm\", *dataset_names, \"Avg\"]\n",
    "    header_text = f\"Averages, model selection method: {selection_method.name}\"\n",
    "    print_table(table, header_text, alg_names, col_labels, colwidth=25,\n",
    "        latex=latex)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T16:49:14.843253Z",
     "start_time": "2024-07-19T16:49:14.834740Z"
    }
   },
   "id": "73ca9d266f0c31a8",
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: './results_hess_diag'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[5], line 34\u001B[0m\n\u001B[1;32m     31\u001B[0m latex \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[1;32m     32\u001B[0m results_file \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mresults.tex\u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m latex \u001B[38;5;28;01melse\u001B[39;00m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mresults.txt\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m---> 34\u001B[0m records \u001B[38;5;241m=\u001B[39m \u001B[43mreporting\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_records\u001B[49m\u001B[43m(\u001B[49m\u001B[43minput_dir\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     35\u001B[0m \u001B[38;5;66;03m# selection_methods = [\u001B[39;00m\n\u001B[1;32m     36\u001B[0m \u001B[38;5;66;03m#     model_selection.IIDAccuracySelectionMethod,\u001B[39;00m\n\u001B[1;32m     37\u001B[0m \u001B[38;5;66;03m#     model_selection.OracleSelectionMethod\u001B[39;00m\n\u001B[1;32m     38\u001B[0m \u001B[38;5;66;03m# ]\u001B[39;00m\n\u001B[1;32m     39\u001B[0m selection_methods \u001B[38;5;241m=\u001B[39m [\n\u001B[1;32m     40\u001B[0m     IIDAccuracySelectionMethod,\n\u001B[1;32m     41\u001B[0m     OracleSelectionMethod,\n\u001B[1;32m     42\u001B[0m ]\n",
      "File \u001B[0;32m~/Desktop/CMA/DomainBed/domainbed/lib/reporting.py:14\u001B[0m, in \u001B[0;36mload_records\u001B[0;34m(path)\u001B[0m\n\u001B[1;32m     12\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mload_records\u001B[39m(path):\n\u001B[1;32m     13\u001B[0m     records \u001B[38;5;241m=\u001B[39m []\n\u001B[0;32m---> 14\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m i, subdir \u001B[38;5;129;01min\u001B[39;00m tqdm\u001B[38;5;241m.\u001B[39mtqdm(\u001B[38;5;28mlist\u001B[39m(\u001B[38;5;28menumerate\u001B[39m(\u001B[43mos\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlistdir\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpath\u001B[49m\u001B[43m)\u001B[49m)),\n\u001B[1;32m     15\u001B[0m                                ncols\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m80\u001B[39m,\n\u001B[1;32m     16\u001B[0m                                leave\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n\u001B[1;32m     17\u001B[0m         results_path \u001B[38;5;241m=\u001B[39m os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mjoin(path, subdir, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mresults.jsonl\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m     18\u001B[0m         \u001B[38;5;66;03m# breakpoint()\u001B[39;00m\n",
      "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: './results_hess_diag'"
     ]
    }
   ],
   "source": [
    "# Sweep output directory to summarise. Earlier sweeps were summarised from\n",
    "# sibling directories (results_vits_*, results_hgp, results_hgp_cma, ...).\n",
    "# NOTE(review): this path did not exist on the last run (see this cell's\n",
    "# FileNotFoundError output) -- confirm it before re-running.\n",
    "input_dir = \"./results_hess_diag\"\n",
    "\n",
    "latex = False  # set True to print LaTeX tables instead of plain text\n",
    "results_file = \"results.tex\" if latex else \"results.txt\"\n",
    "\n",
    "records = reporting.load_records(input_dir)\n",
    "# Use the selection-method classes defined in this notebook rather than\n",
    "# domainbed.model_selection.\n",
    "selection_methods = [\n",
    "    IIDAccuracySelectionMethod,\n",
    "    OracleSelectionMethod,\n",
    "]\n",
    "\n",
    "for selection_method in selection_methods:\n",
    "    print_results_tables(records, selection_method, latex)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-07-19T16:50:22.438716Z",
     "start_time": "2024-07-19T16:50:22.158502Z"
    }
   },
   "id": "83057a822b38d293",
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "-------- Dataset: ColoredMNIST, model selection method: training-domain validation set\n",
      "Algorithm             +90%                  +80%                  -90%                  Avg                  \n",
      "HessianAlignment      72.2 +/- 0.2          72.8 +/- 0.0          10.0 +/- 0.2          51.7                 \n",
      "\n",
      "-------- Dataset: RotatedMNIST, model selection method: training-domain validation set\n",
      "Algorithm             0                     15                    30                    45                    60                    75                    Avg                  \n",
      "HessianAlignment      94.6 +/- 0.6          98.6 +/- 0.2          99.0 +/- 0.0          98.6 +/- 0.3          98.3 +/- 0.3          95.3 +/- 0.6          97.4                 \n",
      "\n",
      "-------- Dataset: VLCS, model selection method: training-domain validation set\n",
      "Algorithm             C                     L                     S                     V                     Avg                  \n",
      "HessianAlignment      96.4 +/- 0.3          63.5 +/- 0.4          74.0 +/- 0.3          76.7 +/- 0.6          77.6                 \n",
      "\n",
      "-------- Dataset: PACS, model selection method: training-domain validation set\n",
      "Algorithm             A                     C                     P                     S                     Avg                  \n",
      "HessianAlignment      83.5 +/- 1.0          74.5 +/- 0.8          97.0 +/- 0.2          65.9 +/- 1.5          80.2                 \n",
      "\n",
      "-------- Dataset: TerraIncognita, model selection method: training-domain validation set\n",
      "Algorithm             L100                  L38                   L43                   L46                   Avg                  \n",
      "HessianAlignment      44.9 +/- 1.9          18.8 +/- 2.7          37.4 +/- 0.5          32.6 +/- 0.5          33.4                 \n",
      "\n",
      "-------- Averages, model selection method: training-domain validation set\n",
      "Algorithm                  ColoredMNIST               RotatedMNIST               VLCS                       PACS                       TerraIncognita             Avg                       \n",
      "HessianAlignment           44.6 +/- 3.0               97.4 +/- 0.2               77.6 +/- 0.1               80.2 +/- 0.1               33.4 +/- 0.9               66.7                      \n",
      "\n",
      "-------- Dataset: ColoredMNIST, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             +90%                  +80%                  -90%                  Avg                  \n",
      "HessianAlignment      68.7 +/- 0.4          68.0 +/- 0.0          23.2 +/- 2.4          53.3                 \n",
      "\n",
      "-------- Dataset: RotatedMNIST, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             0                     15                    30                    45                    60                    75                    Avg                  \n",
      "HessianAlignment      94.6 +/- 0.9          98.5 +/- 0.4          98.7 +/- 0.1          98.7 +/- 0.3          98.1 +/- 0.4          94.5 +/- 1.0          97.2                 \n",
      "\n",
      "-------- Dataset: VLCS, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             C                     L                     S                     V                     Avg                  \n",
      "HessianAlignment      95.9 +/- 0.5          63.0 +/- 0.6          73.0 +/- 0.8          75.2 +/- 1.3          76.8                 \n",
      "\n",
      "-------- Dataset: PACS, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             A                     C                     P                     S                     Avg                  \n",
      "HessianAlignment      84.1 +/- 1.5          75.6 +/- 1.3          97.1 +/- 0.2          70.0 +/- 0.8          81.7                 \n",
      "\n",
      "-------- Dataset: TerraIncognita, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             L100                  L38                   L43                   L46                   Avg                  \n",
      "HessianAlignment      49.1 +/- 0.7          24.4 +/- 2.0          39.9 +/- 1.6          35.4 +/- 2.3          37.2                 \n",
      "\n",
      "-------- Averages, model selection method: test-domain validation set (oracle)\n",
      "Algorithm                  ColoredMNIST               RotatedMNIST               VLCS                       PACS                       TerraIncognita             Avg                       \n",
      "HessianAlignment           48.3 +/- 2.4               97.2 +/- 0.2               76.8 +/- 0.4               81.7 +/- 0.5               37.2 +/- 0.2               68.2                      \n"
     ]
    }
   ],
   "source": [
    "# Sweep output directory to summarise.\n",
    "input_dir = \"./results_vits_terra_pacs\"\n",
    "\n",
    "latex = False  # set True to print LaTeX tables instead of plain text\n",
    "results_file = \"results.tex\" if latex else \"results.txt\"\n",
    "\n",
    "records = reporting.load_records(input_dir)\n",
    "# Use the selection-method classes defined in this notebook rather than\n",
    "# domainbed.model_selection.\n",
    "selection_methods = [\n",
    "    IIDAccuracySelectionMethod,\n",
    "    OracleSelectionMethod,\n",
    "]\n",
    "\n",
    "for selection_method in selection_methods:\n",
    "    print_results_tables(records, selection_method, latex)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-22T01:32:40.967155Z",
     "start_time": "2024-05-22T01:32:40.638223Z"
    }
   },
   "id": "925fcf82bb0ee22c",
   "execution_count": 217
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.852764680668667 0.10998017892537633 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9129082782560279, 'img_size': 224, 'lambda': 809.251247284292, 'lr': 0.00011906162909279594, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4072, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/014ed16d218abe1223a01977f0d3848d\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7353193313330476 0.7280372830512106 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9752755873490916, 'img_size': 224, 'lambda': 35.812115651904186, 'lr': 0.0026225693030627445, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4399, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/f6b31858dbf0281ee2ee892bc96e94cc\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7409987141020146 0.7181808442254125 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9129082782560279, 'img_size': 224, 'lambda': 809.251247284292, 'lr': 0.00011906162909279594, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4072, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/48b4c177a6a1b1f35e0259a37d5d2edf\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.7347835405057865 0.7349868752343708 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9387863036486702, 'img_size': 224, 'lambda': 40.997125753143685, 'lr': 0.00017414079263307025, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 168, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/76f789de544159360ee1dda77d5dac31\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.7348906986712387 0.7328976268280923 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9185575699327962, 'img_size': 224, 'lambda': 1181.7633941221247, 'lr': 0.00040753017105436845, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1319, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/d182942e81e04c22759d2b403f2ddf81\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.8525503643377625 0.10687309155193657 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9877960292129928, 'img_size': 224, 'lambda': 44.896933381501775, 'lr': 9.242078636590432e-05, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4289, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/bb9fc32b66fe2857b43552c2bf9e8b55\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7376768109729961 0.7313049067923719 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9578165657568851, 'img_size': 224, 'lambda': 115.84804357438838, 'lr': 0.00013416901146173725, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 2536, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/ac415ff133a017831ef6e55981e1c483\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.7387483926275182 0.7308083784218139 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.95, 'img_size': 224, 'lambda': 1000.0, 'lr': 0.001, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1500, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/438fc04167641e6808dc61439e43cf1b\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.8560865837976854 0.10199817860395351 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.95, 'img_size': 224, 'lambda': 1000.0, 'lr': 0.001, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1500, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/5fcac96fcbbe327884d98b116d21f15e\n",
      "\n",
      "-------- Dataset: ColoredMNIST, model selection method: training-domain validation set\n",
      "Algorithm             +90%                  +80%                  -90%                  Avg                  \n",
      "ERM                   72.2 +/- 0.2          72.9 +/- 0.2          10.1 +/- 0.1          51.7                 \n",
      "Fishr                 72.6 +/- 0.3          73.3 +/- 0.1          10.6 +/- 0.2          52.2                 \n",
      "HessianAlignment      72.4 +/- 0.2          73.2 +/- 0.1          10.0 +/- 0.2          51.9                 \n",
      "\n",
      "-------- Dataset: RotatedMNIST, model selection method: training-domain validation set\n",
      "Algorithm             0                     15                    30                    45                    60                    75                    Avg                  \n",
      "ERM                   95.3 +/- 0.2          98.6 +/- 0.1          99.1 +/- 0.1          98.9 +/- 0.0          98.9 +/- 0.0          96.1 +/- 0.2          97.8                 \n",
      "Fishr                 95.6 +/- 0.3          98.5 +/- 0.1          99.1 +/- 0.1          99.0 +/- 0.1          99.0 +/- 0.1          96.4 +/- 0.0          97.9                 \n",
      "HessianAlignment      94.9 +/- 0.7          98.8 +/- 0.1          99.0 +/- 0.0          98.9 +/- 0.1          98.9 +/- 0.1          95.6 +/- 0.5          97.7                 \n",
      "\n",
      "-------- Dataset: VLCS, model selection method: training-domain validation set\n",
      "Algorithm             C                     L                     S                     V                     Avg                  \n",
      "ERM                   97.1 +/- 0.1          62.3 +/- 0.3          71.9 +/- 0.7          77.2 +/- 0.4          77.2                 \n",
      "Fishr                 96.4 +/- 0.6          63.3 +/- 0.9          74.8 +/- 0.6          76.2 +/- 0.4          77.7                 \n",
      "HessianAlignment      96.4 +/- 0.3          63.5 +/- 0.4          74.0 +/- 0.3          76.7 +/- 0.6          77.6                 \n",
      "\n",
      "-------- Dataset: PACS, model selection method: training-domain validation set\n",
      "Algorithm             A                     C                     P                     S                     Avg                  \n",
      "ERM                   80.2 +/- 0.6          75.4 +/- 0.2          95.9 +/- 0.8          66.6 +/- 0.3          79.5                 \n",
      "Fishr                 83.1 +/- 1.0          74.8 +/- 0.5          97.2 +/- 0.2          68.7 +/- 0.8          81.0                 \n",
      "HessianAlignment      83.5 +/- 1.0          74.5 +/- 0.8          97.0 +/- 0.2          65.9 +/- 1.5          80.2                 \n",
      "\n",
      "-------- Dataset: TerraIncognita, model selection method: training-domain validation set\n",
      "Algorithm             L100                  L38                   L43                   L46                   Avg                  \n",
      "ERM                   48.2 +/- 2.1          17.8 +/- 2.3          37.8 +/- 1.0          34.2 +/- 0.5          34.5                 \n",
      "Fishr                 47.2 +/- 2.1          16.5 +/- 1.6          39.9 +/- 1.9          33.2 +/- 0.7          34.2                 \n",
      "HessianAlignment      44.9 +/- 1.9          18.8 +/- 2.7          37.4 +/- 0.5          32.6 +/- 0.5          33.4                 \n",
      "\n",
      "-------- Averages, model selection method: training-domain validation set\n",
      "Algorithm                  ColoredMNIST               RotatedMNIST               VLCS                       PACS                       TerraIncognita             Avg                       \n",
      "ERM                        51.7 +/- 0.1               97.8 +/- 0.1               77.2 +/- 0.2               79.5 +/- 0.3               34.5 +/- 0.4               68.1                      \n",
      "Fishr                      52.2 +/- 0.1               97.9 +/- 0.1               77.7 +/- 0.4               81.0 +/- 0.3               34.2 +/- 0.9               68.6                      \n",
      "HessianAlignment           51.9 +/- 0.1               97.7 +/- 0.2               77.6 +/- 0.1               80.2 +/- 0.1               33.4 +/- 0.9               68.2                      \n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.35040720102871836 0.34836877912894415 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9509165224455243, 'img_size': 224, 'lambda': 1974.3832860379962, 'lr': 0.00014524944240203904, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 2498, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/a38dc68c0496304daaf502faf3ebfab9\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7280325760822974 0.7345725305335333 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9877960292129928, 'img_size': 224, 'lambda': 44.896933381501775, 'lr': 9.242078636590432e-05, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4289, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/b1529ce69648524c2b0a68b47fd53fdf\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7486069438491213 0.7466788086565246 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9509165224455243, 'img_size': 224, 'lambda': 1974.3832860379962, 'lr': 0.00014524944240203904, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 2498, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/a2b63c0154f0348d60430a6bb5eba199\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.7295327903986284 0.7304869555900787 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.95, 'img_size': 224, 'lambda': 1000.0, 'lr': 0.001, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1500, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/25272584fc938184a92059a278da0333\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.7383197599657094 0.7399153586543098 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.95, 'img_size': 224, 'lambda': 1000.0, 'lr': 0.001, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1500, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600/5c23dfe6b21f883155f968c08656d12e\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.30818688384054865 0.2982268173782611 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9403353191998677, 'img_size': 224, 'lambda': 636.7590532387758, 'lr': 0.0021690714686971317, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 47, 'resnet18': False, 'resnet_dropout': 0.1, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/678d1d6361c8acca5c860ff2f33e22df\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [0]: 0.7471067295327904 0.7363402614098993 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.95, 'img_size': 224, 'lambda': 1000.0, 'lr': 0.001, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1500, 'resnet18': False, 'resnet_dropout': 0.0, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_3600_32/82ed95fda82d7a542c8613b5f0a4d324\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [1]: 0.731890270038577 0.733808324851342 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9403353191998677, 'img_size': 224, 'lambda': 636.7590532387758, 'lr': 0.0021690714686971317, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 47, 'resnet18': False, 'resnet_dropout': 0.1, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/d5fa96050860291e08240c5147430444\n",
      "Algorithm: Fishr\n",
      "Best val acc and test for ColoredMNIST, env [2]: 0.5040720102871838 0.5098301816038999 Hparams {'batch_size': 64, 'class_balanced': False, 'data_augmentation': True, 'ema': 0.9771289458110848, 'img_size': 224, 'lambda': 6440.079124001571, 'lr': 9.376027078720673e-05, 'model_type': 'ViT-S', 'nonlinear_classifier': False, 'penalty_anneal_iters': 1426, 'resnet18': False, 'resnet_dropout': 0.1, 'weight_decay': 0.0}\n",
      "Output dir ./domainbed/results_vits_ERM_Fishr/7ce7dd2925b2833c9b50cbc14b0095b6\n",
      "\n",
      "-------- Dataset: ColoredMNIST, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             +90%                  +80%                  -90%                  Avg                  \n",
      "ERM                   68.1 +/- 1.1          70.5 +/- 0.7          25.0 +/- 1.9          54.5                 \n",
      "Fishr                 73.9 +/- 0.3          73.5 +/- 0.2          38.5 +/- 5.2          62.0                 \n",
      "HessianAlignment      68.8 +/- 0.7          71.2 +/- 0.3          27.6 +/- 2.0          55.9                 \n",
      "\n",
      "-------- Dataset: RotatedMNIST, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             0                     15                    30                    45                    60                    75                    Avg                  \n",
      "ERM                   95.2 +/- 0.3          98.5 +/- 0.1          98.9 +/- 0.1          98.9 +/- 0.1          99.0 +/- 0.1          96.2 +/- 0.2          97.8                 \n",
      "Fishr                 95.7 +/- 0.2          98.7 +/- 0.0          99.0 +/- 0.1          99.1 +/- 0.1          98.8 +/- 0.2          96.4 +/- 0.0          97.9                 \n",
      "HessianAlignment      94.8 +/- 1.0          98.7 +/- 0.2          98.9 +/- 0.2          98.9 +/- 0.1          98.9 +/- 0.1          95.8 +/- 0.6          97.7                 \n",
      "\n",
      "-------- Dataset: VLCS, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             C                     L                     S                     V                     Avg                  \n",
      "ERM                   96.4 +/- 0.1          62.3 +/- 1.0          72.1 +/- 0.6          76.7 +/- 0.3          76.9                 \n",
      "Fishr                 96.0 +/- 0.8          64.0 +/- 0.1          73.5 +/- 0.7          76.4 +/- 0.6          77.5                 \n",
      "HessianAlignment      95.9 +/- 0.5          63.0 +/- 0.6          73.0 +/- 0.8          75.2 +/- 1.3          76.8                 \n",
      "\n",
      "-------- Dataset: PACS, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             A                     C                     P                     S                     Avg                  \n",
      "ERM                   81.2 +/- 0.9          73.4 +/- 0.9          96.1 +/- 0.6          70.3 +/- 0.5          80.2                 \n",
      "Fishr                 83.6 +/- 0.6          74.9 +/- 1.0          97.4 +/- 0.3          70.1 +/- 0.5          81.5                 \n",
      "HessianAlignment      84.1 +/- 1.5          75.6 +/- 1.3          97.1 +/- 0.2          70.0 +/- 0.8          81.7                 \n",
      "\n",
      "-------- Dataset: TerraIncognita, model selection method: test-domain validation set (oracle)\n",
      "Algorithm             L100                  L38                   L43                   L46                   Avg                  \n",
      "ERM                   50.2 +/- 0.4          25.0 +/- 1.9          36.3 +/- 1.6          34.5 +/- 0.1          36.5                 \n",
      "Fishr                 49.9 +/- 2.1          23.2 +/- 1.8          41.4 +/- 1.2          34.7 +/- 0.7          37.3                 \n",
      "HessianAlignment      49.1 +/- 0.7          24.4 +/- 2.0          39.9 +/- 1.6          35.4 +/- 2.3          37.2                 \n",
      "\n",
      "-------- Averages, model selection method: test-domain validation set (oracle)\n",
      "Algorithm                  ColoredMNIST               RotatedMNIST               VLCS                       PACS                       TerraIncognita             Avg                       \n",
      "ERM                        54.5 +/- 0.2               97.8 +/- 0.1               76.9 +/- 0.3               80.2 +/- 0.5               36.5 +/- 0.5               69.2                      \n",
      "Fishr                      62.0 +/- 1.7               97.9 +/- 0.0               77.5 +/- 0.5               81.5 +/- 0.2               37.3 +/- 1.1               71.2                      \n",
      "HessianAlignment           55.9 +/- 0.7               97.7 +/- 0.3               76.8 +/- 0.4               81.7 +/- 0.5               37.2 +/- 0.2               69.8                      \n"
     ]
    }
   ],
   "source": [
    "# Summarise the ./results_vits_3600_32 sweep under both selection methods.\n",
    "input_dir = \"./results_vits_3600_32\"\n",
    "latex = False\n",
    "results_file = \"results.tex\" if latex else \"results.txt\"\n",
    "\n",
    "records = reporting.load_records(input_dir)\n",
    "# NOTE(review): bare selection-method names are expected to be defined in an earlier cell\n",
    "# (the fully-qualified model_selection.* variants are not used here) -- confirm.\n",
    "selection_methods = [\n",
    "    IIDAccuracySelectionMethod,\n",
    "    OracleSelectionMethod,\n",
    "]\n",
    "\n",
    "for selection_method in selection_methods:\n",
    "    print_results_tables(records, selection_method, latex)\n",
    "\n",
    "# Scratch (not run): per-group sweep accuracies for the last selection_method of the loop.\n",
    "# a = selection_method.hparams_accs(records)\n",
    "# grouped_records = reporting.get_grouped_records(records).map(lambda group:\n",
    "#     { **group, \"sweep_acc\": selection_method.sweep_acc(group[\"records\"]) }\n",
    "# ).filter(lambda g: g[\"sweep_acc\"] is not None)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-18T04:57:46.596805Z",
     "start_time": "2024-05-18T04:57:45.064358Z"
    }
   },
   "id": "46db6111be0f0d03",
   "execution_count": 1148
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Load results.jsonl of the PACS (test env 3) Fishr run into df_fishr, and of the\n",
    "# PACS (test env 3) HessianAlignment run with penalty_anneal_iters == 3600 into df.\n",
    "result_dir = './results_vits_3600_32'\n",
    "\n",
    "fishr_paths, hessian_paths = [], []\n",
    "# sorted(): os.listdir order is arbitrary, so the run picked below would otherwise vary.\n",
    "for exp in sorted(os.listdir(result_dir)):\n",
    "    results_path = os.path.join(result_dir, exp, 'results.jsonl')\n",
    "    # Skip stray files, empty dirs and runs that died before writing any results.\n",
    "    if not os.path.isfile(results_path) or os.path.getsize(results_path) == 0:\n",
    "        continue\n",
    "    with open(results_path) as f:\n",
    "        first_line = json.loads(f.readline())\n",
    "    run_args = first_line['args']\n",
    "    if run_args['dataset'] != 'PACS' or run_args['test_envs'] != [3]:\n",
    "        continue\n",
    "    if run_args['algorithm'] == 'Fishr':\n",
    "        fishr_paths.append(results_path)\n",
    "    elif (run_args['algorithm'] == 'HessianAlignment'\n",
    "          and first_line['hparams'].get('penalty_anneal_iters') == 3600):\n",
    "        hessian_paths.append(results_path)\n",
    "\n",
    "\n",
    "def _load_last_run(paths, label):\n",
    "    \"\"\"Read the last of `paths` (sorted order) as a line-delimited JSON DataFrame.\n",
    "\n",
    "    Raises FileNotFoundError when nothing matched (instead of silently keeping a stale\n",
    "    DataFrame from an earlier execution) and warns when several runs matched.\n",
    "    \"\"\"\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f'no {label} run found under {result_dir}')\n",
    "    if len(paths) > 1:\n",
    "        warnings.warn(f'{len(paths)} {label} runs matched; using {paths[-1]}')\n",
    "    return pd.read_json(paths[-1], lines=True)\n",
    "\n",
    "\n",
    "df_fishr = _load_last_run(fishr_paths, 'Fishr/PACS/test_envs=[3]')\n",
    "df = _load_last_run(hessian_paths, 'HessianAlignment/PACS/test_envs=[3]/anneal=3600')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-15T13:17:46.321470Z",
     "start_time": "2024-05-15T13:17:46.092331Z"
    }
   },
   "id": "302bfce095b3fbd2",
   "execution_count": 787
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Flatten selected hyperparameters (from each row's `hparams` dict) and the test\n",
    "# environments (from each row's `args` dict) into top-level columns of df.\n",
    "for hparam_name in ('grad_alpha', 'hess_beta', 'penalty_anneal_iters'):\n",
    "    df[hparam_name] = df['hparams'].apply(lambda hp: hp[hparam_name])\n",
    "df['test_envs'] = df['args'].apply(lambda run_args: run_args['test_envs'])\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-14T04:35:28.670894Z",
     "start_time": "2024-05-14T04:35:28.652706Z"
    }
   },
   "id": "b387311b5aeb2917",
   "execution_count": 723
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Copy each row's test environments out of its `args` dict into a dedicated column.\n",
    "df_fishr['test_envs'] = df_fishr['args'].map(lambda run_args: run_args['test_envs'])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-14T04:35:29.094435Z",
     "start_time": "2024-05-14T04:35:29.091038Z"
    }
   },
   "id": "e0b60059d88ea985",
   "execution_count": 724
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Merge the HessianAlignment runs of the CMA sweep and the HGP runs of the HGP sweep\n",
    "# into a single directory so they can be loaded together.\n",
    "combined_dir = 'results_hgp_cma'\n",
    "cma_dir = 'results_vits_combined_bias'\n",
    "hgp_dir = 'results_hgp'\n",
    "\n",
    "\n",
    "def copy_runs(src_dir, dst_dir, algorithms):\n",
    "    \"\"\"Copy each run directory of src_dir whose first results record uses one of `algorithms`.\n",
    "\n",
    "    Runs without a non-empty results.jsonl (stray files, empty or crashed runs) are skipped.\n",
    "    A run directory that already exists in dst_dir is merged into, as `cp -r` would do.\n",
    "    Returns the names of the copied run directories.\n",
    "    \"\"\"\n",
    "    import shutil  # not part of the notebook's top-level import cell\n",
    "\n",
    "    copied = []\n",
    "    for exp in sorted(os.listdir(src_dir)):\n",
    "        src = os.path.join(src_dir, exp)\n",
    "        results_path = os.path.join(src, 'results.jsonl')\n",
    "        if not os.path.isfile(results_path) or os.path.getsize(results_path) == 0:\n",
    "            continue\n",
    "        with open(results_path) as f:\n",
    "            first_line = json.loads(f.readline())\n",
    "        if first_line['args']['algorithm'] in algorithms:\n",
    "            # copytree rather than os.system('cp -r ...'): no shell quoting/injection,\n",
    "            # and a failed copy raises instead of being silently ignored.\n",
    "            shutil.copytree(src, os.path.join(dst_dir, exp), dirs_exist_ok=True)\n",
    "            copied.append(exp)\n",
    "    return copied\n",
    "\n",
    "\n",
    "os.makedirs(combined_dir, exist_ok=True)\n",
    "copied_cma = copy_runs(cma_dir, combined_dir, ['HessianAlignment'])\n",
    "copied_hgp = copy_runs(hgp_dir, combined_dir, ['HGP'])\n",
    "print(f'copied {len(copied_cma)} runs from {cma_dir} and {len(copied_hgp)} runs from {hgp_dir} into {combined_dir}')\n",
    "\n",
    "# terr_dir = './results_vits_hessian_class_terra'"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-22T15:03:32.621561Z",
     "start_time": "2024-05-22T15:03:30.356951Z"
    }
   },
   "id": "25c0df118a09efd2",
   "execution_count": 224
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "b47aa1d6f7519774"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
