{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cde0e12",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "print(os.getcwd())\n",
    "# The launch commands below run `main.py` relative to the working directory, so move to the\n",
    "# directory that contains it. The guard keeps this cell idempotent: re-running it does not\n",
    "# keep climbing the directory tree.\n",
    "if not os.path.exists('main.py'):\n",
    "    os.chdir('../')\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "428bdb7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd \n",
    "import wandb\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c965125",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate explanation methods, grouped exactly as authored.\n",
    "explanation_method_main = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcam\", \"gradcamgithub\", \"igembedding\", \"vanillaembedding\",\n",
    "     \"sgembedding\", \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"ours\"],\n",
    "    [\"random\"],\n",
    "]\n",
    "# The same methods as one flat list, in the order listed above.\n",
    "explanation_method_main_flatten = sum(explanation_method_main, [])\n",
    "\n",
    "# Settings for this run of the notebook.\n",
    "dataset_name = \"ImageNette\"  # the launch cells handle \"ImageNette\" and \"MURA\"\n",
    "explanation_method = explanation_method_main_flatten[0]  # change the index to retrain on another method\n",
    "gpus_classifier = 3  # GPU index passed to main.py as gpus_classifier=[...]\n",
    "backbone_type = \"vit_small_patch16_224\"  # classifier backbone to retrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8aa299e",
   "metadata": {},
   "outputs": [],
   "source": [
    "explanation_method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "815af627",
   "metadata": {},
   "outputs": [],
   "source": [
    "!gpustat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ef1939",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_retraining_status(backbone_type, dataset_name, api_dir=\"ch6845/transformer_interpretability_project_retraining\", verbose=False):\n",
    "    \"\"\"Collect the ROAR retraining runs logged to a W&B project into a DataFrame.\n",
    "\n",
    "    Args:\n",
    "        backbone_type: classifier backbone to keep; compared with each run's\n",
    "            `classifier_backbone_type`, the backbone key the launch commands set.\n",
    "        dataset_name: dataset to keep; compared with each run's `datasets` config entry.\n",
    "        api_dir: W&B `entity/project` path to query.\n",
    "        verbose: if True, print every run name and each run that is filtered out.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one row per matching run: test accuracy / Cohen's kappa (NaN if\n",
    "        not logged yet), backbone, dataset, positional-embedding flag and the explanation\n",
    "        masking settings of each split. The columns exist even when no run matches.\n",
    "    \"\"\"\n",
    "    split_keys = [f\"explanation_{field}_{split}\"\n",
    "                  for split in (\"train\", \"val\", \"test\")\n",
    "                  for field in (\"location\", \"mask_amount\", \"mask_ascending\")]\n",
    "    columns = [\"name\", \"accuracy\", \"cohenkappa\", \"backbone_type\", \"datasets\",\n",
    "               \"classifier_enable_pos_embed\"] + split_keys\n",
    "\n",
    "    api.flush()  # drop cached results so runs finished since the last call are visible\n",
    "    result_dict_list = []\n",
    "    for run in api.runs(api_dir):\n",
    "        try:\n",
    "            if verbose:\n",
    "                print(run.name)\n",
    "            # Filter on classifier_backbone_type: it is the key the launch commands pass backbone_type to.\n",
    "            if run.config[\"datasets\"] != dataset_name or run.config[\"classifier_backbone_type\"] != backbone_type:\n",
    "                if verbose:\n",
    "                    print(run.name, \"pass\")\n",
    "                continue\n",
    "            pos_embed = run.config.get(\"classifier_enable_pos_embed\", True)\n",
    "            # A flag launched as 'classifier_enable_pos_embed=[True]' is presumably stored as a\n",
    "            # one-element list; unwrap it so it compares equal to a plain bool.\n",
    "            if isinstance(pos_embed, list) and len(pos_embed) == 1:\n",
    "                pos_embed = pos_embed[0]\n",
    "            record = {\n",
    "                \"name\": run.name,\n",
    "                \"accuracy\": run.summary[\"test/accuracy\"] if \"test/accuracy\" in run.summary.keys() else np.nan,\n",
    "                \"cohenkappa\": run.summary[\"test/cohenkappa\"] if \"test/cohenkappa\" in run.summary.keys() else np.nan,\n",
    "                \"backbone_type\": run.config[\"classifier_backbone_type\"],\n",
    "                \"datasets\": run.config[\"datasets\"],\n",
    "                \"classifier_enable_pos_embed\": pos_embed,\n",
    "            }\n",
    "            record.update({key: run.config[key] for key in split_keys})\n",
    "            result_dict_list.append(record)\n",
    "        except Exception as err:\n",
    "            # Best effort: skip runs missing an expected config key (e.g. older runs) instead of\n",
    "            # aborting, but report why. Catching Exception (not a bare except) lets Ctrl-C stop the loop.\n",
    "            print(run.name, 'error', repr(err))\n",
    "    return pd.DataFrame(result_dict_list, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57547df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Retrain the classifier (positional embeddings enabled) on ROAR-masked inputs for every\n",
    "# (insert/delete, num_stage) setting, skipping settings that already have a finished run.\n",
    "wandb_project_name = \"transformer_interpretability_project_retraining\"\n",
    "classifier_enable_pos_embed = True\n",
    "\n",
    "# Dataset-specific training options; they do not depend on the loop variables.\n",
    "if dataset_name == \"ImageNette\":\n",
    "    checkpoint_metric = \"accuracy\"\n",
    "    loss_weight = ''\n",
    "elif dataset_name == \"MURA\":\n",
    "    checkpoint_metric = \"CohenKappa\"\n",
    "    loss_weight = \" 'loss_weight = [21935, 14873]'\"\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported dataset_name: {dataset_name!r}\")\n",
    "\n",
    "# The explanation files are always the vit_base_patch16_224 ones, whatever backbone_type is.\n",
    "splits = [\"train\", \"val\", \"test\"]\n",
    "explanation_dir = f\"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/{dataset_name}\"\n",
    "explanation_location = {split: f\"{explanation_dir}/vit_base_patch16_224_{explanation_method}_{split}.pickle\"\n",
    "                        for split in splits}\n",
    "\n",
    "for insert_delete in [\"delete\", \"insert\"]:\n",
    "    for num_stage in [1, 3, 7, 14, 28, 56, 84, 112, 140, 168, 182]:\n",
    "        exp_name = f\"{dataset_name}_ROAR_classifier_{backbone_type}_1e-5_train_{insert_delete}_{explanation_method}_{num_stage}\"\n",
    "\n",
    "        # Same masking setting for the train, val and test splits. 196 is presumably the\n",
    "        # (224 / 16) ** 2 patch count of a *_patch16_224 ViT.\n",
    "        explanation_mask_amount = 196 - num_stage if insert_delete == \"insert\" else num_stage\n",
    "        explanation_mask_ascending = insert_delete == \"insert\"\n",
    "\n",
    "        # Query W&B on every iteration (not once before the loop) so runs that finished in the\n",
    "        # meantime, e.g. from another copy of this notebook, are not launched twice.\n",
    "        retraining_status = get_retraining_status(backbone_type=backbone_type,\n",
    "                                                  dataset_name=dataset_name,\n",
    "                                                  api_dir=f\"ch6845/{wandb_project_name}\")\n",
    "        if len(retraining_status) > 0:\n",
    "            already_done = ((~retraining_status[\"accuracy\"].isnull()) &\n",
    "                            (retraining_status[\"classifier_enable_pos_embed\"] == classifier_enable_pos_embed))\n",
    "            for split in splits:\n",
    "                already_done &= ((retraining_status[f\"explanation_location_{split}\"] == explanation_location[split]) &\n",
    "                                 (retraining_status[f\"explanation_mask_amount_{split}\"] == explanation_mask_amount) &\n",
    "                                 (retraining_status[f\"explanation_mask_ascending_{split}\"] == explanation_mask_ascending))\n",
    "            if already_done.any():\n",
    "                continue\n",
    "\n",
    "        explanation_args = \" \".join(\n",
    "            f\"'explanation_location_{split} = \\\"{explanation_location[split]}\\\"' \"\n",
    "            f\"'explanation_mask_amount_{split} = {explanation_mask_amount}' \"\n",
    "            f\"'explanation_mask_ascending_{split} = {explanation_mask_ascending}'\"\n",
    "            for split in splits)\n",
    "        # Pass classifier_enable_pos_embed as a plain bool: wrapping it in [...] would store a list\n",
    "        # in the run config, which never equals the bool in the already-done check above.\n",
    "        command = (f\"~/miniconda3/envs/vitshapley/bin/python main.py with 'stage = \\\"classifier\\\"' \"\n",
    "                   f\"'wandb_project_name = \\\"{wandb_project_name}\\\"' 'exp_name = \\\"{exp_name}\\\"' \"\n",
    "                   f\"env_username 'gpus_classifier=[{gpus_classifier}]' \"\n",
    "                   f\"dataset_{dataset_name}_ROAR \"\n",
    "                   f\"'classifier_enable_pos_embed={classifier_enable_pos_embed}' \"\n",
    "                   f\"{explanation_args} \"\n",
    "                   f\"'classifier_backbone_type = \\\"{backbone_type}\\\"' 'classifier_download_weight = True' 'classifier_load_path = None' \"\n",
    "                   f\"training_hyperparameters_transformer{loss_weight} 'checkpoint_metric = \\\"{checkpoint_metric}\\\"' 'learning_rate = 1e-5'\")\n",
    "        print(command)\n",
    "        !{command}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df6747ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Same ROAR retraining sweep with positional embeddings disabled, logged to its own W&B project.\n",
    "wandb_project_name = \"transformer_interpretability_project_retraining_nopos\"\n",
    "classifier_enable_pos_embed = False\n",
    "\n",
    "# Dataset-specific training options; they do not depend on the loop variables.\n",
    "if dataset_name == \"ImageNette\":\n",
    "    checkpoint_metric = \"accuracy\"\n",
    "    loss_weight = ''\n",
    "elif dataset_name == \"MURA\":\n",
    "    checkpoint_metric = \"CohenKappa\"\n",
    "    loss_weight = \" 'loss_weight = [21935, 14873]'\"\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported dataset_name: {dataset_name!r}\")\n",
    "\n",
    "# The explanation files are always the vit_base_patch16_224 ones, whatever backbone_type is.\n",
    "splits = [\"train\", \"val\", \"test\"]\n",
    "explanation_dir = f\"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/{dataset_name}\"\n",
    "explanation_location = {split: f\"{explanation_dir}/vit_base_patch16_224_{explanation_method}_{split}.pickle\"\n",
    "                        for split in splits}\n",
    "num_stages = [1, 3, 7, 14, 28, 56, 84, 112, 140, 168, 182]\n",
    "\n",
    "for insert_delete in [\"delete\", \"insert\"]:\n",
    "    # delete settings are visited from the largest num_stage down.\n",
    "    for num_stage in (num_stages if insert_delete == \"insert\" else num_stages[::-1]):\n",
    "        exp_name = f\"{dataset_name}_ROAR_classifier_{backbone_type}_1e-5_train_{insert_delete}_{explanation_method}_{num_stage}_nopos\"\n",
    "\n",
    "        # Same masking setting for the train, val and test splits. 196 is presumably the\n",
    "        # (224 / 16) ** 2 patch count of a *_patch16_224 ViT.\n",
    "        explanation_mask_amount = 196 - num_stage if insert_delete == \"insert\" else num_stage\n",
    "        explanation_mask_ascending = insert_delete == \"insert\"\n",
    "\n",
    "        # Query W&B on every iteration (not once before the loop) so runs that finished in the\n",
    "        # meantime, e.g. from another copy of this notebook, are not launched twice.\n",
    "        retraining_status = get_retraining_status(backbone_type=backbone_type,\n",
    "                                                  dataset_name=dataset_name,\n",
    "                                                  api_dir=f\"ch6845/{wandb_project_name}\")\n",
    "        if len(retraining_status) > 0:\n",
    "            already_done = ((~retraining_status[\"accuracy\"].isnull()) &\n",
    "                            (retraining_status[\"classifier_enable_pos_embed\"] == classifier_enable_pos_embed))\n",
    "            for split in splits:\n",
    "                already_done &= ((retraining_status[f\"explanation_location_{split}\"] == explanation_location[split]) &\n",
    "                                 (retraining_status[f\"explanation_mask_amount_{split}\"] == explanation_mask_amount) &\n",
    "                                 (retraining_status[f\"explanation_mask_ascending_{split}\"] == explanation_mask_ascending))\n",
    "            if already_done.any():\n",
    "                continue\n",
    "\n",
    "        explanation_args = \" \".join(\n",
    "            f\"'explanation_location_{split} = \\\"{explanation_location[split]}\\\"' \"\n",
    "            f\"'explanation_mask_amount_{split} = {explanation_mask_amount}' \"\n",
    "            f\"'explanation_mask_ascending_{split} = {explanation_mask_ascending}'\"\n",
    "            for split in splits)\n",
    "        command = (f\"~/miniconda3/envs/vitshapley/bin/python main.py with 'stage = \\\"classifier\\\"' \"\n",
    "                   f\"'wandb_project_name = \\\"{wandb_project_name}\\\"' 'exp_name = \\\"{exp_name}\\\"' \"\n",
    "                   f\"env_username 'gpus_classifier=[{gpus_classifier}]' \"\n",
    "                   f\"dataset_{dataset_name}_ROAR \"\n",
    "                   f\"'classifier_enable_pos_embed={classifier_enable_pos_embed}' \"\n",
    "                   f\"{explanation_args} \"\n",
    "                   f\"'classifier_backbone_type = \\\"{backbone_type}\\\"' 'classifier_download_weight = True' 'classifier_load_path = None' \"\n",
    "                   f\"training_hyperparameters_transformer{loss_weight} 'checkpoint_metric = \\\"{checkpoint_metric}\\\"' 'learning_rate = 1e-5'\")\n",
    "        print(command)\n",
    "        !{command}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ab595a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One extra setting with positional embeddings disabled: insert with num_stage = 0,\n",
    "# i.e. a mask amount of 196 in ascending order on every split.\n",
    "wandb_project_name = \"transformer_interpretability_project_retraining_nopos\"\n",
    "classifier_enable_pos_embed = False\n",
    "\n",
    "# Dataset-specific training options; they do not depend on the loop variables.\n",
    "if dataset_name == \"ImageNette\":\n",
    "    checkpoint_metric = \"accuracy\"\n",
    "    loss_weight = ''\n",
    "elif dataset_name == \"MURA\":\n",
    "    checkpoint_metric = \"CohenKappa\"\n",
    "    loss_weight = \" 'loss_weight = [21935, 14873]'\"\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported dataset_name: {dataset_name!r}\")\n",
    "\n",
    "# The explanation files are always the vit_base_patch16_224 ones, whatever backbone_type is.\n",
    "splits = [\"train\", \"val\", \"test\"]\n",
    "explanation_dir = f\"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/{dataset_name}\"\n",
    "explanation_location = {split: f\"{explanation_dir}/vit_base_patch16_224_{explanation_method}_{split}.pickle\"\n",
    "                        for split in splits}\n",
    "\n",
    "for insert_delete in [\"insert\"]:\n",
    "    for num_stage in [0]:\n",
    "        exp_name = f\"{dataset_name}_ROAR_classifier_{backbone_type}_1e-5_train_{insert_delete}_{explanation_method}_{num_stage}_nopos\"\n",
    "\n",
    "        # Same masking setting for the train, val and test splits. 196 is presumably the\n",
    "        # (224 / 16) ** 2 patch count of a *_patch16_224 ViT.\n",
    "        explanation_mask_amount = 196 - num_stage if insert_delete == \"insert\" else num_stage\n",
    "        explanation_mask_ascending = insert_delete == \"insert\"\n",
    "\n",
    "        retraining_status = get_retraining_status(backbone_type=backbone_type,\n",
    "                                                  dataset_name=dataset_name,\n",
    "                                                  api_dir=f\"ch6845/{wandb_project_name}\")\n",
    "        if len(retraining_status) > 0:\n",
    "            already_done = ((~retraining_status[\"accuracy\"].isnull()) &\n",
    "                            (retraining_status[\"classifier_enable_pos_embed\"] == classifier_enable_pos_embed))\n",
    "            for split in splits:\n",
    "                already_done &= ((retraining_status[f\"explanation_location_{split}\"] == explanation_location[split]) &\n",
    "                                 (retraining_status[f\"explanation_mask_amount_{split}\"] == explanation_mask_amount) &\n",
    "                                 (retraining_status[f\"explanation_mask_ascending_{split}\"] == explanation_mask_ascending))\n",
    "            if already_done.any():\n",
    "                continue\n",
    "\n",
    "        explanation_args = \" \".join(\n",
    "            f\"'explanation_location_{split} = \\\"{explanation_location[split]}\\\"' \"\n",
    "            f\"'explanation_mask_amount_{split} = {explanation_mask_amount}' \"\n",
    "            f\"'explanation_mask_ascending_{split} = {explanation_mask_ascending}'\"\n",
    "            for split in splits)\n",
    "        command = (f\"~/miniconda3/envs/vitshapley/bin/python main.py with 'stage = \\\"classifier\\\"' \"\n",
    "                   f\"'wandb_project_name = \\\"{wandb_project_name}\\\"' 'exp_name = \\\"{exp_name}\\\"' \"\n",
    "                   f\"env_username 'gpus_classifier=[{gpus_classifier}]' \"\n",
    "                   f\"dataset_{dataset_name}_ROAR \"\n",
    "                   f\"'classifier_enable_pos_embed={classifier_enable_pos_embed}' \"\n",
    "                   f\"{explanation_args} \"\n",
    "                   f\"'classifier_backbone_type = \\\"{backbone_type}\\\"' 'classifier_download_weight = True' 'classifier_load_path = None' \"\n",
    "                   f\"training_hyperparameters_transformer{loss_weight} 'checkpoint_metric = \\\"{checkpoint_metric}\\\"' 'learning_rate = 1e-5'\")\n",
    "        print(command)\n",
    "        !{command}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "085b17f4",
   "metadata": {},
   "source": [
    "Reference command for the MURA `train_original` run, with every `explanation_*` option set to `None`:\n",
    "\n",
    "```bash\n",
    "~/miniconda3/envs/vitshapley/bin/python main.py with 'stage = \"classifier\"'             'wandb_project_name = \"transformer_interpretability_project_retraining\"' 'exp_name = \"MURA_ROAR_classifier_vit_base_patch16_224_1e-5_train_original\"'             env_username 'gpus_classifier=[7]'             dataset_MURA_ROAR             'explanation_location_train = None' 'explanation_mask_amount_train = None' 'explanation_mask_ascending_train = None'             'explanation_location_val = None' 'explanation_mask_amount_val = None' 'explanation_mask_ascending_val = None'             'explanation_location_test = None' 'explanation_mask_amount_test = None' 'explanation_mask_ascending_test = None'             'classifier_backbone_type = \"vit_base_patch16_224\"' 'classifier_download_weight = True' 'classifier_load_path = None'             training_hyperparameters_transformer 'loss_weight = [21935, 14873]' 'checkpoint_metric = \"CohenKappa\"' 'learning_rate = 1e-5'\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a5c80b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-off launch of a single setting. The leading `!` runs the line as a shell command;\n",
    "# without it IPython parses the line as Python and raises a SyntaxError.\n",
    "!~/miniconda3/envs/vitshapley/bin/python main.py with 'stage = \"classifier\"'             'wandb_project_name = \"transformer_interpretability_project_retraining_nopos\"' 'exp_name = \"ImageNette_ROAR_classifier_vit_base_patch16_224_1e-5_train_insert_attention_last_0_nopos\"'             env_username 'gpus_classifier=[3]'             dataset_ImageNette_ROAR             'classifier_enable_pos_embed=False'             'explanation_location_train = \"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/ImageNette/vit_base_patch16_224_attention_last_train.pickle\"' 'explanation_mask_amount_train = 196' 'explanation_mask_ascending_train = True'             'explanation_location_val = \"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/ImageNette/vit_base_patch16_224_attention_last_val.pickle\"' 'explanation_mask_amount_val = 196' 'explanation_mask_ascending_val = True'             'explanation_location_test = \"/homes/gws/username/ViT_shapley/results/4_1_explanation_generate/ImageNette/vit_base_patch16_224_attention_last_test.pickle\"' 'explanation_mask_amount_test = 196' 'explanation_mask_ascending_test = True'             'classifier_backbone_type = \"vit_base_patch16_224\"' 'classifier_download_weight = True' 'classifier_load_path = None'             training_hyperparameters_transformer 'checkpoint_metric = \"accuracy\"' 'learning_rate = 1e-5'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vitshapley",
   "language": "python",
   "name": "vitshapley"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
