{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tab ddpm\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "from typing import cast\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "ddpm_dir = '/mnt/data/sonia/ckpts/tab-ddpm'\n",
    "\n",
    "\n",
    "# https://github.com/yandex-research/rtdl-num-embeddings/blob/main/bin/datasets.py#L64\n",
    "def tabddpm(config, train, val, test, alldf, outpath_date, outpath_latest):\n",
    "    \"\"\"Export train/val/test splits in the tab-ddpm .npy dataset layout.\n",
    "\n",
    "    Into both `<outpath_date>/tab-ddpm` and `<outpath_latest>/tab-ddpm` it writes\n",
    "    info.json, {y,X_num,X_cat}_{train,val,test}.npy, idx_{train,val,test}.npy\n",
    "    and, for classification, the fitted label_encoder.pkl.\n",
    "\n",
    "    Args:\n",
    "        config: dict with 'task' ('classification'; any other value is treated\n",
    "            as regression), 'labs' (first entry is the target column), 'nums',\n",
    "            'ords', 'dataset_name' and 'creation_time'.\n",
    "        train, val, test: DataFrames with identical columns and dtypes.\n",
    "        alldf: full DataFrame; the label encoder is fitted on it so every class\n",
    "            gets a code even if it is absent from one split.\n",
    "        outpath_date, outpath_latest: base output directories.\n",
    "\n",
    "    Returns:\n",
    "        (traindict, valdict, testdict): each maps 'y' and, when the matching\n",
    "        column list is non-empty, 'X_num' / 'X_cat' to the split's data.\n",
    "    \"\"\"\n",
    "    assert (train.dtypes == val.dtypes).all()\n",
    "    assert (train.dtypes == test.dtypes).all()\n",
    "    assert (train.columns == val.columns).all()\n",
    "    assert (train.columns == test.columns).all()\n",
    "    train = deepcopy(train)\n",
    "    val = deepcopy(val)\n",
    "    test = deepcopy(test)\n",
    "    target = config['labs'][0]\n",
    "    is_classification = config['task'] == 'classification'\n",
    "\n",
    "    if is_classification:\n",
    "        label_encoder = LabelEncoder()\n",
    "        label_encoder.fit(alldf[target])\n",
    "        train[target] = label_encoder.transform(train[target])\n",
    "        val[target] = label_encoder.transform(val[target])\n",
    "        test[target] = label_encoder.transform(test[target])\n",
    "\n",
    "    def get_Xy(df, config):\n",
    "        \"\"\"Split a frame into {'y', 'X_num', 'X_cat'}; empty feature groups are omitted.\"\"\"\n",
    "        df = deepcopy(df)\n",
    "        y = df.pop(config['labs'][0])\n",
    "        # Class codes are integers, but regression targets must stay real-valued:\n",
    "        # casting them to int64 silently truncated e.g. 4.526 to 4.\n",
    "        y_dtype = 'int64' if config['task'] == 'classification' else 'float64'\n",
    "        d = {'y': y.astype(y_dtype)}\n",
    "        xnum = df.loc[:, config['nums']].values\n",
    "        if xnum.shape[1] > 0:\n",
    "            d['X_num'] = xnum\n",
    "        xcat = df.loc[:, config['ords']].values\n",
    "        if xcat.shape[1] > 0:\n",
    "            d['X_cat'] = xcat\n",
    "        return d\n",
    "\n",
    "    traindict = get_Xy(train, config)\n",
    "    trainidx = np.arange(0, len(train))\n",
    "    valdict = get_Xy(val, config)\n",
    "    validx = np.arange(len(train), len(train)+len(val))\n",
    "    testdict = get_Xy(test, config)\n",
    "    testidx = np.arange(len(train)+len(val), len(train)+len(val)+len(test))\n",
    "\n",
    "    datedirname = '.'.join(config['creation_time'].split())\n",
    "    if not is_classification:\n",
    "        task_type = 'regression'\n",
    "    elif len(label_encoder.classes_) == 2:\n",
    "        task_type = 'binclass'\n",
    "    else:\n",
    "        task_type = 'multiclass'\n",
    "    info = {\n",
    "        'name': config['dataset_name'],\n",
    "        'id': datedirname,\n",
    "        'task_type': task_type,\n",
    "        'n_num_features': len(config['nums']),\n",
    "        'n_cat_features': len(config['ords']),\n",
    "        'train_size': len(train),\n",
    "        'val_size': len(val),\n",
    "        'test_size': len(test)\n",
    "    }\n",
    "\n",
    "    outpath_date_tddpm = os.path.join(outpath_date, 'tab-ddpm')\n",
    "    outpath_latest_tddpm = os.path.join(outpath_latest, 'tab-ddpm')\n",
    "    for path in [outpath_date_tddpm, outpath_latest_tddpm]:\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "        with open(os.path.join(path, 'info.json'), 'w') as f:\n",
    "            json.dump(info, f, indent=4)\n",
    "        if is_classification:\n",
    "            with open(os.path.join(path, 'label_encoder.pkl'), 'wb') as file:\n",
    "                pickle.dump(label_encoder, file)\n",
    "        for split, arrays in (('train', traindict), ('val', valdict), ('test', testdict)):\n",
    "            for name, npy in arrays.items():\n",
    "                np.save(os.path.join(path, f'{name}_{split}.npy'), npy)\n",
    "        np.save(os.path.join(path, 'idx_train.npy'), trainidx)\n",
    "        np.save(os.path.join(path, 'idx_val.npy'), validx)\n",
    "        np.save(os.path.join(path, 'idx_test.npy'), testidx)\n",
    "\n",
    "    return traindict, valdict, testdict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_114413/2131061127.py:7: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n",
      "  dataset = openml.datasets.get_dataset('sick')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (2829, 30) val (282, 30) test (661, 30)\n"
     ]
    }
   ],
   "source": [
    "import openml\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Pass the download flags explicitly: openml warns that their defaults switch to\n",
    "# False (lazy loading) in 0.15; True keeps the current eager download behaviour.\n",
    "dataset = openml.datasets.get_dataset('sick', download_data=True, download_qualities=True,\n",
    "                                      download_features_meta_data=True)\n",
    "df, _, _, _ = dataset.get_data(dataset_format=\"dataframe\")\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'sick',\n",
    "    'task': 'classification',\n",
    "    'raw_path': \"openml.datasets.get_dataset('sick')\",\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': list(df.columns),\n",
    "}\n",
    "\n",
    "config['ords'] = ['sex', 'on_thyroxine', 'query_on_thyroxine', 'on_antithyroid_medication', 'sick', 'pregnant', 'thyroid_surgery',\n",
    "                  'I131_treatment', 'query_hypothyroid', 'query_hyperthyroid', 'lithium', 'goitre', 'tumor', 'hypopituitary', \n",
    "                  'psych', 'TSH_measured', 'T3_measured', 'TT4_measured', 'T4U_measured', 'FTI_measured',\n",
    "                  'TBG_measured', 'referral_source']\n",
    "config['nums'] = ['age', 'TSH', 'T3', 'TT4', 'T4U', 'FTI', 'TBG', ]\n",
    "config['labs'] = ['Class']\n",
    "\n",
    "# every column must belong to exactly one of ords / nums / labs\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols']) \n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "assert len(train) + len(val) + len(test) == n\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (36631, 15) val (3663, 15) test (8548, 15)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'adult',\n",
    "    'raw_path': './adult.csv',\n",
    "    'task': 'classification',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': ['age', 'class', 'financial-weight', 'education', 'years-education', 'marital-status',\n",
    "             'occupation', 'relationship', 'race', 'sex', 'gain-capital', 'loss-capital',\n",
    "             'hours-per-week', 'native-country', 'income'],\n",
    "    'ords': ['class', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
    "             'native-country'],\n",
    "    'nums': ['age', 'financial-weight', 'years-education', 'gain-capital', 'loss-capital',\n",
    "             'hours-per-week'],\n",
    "    'labs': ['income'],\n",
    "}\n",
    "\n",
    "# Load the raw CSV and relabel its columns, by position, with the names above.\n",
    "df = pd.read_csv(config['raw_path']).set_axis(config['cols'], axis=1)\n",
    "\n",
    "# Every column must belong to exactly one of ords / nums / labs.\n",
    "grouped = config['ords'] + config['nums'] + config['labs']\n",
    "assert set(grouped) == set(config['cols'])\n",
    "assert len(grouped) == len(config['cols'])\n",
    "\n",
    "# Deterministic shuffle, then contiguous train / val / test slices.\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "val_end = train_size + val_size\n",
    "train, val, test = df.iloc[:train_size], df.iloc[train_size:val_end], df.iloc[val_end:]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# Write a timestamped snapshot plus a 'latest' copy of the splits and the config.\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date, outpath_latest = (os.path.join('./data/', config['dataset_name'], sub)\n",
    "                                for sub in (datedirname, 'latest'))\n",
    "for path in (outpath_date, outpath_latest):\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    for fname, frame in (('train.csv', train), ('val.csv', val), ('test.csv', test), ('all.csv', df)):\n",
    "        frame.to_csv(os.path.join(path, fname), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diabetes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocessing follows https://huggingface.co/datasets/imodels/diabetes-readmission, except that categorical columns are not one-hot encoded with `get_dummies()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/data/sonia/miniconda3/envs/great/lib/python3.11/site-packages/ucimlrepo/fetch.py:97: DtypeWarning: Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  df = pd.read_csv(data_url)\n",
      "/tmp/ipykernel_114413/2797172938.py:88: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'Circulatory' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n",
      "  data.loc[(data[col]>=390) & (data[col]<=459) | (data[col]==785), \"temp_diag\"] = \"Circulatory\"\n",
      "/tmp/ipykernel_114413/2797172938.py:88: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'Circulatory' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n",
      "  data.loc[(data[col]>=390) & (data[col]<=459) | (data[col]==785), \"temp_diag\"] = \"Circulatory\"\n",
      "/tmp/ipykernel_114413/2797172938.py:88: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'Circulatory' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n",
      "  data.loc[(data[col]>=390) & (data[col]<=459) | (data[col]==785), \"temp_diag\"] = \"Circulatory\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "keep meds ['metformin' 'repaglinide' 'nateglinide' 'chlorpropamide' 'glimepiride'\n",
      " 'glipizide' 'glyburide' 'pioglitazone' 'rosiglitazone' 'acarbose'\n",
      " 'miglitol' 'tolazamide' 'insulin' 'glyburide-metformin']\n",
      "train (76322, 37) val (7632, 37) test (17809, 37)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "from ucimlrepo import fetch_ucirepo \n",
    "import numpy as np\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'diabetes',\n",
    "    'raw_path': 'fetch_ucirepo(id=296)[\"data\"][\"original\"]',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'task': 'classification',\n",
    "}\n",
    "# fetch dataset \n",
    "df = fetch_ucirepo(id=296)['data']['original']\n",
    "\n",
    "#preprocessing\n",
    "df['readmitted'] = df['readmitted'].replace({'NO': 'no', '>30': 'yes', '<30': 'yes'}) #target\n",
    "# .copy(): the filtered frame is modified column by column below, so make it an\n",
    "# independent frame rather than a slice of the unfiltered one (SettingWithCopyWarning).\n",
    "df = df[df['gender'] != 'Unknown/Invalid'].copy()\n",
    "df['age'] = df['age'].replace({\"[70-80)\":\"70+\",\n",
    "                               \"[60-70)\":\"[50-70)\",\n",
    "                               \"[50-60)\":\"[50-70)\",\n",
    "                               \"[80-90)\":\"70+\",\n",
    "                               \"[40-50)\":\"[20-50)\",\n",
    "                               \"[30-40)\":\"[20-50)\",\n",
    "                               \"[90-100)\":\"70+\",\n",
    "                               \"[20-30)\":\"[20-50)\"})\n",
    "df['admission_type_id'] = df['admission_type_id'].replace({1.0:\"Emergency\",\n",
    "                                                           2.0:\"Emergency\",\n",
    "                                                           3.0:\"Elective\",\n",
    "                                                           4.0:\"Newborn\",\n",
    "                                                           5.0:'?',\n",
    "                                                           6.0:'?',\n",
    "                                                           7.0:\"Trauma Center\",\n",
    "                                                           8.0:'?'})\n",
    "df['discharge_disposition_id'] = df['discharge_disposition_id'].replace(\n",
    "    {1:\"Discharged-Home\",\n",
    "     6:\"Discharged-Home\",\n",
    "     8:\"Discharged-Home\",\n",
    "     13:\"Discharged-Home\",\n",
    "     19:\"Discharged-Home\",\n",
    "     18:'?', 25:'?', 26:'?',\n",
    "     2:\"Other\", 3:\"Other\", 4:\"Other\",\n",
    "     5:\"Other\", 7:\"Other\", 9:\"Other\",\n",
    "     10:\"Other\", 11:\"Other\", 12:\"Other\",\n",
    "     14:\"Other\", 15:\"Other\", 16:\"Other\",\n",
    "     17:\"Other\", 20:\"Other\", 21:\"Other\",\n",
    "     22:\"Other\", 23:\"Other\", 24:\"Other\",\n",
    "     27:\"Other\", 28:\"Other\", 29:\"Other\", 30:\"Other\"}\n",
    ") \n",
    "df['admission_source_id'] = df['admission_source_id'].replace(\n",
    "    {1:\"Referral\", 2:\"Referral\", 3:\"Referral\", 4:\"Transfer\",\n",
    "     5:\"Transfer\", 6:\"Transfer\", 7:\"Emergency\", 8:\"Other\",\n",
    "     9:\"Other\", 10:\"Transfer\", 11:\"Other\", 12:\"Other\",\n",
    "     13:\"Other\", 14:\"Other\", 15:'?', 17:'?', \n",
    "     18:\"Transfer\", 19:\"Other\", 20:'?', 21:'?',\n",
    "     22:\"Transfer\", 23:\"Other\", 24: \"Other\", 25:\"Transfer\",\n",
    "     26: \"Transfer\"}\n",
    ")\n",
    "df['medical_specialty'] = df['medical_specialty'].replace(\n",
    "    {\"Orthopedics-Reconstructive\": \"Orthopedics\",\n",
    "     \"Surgeon\": \"Surgery-General\",\n",
    "     \"Surgery-Cardiovascular\": \"Surgery-Cardiovascular/Thoracic\",\n",
    "     \"Surgery-Thoracic\": \"Surgery-Cardiovascular/Thoracic\",\n",
    "     \"Pediatrics-Endocrinology\": \"Pediatrics\",\n",
    "     \"Pediatrics-CriticalCare\": \"Pediatrics\",\n",
    "     \"Pediatrics-Pulmonology\": \"Pediatrics\",\n",
    "     \"Radiologist\": \"Radiology\",\n",
    "     \"Oncology\": \"Hematology/Oncology\",\n",
    "     \"Hematology\": \"Hematology/Oncology\",\n",
    "     \"Gynecology\": \"Obstetrics/Gynecology\",\n",
    "     \"Obstetrics\": \"Obstetrics/Gynecology\"\n",
    "     }\n",
    ")\n",
    "# keep the 15 most frequent specialties, bucket the rest as \"Other\"\n",
    "df['medical_specialty'] = df['medical_specialty'].replace(\n",
    "    {spec: \"Other\" for spec in df['medical_specialty'].value_counts().index.values[15:]}\n",
    ")\n",
    "def map_diagnosis(data, cols):\n",
    "    \"\"\"Replace the ICD-9 codes in `cols` with coarse diagnosis-group names.\n",
    "\n",
    "    Codes containing 'V' or 'E', codes outside every listed range and missing\n",
    "    values all become \"Other\". Returns the modified frame.\n",
    "    \"\"\"\n",
    "    for col in cols:\n",
    "        data.loc[(data[col].str.contains(\"V\")) | (data[col].str.contains(\"E\")), col] = -1\n",
    "        # float64, not float16: float16 is only accurate to 0.125-0.5 for values\n",
    "        # between 128 and 1024, so distinct codes merge (e.g. 787.91 would round\n",
    "        # to 788.0 and be labelled Genitourinary). This deviates from float16-based\n",
    "        # reference preprocessing for such boundary codes.\n",
    "        data[col] = data[col].astype(np.float64)\n",
    "\n",
    "    for col in cols:\n",
    "        # start as \"Other\" (object dtype): rows matching no range keep it, and the\n",
    "        # string assignments below no longer upcast a float NaN column.\n",
    "        data[\"temp_diag\"] = \"Other\"\n",
    "        data.loc[(data[col]>=390) & (data[col]<=459) | (data[col]==785), \"temp_diag\"] = \"Circulatory\"\n",
    "        data.loc[(data[col]>=460) & (data[col]<=519) | (data[col]==786), \"temp_diag\"] = \"Respiratory\"\n",
    "        data.loc[(data[col]>=520) & (data[col]<=579) | (data[col]==787), \"temp_diag\"] = \"Digestive\"\n",
    "        data.loc[(data[col]>=680) & (data[col]<=709) | (data[col]==782), \"temp_diag\"] = \"Skin\"\n",
    "        data.loc[(data[col]>=240) & (data[col]<250) | (data[col]>251) & (data[col]<=279), \"temp_diag\"] = \"Non-diabetes;endocrine/metabolic\"\n",
    "        data.loc[(data[col]>=250) & (data[col]<251), \"temp_diag\"] = \"Diabetes\"\n",
    "        data.loc[(data[col]>=800) & (data[col]<=999), \"temp_diag\"] = \"Injury\"\n",
    "        data.loc[(data[col]>=710) & (data[col]<=739), \"temp_diag\"] = \"Musculoskeletal\"\n",
    "        data.loc[(data[col]>=580) & (data[col]<=629) | (data[col] == 788), \"temp_diag\"] = \"Genitourinary\"\n",
    "        data.loc[(data[col]>=140) & (data[col]<=239), \"temp_diag\"] = \"Neoplasms\"\n",
    "        data.loc[(data[col]>=290) & (data[col]<=319), \"temp_diag\"] = \"Mental\"\n",
    "        data.loc[(data[col]>=1) & (data[col]<=139), \"temp_diag\"] = \"Infectious\"\n",
    "\n",
    "        data[col] = data[\"temp_diag\"]\n",
    "        data = data.drop(\"temp_diag\", axis=1)\n",
    "\n",
    "    return data\n",
    "df = map_diagnosis(df, [\"diag_1\",\"diag_2\",\"diag_3\"])\n",
    "df['change'] = df['change'].replace({'Ch': 'yes', 'No': 'no'})\n",
    "# NOTE(review): assumes the per-drug columns sit at positions 24..46 (UCI column order)\n",
    "all_meds = df.columns[24:47]\n",
    "# .get(): a drug with several levels but no 'Steady' rows is dropped instead of raising KeyError\n",
    "keep_meds = all_meds.values[\n",
    "    [(df[med].value_counts().shape[0] > 1) and (df[med].value_counts().get('Steady', 0) > 30) for med in all_meds]\n",
    "]\n",
    "drop_meds = all_meds.values[~all_meds.isin(keep_meds)]\n",
    "print('keep meds', keep_meds)\n",
    "\n",
    "drop_columns = ['encounter_id', 'patient_nbr', 'weight', 'payer_code'] + drop_meds.tolist()\n",
    "df = df.drop(drop_columns, axis=1)\n",
    "\n",
    "# specify column types\n",
    "config['max_col_length'] = 20\n",
    "config['cols'] = list(df.columns)\n",
    "config['ords'] = (['race', 'gender', 'age', 'admission_type_id', 'discharge_disposition_id', 'admission_source_id',\n",
    "                   'medical_specialty', 'diag_1', 'diag_2', 'diag_3', 'max_glu_serum', 'A1Cresult']\n",
    "                  + list(keep_meds) + ['change', 'diabetesMed'])\n",
    "config['nums'] = ['time_in_hospital', 'num_lab_procedures', 'num_procedures', 'num_medications', 'number_outpatient',\n",
    "                  'number_emergency', 'number_inpatient', 'number_diagnoses']\n",
    "config['labs'] = ['readmitted']\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols']) \n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "df = df.fillna('?')\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "assert len(train) + len(val) + len(test) == n\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Diabetes NEW (OpenML `diabetes`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_159473/3647332689.py:7: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n",
      "  dataset = openml.datasets.get_dataset('diabetes')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (576, 9) val (57, 9) test (135, 9)\n"
     ]
    }
   ],
   "source": [
    "import openml\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "# Pass the download flags explicitly: openml warns that their defaults switch to\n",
    "# False (lazy loading) in 0.15; True keeps the current eager download behaviour.\n",
    "dataset = openml.datasets.get_dataset('diabetes', download_data=True, download_qualities=True,\n",
    "                                      download_features_meta_data=True)\n",
    "df, _, _, _ = dataset.get_data(dataset_format=\"dataframe\")\n",
    "\n",
    "cols = ['pregnancies', 'glucose-plasma', 'blood-pressure', 'skin-thickness', 'insulin', 'BMI', 'pedigree', 'age', 'diagnosis']\n",
    "ords = []\n",
    "labs = ['diagnosis']\n",
    "nums = ['pregnancies', 'glucose-plasma', 'blood-pressure', 'skin-thickness', 'insulin', 'BMI', 'pedigree', 'age']\n",
    "\n",
    "# positional rename of the raw columns\n",
    "df.columns = cols \n",
    "# anything other than 'tested_positive' (including missing) becomes 'negative'\n",
    "df['diagnosis'] = df['diagnosis'].map(lambda x: 'positive' if x=='tested_positive' else 'negative')\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'diabetes-new',\n",
    "    'task': 'classification',\n",
    "    'raw_path': \"openml.datasets.get_dataset('diabetes')\",\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': cols,\n",
    "    'ords': ords,\n",
    "    'nums': nums,\n",
    "    'labs': labs,\n",
    "}\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols']) \n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "assert len(train) + len(val) + len(test) == n\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CA Housing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/data/sonia/miniconda3/envs/great/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (15480, 10) val (1548, 10) test (3612, 10)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "df = pd.read_csv(\"hf://datasets/leostelon/california-housing/housing.csv\")\n",
    "\n",
    "# Rename the columns so that no two names start with the same token.\n",
    "cols = ['longitude', 'latitude', 'age_median', 'rooms', 'bedrooms', 'population',\n",
    "        'households', 'income_median', 'value_median_house', 'ocean_proximity']\n",
    "df.columns = cols\n",
    "\n",
    "# Missing values become the '?' token; every other value in these columns is cast to int.\n",
    "ints = ['age_median', 'rooms', 'bedrooms', 'population', 'households', 'value_median_house']\n",
    "df = df.fillna('?')\n",
    "def mapping(v):\n",
    "    return v if v == '?' else int(v)\n",
    "df[ints] = df[ints].map(mapping)\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'house',\n",
    "    'raw_path': 'hf://datasets/leostelon/california-housing/housing.csv',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'task': 'regression',\n",
    "}\n",
    "config.update(\n",
    "    cols=list(df.columns),\n",
    "    ords=[\"ocean_proximity\"],\n",
    "    nums=[\"longitude\", \"latitude\", \"age_median\", \"rooms\", \"bedrooms\", \"population\", \"households\", \"income_median\"],\n",
    "    labs=[\"value_median_house\"],\n",
    ")\n",
    "\n",
    "# Every column must belong to exactly one of ords / nums / labs.\n",
    "grouped = config['ords'] + config['nums'] + config['labs']\n",
    "assert set(grouped) == set(config['cols'])\n",
    "assert len(grouped) == len(config['cols'])\n",
    "\n",
    "# Deterministic shuffle, then contiguous train / val / test slices.\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "val_end = train_size + val_size\n",
    "train, val, test = df.iloc[:train_size], df.iloc[train_size:val_end], df.iloc[val_end:]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# Write a timestamped snapshot plus a 'latest' copy of the splits and the config.\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date, outpath_latest = (os.path.join('./data/', config['dataset_name'], sub)\n",
    "                                for sub in (datedirname, 'latest'))\n",
    "for path in (outpath_date, outpath_latest):\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    for fname, frame in (('train.csv', train), ('val.csv', val), ('test.csv', test), ('all.csv', df)):\n",
    "        frame.to_csv(os.path.join(path, fname), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CA Housing NEW (scikit-learn `fetch_california_housing`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (15480, 9) val (1548, 9) test (3612, 9)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "df = fetch_california_housing(as_frame=True).frame\n",
    "\n",
    "# Rename by source column name (not by position) so that no two new names start\n",
    "# with the same token; a missing source column fails the check below.\n",
    "rename_map = {\n",
    "    'MedInc': 'income_median',\n",
    "    'HouseAge': 'age_median',\n",
    "    'AveRooms': 'rooms',\n",
    "    'AveBedrms': 'bedrooms',\n",
    "    'Population': 'population',\n",
    "    'AveOccup': 'occupancy',\n",
    "    'Latitude': 'latitude',\n",
    "    'Longitude': 'longitude',\n",
    "    'MedHouseVal': 'value_median_house',\n",
    "}\n",
    "cols = list(rename_map.values())\n",
    "df = df.rename(columns=rename_map)\n",
    "assert set(df.columns) == set(cols), list(df.columns)\n",
    "df = df[cols]\n",
    "df = df.fillna('?')\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'house-new',\n",
    "    'raw_path': 'fetch_california_housing(as_frame=True).frame',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'task': 'regression',\n",
    "}\n",
    "\n",
    "config['cols'] = list(df.columns)\n",
    "config[\"ords\"] = []\n",
    "config[\"nums\"] = ['income_median', 'age_median', 'rooms', 'bedrooms', 'population', \n",
    "        'occupancy', 'latitude', 'longitude',]\n",
    "config[\"labs\"] = [\"value_median_house\"]\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols']) \n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "assert len(train) + len(val) + len(test) == n\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "ddpmout = tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CA Housing New Tiny"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (5160, 6) val (1548, 6) test (13932, 6)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "\n",
    "df = fetch_california_housing(as_frame=True).frame\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'house-new-tiny',\n",
    "    'raw_path': 'fetch_california_housing(as_frame=True).frame',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.25,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'task': 'regression',\n",
    "}\n",
    "\n",
    "# rename cols so none start with same token; map by the original sklearn names\n",
    "# rather than by position so an upstream column reordering cannot silently mislabel data\n",
    "rename_map = {\n",
    "    'MedInc': 'income_median',\n",
    "    'HouseAge': 'age_median',\n",
    "    'AveRooms': 'rooms',\n",
    "    'AveBedrms': 'bedrooms',\n",
    "    'Population': 'population',\n",
    "    'AveOccup': 'occupancy',\n",
    "    'Latitude': 'latitude',\n",
    "    'Longitude': 'longitude',\n",
    "    'MedHouseVal': 'value_median_house',\n",
    "}\n",
    "assert set(df.columns) == set(rename_map), f'unexpected columns: {list(df.columns)}'\n",
    "df = df.rename(columns=rename_map)\n",
    "# '?' marks missing values (note: it turns any numeric column containing NaN into object dtype)\n",
    "df = df.fillna('?')\n",
    "\n",
    "# tiny variant: keep only a subset of the features (population, latitude, longitude are dropped)\n",
    "config['cols'] = ['income_median','age_median','rooms','bedrooms','occupancy','value_median_house']\n",
    "config['ords'] = []\n",
    "config['nums'] = ['income_median', 'age_median', 'rooms', 'bedrooms', 'occupancy']\n",
    "config['labs'] = ['value_median_house']\n",
    "# every column must be exactly one of ordinal / numerical / label\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "df = df[config['cols']]\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets (test takes whatever remains after train and val)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out, both to a timestamped dir and to 'latest'\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "# ddpmout =tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_51891/1533411410.py:7: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n",
      "  dataset = openml.datasets.get_dataset('rainfall_bangladesh')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (12566, 4) val (1256, 4) test (2933, 4)\n"
     ]
    }
   ],
   "source": [
    "import openml\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "# pass the download flags explicitly: openml 0.15 flips their defaults to False (lazy loading),\n",
    "# so this pins today's eager behaviour and silences the FutureWarning\n",
    "dataset = openml.datasets.get_dataset(\n",
    "    'rainfall_bangladesh',\n",
    "    download_data=True,\n",
    "    download_qualities=True,\n",
    "    download_features_meta_data=True,\n",
    ")\n",
    "df, _, _, _ = dataset.get_data(dataset_format=\"dataframe\")\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'rain',\n",
    "    'task': 'regression',\n",
    "    'raw_path': \"openml.datasets.get_dataset('rainfall_bangladesh')\",\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': list(df.columns),\n",
    "}\n",
    "config['ords'] = ['Station', 'Month']\n",
    "config['nums'] = ['Year']\n",
    "config['labs'] = ['Rainfall']\n",
    "\n",
    "# every column must be exactly one of ordinal / numerical / label\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets (test takes whatever remains after train and val)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out, both to a timestamped dir and to 'latest'\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Abalone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_112632/408466766.py:7: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n",
      "  dataset = openml.datasets.get_dataset('abalone')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (3132, 9) val (313, 9) test (732, 9)\n"
     ]
    }
   ],
   "source": [
    "import openml\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "# pass the download flags explicitly: openml 0.15 flips their defaults to False (lazy loading),\n",
    "# so this pins today's eager behaviour and silences the FutureWarning\n",
    "dataset = openml.datasets.get_dataset(\n",
    "    'abalone',\n",
    "    download_data=True,\n",
    "    download_qualities=True,\n",
    "    download_features_meta_data=True,\n",
    ")\n",
    "df, _, _, _ = dataset.get_data(dataset_format=\"dataframe\")\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'abalone',\n",
    "    'task': 'regression',\n",
    "    'raw_path': \"openml.datasets.get_dataset('abalone')\",\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': list(df.columns),\n",
    "}\n",
    "config['ords'] = ['Sex']\n",
    "config['nums'] = ['Length', 'Diameter', 'Height', 'Whole_weight', 'Shucked_weight',\n",
    "                  'Viscera_weight', 'Shell_weight']\n",
    "config['labs'] = ['Class_number_of_rings']\n",
    "\n",
    "# every column must be exactly one of ordinal / numerical / label\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets (test takes whatever remains after train and val)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out, both to a timestamped dir and to 'latest'\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Travel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (715, 7) val (71, 7) test (168, 7)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'travel',\n",
    "    'raw_path': 'https://www.kaggle.com/datasets/tejashvi14/tour-travels-customer-churn-prediction?resource=download',\n",
    "    'task': 'classification',\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'cols': ['Age','Frequent-Flyer','Class','Services','Social-Media','Hotel','Target'],\n",
    "    'ords': ['Frequent-Flyer','Class','Social-Media','Hotel'],\n",
    "    'nums': ['Age','Services',],\n",
    "    'labs': ['Target']\n",
    "}\n",
    "\n",
    "# read in; the local travel.csv must already use the column names listed in config['cols']\n",
    "df = pd.read_csv('travel.csv')\n",
    "# every column must be exactly one of ordinal / numerical / label\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "assert set(df.columns) == set(config['cols']), f'travel.csv columns {list(df.columns)} do not match config'\n",
    "\n",
    "# replace spaces inside categorical values with hyphens\n",
    "df[config['ords']] = df[config['ords']].map(lambda x: x.replace(' ', '-'))\n",
    "\n",
    "# shuffle data\n",
    "df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets (test takes whatever remains after train and val)\n",
    "n = len(df)\n",
    "train_size = int(config['train_frac'] * n)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "# write everything out, both to a timestamped dir and to 'latest'\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)\n",
    "\n",
    "# ddpmout =tabddpm(config, train, val, test, df, outpath_date, outpath_latest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# cautab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/sonia\n",
      "fatal: destination path 'CauTabBench' already exists and is not an empty directory.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sonia/miniconda3/envs/great/lib/python3.11/site-packages/IPython/core/magics/osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.\n",
      "  bkms = self.shell.db.get('bookmarks', {})\n",
      "/home/sonia/miniconda3/envs/great/lib/python3.11/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/sonia/CauTabBench\n",
      "100\n",
      "sim_lg (17117, 11) (1902, 11) (19019, 11)\n",
      "Numerical (17117, 10)\n",
      "Categorical (17117, 0)\n",
      "Processing and Saving sim_lg Successfully!\n",
      "sim_lg\n",
      "Total 19019\n",
      "Train 17117\n",
      "Test 1902\n",
      "Num 10\n",
      "Cat 1\n",
      "train (15834, 11) val (1283, 11)\n",
      "/home/sonia/tabby\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/sonia/miniconda3/envs/great/lib/python3.11/site-packages/IPython/core/magics/osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.\n",
      "  bkms = self.shell.db.get('bookmarks', {})\n",
      "/home/sonia/miniconda3/envs/great/lib/python3.11/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    }
   ],
   "source": [
    "%cd ~\n",
    "!git clone https://github.com/TURuibo/CauTabBench.git\n",
    "%cd CauTabBench\n",
    "!python process_sim_dataset.py --seed 100 --cm lg\n",
    "\n",
    "import pandas  as pd \n",
    "import json \n",
    "import os\n",
    "import datetime\n",
    "from shutil import copyfile\n",
    "\n",
    "config = {\n",
    "    'dataset_name': 'cautab',\n",
    "    'raw_path': 'python process_sim_dataset.py --seed 100 --cm lg',\n",
    "    'random_state': 42,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "    'task': 'classification',\n",
    "}\n",
    "config['cols'] = ['V0','V1','V2','V3','V4','V5','V6','V7','V8','V9','target']\n",
    "config['ords'] = [] # no ordinal features; key kept so every dataset config has the same schema\n",
    "config['nums'] = ['V0','V1','V2','V3','V4','V5','V6','V7','V8','V9',]\n",
    "config['labs'] = ['target']\n",
    "# every column must be exactly one of ordinal / numerical / label\n",
    "assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "# CauTabBench ships its own train/test split: carve val off the tail of its train split\n",
    "df = pd.read_csv('~/CauTabBench/data/sim_lg/100/train.csv')\n",
    "n = len(df)\n",
    "val_size = int(config['val_frac'] * n)\n",
    "# with val_size == 0, iloc[:-0] would silently produce an empty train set\n",
    "assert val_size > 0, f'val_frac too small for {n} rows'\n",
    "train = df.iloc[:-val_size, :]\n",
    "val = df.iloc[-val_size:, :]\n",
    "# the held-out test set is CauTabBench's test.csv (reading train.csv here would leak train rows into test)\n",
    "test = pd.read_csv('~/CauTabBench/data/sim_lg/100/test.csv')\n",
    "alls = pd.concat([df, test], axis=0, ignore_index=True)\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "%cd ~/tabby/\n",
    "# write everything out, both to a timestamped dir and to 'latest'\n",
    "datedirname = '.'.join(config['creation_time'].split())\n",
    "outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "for path in [outpath_date, outpath_latest]:\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "    val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "    test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "    alls.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "    with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "        json.dump(config, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inria Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, get_dataset_config_names, load_dataset_builder\n",
    "import string\n",
    "import pandas as pd \n",
    "import json \n",
    "import datetime\n",
    "from copy import deepcopy\n",
    "import os\n",
    "\n",
    "defaultconfig = {\n",
    "    'random_state': 42,\n",
    "    'train_frac': 0.75,\n",
    "    'val_frac': 0.075,\n",
    "    'creation_time': str(datetime.datetime.now()),\n",
    "    'max_col_length': 20,\n",
    "}\n",
    "\n",
    "skip = ['clf_num_california', 'clf_num_Diabetes130US']\n",
    "\n",
    "# one distinct prefix per column so that no two column names start with the same token\n",
    "prepend = [str(num) for num in range(10)] + list(string.ascii_lowercase)\n",
    "prepend = [e+'. ' for e in prepend] #36 items\n",
    "\n",
    "names = get_dataset_config_names(\"inria-soda/tabular-benchmark\")\n",
    "# names = ['clf_cat_albert']\n",
    "for name in names:\n",
    "    if name in skip:\n",
    "        continue\n",
    "\n",
    "    # datasets with more columns than available prefixes are skipped and reported at the end\n",
    "    ds = load_dataset_builder(\"inria-soda/tabular-benchmark\", name)\n",
    "    if len(ds.info.features) > len(prepend):\n",
    "        skip.append(name)\n",
    "        continue\n",
    "\n",
    "    df = load_dataset(\"inria-soda/tabular-benchmark\", name)['train'].to_pandas() # only has a train split\n",
    "    ncols = len(df.columns)\n",
    "    df.columns = [pre+col for pre,col in zip(prepend[:ncols], df.columns)]\n",
    "\n",
    "    config = deepcopy(defaultconfig)\n",
    "    config['dataset_name'] = name\n",
    "    config['cols'] = list(df.columns)\n",
    "    config['labs'] = [config['cols'][-1]] # last col is label\n",
    "    features = config['cols'][:-1] # every column except the label\n",
    "    if name.startswith('clf'):\n",
    "        config['task'] = 'classification'\n",
    "    elif name.startswith('reg'):\n",
    "        config['task'] = 'regression'\n",
    "    else:\n",
    "        raise ValueError(f'unknown task for {name}')\n",
    "    if name.startswith('clf_cat') or name.startswith('reg_cat'): # features are numerical or categorical\n",
    "        # not best way, but just assume non-numeric (str/object/category) cols are ordinal and the rest\n",
    "        # numerical. Comparing dtypes to 'str' never matches pandas' object dtype, and the label must\n",
    "        # not be counted among the features.\n",
    "        is_num = {col: pd.api.types.is_numeric_dtype(df[col]) for col in features}\n",
    "        config['ords'] = [col for col in features if not is_num[col]]\n",
    "        config['nums'] = [col for col in features if is_num[col]]\n",
    "    elif name.startswith('clf_num') or name.startswith('reg_num'): #features all numerical\n",
    "        config['nums'] = list(features)\n",
    "        config['ords'] = []\n",
    "    else:\n",
    "        raise ValueError(f'unknown feature types for {name}')\n",
    "    # every column must be exactly one of ordinal / numerical / label\n",
    "    assert set(config['ords']+config['nums']+config['labs'])==set(config['cols'])\n",
    "    assert len(config['ords'])+len(config['nums'])+len(config['labs']) == len(config['cols'])\n",
    "\n",
    "    # fill missing values only after the dtypes were inspected ('?' turns numeric NaN cols into object)\n",
    "    df = df.fillna('?')\n",
    "    df = df.sample(frac=1, random_state=config['random_state'], ignore_index=True)\n",
    "    # split into train/val/test sets (test takes whatever remains after train and val)\n",
    "    n = len(df)\n",
    "    train_size = int(config['train_frac'] * n)\n",
    "    val_size = int(config['val_frac'] * n)\n",
    "    train = df.iloc[:train_size, :]\n",
    "    val = df.iloc[train_size:train_size+val_size, :]\n",
    "    test = df.iloc[train_size+val_size:, :]\n",
    "    print(name, '\\t\\t\\t\\ttrain', train.shape, 'val', val.shape, 'test', test.shape)\n",
    "\n",
    "    # write everything out, both to a timestamped dir and to 'latest'\n",
    "    datedirname = '.'.join(config['creation_time'].split())\n",
    "    outpath_date   = os.path.join('./data/', config['dataset_name'], datedirname)\n",
    "    outpath_latest = os.path.join('./data/', config['dataset_name'], 'latest')\n",
    "\n",
    "    for path in [outpath_date, outpath_latest]:\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "        train.to_csv(os.path.join(path, 'train.csv'), index=False)\n",
    "        val.to_csv(os.path.join(path, 'val.csv'), index=False)\n",
    "        test.to_csv(os.path.join(path, 'test.csv'), index=False)\n",
    "        df.to_csv(os.path.join(path, 'all.csv'), index=False)\n",
    "        with open(os.path.join(path, 'config.json'), 'w') as f:\n",
    "            json.dump(config, f)\n",
    "\n",
    "    # ddpmout =tabddpm(config, train, val, test, df, outpath_date, outpath_latest)\n",
    "\n",
    "print('skipped\\n', skip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "great",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
