{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d2a67595",
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'NamedTuple' from 'typing_extensions' (/opt/conda/lib/python3.8/site-packages/typing_extensions.py)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Input \u001b[0;32mIn [4]\u001b[0m, in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmanifold\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TSNE\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlibs\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtransform\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;241m,\u001b[39m \u001b[38;5;21;01mjson\u001b[39;00m\u001b[38;5;241m,\u001b[39m \u001b[38;5;21;01mcopy\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/__init__.py:21\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;66;03m# License: BSD 3-Clause\u001b[39;00m\n\u001b[1;32m     19\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m__future__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m annotations\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     22\u001b[0m     _api_calls,\n\u001b[1;32m     23\u001b[0m     config,\n\u001b[1;32m     24\u001b[0m     datasets,\n\u001b[1;32m     25\u001b[0m     evaluations,\n\u001b[1;32m     26\u001b[0m     exceptions,\n\u001b[1;32m     27\u001b[0m     extensions,\n\u001b[1;32m     28\u001b[0m     flows,\n\u001b[1;32m     29\u001b[0m     runs,\n\u001b[1;32m     30\u001b[0m     setups,\n\u001b[1;32m     31\u001b[0m     study,\n\u001b[1;32m     32\u001b[0m     tasks,\n\u001b[1;32m     33\u001b[0m     utils,\n\u001b[1;32m     34\u001b[0m )\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m__version__\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m __version__\n\u001b[1;32m     36\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLDataFeature, OpenMLDataset\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/evaluations/__init__.py:3\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# License: BSD 3-Clause\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mevaluation\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLEvaluation\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m list_evaluation_measures, list_evaluations, list_evaluations_setups\n\u001b[1;32m      6\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOpenMLEvaluation\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlist_evaluations\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      9\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlist_evaluation_measures\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlist_evaluations_setups\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     11\u001b[0m ]\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/evaluations/evaluation.py:7\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mflows\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mruns\u001b[39;00m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtasks\u001b[39;00m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mOpenMLEvaluation\u001b[39;00m:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/runs/__init__.py:3\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# License: BSD 3-Clause\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m      4\u001b[0m     delete_run,\n\u001b[1;32m      5\u001b[0m     get_run,\n\u001b[1;32m      6\u001b[0m     get_run_trace,\n\u001b[1;32m      7\u001b[0m     get_runs,\n\u001b[1;32m      8\u001b[0m     initialize_model_from_run,\n\u001b[1;32m      9\u001b[0m     initialize_model_from_trace,\n\u001b[1;32m     10\u001b[0m     list_runs,\n\u001b[1;32m     11\u001b[0m     run_exists,\n\u001b[1;32m     12\u001b[0m     run_flow_on_task,\n\u001b[1;32m     13\u001b[0m     run_model_on_task,\n\u001b[1;32m     14\u001b[0m )\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrun\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLRun\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrace\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLRunTrace, OpenMLTraceIteration\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/runs/functions.py:32\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mflows\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mflow\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _copy_server_fields\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msetups\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m initialize_model, setup_exists\n\u001b[0;32m---> 32\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtasks\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     33\u001b[0m     OpenMLClassificationTask,\n\u001b[1;32m     34\u001b[0m     OpenMLClusteringTask,\n\u001b[1;32m     35\u001b[0m     OpenMLLearningCurveTask,\n\u001b[1;32m     36\u001b[0m     OpenMLRegressionTask,\n\u001b[1;32m     37\u001b[0m     OpenMLSupervisedTask,\n\u001b[1;32m     38\u001b[0m     OpenMLTask,\n\u001b[1;32m     39\u001b[0m     TaskType,\n\u001b[1;32m     40\u001b[0m     get_task,\n\u001b[1;32m     41\u001b[0m )\n\u001b[1;32m     43\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrun\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLRun\n\u001b[1;32m     44\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrace\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLRunTrace\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/tasks/__init__.py:3\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# License: BSD 3-Clause\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m      4\u001b[0m     create_task,\n\u001b[1;32m      5\u001b[0m     delete_task,\n\u001b[1;32m      6\u001b[0m     get_task,\n\u001b[1;32m      7\u001b[0m     get_tasks,\n\u001b[1;32m      8\u001b[0m     list_tasks,\n\u001b[1;32m      9\u001b[0m )\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msplit\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLSplit\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtask\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     12\u001b[0m     OpenMLClassificationTask,\n\u001b[1;32m     13\u001b[0m     OpenMLClusteringTask,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     18\u001b[0m     TaskType,\n\u001b[1;32m     19\u001b[0m )\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/tasks/functions.py:18\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_dataset\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexceptions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLCacheException\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtask\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m     19\u001b[0m     OpenMLClassificationTask,\n\u001b[1;32m     20\u001b[0m     OpenMLClusteringTask,\n\u001b[1;32m     21\u001b[0m     OpenMLLearningCurveTask,\n\u001b[1;32m     22\u001b[0m     OpenMLRegressionTask,\n\u001b[1;32m     23\u001b[0m     OpenMLSupervisedTask,\n\u001b[1;32m     24\u001b[0m     OpenMLTask,\n\u001b[1;32m     25\u001b[0m     TaskType,\n\u001b[1;32m     26\u001b[0m )\n\u001b[1;32m     28\u001b[0m TASKS_CACHE_DIR_NAME \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtasks\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_cached_tasks\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mint\u001b[39m, OpenMLTask]:\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/tasks/task.py:19\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbase\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLBase\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mopenml\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _create_cache_directory_for_id\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msplit\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m OpenMLSplit\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TYPE_CHECKING:\n\u001b[1;32m     22\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/openml/tasks/split.py:8\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpathlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Path\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Any\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping_extensions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NamedTuple\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01marff\u001b[39;00m  \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'NamedTuple' from 'typing_extensions' (/opt/conda/lib/python3.8/site-packages/typing_extensions.py)"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import copy\n",
    "import json\n",
    "import warnings\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import openml\n",
    "import torch\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# Project-local: import the one transform this notebook uses explicitly\n",
    "# instead of `import *`, so the origin of every name stays visible and no\n",
    "# earlier import can be silently shadowed.\n",
    "from libs.transform import Masking\n",
    "\n",
    "# Suppress FutureWarning / UserWarning messages for the rest of the session.\n",
    "warnings.filterwarnings('ignore', category=FutureWarning)\n",
    "warnings.filterwarnings('ignore', category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6540c3e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenML dataset ID; later cells use it for the metadata lookup and to fetch the dataset.\n",
    "openml_id = 4153"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4f5614a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read dataset_id.json, which is keyed by the dataset ID as a string, and\n",
    "# pull out the task type recorded for the selected dataset.\n",
    "# NOTE(review): the path is absolute and machine-specific.\n",
    "with open('/home/SemiTab/dataset_id.json', 'r') as file:\n",
    "    data_info = json.load(file)\n",
    "# Index directly: an ID missing from the file raises KeyError naming the ID,\n",
    "# instead of the opaque TypeError that `.get(...)[...]` on None produced.\n",
    "tasktype = data_info[str(openml_id)]['tasktype']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a2d90265",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the OpenML dataset object for the selected ID.\n",
    "dataset = openml.datasets.get_dataset(openml_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "45a44b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the four values returned by get_data, using the dataset's default\n",
    "# target attribute as the target column.\n",
    "X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
    "    target=dataset.default_target_attribute\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eaee4f18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the feature table to a tensor exactly once. The guard keeps this\n",
    "# cell re-runnable: on a second run X is already a tensor, and `X.values`\n",
    "# would then be a bound method rather than an array, so the conversion fails.\n",
    "# NOTE(review): assumes X is the pandas object returned by get_data and that\n",
    "# all of its columns are numeric - confirm for this dataset.\n",
    "if not torch.is_tensor(X):\n",
    "    X = torch.tensor(X.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3389d2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Masking is provided by libs.transform (imported in the first cell).\n",
    "# NOTE(review): 0.3 is presumably the per-entry masking probability - confirm\n",
    "# against Masking.__init__.\n",
    "t = Masking(0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "72e0626b",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The NVIDIA driver on your system is too old (found version 11080). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m xaug \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mimage\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdeepcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
      "File \u001b[0;32m/home/SemiTab/libs/transform.py:24\u001b[0m, in \u001b[0;36mMasking.__call__\u001b[0;34m(self, sample)\u001b[0m\n\u001b[1;32m     21\u001b[0m mask \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39muniform(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, size\u001b[38;5;241m=\u001b[39mimg\u001b[38;5;241m.\u001b[39mshape) \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmask_prob    \n\u001b[1;32m     23\u001b[0m img[torch\u001b[38;5;241m.\u001b[39mtensor(mask)\u001b[38;5;241m.\u001b[39mto(img\u001b[38;5;241m.\u001b[39mdevice)] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmasking_constant\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mimage\u001b[39m\u001b[38;5;124m'\u001b[39m: img, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmask\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mrequires_grad_(requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)}\n",
      "File \u001b[0;32m/opt/conda/envs/tabular/lib/python3.9/site-packages/torch/cuda/__init__.py:314\u001b[0m, in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m    312\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCUDA_MODULE_LOADING\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m os\u001b[38;5;241m.\u001b[39menviron:\n\u001b[1;32m    313\u001b[0m     os\u001b[38;5;241m.\u001b[39menviron[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCUDA_MODULE_LOADING\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLAZY\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 314\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cuda_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    315\u001b[0m \u001b[38;5;66;03m# Some of the queued calls may reentrantly call _lazy_init();\u001b[39;00m\n\u001b[1;32m    316\u001b[0m \u001b[38;5;66;03m# we need to just return without initializing in that case.\u001b[39;00m\n\u001b[1;32m    317\u001b[0m \u001b[38;5;66;03m# However, we must not let any *other* threads in!\u001b[39;00m\n\u001b[1;32m    318\u001b[0m _tls\u001b[38;5;241m.\u001b[39mis_initializing \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The NVIDIA driver on your system is too old (found version 11080). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver."
     ]
    }
   ],
   "source": [
    "# Run the masking transform on a deep copy of X (presumably so the in-place\n",
    "# masking does not modify X itself) and keep only the transformed tensor.\n",
    "# NOTE(review): per the traceback saved with this cell, Masking.__call__\n",
    "# draws its mask with np.random.uniform (no seed is set in this notebook) and\n",
    "# builds the returned mask with device='cuda', so a working CUDA setup is\n",
    "# required.\n",
    "masked_sample = t({\"image\": copy.deepcopy(X), \"mask\": None})\n",
    "xaug = masked_sample[\"image\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4186fd2d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
