{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "# sys.path.append(\".\")\n",
    "# from configs.coco_classnames import coco_classnames\n",
    "from measures.feature_measure import maha\n",
    "\n",
    "res_dir = \"./coco_features/\"\n",
    "\n",
    "# Load features, labels, and per-annotation / per-image ids.\n",
    "features = np.load(os.path.join(res_dir, \"features.npy\"))\n",
    "labels = np.load(os.path.join(res_dir, \"labels.npy\"))\n",
    "# Shift labels to 0-based so later cells can use them directly as list indices\n",
    "# (assumes the stored labels start at 1 -- TODO confirm).\n",
    "labels -= 1\n",
    "label_id = np.load(os.path.join(res_dir, \"label_id.npy\"))\n",
    "img_id = np.load(os.path.join(res_dir, \"img_id.npy\"))\n",
    "print(features.shape, labels.shape, label_id.shape, img_id.shape)\n",
    "\n",
    "# Drop every sample whose feature vector contains at least one NaN.\n",
    "nan_mask = np.isnan(features).any(axis=1)\n",
    "features_new = features[~nan_mask, :]\n",
    "labels_new = labels[~nan_mask]\n",
    "label_id_new = label_id[~nan_mask]\n",
    "img_id_new = img_id[~nan_mask]\n",
    "\n",
    "# Shapes after NaN filtering.\n",
    "print(features_new.shape, labels_new.shape, label_id_new.shape, img_id_new.shape)\n",
    "\n",
    "# Fit the per-class Gaussian statistics on the NaN-free samples only; NaN rows would\n",
    "# presumably poison the estimated means/covariances.\n",
    "# NOTE(review): the meaning of indist_classes=7 is defined by measures.feature_measure.maha.\n",
    "maha_dict = maha(features_new, labels_new, indist_classes=7)\n",
    "maha_path = os.path.join(res_dir, \"maha_dict.npy\")\n",
    "np.save(maha_path, maha_dict)\n",
    "\n",
    "# # Load a previously saved maha dict instead of recomputing it\n",
    "# maha_dict = np.load(maha_path, allow_pickle=True).item()\n",
    "# print(maha_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from measures.feature_measure import maha_distance\n",
    "from tqdm import tqdm\n",
    "\n",
    "# NOTE(review): `maha_intermediate_dict` is not defined by any cell in this notebook\n",
    "# (the first cell builds `maha_dict`), so this cell fails on a fresh kernel. It must\n",
    "# provide 'class_cov_invs' and 'class_means' indexable by 0-based class label --\n",
    "# confirm where it comes from.\n",
    "class_cov_invs = maha_intermediate_dict['class_cov_invs']\n",
    "class_means = maha_intermediate_dict['class_means']\n",
    "\n",
    "# Minimum number of per-class buckets (presumably the COCO category-id range -- TODO confirm).\n",
    "NUM_CLASS_BUCKETS = 90\n",
    "if len(labels_new) > 0:\n",
    "    # Labels are used as list indices: a negative label would silently land in the\n",
    "    # last bucket, so fail loudly instead.\n",
    "    assert labels_new.min() >= 0, \"labels_new must hold non-negative 0-based class indices\"\n",
    "    # Grow the bucket list if any label exceeds the default range (instead of IndexError).\n",
    "    n_buckets = max(NUM_CLASS_BUCKETS, int(labels_new.max()) + 1)\n",
    "else:\n",
    "    n_buckets = NUM_CLASS_BUCKETS\n",
    "\n",
    "# maha_res_list_class[c] collects one result dict per annotation whose label is c.\n",
    "maha_res_list_class = [[] for _ in range(n_buckets)]\n",
    "\n",
    "for i in tqdm(range(len(labels_new))):\n",
    "    label = labels_new[i]\n",
    "    maha_dist = maha_distance(features_new[i], cov_inv_in=class_cov_invs[label], mean_in=class_means[label], norm_type=\"L2\")\n",
    "    res_dict = {\n",
    "        \"maha_dist\": maha_dist.item(),\n",
    "        \"label\": label,\n",
    "        \"ann_id\": label_id_new[i],\n",
    "    }\n",
    "    maha_res_list_class[label].append(res_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each class independently: compress the distances with log(sqrt(d)), then\n",
    "# rescale them linearly so the class minimum maps to SCORE_LEFT and the class\n",
    "# maximum maps to SCORE_RIGHT.\n",
    "SCORE_LEFT = 0.01\n",
    "SCORE_RIGHT = 0.99\n",
    "\n",
    "maha_res_list_reclass = []\n",
    "maha_res_dict_reclass = {}\n",
    "for maha_dist_class in maha_res_list_class:\n",
    "    if len(maha_dist_class) == 0:\n",
    "        continue\n",
    "    maha_dist_per_class = np.array([d['maha_dist'] for d in maha_dist_class])\n",
    "    annids_per_class = [d['ann_id'] for d in maha_dist_class]\n",
    "\n",
    "    # NOTE(review): a distance of exactly 0 maps to -inf here -- confirm distances are strictly positive.\n",
    "    transformed_data = np.log(np.sqrt(maha_dist_per_class))\n",
    "    min_val = np.min(transformed_data)\n",
    "    max_val = np.max(transformed_data)\n",
    "    print(max_val, min_val)\n",
    "\n",
    "    value_range = max_val - min_val\n",
    "    if value_range == 0:\n",
    "        # Single-sample (or constant) class: the linear map would divide by zero and\n",
    "        # produce NaN scores, so every sample gets the lower bound instead.\n",
    "        data_minmax = np.full_like(transformed_data, SCORE_LEFT)\n",
    "    else:\n",
    "        slope = (SCORE_RIGHT - SCORE_LEFT) / value_range\n",
    "        data_minmax = SCORE_LEFT + slope * (transformed_data - min_val)\n",
    "\n",
    "    # Scores are keyed by the string form of the annotation id.\n",
    "    for ann_id, score in zip(annids_per_class, data_minmax):\n",
    "        maha_res_dict_reclass[str(ann_id)] = np.array([score])\n",
    "        maha_res_list_reclass.append(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the per-annotation normalized scores (dict keyed by str(ann_id)). np.save stores\n",
    "# the dict as a pickled object array; reload with np.load(path, allow_pickle=True).item().\n",
    "scores_path = os.path.join(res_dir, \"maha_dic_perclass_data_scores.npy\")\n",
    "np.save(scores_path, maha_res_dict_reclass)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
