{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9c8ff20b-de73-4333-aba1-46a2f123e42e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import xgboost as xgb\n",
    "import sklearn\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "import sklearn.model_selection as ms\n",
    "import numpy as np\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "from nnlib.nnlib import utils\n",
    "\n",
    "import warnings\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import h5py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4489286c-741e-4cc9-a25f-07a1af491458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic, seeded stand-in for model outputs, used to exercise the calibration cell below.\n",
    "num_classes = 10   # number of classes\n",
    "num_samples = 100  # number of samples\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "# Random predicted probabilities, normalised so that each row sums to 1.\n",
    "preds = rng.random((num_samples, num_classes))\n",
    "preds /= preds.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Draw each sample's true label from its predicted distribution. Using np.argmax here\n",
    "# would make every prediction correct and the calibration check trivially perfect.\n",
    "labels = np.array([rng.choice(num_classes, p=row) for row in preds])\n",
    "\n",
    "# Names read by the calibration cell: scores `p`, one-hot `label`, sample count `N`.\n",
    "p = preds\n",
    "label = np.eye(num_classes)[labels]\n",
    "N = num_samples\n",
    "\n",
    "preds[:3].round(3), labels[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa09de57-8306-47b1-b207-2592d70ba5bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the samples whose label is class 1 (compact instead of a 100-element boolean mask).\n",
    "np.flatnonzero(labels == 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7e6c8022-e90f-4fd1-8f34-cbe8201a53c6",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "index 0 is out of bounds for axis 0 with size 0",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m np\u001b[38;5;241m.\u001b[39marray(\u001b[43m[\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m)\n",
      "Cell \u001b[0;32mIn[15], line 1\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m----> 1\u001b[0m np\u001b[38;5;241m.\u001b[39marray([\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m labels\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m)])\n",
      "\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9669aca8-d02f-4f2b-85bf-02de7d6014e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the predictions to a binary (confidence, correct?) problem for calibration.\n",
    "# Reads `p` (per-sample class scores, rows normalised here) and `label` (one-hot targets).\n",
    "# NOTE(review): both assumed to be 2-D numpy arrays with one row per sample - confirm upstream.\n",
    "\n",
    "# One-hot -> class index: position of the first 1 in each row. Fail with a clear message\n",
    "# (rather than an opaque IndexError) when a row contains no 1.\n",
    "is_one = np.asarray(label) == 1\n",
    "if not is_one.any(axis=1).all():\n",
    "    raise ValueError(\"label must be one-hot encoded: every row needs a 1\")\n",
    "label_index = is_one.argmax(axis=1)\n",
    "\n",
    "if p.shape[1] != 2:\n",
    "    # Multi-class: confidence = normalised score of the top class, float32, shape (n, 1);\n",
    "    # label_binary = 1.0 where the top class is the true class else 0.0, shape (n, 1).\n",
    "    # The sample count comes from `p` itself instead of a separately defined global `N`.\n",
    "    p_new = torch.from_numpy(p)\n",
    "    pred_label = torch.argmax(p_new, dim=1).numpy()\n",
    "    p_b = (p_new.max(dim=1).values / p_new.sum(dim=1)).float().unsqueeze(1)\n",
    "    label_binary = (pred_label == label_index).astype(np.float64)[:, None]\n",
    "else:\n",
    "    # Binary: confidence = normalised probability of class 1; the class index is already 0/1.\n",
    "    # NOTE(review): these are shape (n,) whereas the multi-class branch yields (n, 1) -\n",
    "    # confirm downstream code accepts both.\n",
    "    p_b = torch.from_numpy((p / np.sum(p, 1)[:, None])[:, 1])\n",
    "    label_binary = label_index"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
