{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import re\n",
    "import json\n",
    "from PIL import Image\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_path = '/mnt/hdd/tanhz/clip_bk/res'\n",
    "res_start_str='=> result'\n",
    "keyword = 'accuracy'\n",
    "ky = 'total'\n",
    "metric = {\n",
    "        \"name\": keyword,\n",
    "        \"regex\": re.compile(fr\"\\* {keyword}: ([\\.\\deE+-]+)%\"),\n",
    "    }\n",
    "all_results = dict()\n",
    "for dataset_name in os.listdir(root_path):\n",
    "    all_results[dataset_name] = dict()\n",
    "    dataset_path = os.path.join(root_path, dataset_name)\n",
    "    for trainer in os.listdir(dataset_path):\n",
    "        all_results[dataset_name][trainer] = dict()\n",
    "        sub_path = os.path.join(dataset_path, trainer)\n",
    "        for sub in os.listdir(sub_path):\n",
    "            all_results[dataset_name][trainer][sub] = dict()\n",
    "            seed_path = os.path.join(sub_path, sub)\n",
    "            outputs = []\n",
    "            for seed in os.listdir(seed_path):\n",
    "                output = dict()\n",
    "                with open(os.path.join(seed_path, seed, 'log.txt'), 'r') as f:\n",
    "                    lines = f.readlines()\n",
    "                    res_signal = False\n",
    "                    for line in lines:\n",
    "                        line = line.strip()\n",
    "                        if line == res_start_str:\n",
    "                            res_signal = True\n",
    "                        if not res_signal:\n",
    "                            continue\n",
    "                        # for metric in metrics:\n",
    "                        match = metric[\"regex\"].search(line)\n",
    "                        if match:\n",
    "                            regex = float(match.group(1))\n",
    "                            output[metric['name']]=regex\n",
    "                if output:\n",
    "                    outputs.append(output)\n",
    "            results = dict()\n",
    "            for output in outputs:\n",
    "                for key, value in output.items():\n",
    "                    if key not in results:\n",
    "                        results[key] = []\n",
    "                    results[key].append(value)\n",
    "            final_results = dict()\n",
    "            for key, values in results.items():\n",
    "                if key == 'accuracy':\n",
    "                    avg = np.mean(values)\n",
    "                    std = np.std(values)\n",
    "                    final_results['avg'] = avg\n",
    "                    final_results['std'] = std\n",
    "                elif key == 'total':\n",
    "                    avg = np.mean(values)\n",
    "                    final_results['total'] = avg\n",
    "            all_results[dataset_name][trainer][sub] = final_results\n",
    "\n",
    "with open('/mnt/hdd/tanhz/clip_bk/res.json','w') as f:\n",
    "    json.dump(all_results, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_names = [\n",
    "    'imagenet','caltech101','dtd','eurosat','fgvc_aircraft','food101','oxford_flowers','oxford_pets','stanford_cars','sun397','ucf101']\n",
    "dataset_labels = [\n",
    "    'ImageNet','Caltech101','DTD','EuroSAT','FGVCAircraft','Food101','Flowers102','OxfordPets','StanfordCars','SUN397','UCF101']\n",
    "\n",
    "# colors = [\n",
    "#         \"#003147\",\"#FF0097\",\"#DC3023\",\"#3DE1AD\",\"#057748\",\"#009ADE\",\"#AF58BA\",\"#FFC61E\",\"#F28522\",\"#A6761D\",\"#443F90\",\"#8F003B\"\n",
    "#     ]\n",
    "# colors = [\n",
    "#         \"#003147\",\"#104680\",\"#317CB7\",\"#6DADD1\",\"#B6D7E8\",\"#E9F1F4\",\"#FBE3D5\",\"#F6B293\",\"#DC6D57\",\"#B72230\",\"#6D011F\"\n",
    "#     ]\n",
    "# \"#44045A\"\n",
    "colors = [\n",
    "        \"#44045A\",\"#003147\",\"#413E85\",\"#104680\",\"#317CB7\",\n",
    "        \"#FFB703\",\"#FB8302\",\"#FF5100\",\n",
    "        \"#D84527\",\"#B72230\",\"#6D011F\"\n",
    "    ]\n",
    "markers = ['s','v','o','^','<','>','p','P','*','h','D','X','8']\n",
    "trainers = ['ZeroshotCLIP','CoOp','CoCoOp','VPT','MaPLe',\n",
    "            'ProGrad','KgCoOp','RPO',\n",
    "            'PromptSRC','ProDA','TCP'\n",
    "            ]\n",
    "trainer_labels = ['CLIP(Zero-shot)','CoOp','CoCoOp','VPT','MaPLe',\n",
    "            'ProGrad','KgCoOp','RPO',\n",
    "            'PromptSRC','ProDA','TCP'\n",
    "            ]\n",
    "# ,'PLOTPP'\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams['font.size'] = 20\n",
    "plt.rcParams['figure.dpi'] = 400\n",
    "plt.rcParams['figure.dpi'] = 400\n",
    "# colors = plt.get_cmap('hsv')\n",
    "\n",
    "# plt.rcParams['font.weight'] = 'bold'\n",
    "# plt.rcParams['font.style'] = 'italic'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "集合结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'dataset_names' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 11\u001b[0m\n\u001b[1;32m      9\u001b[0m grid \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39mGridSpec(\u001b[38;5;241m3\u001b[39m,\u001b[38;5;241m4\u001b[39m,wspace\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.15\u001b[39m,hspace\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m,figure\u001b[38;5;241m=\u001b[39mplt\u001b[38;5;241m.\u001b[39mfigure(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m28\u001b[39m,\u001b[38;5;241m25\u001b[39m)))\n\u001b[1;32m     10\u001b[0m p \u001b[38;5;241m=\u001b[39m q \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m---> 11\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataset_name,dataset_label \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[43mdataset_names\u001b[49m,dataset_labels):\n\u001b[1;32m     12\u001b[0m     plt\u001b[38;5;241m.\u001b[39msubplot(grid[p,q])\n\u001b[1;32m     13\u001b[0m     q \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'dataset_names' is not defined"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 2800x2500 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "xs = dict()\n",
    "xs['new'] = ['base','new1','new2','new3','new4','new5']\n",
    "xs['newratio'] = ['base','new_ratio1','new_ratio2','new_ratio3','new_ratio4','new_ratio5']\n",
    "x=dict()\n",
    "x['t'] = ['0','0.2','0.4','0.6','0.8','1.0']\n",
    "plt.rcParams['font.size'] = 20\n",
    "for x_type in xs.keys():\n",
    "    x_axis = xs[x_type]\n",
    "    grid = plt.GridSpec(3,4,wspace=0.15,hspace=0.3,figure=plt.figure(figsize=(28,25)))\n",
    "    p = q = 0\n",
    "    for dataset_name,dataset_label in zip(dataset_names,dataset_labels):\n",
    "        plt.subplot(grid[p,q])\n",
    "        q += 1\n",
    "        if q > 3:\n",
    "            q = 0\n",
    "            p += 1\n",
    "        i = 0\n",
    "        for trainer, trainer_label in zip(trainers,trainer_labels):\n",
    "            y_axis_acc = []\n",
    "            y_axis_std = []\n",
    "            for sub in x_axis:\n",
    "                y_axis_acc.append(all_results[dataset_name][trainer][sub]['avg'])\n",
    "                y_axis_std.append(all_results[dataset_name][trainer][sub]['std'])\n",
    "            y_axis_acc = np.array(y_axis_acc)\n",
    "            y_axis_std = np.array(y_axis_std)\n",
    "            plt.plot(x['t'], y_axis_acc, linestyle='--',marker=markers[i],markersize=7,color=colors[i] ,alpha=1, linewidth=1, label=trainer_label)\n",
    "            plt.fill_between(x['t'], y_axis_acc-y_axis_std, y_axis_acc+y_axis_std, color=colors[i], alpha=0.1)\n",
    "            # plt.errorbar(x_axis, y_axis_acc, yerr=y_axis_std, ecolor=colors[i], color=colors[i], elinewidth=1, capsize=5)\n",
    "            i += 1\n",
    "        plt.title(f'{dataset_label}')\n",
    "        # _base2{x_type}\n",
    "        plt.xlabel('t')\n",
    "        plt.ylabel('Accuracy')\n",
    "        if p == 2 and q == 3 :\n",
    "            plt.legend(bbox_to_anchor=(1.15, 0), loc='lower left')\n",
    "            # print(legend.get_children)\n",
    "        if x_type=='newratio' and dataset_name=='sun397':\n",
    "            plt.ylim(66,86)\n",
    "            plt.yticks(np.arange(68,88,4))\n",
    "        if x_type=='new' and dataset_name=='oxford_flowers':\n",
    "            plt.ylim(63,99)\n",
    "        if x_type=='newratio' and dataset_name=='oxford_flowers':\n",
    "            plt.ylim(59,99)\n",
    "        # plt.savefig(f'/mnt/hdd/tanhz/clip_bk/img/{dataset_name}_base2{x_type}.png',dpi=400)\n",
    "        # plt.clf()\n",
    "    plt.savefig(f'/mnt/hdd/tanhz/clip_bk/new_img_2/base2{x_type}-all.png',bbox_inches='tight', pad_inches=0.2)\n",
    "    plt.savefig(f'/mnt/hdd/tanhz/clip_bk/new_img_2/base2{x_type}-all.eps',bbox_inches='tight', pad_inches=0.2)\n",
    "    plt.savefig(f'/mnt/hdd/tanhz/clip_bk/new_img_2/base2{x_type}-all.pdf',format='pdf',bbox_inches='tight', pad_inches=0.2)\n",
    "    plt.clf()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NewRatio - New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = dict()\n",
    "xs['new'] = ['base','new1','new2','new3','new4','new5']\n",
    "xs['newratio'] = ['base','new_ratio1','new_ratio2','new_ratio3','new_ratio4','new_ratio5']\n",
    "x_diff = ['new_ratio1-new1','new_ratio2-new2','new_ratio3-new3','new_ratio4-new4','new_ratio5-new5']\n",
    "colors = plt.get_cmap('hsv')\n",
    "\n",
    "# for dataset_name in os.listdir(root_path):\n",
    "for dataset_name in os.listdir(root_path):\n",
    "\n",
    "    fig = plt.figure(figsize=(8, 6))\n",
    "    i = 0\n",
    "    for trainer in all_results[dataset_name].keys():\n",
    "        y_axis_accs = []\n",
    "        for x_type in xs.keys():\n",
    "            x_axis = xs[x_type]\n",
    "            y_axis_acc = []\n",
    "            for sub in x_axis:\n",
    "                y_axis_acc.append(all_results[dataset_name][trainer][sub]['avg'])\n",
    "            y_axis_acc = np.array(y_axis_acc)\n",
    "            y_axis_accs.append(y_axis_acc)\n",
    "        y_axis = np.array(y_axis_accs[1]-y_axis_accs[0])[1:]\n",
    "        plt.plot(x_diff, y_axis, '*--',color=colors(i) ,alpha=0.5, linewidth=1, label=trainer)\n",
    "        #     plt.fill_between(x_axis, y_axis_acc-y_axis_std, y_axis_acc+y_axis_std, color=colors(i), alpha=0.2)\n",
    "            # plt.errorbar(x_axis, y_axis_acc, yerr=y_axis_std, ecolor=colors(i), color=colors(i), elinewidth=1, capsize=5)\n",
    "        i += 30\n",
    "        plt.xlabel(f'{dataset_name}_diff')\n",
    "    plt.ylabel('differece')\n",
    "    plt.legend(bbox_to_anchor=(1, 0), loc='lower left')\n",
    "    # plt.ylim(85,100)\n",
    "    # plt.savefig(f'/mnt/hdd/tanhz/clip_bk/img/{dataset_name}_base2{x_type}.png',dpi=400)\n",
    "    # plt.clf()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "分数据集结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 2560x1920 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "xs = dict()\n",
    "xs['new'] = ['base','new1','new2','new3','new4','new5']\n",
    "xs['newratio'] = ['base','new_ratio1','new_ratio2','new_ratio3','new_ratio4','new_ratio5']\n",
    "# colors = plt.get_cmap('hsv')\n",
    "# plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams['font.size'] = 10\n",
    "# plt.rcParams['font.weight'] = 'bold'\n",
    "# plt.rcParams['font.style'] = 'italic'\n",
    "\n",
    "for x_type in xs.keys():\n",
    "    x_axis = xs[x_type]\n",
    "    for dataset_name in os.listdir(root_path):\n",
    "        i = 0\n",
    "        y_axis_zero = []\n",
    "        for sub in x_axis:\n",
    "            y_axis_zero.append(all_results[dataset_name]['ZeroshotCLIP'][sub]['avg'])\n",
    "        y_axis_zero = np.array(y_axis_zero)\n",
    "        for trainer, trainer_label in zip(trainers,trainer_labels):\n",
    "            y_axis_acc = []\n",
    "            y_axis_std = []\n",
    "            for sub in x_axis:\n",
    "                y_axis_acc.append(all_results[dataset_name][trainer][sub]['avg'])\n",
    "                y_axis_std.append(all_results[dataset_name][trainer][sub]['std'])\n",
    "            y_axis_acc = np.array(y_axis_acc)\n",
    "            y_axis_std = np.array(y_axis_std)\n",
    "            # y_axis_minus = y_axis_acc - y_axis_zero\n",
    "            plt.plot(x['t'], y_axis_acc,linestyle='--',marker=markers[i],markersize=4,color=colors[i] ,alpha=1, linewidth=1, label=trainer_label)\n",
    "            plt.fill_between(x['t'], y_axis_acc-y_axis_std, y_axis_acc+y_axis_std, color=colors[i], alpha=0.1)\n",
    "            # plt.errorbar(x_axis, y_axis_acc, yerr=y_axis_std, ecolor=colors[i], color=colors[i], elinewidth=1, capsize=5)\n",
    "            i += 1\n",
    "        plt.title(f'{dataset_label}_{x_type}')\n",
    "        # _base2{x_type}\n",
    "        plt.xlabel('t')\n",
    "        plt.ylabel('Accuracy')\n",
    "        plt.legend(bbox_to_anchor=(0, 0), loc='lower left',fontsize=7)\n",
    "        # plt.ylim(85,100)\n",
    "        # plt.savefig(f'/mnt/hdd/tanhz/clip_bk/img/compare_zero/{dataset_name}_base2{x_type}.png',dpi=400)\n",
    "        plt.savefig(f'/mnt/hdd/tanhz/clip_bk/img/{dataset_name}_base2{x_type}.png',dpi=400)\n",
    "        plt.clf()\n",
    "    # plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AUC(y_axis:np.ndarray):\n",
    "    auc = (np.sum(y_axis[:-1]) + np.sum(y_axis[1:])) / (2 * (len(y_axis)-1))\n",
    "    return auc\n",
    "\n",
    "def WA(y_axis:np.ndarray):\n",
    "    wa = np.min(y_axis)\n",
    "    return wa\n",
    "\n",
    "def base_acc(y_axis:np.ndarray):\n",
    "    return y_axis[0]\n",
    "\n",
    "def difference(y_axis:np.ndarray):\n",
    "    return np.max(y_axis) - np.min(y_axis)\n",
    "\n",
    "def EVM(y_axis:np.ndarray):\n",
    "    evm = 0\n",
    "    for j in range(len(y_axis)-1):\n",
    "        evm += abs(y_axis[j+1] - y_axis[j])\n",
    "    return evm\n",
    "\n",
    "def VS(y_axis:np.ndarray):\n",
    "    vs = 0\n",
    "    for j in range(len(y_axis)-1):\n",
    "        vs += (y_axis[j+1] - y_axis[j]) ** 2\n",
    "    vs *= (len(y_axis) - 1)\n",
    "    A = y_axis[-1] - y_axis[0]\n",
    "    vs -= A ** 2\n",
    "    return vs\n",
    "\n",
    "def upper_down_area(y_axis:np.ndarray,y_axis_zero:np.ndarray):\n",
    "    assert len(y_axis) == len(y_axis_zero)\n",
    "    upper_area = 0\n",
    "    down_area = 0\n",
    "    length = len(y_axis)-1\n",
    "    for i in range(length):\n",
    "        det0 = y_axis[i] - y_axis_zero[i]\n",
    "        det1 = y_axis[i+1] - y_axis_zero[i+1]\n",
    "        if det0 * det1 >= 0:\n",
    "            s = (det0 + det1) / 2 / length\n",
    "            if s > 0:\n",
    "                upper_area += s\n",
    "            else:\n",
    "                down_area += s\n",
    "        else:\n",
    "            s0 = (det0 ** 2) / (abs(det0) + abs(det1)) / 2 / length\n",
    "            s1 = (det1 ** 2) / (abs(det0) + abs(det1)) / 2 / length\n",
    "            if det0 > 0:\n",
    "                upper_area += s0\n",
    "                down_area += s1\n",
    "            else:\n",
    "                upper_area += s1\n",
    "                down_area += s0\n",
    "    return abs(upper_area),abs(down_area)\n",
    "\n",
    "\n",
    "def metrics_cal(y_axis:np.ndarray,y_axis_zero:np.ndarray, metric:str):\n",
    "    y_axis_small = y_axis / 100\n",
    "    y_axis_zero_small = y_axis_zero / 100\n",
    "    if metric == 'AUC':\n",
    "        return AUC(y_axis_small)\n",
    "    elif metric == 'WA':\n",
    "        return WA(y_axis_small)\n",
    "    elif metric == 'EVM':\n",
    "        return EVM(y_axis_small)\n",
    "    elif metric == 'VS':\n",
    "        return VS(y_axis_small)\n",
    "    elif metric == 'base_acc':\n",
    "        return base_acc(y_axis_small)\n",
    "    elif metric == 'difference':\n",
    "        return difference(y_axis_small)\n",
    "    elif metric == 'upper_area':\n",
    "        upper_area, _ = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        return upper_area\n",
    "    elif metric == 'down_area':\n",
    "        _, down_area = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        return down_area\n",
    "    elif metric == 'u-d':\n",
    "        upper_area, down_area = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        if upper_area == 0:\n",
    "            return 'inf'\n",
    "        return down_area/upper_area\n",
    "\n",
    "xs = dict()\n",
    "xs['new'] = ['base','new1','new2','new3','new4','new5']\n",
    "xs['newratio'] = ['base','new_ratio1','new_ratio2','new_ratio3','new_ratio4','new_ratio5']\n",
    "metrics = ['base_acc','AUC', 'WA', 'EVM', 'VS', 'upper_area', 'down_area']\n",
    "# trainers = ['ZeroshotCLIP','CoOp','CoCoOp','KgCoOp','MaPLe','ProDA','ProGrad','PromptSRC','TCP','VPT']\n",
    "\n",
    "for x_type in xs.keys():\n",
    "    x_axis = xs[x_type]\n",
    "    ranks = []\n",
    "    for dataset_name in dataset_names:\n",
    "        # trainers = list(all_results[dataset_name].keys())\n",
    "        df = pd.DataFrame(columns=metrics, index=trainer_labels)\n",
    "        y_axis_acc_zero = []\n",
    "        y_axis_std_zero = []\n",
    "        for sub in x_axis:\n",
    "            y_axis_acc_zero.append(all_results[dataset_name]['ZeroshotCLIP'][sub]['avg'])\n",
    "            y_axis_std_zero.append(all_results[dataset_name]['ZeroshotCLIP'][sub]['std'])\n",
    "        y_axis_acc_zero = np.array(y_axis_acc_zero)\n",
    "        y_axis_std_zero = np.array(y_axis_std_zero)\n",
    "        y_s = []\n",
    "        for trainer, trainer_label in zip(trainers,trainer_labels):\n",
    "            y_axis_acc = []\n",
    "            y_axis_std = []\n",
    "            for sub in x_axis:\n",
    "                y_axis_acc.append(all_results[dataset_name][trainer][sub]['avg'])\n",
    "                y_axis_std.append(all_results[dataset_name][trainer][sub]['std'])\n",
    "            y_axis_acc = np.array(y_axis_acc)\n",
    "            y_s.append(y_axis_acc)\n",
    "            y_axis_std = np.array(y_axis_std)\n",
    "            for metric in metrics:\n",
    "                df.loc[trainer_label, metric] = metrics_cal(y_axis_acc,y_axis_acc_zero, metric)\n",
    "        y_s = np.array(y_s)\n",
    "        avg_rank =  np.mean(np.argsort(y_s.argsort(axis=0)[::-1], axis=0) + 1, axis=1)\n",
    "        # df['rank'] = avg_rank\n",
    "        # ranks.append(avg_rank)\n",
    "        df.to_csv(f'/mnt/hdd/tanhz/clip_bk/statistic/metrics-{dataset_name}-{x_type}.csv')\n",
    "    # ranks = np.array(ranks)\n",
    "    # df_ranks = pd.DataFrame(ranks,columns=trainers, index=dataset_names)\n",
    "    # df_ranks.loc['avg'] = np.mean(ranks, axis=0)\n",
    "    # df_ranks.to_csv(f'/mnt/hdd/tanhz/clip_bk/statistic/ranks-{x_type}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def AUC(y_axis:np.ndarray):\n",
    "    auc = (np.sum(y_axis[:-1]) + np.sum(y_axis[1:])) / (2 * (len(y_axis)-1))\n",
    "    return auc\n",
    "\n",
    "def WA(y_axis:np.ndarray):\n",
    "    wa = np.min(y_axis)\n",
    "    return wa\n",
    "\n",
    "def base_acc(y_axis:np.ndarray):\n",
    "    return y_axis[0]\n",
    "\n",
    "def difference(y_axis:np.ndarray):\n",
    "    return np.max(y_axis) - np.min(y_axis)\n",
    "\n",
    "def EVM(y_axis:np.ndarray):\n",
    "    evm = 0\n",
    "    for j in range(len(y_axis)-1):\n",
    "        evm += abs(y_axis[j+1] - y_axis[j])\n",
    "    return evm\n",
    "\n",
    "def VS(y_axis:np.ndarray):\n",
    "    vs = 0\n",
    "    for j in range(len(y_axis)-1):\n",
    "        vs += (y_axis[j+1] - y_axis[j]) ** 2\n",
    "    vs *= (len(y_axis) - 1)\n",
    "    A = y_axis[-1] - y_axis[0]\n",
    "    vs -= A ** 2\n",
    "    return vs\n",
    "\n",
    "def upper_down_area(y_axis:np.ndarray,y_axis_zero:np.ndarray):\n",
    "    assert len(y_axis) == len(y_axis_zero)\n",
    "    upper_area = 0\n",
    "    down_area = 0\n",
    "    length = len(y_axis)-1\n",
    "    for i in range(length):\n",
    "        det0 = y_axis[i] - y_axis_zero[i]\n",
    "        det1 = y_axis[i+1] - y_axis_zero[i+1]\n",
    "        if det0 * det1 >= 0:\n",
    "            s = (det0 + det1) / 2 / length\n",
    "            if s > 0:\n",
    "                upper_area += s\n",
    "            else:\n",
    "                down_area += s\n",
    "        else:\n",
    "            s0 = (det0 ** 2) / (abs(det0) + abs(det1)) / 2 / length\n",
    "            s1 = (det1 ** 2) / (abs(det0) + abs(det1)) / 2 / length\n",
    "            if det0 > 0:\n",
    "                upper_area += s0\n",
    "                down_area += s1\n",
    "            else:\n",
    "                upper_area += s1\n",
    "                down_area += s0\n",
    "    return abs(upper_area),abs(down_area)\n",
    "\n",
    "\n",
    "def metrics_cal(y_axis:np.ndarray,y_axis_zero:np.ndarray, metric:str):\n",
    "    y_axis_small = y_axis / 100\n",
    "    y_axis_zero_small = y_axis_zero / 100\n",
    "    if metric == 'AUC':\n",
    "        return AUC(y_axis_small)\n",
    "    elif metric == 'WA':\n",
    "        return WA(y_axis_small)\n",
    "    elif metric == 'EVM':\n",
    "        return EVM(y_axis_small)\n",
    "    elif metric == 'VS':\n",
    "        return VS(y_axis_small)\n",
    "    elif metric == 'base_acc':\n",
    "        return base_acc(y_axis_small)\n",
    "    elif metric == 'difference':\n",
    "        return difference(y_axis_small)\n",
    "    elif metric == 'upper_area':\n",
    "        upper_area, _ = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        return upper_area\n",
    "    elif metric == 'down_area':\n",
    "        _, down_area = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        return down_area\n",
    "    elif metric == 'detaauc':\n",
    "        return AUC(y_axis_small)-AUC(y_axis_zero_small)\n",
    "    elif metric == 'd/u':\n",
    "        upper_area, down_area = upper_down_area(y_axis_small,y_axis_zero_small)\n",
    "        # if upper_area == 0:\n",
    "        #     return 'inf'\n",
    "        return upper_area-down_area\n",
    "\n",
    "xs = dict()\n",
    "xs['new'] = ['base','new1','new2','new3','new4','new5']\n",
    "xs['newratio'] = ['base','new_ratio1','new_ratio2','new_ratio3','new_ratio4','new_ratio5']\n",
    "# trainers = ['ZeroshotCLIP','CoOp','CoCoOp','KgCoOp','MaPLe','ProDA','ProGrad','PromptSRC','TCP','VPT']\n",
    "metrics = ['detaauc','d/u']\n",
    "\n",
    "\n",
    "for x_type in xs.keys():\n",
    "    x_axis = xs[x_type]\n",
    "    ranks = []\n",
    "    dfs = None\n",
    "    for dataset_name in dataset_names:\n",
    "        # trainers = list(all_results[dataset_name].keys())\n",
    "        df = pd.DataFrame(columns=metrics, index=trainer_labels)\n",
    "        y_axis_acc_zero = []\n",
    "        y_axis_std_zero = []\n",
    "        for sub in x_axis:\n",
    "            y_axis_acc_zero.append(all_results[dataset_name]['ZeroshotCLIP'][sub]['avg'])\n",
    "            y_axis_std_zero.append(all_results[dataset_name]['ZeroshotCLIP'][sub]['std'])\n",
    "        y_axis_acc_zero = np.array(y_axis_acc_zero)\n",
    "        y_axis_std_zero = np.array(y_axis_std_zero)\n",
    "        y_s = []\n",
    "        for trainer, trainer_label in zip(trainers,trainer_labels):\n",
    "            y_axis_acc = []\n",
    "            y_axis_std = []\n",
    "            for sub in x_axis:\n",
    "                y_axis_acc.append(all_results[dataset_name][trainer][sub]['avg'])\n",
    "                y_axis_std.append(all_results[dataset_name][trainer][sub]['std'])\n",
    "            y_axis_acc = np.array(y_axis_acc)\n",
    "            y_s.append(y_axis_acc)\n",
    "            y_axis_std = np.array(y_axis_std)\n",
    "            for metric in metrics:\n",
    "                df.loc[trainer_label, metric] = metrics_cal(y_axis_acc,y_axis_acc_zero, metric)\n",
    "        y_s = np.array(y_s)\n",
    "        avg_rank =  np.mean(np.argsort(y_s.argsort(axis=0)[::-1], axis=0) + 1, axis=1)\n",
    "        df['rank'] = avg_rank\n",
    "        true_rank = np.argsort(np.argsort(avg_rank)) + 1\n",
    "        df['true_rank'] = true_rank\n",
    "        ranks.append(avg_rank)\n",
    "        df.to_csv(f'/mnt/hdd/tanhz/clip_bk/statistic/robust-{dataset_name}-{x_type}.csv')\n",
    "        if dfs is None:\n",
    "            dfs = df\n",
    "        else:\n",
    "            dfs += df\n",
    "    ranks = np.array(ranks)\n",
    "    dfs /= len(dataset_names)\n",
    "    dfs['true_rank'] = np.argsort(np.argsort(dfs['rank']))  + 1\n",
    "    df_ranks = pd.DataFrame(ranks,columns=trainers, index=dataset_names)\n",
    "    df_ranks.loc['avg'] = np.mean(ranks, axis=0)\n",
    "    df_ranks.loc['true_rank'] = np.argsort(np.argsort(np.mean(ranks, axis=0))) + 1\n",
    "    df_ranks.to_csv(f'/mnt/hdd/tanhz/clip_bk/statistic/ranks-{x_type}.csv')\n",
    "    dfs.to_csv(f'/mnt/hdd/tanhz/clip_bk/statistic/robust-avg-{x_type}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MLL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
