{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from methods import *\n",
    "from templates import *\n",
    "datasets=['imagenet','cifar100','caltech101','cars196','dtd','eurosat','food101','oxford_flowers102','oxford_iiit_pet','resisc45','sun397','fgvc_aircraft']\n",
    "backbone_names=['ViT-B/16','ViT-L/14']\n",
    "\n",
    "global method_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# template: a photo of a\n",
    "def A_Photo_Of_A(dataset_name,backbone_name,test_embeds,test_labels):\n",
    "    method_name = \"template: a photo of a\"\n",
    "    print(method_name)\n",
    "    method_list.append(method_name)\n",
    "    all_templates = [\"a photo of a {}.\"]\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    " # Prompt Ensembling (hand-crafted)\n",
    "def Prompt_Ensembling_H(dataset_name,backbone_name,test_embeds,test_labels):\n",
    "    method_name = \"Prompt_Ensembling_H\"\n",
    "    print(method_name)\n",
    "    method_list.append(method_name)\n",
    "\n",
    "    all_templates = ZEROSHOT_TEMPLATES[dataset_name]\n",
    "    print(\"len(all_templates)\", len(all_templates))\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max-Logit Scoring : Using hand craft\n",
    "def Max_Logit_Scoring(dataset_name,backbone_name,test_embeds,test_labels,hand_craft=True):\n",
    "    if hand_craft:\n",
    "        method_name = \"Max-Logit Scoring: Using Hand Craft\"\n",
    "        all_templates =ZEROSHOT_TEMPLATES[dataset_name]\n",
    "    else:\n",
    "        method_name = \"Max-Logit Scoring: Using Pool Set\"\n",
    "        all_templates =list(set(sum(ZEROSHOT_TEMPLATES.values(), [])))\n",
    "    \n",
    "\n",
    "    print(method_name)\n",
    "    method_list.append(method_name)\n",
    "\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "    scores = [None] * len(all_class_embeds)\n",
    "    for p, class_embeds in enumerate(all_class_embeds):\n",
    "        logits = cosine_similarity(test_embeds, class_embeds)\n",
    "        max_logits = logits.max(dim=1).values\n",
    "        scores[p] = max_logits.mean()\n",
    "    scores = torch.stack(scores)\n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds,scores).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cmm Scoring :Using hand craft \n",
    "def Cmm_Scoring(dataset_name,backbone_name,test_embeds,test_labels,hand_craft=True,modify=False):\n",
    "   \n",
    "    if hand_craft:\n",
    "        all_templates = ZEROSHOT_TEMPLATES[dataset_name]\n",
    "        method_name = \"Cmm Scoring, Using Hand Craft\"\n",
    "    else:\n",
    "        all_templates =list(set(sum(ZEROSHOT_TEMPLATES.values(), [])))\n",
    "        method_name = \"Cmm Scoring, Using Pool Set\"\n",
    "\n",
    "    if modify==True:\n",
    "        method_name+=\" Modified\"\n",
    "\n",
    "    method_list.append(method_name)\n",
    "\n",
    "\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "    worst_k=all_class_embeds.shape[1]//10\n",
    "    scores = [None] * len(all_class_embeds)\n",
    "    for p, class_embeds in enumerate(all_class_embeds):   \n",
    "        if modify==False:\n",
    "            scores[p]=get_cmm_score(test_embeds,class_embeds,None,worst_k)  \n",
    "        else:\n",
    "            scores[p]=get_cmm_score_modified(test_embeds,class_embeds,None,worst_k)  \n",
    "    scores = torch.stack(scores).cuda().half()\n",
    "    scores-=min(scores.min(), 0)\n",
    "    \n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds,scores).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ZPE\n",
    "\n",
    "def ZPE_Scoring(dataset_name,backbone_name,test_embeds,test_labels,hand_craft=True, selection=False):\n",
    "\n",
    "    global method_list\n",
    "\n",
    "    if selection:\n",
    "        if_sel = \"(with prompt selection)\"\n",
    "    else:\n",
    "        if_sel = \"\"\n",
    "    if hand_craft:\n",
    "        method = f\"ZPE Scoring{if_sel}: Using Hand Craft\"\n",
    "        print(method)\n",
    "        method_list.append(method)\n",
    "        all_templates =ZEROSHOT_TEMPLATES[dataset_name]\n",
    "    else:\n",
    "        method = f\"ZPE Scoring{if_sel}: Using Pool Set\"\n",
    "        print(method)\n",
    "        method_list.append(method)\n",
    "        all_templates =list(set(sum(ZEROSHOT_TEMPLATES.values(), [])))\n",
    "\n",
    "        \n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "\n",
    "    if backbone_name == 'ViT-L/14':\n",
    "        file = os.path.join(\"laion\", \"vitl14\", \"pretrain_embeds_0.pt\")\n",
    "    elif backbone_name == 'ViT-B/16':\n",
    "    \n",
    "        file = os.path.join(\"laion\", \"vitb16\", \"pretrain_embeds_0.pt\")\n",
    "    else:\n",
    "        raise Exception(\"Not Implemented\")\n",
    "    \n",
    "    pre_embed = torch.load(file)\n",
    "\n",
    "\n",
    "\n",
    "    pre_embed = torch.Tensor(pre_embed).to(\"cuda\").half()\n",
    "\n",
    "    pre_embed_norm = torch.nn.functional.normalize(pre_embed, dim=-1)\n",
    "    pre_embed_mean = pre_embed_norm.mean(dim=0)\n",
    "\n",
    "\n",
    "    zpe_scores = [None] * len(all_class_embeds)\n",
    "    for p, class_embeds in enumerate(all_class_embeds):\n",
    "        logits = cosine_similarity(test_embeds, class_embeds)\n",
    "    \n",
    "        logits_pretrain = cosine_similarity(pre_embed_mean.reshape(1, -1), class_embeds)\n",
    "\n",
    "        e_pretrain = logits_pretrain.mean(dim=0)\n",
    "        e_test = logits.mean(dim=0)\n",
    "\n",
    "        logits_normalized = logits - (e_pretrain + e_test) / 2\n",
    "\n",
    "    \n",
    "        max_logits = logits_normalized.max(dim=1).values\n",
    "        zpe_scores[p] = max_logits.mean()\n",
    "    zpe_scores = torch.stack(zpe_scores)\n",
    "    zpe_scores -= min(zpe_scores.min(), 0)\n",
    "\n",
    "    if selection:\n",
    "        th=zpe_scores.median()\n",
    "        zpe_scores = zpe_scores * (zpe_scores >th)\n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds,zpe_scores).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def DES(dataset_name,backbone_name,test_embeds,test_labels):\n",
    "\n",
    "\n",
    "    global method_list\n",
    "    \n",
    "    method_name = \"DES [template: a photo of a ]\"\n",
    "    print(method_name)\n",
    "    method_list.append(method_name)\n",
    "    all_templates = [\"{}.\"]\n",
    "\n",
    "\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "\n",
    "    prompt_dict_file = os.path.join(\"cache_prompt_embed\", dataset_name, backbone_name, \"prompts_dict.pt\")\n",
    "    \n",
    "    from prompt_engine import load_all_dict\n",
    "    \n",
    "    all_prompt_embed_dict = load_all_dict(save_file=prompt_dict_file)\n",
    "    print(\"loading finished\")\n",
    "\n",
    "    logits = cosine_similarity(test_embeds, all_class_embeds[0])\n",
    "    logits_des = logits.clone()\n",
    "\n",
    "\n",
    "    classname_list = ZEROSHOT_CLASS_NAMES[dataset_name]\n",
    "    worst_index=range(len(classname_list))\n",
    "    for c_i in worst_index:\n",
    "\n",
    "        # try:\n",
    "            \n",
    "        prefix_class_key = (all_templates[0], classname_list[c_i])\n",
    "        # print(\"try \", prefix_class_key)\n",
    "        embed_col = [all_prompt_embed_dict[prefix_class_key][key_i].reshape(1, -1) for key_i in all_prompt_embed_dict[prefix_class_key]]\n",
    "        prompts_embed = torch.cat(embed_col, dim=0)\n",
    "\n",
    "\n",
    "        prompts_embed = prompts_embed.to(\"cuda\")\n",
    "\n",
    "        similarity_matrix_chunk = cosine_similarity(test_embeds, prompts_embed)\n",
    "        class_similarity = similarity_matrix_chunk.mean(dim=1)\n",
    "        \n",
    "\n",
    "        logits_des[:, c_i] = class_similarity\n",
    "    test_preds = logits_des.argmax(dim=1)\n",
    "\n",
    "    print(test_preds)\n",
    "\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    \n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CMM_DES\n",
    "\n",
    "def CMM_DES_Scoring(dataset_name,backbone_name,test_embeds,test_labels,hand_craft=True, selection=False):\n",
    "\n",
    "\n",
    "    global method_list\n",
    "\n",
    "    if selection:\n",
    "        if_sel = \"(with prompt selection)\"\n",
    "    else:\n",
    "        if_sel = \"\"\n",
    "    if hand_craft:\n",
    "        method = f\"CMM-DES-Scoring{if_sel}: Using Hand Craft\"\n",
    "        print(method)\n",
    "        method_list.append(method)\n",
    "        all_templates =ZEROSHOT_TEMPLATES[dataset_name]\n",
    "    else:\n",
    "        method = f\"CMM-DES-Scoring{if_sel}: Using Pool Set\"\n",
    "        print(method)\n",
    "        method_list.append(method)\n",
    "        all_templates =list(set(sum(ZEROSHOT_TEMPLATES.values(), [])))\n",
    "\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "\n",
    "    prompt_dict_file = os.path.join(\"cache_prompt_embed\", dataset_name, backbone_name, \"prompts_dict.pt\")\n",
    "\n",
    "    from prompt_engine import load_all_dict\n",
    "    \n",
    "    all_prompt_embed_dict = load_all_dict(save_file=prompt_dict_file)\n",
    "\n",
    "    print(\"loading finished\")\n",
    "\n",
    "\n",
    "    worst_k=all_class_embeds.shape[1]//10\n",
    "\n",
    "    scores = [None] * len(all_class_embeds)\n",
    "\n",
    "    all_logits = []\n",
    "    \n",
    "    classname_list = ZEROSHOT_CLASS_NAMES[dataset_name]\n",
    "    \n",
    "  \n",
    "\n",
    "    success_count, all_load_count = 0, 0\n",
    "\n",
    "    for prefix_i in tqdm(range(len(all_templates))): \n",
    "\n",
    "        score_over_class = torch.Tensor([])\n",
    "        logits = cosine_similarity(test_embeds, all_class_embeds[prefix_i])\n",
    "        score_over_class=get_all_cmm_score_from_logits(logits,test_labels=None,worst_k=None)\n",
    "\n",
    "        sorted_cmm, sorted_index = torch.sort(score_over_class, descending=False)\n",
    "\n",
    "        worst_index = sorted_index[:worst_k]\n",
    "\n",
    "        logits_des = logits.clone()\n",
    "\n",
    "        for c_i in worst_index:\n",
    "            all_load_count += 1\n",
    "            try:\n",
    "            \n",
    "                prefix_class_key = (all_templates[0], classname_list[c_i])\n",
    "                embed_col = [all_prompt_embed_dict[prefix_class_key][key_i].reshape(1, -1) for key_i in all_prompt_embed_dict[prefix_class_key]]\n",
    "                prompts_embed = torch.cat(embed_col, dim=0)\n",
    "\n",
    "\n",
    "                prompts_embed = prompts_embed.to(\"cuda\")\n",
    "\n",
    "                similarity_matrix_chunk = cosine_similarity(test_embeds, prompts_embed)\n",
    "                class_similarity = similarity_matrix_chunk.mean(dim=1)\n",
    "\n",
    "                logits_des[:, c_i] = class_similarity\n",
    "\n",
    "                # print(\"wrap class {} [idx: {}] with gpt description...\".format(classname_list[c_i], c_i))\n",
    "\n",
    "                success_count += 1\n",
    "                \n",
    "            except:\n",
    "                pass\n",
    "\n",
    "        all_logits.append(logits_des.reshape(1, logits_des.shape[0],logits_des.shape[1]).cpu())\n",
    "        scores[prefix_i] = get_cmm_score_from_logits(logits_des,test_labels=None,worst_k=worst_k)  \n",
    "\n",
    "\n",
    "    \n",
    "    print(\"success rate: \", success_count / all_load_count)\n",
    "\n",
    "    scores = torch.stack(scores).cuda().half()\n",
    "    all_logits = torch.cat(all_logits, dim=0)\n",
    "    # print(all_logits.shape)\n",
    "\n",
    "    softmax_scores = torch.softmax(scores.reshape(1,-1),dim=1).reshape(-1)\n",
    "\n",
    "\n",
    "    if selection:\n",
    "        th=softmax_scores.median()\n",
    "        softmax_scores = softmax_scores * (softmax_scores >= th)\n",
    "\n",
    "\n",
    "    test_preds = calc_pred_from_logits(all_logits=all_logits,weights=softmax_scores)\n",
    "\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    \n",
    "    return cls_accs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "def to_excel(backbone_name,dataset_name,res,names):\n",
    "    backbone_name=backbone_name.replace('/','')\n",
    "    save_path=\"./results/{}_{}.csv\".format(dataset_name,backbone_name)\n",
    "    print(save_path)\n",
    "    res_set={'Accuracy':[],'Worst@1':[],'Worst@5':[],'Worst@10':[],'Worst@20':[],'Worst@50':[],'Worst@100':[],'Harmonic Mean':[],'Geometric Mean':[]}\n",
    "    assert(len(res)==len(names))\n",
    "    for i in range(len(res)):\n",
    "        accs=res[i]\n",
    "        method_name=names[i]\n",
    "        for key in res_set.keys():\n",
    "            res_set[key].append(accs[key])\n",
    "    df = pd.DataFrame(res_set, index = names)\n",
    "    df.to_csv(save_path)\n",
    "    \n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# template: a photo of a\n",
    "def Class_Name(dataset_name,backbone_name,test_embeds,test_labels):\n",
    "    method_name = \"Class Name\"\n",
    "    print(method_name)\n",
    "    method_list.append(method_name)\n",
    "    all_templates = [\"{}\"]\n",
    "    all_class_embeds = load_all_class_embeds(dataset_name, backbone_name, all_templates)\n",
    "    test_preds = cosine_similarity(test_embeds, all_class_embeds).argmax(dim=-1)\n",
    "    cls_accs = evaluate(test_labels, test_preds)\n",
    "    return cls_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 80/80 [00:00<00:00, 1139.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 68.26%\n",
      "Number of classes: 1000\n",
      "Worst@1: 0.0% index: 231\n",
      "Worst@5: 0.4% Acc: [0. 0. 0. 0. 2.]\n",
      "Worst@10: 1.6%\n",
      "Worst@20: 4.7% Acc: [ 0.  0.  0.  0.  2.  2.  2.  2.  4.  4.  6.  6.  6.  6.  8.  8.  8. 10.\n",
      " 10. 10.]\n",
      "Worst@50: 12.0%\n",
      "Worst@100: 21.46%\n",
      "Mean Class Acc: 68.26%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/imagenet_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 80/80 [00:02<00:00, 31.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 75.37%\n",
      "Number of classes: 1000\n",
      "Worst@1: 0.0% index: 638\n",
      "Worst@5: 1.2% Acc: [0. 0. 0. 2. 4.]\n",
      "Worst@10: 3.4%\n",
      "Worst@20: 8.1% Acc: [ 0.  0.  0.  2.  4.  4.  4.  6.  6.  8.  8. 10. 12. 12. 12. 12. 14. 14.\n",
      " 16. 18.]\n",
      "Worst@50: 19.72%\n",
      "Worst@100: 29.76%\n",
      "Mean Class Acc: 75.37%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/imagenet_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:00<00:00, 3927.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 68.62%\n",
      "Number of classes: 100\n",
      "Worst@1: 10.0% index: 67\n",
      "Worst@5: 19.2% Acc: [10. 17. 21. 23. 25.]\n",
      "Worst@10: 25.4%\n",
      "Worst@20: 36.7% Acc: [10. 17. 21. 23. 25. 25. 26. 28. 39. 40. 45. 45. 45. 46. 47. 48. 50. 51.\n",
      " 51. 52.]\n",
      "Worst@50: 52.46%\n",
      "Worst@100: 68.62%\n",
      "Mean Class Acc: 68.62%\n",
      "Harmonic Mean: 57.55%\n",
      "Geometric Mean: 64.32%\n",
      "./results/cifar100_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:00<00:00, 59.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 78.44%\n",
      "Number of classes: 100\n",
      "Worst@1: 26.0% index: 67\n",
      "Worst@5: 38.0% Acc: [26. 39. 40. 40. 45.]\n",
      "Worst@10: 46.2%\n",
      "Worst@20: 54.2% Acc: [26. 39. 40. 40. 45. 52. 52. 55. 55. 58. 59. 61. 61. 61. 62. 62. 63. 64.\n",
      " 64. 65.]\n",
      "Worst@50: 66.3%\n",
      "Worst@100: 78.44%\n",
      "Mean Class Acc: 78.44%\n",
      "Harmonic Mean: 74.18%\n",
      "Geometric Mean: 76.58%\n",
      "./results/cifar100_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 34/34 [00:00<00:00, 2397.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 87.7%\n",
      "Number of classes: 101\n",
      "Worst@1: 20.18% index: 55\n",
      "Worst@5: 24.07% Acc: [20.1754386  20.68965517 23.52941176 24.24242424 31.70731707]\n",
      "Worst@10: 40.06%\n",
      "Worst@20: 61.03% Acc: [20.1754386  20.68965517 23.52941176 24.24242424 31.70731707 33.87096774\n",
      " 55.55555556 57.47126437 66.66666667 66.66666667 70.58823529 72.\n",
      " 80.39215686 82.35294118 84.44444444 85.71428571 85.71428571 85.71428571\n",
      " 86.04651163 87.14285714]\n",
      "Worst@50: 80.58%\n",
      "Worst@100: 89.92%\n",
      "Mean Class Acc: 90.02%\n",
      "Harmonic Mean: 79.83%\n",
      "Geometric Mean: 86.54%\n",
      "./results/caltech101_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 34/34 [00:00<00:00, 82.22it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 90.55%\n",
      "Number of classes: 101\n",
      "Worst@1: 8.28% index: 0\n",
      "Worst@5: 44.87% Acc: [ 8.27586207 43.93939394 48.7804878  58.82352941 64.51612903]\n",
      "Worst@10: 59.32%\n",
      "Worst@20: 74.57% Acc: [ 8.27586207 43.93939394 48.7804878  58.82352941 64.51612903 67.54385965\n",
      " 68.88888889 76.5        77.47126437 78.43137255 85.71428571 88.\n",
      " 88.         88.23529412 89.74358974 91.42857143 91.42857143 91.4893617\n",
      " 91.80327869 92.30769231]\n",
      "Worst@50: 87.74%\n",
      "Worst@100: 93.85%\n",
      "Mean Class Acc: 93.92%\n",
      "Harmonic Mean: 84.37%\n",
      "Geometric Mean: 91.75%\n",
      "./results/caltech101_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 2953.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 64.67%\n",
      "Number of classes: 196\n",
      "Worst@1: 0.0% index: 74\n",
      "Worst@5: 0.93% Acc: [0.        0.        0.        2.3255814 2.3255814]\n",
      "Worst@10: 2.72%\n",
      "Worst@20: 8.58% Acc: [ 0.          0.          0.          2.3255814   2.3255814   2.38095238\n",
      "  4.44444444  4.76190476  5.40540541  5.55555556  6.97674419  9.09090909\n",
      " 13.33333333 14.28571429 14.70588235 15.         16.27906977 16.66666667\n",
      " 19.04761905 19.04761905]\n",
      "Worst@50: 24.16%\n",
      "Worst@100: 41.12%\n",
      "Mean Class Acc: 64.66%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/cars196_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 59.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 77.64%\n",
      "Number of classes: 196\n",
      "Worst@1: 2.78% index: 59\n",
      "Worst@5: 10.42% Acc: [ 2.77777778  6.97674419 10.25641026 15.         17.07317073]\n",
      "Worst@10: 15.35%\n",
      "Worst@20: 23.72% Acc: [ 2.77777778  6.97674419 10.25641026 15.         17.07317073 17.14285714\n",
      " 18.18181818 21.42857143 21.42857143 23.25581395 25.         25.58139535\n",
      " 27.5        30.         32.43243243 32.5        34.14634146 35.71428571\n",
      " 35.8974359  42.10526316]\n",
      "Worst@50: 43.93%\n",
      "Worst@100: 60.67%\n",
      "Mean Class Acc: 77.61%\n",
      "Harmonic Mean: 56.24%\n",
      "Geometric Mean: 71.32%\n",
      "./results/cars196_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 3675.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 46.22%\n",
      "Number of classes: 47\n",
      "Worst@1: 0.0% index: 12\n",
      "Worst@5: 1.0% Acc: [0.  0.  0.  2.5 2.5]\n",
      "Worst@10: 2.25%\n",
      "Worst@20: 12.12% Acc: [ 0.   0.   0.   2.5  2.5  2.5  2.5  2.5  5.   5.   7.5 12.5 17.5 17.5\n",
      " 22.5 25.  25.  27.5 30.  35. ]\n",
      "Worst@50: 46.22%\n",
      "Worst@100: 46.22%\n",
      "Mean Class Acc: 46.22%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/dtd_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 45.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 55.32%\n",
      "Number of classes: 47\n",
      "Worst@1: 0.0% index: 21\n",
      "Worst@5: 4.5% Acc: [ 0.   2.5  2.5  7.5 10. ]\n",
      "Worst@10: 9.5%\n",
      "Worst@20: 23.75% Acc: [ 0.   2.5  2.5  7.5 10.  12.5 12.5 12.5 17.5 17.5 22.5 25.  32.5 35.\n",
      " 37.5 37.5 45.  45.  47.5 52.5]\n",
      "Worst@50: 55.32%\n",
      "Worst@100: 55.32%\n",
      "Mean Class Acc: 55.32%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/dtd_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 2459.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 54.44%\n",
      "Number of classes: 10\n",
      "Worst@1: 16.72% index: 6\n",
      "Worst@5: 28.53% Acc: [16.72       19.76666667 20.03333333 38.64       47.5       ]\n",
      "Worst@10: 54.27%\n",
      "Worst@20: 54.27% Acc: [16.72       19.76666667 20.03333333 38.64       47.5        63.46666667\n",
      " 72.12       78.6        91.6        94.28      ]\n",
      "Worst@50: 54.27%\n",
      "Worst@100: 54.27%\n",
      "Mean Class Acc: 54.27%\n",
      "Harmonic Mean: 36.88%\n",
      "Geometric Mean: 45.47%\n",
      "./results/eurosat_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 132.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 60.58%\n",
      "Number of classes: 10\n",
      "Worst@1: 10.6% index: 0\n",
      "Worst@5: 38.75% Acc: [10.6        17.43333333 41.1        57.95       66.68      ]\n",
      "Worst@10: 61.85%\n",
      "Worst@20: 61.85% Acc: [10.6        17.43333333 41.1        57.95       66.68       67.84\n",
      " 77.26666667 90.36666667 90.56       98.72      ]\n",
      "Worst@50: 61.85%\n",
      "Worst@100: 61.85%\n",
      "Mean Class Acc: 61.85%\n",
      "Harmonic Mean: 37.28%\n",
      "Geometric Mean: 51.22%\n",
      "./results/eurosat_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 1534.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 89.01%\n",
      "Number of classes: 101\n",
      "Worst@1: 37.6% index: 93\n",
      "Worst@5: 64.88% Acc: [37.6 67.2 70.8 74.  74.8]\n",
      "Worst@10: 71.2%\n",
      "Worst@20: 76.04% Acc: [37.6 67.2 70.8 74.  74.8 76.  77.2 77.6 78.  78.8 78.8 78.8 79.6 80.\n",
      " 80.4 81.6 81.6 82.  82.4 83.6]\n",
      "Worst@50: 82.84%\n",
      "Worst@100: 88.9%\n",
      "Mean Class Acc: 89.01%\n",
      "Harmonic Mean: 87.74%\n",
      "Geometric Mean: 88.47%\n",
      "./results/food101_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 46.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 93.32%\n",
      "Number of classes: 101\n",
      "Worst@1: 62.4% index: 52\n",
      "Worst@5: 75.2% Acc: [62.4 65.6 80.  82.4 85.6]\n",
      "Worst@10: 80.96%\n",
      "Worst@20: 84.82% Acc: [62.4 65.6 80.  82.4 85.6 85.6 86.  87.2 87.2 87.6 88.  88.  88.4 88.4\n",
      " 88.8 88.8 88.8 89.2 89.2 89.2]\n",
      "Worst@50: 89.36%\n",
      "Worst@100: 93.26%\n",
      "Mean Class Acc: 93.32%\n",
      "Harmonic Mean: 92.87%\n",
      "Geometric Mean: 93.11%\n",
      "./results/food101_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 2504.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 71.98%\n",
      "Number of classes: 102\n",
      "Worst@1: 0.0% index: 1\n",
      "Worst@5: 0.0% Acc: [0. 0. 0. 0. 0.]\n",
      "Worst@10: 0.0%\n",
      "Worst@20: 5.08% Acc: [ 0.          0.          0.          0.          0.          0.\n",
      "  0.          0.          0.          0.          0.          2.\n",
      "  2.17391304  3.19148936  3.63636364  5.          5.         15.38461538\n",
      " 25.80645161 39.3258427 ]\n",
      "Worst@50: 44.3%\n",
      "Worst@100: 69.83%\n",
      "Mean Class Acc: 70.42%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/oxford_flowers102_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 35.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 78.94%\n",
      "Number of classes: 102\n",
      "Worst@1: 0.0% index: 36\n",
      "Worst@5: 0.85% Acc: [0.         0.         0.         0.6097561  3.63636364]\n",
      "Worst@10: 4.14%\n",
      "Worst@20: 20.77% Acc: [ 0.          0.          0.          0.6097561   3.63636364  4.76190476\n",
      "  5.          5.         10.86956522 11.53846154 17.77777778 21.34831461\n",
      " 22.         23.68421053 28.57142857 31.25       52.5        52.77777778\n",
      " 61.29032258 62.76595745]\n",
      "Worst@50: 58.97%\n",
      "Worst@100: 78.91%\n",
      "Mean Class Acc: 79.33%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/oxford_flowers102_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 2940.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 88.72%\n",
      "Number of classes: 37\n",
      "Worst@1: 31.82% index: 7\n",
      "Worst@5: 53.76% Acc: [31.81818182 33.         66.         67.         71.        ]\n",
      "Worst@10: 68.8%\n",
      "Worst@20: 80.75% Acc: [31.81818182 33.         66.         67.         71.         77.\n",
      " 79.3814433  87.         87.         88.7755102  90.         91.\n",
      " 91.01123596 92.         92.92929293 93.         93.         94.\n",
      " 95.         95.        ]\n",
      "Worst@50: 88.54%\n",
      "Worst@100: 88.54%\n",
      "Mean Class Acc: 88.54%\n",
      "Harmonic Mean: 82.65%\n",
      "Geometric Mean: 86.28%\n",
      "./results/oxford_iiit_pet_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 84.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 93.76%\n",
      "Number of classes: 37\n",
      "Worst@1: 59.0% index: 26\n",
      "Worst@5: 72.49% Acc: [59.         72.         73.         77.         81.44329897]\n",
      "Worst@10: 81.9%\n",
      "Worst@20: 89.05% Acc: [59.         72.         73.         77.         81.44329897 87.\n",
      " 88.76404494 92.92929293 93.87755102 94.         94.         94.\n",
      " 95.         96.         96.         97.         97.         97.\n",
      " 97.97979798 98.        ]\n",
      "Worst@50: 93.75%\n",
      "Worst@100: 93.75%\n",
      "Mean Class Acc: 93.75%\n",
      "Harmonic Mean: 92.54%\n",
      "Geometric Mean: 93.2%\n",
      "./results/oxford_iiit_pet_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:00<00:00, 1559.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 65.46%\n",
      "Number of classes: 45\n",
      "Worst@1: 0.0% index: 6\n",
      "Worst@5: 13.27% Acc: [ 0.          3.1496063  16.15384615 22.04724409 25.        ]\n",
      "Worst@10: 25.41%\n",
      "Worst@20: 40.84% Acc: [ 0.          3.1496063  16.15384615 22.04724409 25.         33.84615385\n",
      " 34.72222222 34.9112426  42.13836478 42.17687075 44.91525424 51.04895105\n",
      " 53.14685315 53.8961039  58.33333333 59.13043478 59.3220339  60.\n",
      " 61.33333333 61.4379085 ]\n",
      "Worst@50: 65.32%\n",
      "Worst@100: 65.32%\n",
      "Mean Class Acc: 65.32%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/resisc45_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 18/18 [00:00<00:00, 312.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 70.67%\n",
      "Number of classes: 45\n",
      "Worst@1: 0.79% index: 42\n",
      "Worst@5: 16.0% Acc: [ 0.78740157 16.66666667 17.69230769 21.76870748 23.07692308]\n",
      "Worst@10: 30.01%\n",
      "Worst@20: 48.22% Acc: [ 0.78740157 16.66666667 17.69230769 21.76870748 23.07692308 30.71428571\n",
      " 44.44444444 46.61016949 47.92899408 50.43478261 61.01694915 62.33766234\n",
      " 63.88888889 64.77987421 65.71428571 67.29559748 67.33333333 67.85714286\n",
      " 71.95121951 72.14285714]\n",
      "Worst@50: 70.57%\n",
      "Worst@100: 70.57%\n",
      "Mean Class Acc: 70.57%\n",
      "Harmonic Mean: 22.18%\n",
      "Geometric Mean: 60.26%\n",
      "./results/resisc45_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 1037.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 63.75%\n",
      "Number of classes: 397\n",
      "Worst@1: 0.0% index: 49\n",
      "Worst@5: 0.0% Acc: [0. 0. 0. 0. 0.]\n",
      "Worst@10: 1.4%\n",
      "Worst@20: 6.25% Acc: [ 0.          0.          0.          0.          0.          0.\n",
      "  2.56410256  2.85714286  3.84615385  4.76190476  5.17241379  5.45454545\n",
      "  5.88235294  8.33333333 10.         13.15789474 14.70588235 15.38461538\n",
      " 15.90909091 16.99346405]\n",
      "Worst@50: 18.76%\n",
      "Worst@100: 30.21%\n",
      "Mean Class Acc: 65.18%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/sun397_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 53.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 67.38%\n",
      "Number of classes: 397\n",
      "Worst@1: 0.0% index: 160\n",
      "Worst@5: 0.0% Acc: [0. 0. 0. 0. 0.]\n",
      "Worst@10: 2.28%\n",
      "Worst@20: 8.05% Acc: [ 0.          0.          0.          0.          0.          2.7027027\n",
      "  3.44827586  3.57142857  6.4516129   6.66666667  7.69230769  8.69565217\n",
      " 11.53846154 11.53846154 14.70588235 15.2173913  15.68627451 17.24137931\n",
      " 17.64705882 18.18181818]\n",
      "Worst@50: 19.25%\n",
      "Worst@100: 30.8%\n",
      "Mean Class Acc: 68.28%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/sun397_ViT-L14.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 2556.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 24.33%\n",
      "Number of classes: 100\n",
      "Worst@1: 0.0% index: 0\n",
      "Worst@5: 0.0% Acc: [0. 0. 0. 0. 0.]\n",
      "Worst@10: 0.0%\n",
      "Worst@20: 0.0% Acc: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "Worst@50: 4.09%\n",
      "Worst@100: 24.27%\n",
      "Mean Class Acc: 24.27%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/fgvc_aircraft_ViT-B16.csv\n",
      "Loading existing checkpoints\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 77.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 31.56%\n",
      "Number of classes: 100\n",
      "Worst@1: 0.0% index: 3\n",
      "Worst@5: 0.0% Acc: [0. 0. 0. 0. 0.]\n",
      "Worst@10: 0.0%\n",
      "Worst@20: 0.74% Acc: [0.         0.         0.         0.         0.         0.\n",
      " 0.         0.         0.         0.         0.         0.\n",
      " 0.         0.         0.         2.94117647 2.94117647 2.94117647\n",
      " 2.94117647 3.03030303]\n",
      "Worst@50: 9.01%\n",
      "Worst@100: 31.57%\n",
      "Mean Class Acc: 31.57%\n",
      "Harmonic Mean: 0.0%\n",
      "Geometric Mean: 0.0%\n",
      "./results/fgvc_aircraft_ViT-L14.csv\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every zero-shot method on every (dataset, backbone) pair and write one results file per pair.\n",
    "# Each method function appends its own display name to the module-level `method_list`;\n",
    "# `to_excel` pairs those names with the accuracies collected in `res`, so the two lists must stay aligned.\n",
    "# No `global` statement is needed here: at notebook top level `method_list` already is a module global.\n",
    "\n",
    "# Methods run in this order; each entry is (method function, extra keyword arguments).\n",
    "EVAL_METHODS = [\n",
    "    (Class_Name, {}),                             # bare class name\n",
    "    (A_Photo_Of_A, {}),                           # single 'a photo of a' template\n",
    "    (Prompt_Ensembling_H, {}),                    # hand-crafted template ensemble\n",
    "    (Max_Logit_Scoring, {}),\n",
    "    (ZPE_Scoring, {}),\n",
    "    (DES, {}),\n",
    "    (Cmm_Scoring, {}),\n",
    "    (CMM_DES_Scoring, {}),\n",
    "    (CMM_DES_Scoring, dict(selection=True)),\n",
    "]\n",
    "\n",
    "for dataset_name in datasets:\n",
    "    for backbone_name in backbone_names:\n",
    "        test_embeds, test_labels = load_test_embeds_and_labels(dataset_name, backbone_name)\n",
    "        # Reset the shared name list; each method call below is expected to append exactly one name.\n",
    "        method_list = []\n",
    "        res = []\n",
    "        for method, extra_kwargs in EVAL_METHODS:\n",
    "            res.append(method(dataset_name, backbone_name, test_embeds, test_labels, **extra_kwargs))\n",
    "\n",
    "        # A method that skips (or doubles) its method_list.append would silently mislabel result rows.\n",
    "        assert len(res) == len(method_list), (dataset_name, backbone_name, method_list)\n",
    "        to_excel(backbone_name, dataset_name, res, method_list)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
