{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import zero\n",
    "import lib\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import sdv\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sdv.datasets.demo import download_demo\n",
    "from scripts.train import Trainer\n",
    "from scripts.sample import to_good_ohe\n",
    "from scripts.utils_train import get_model, make_dataset, update_ema, make_dataset_from_df\n",
    "from tab_ddpm import GaussianMultinomialDiffusion, MLP\n",
    "from sdv.evaluation.single_table import evaluate_quality\n",
    "from sdv.metadata import SingleTableMetadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/scratch/ssd004/scratch/weipang/env/tab_ddpm/lib/python3.10/site-packages/sdv/datasets/demo.py:88: DtypeWarning: Columns (7) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  data[table_name] = pd.read_csv(io.StringIO(file_.decode()))\n"
     ]
    }
   ],
   "source": [
    "# Download the multi-table Rossmann demo: a mapping of tables plus its SDV metadata.\n",
    "real_data, metadata = download_demo(modality='multi_table', dataset_name='rossmann')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: Metadata Pages: 1 -->\n",
       "<svg width=\"255pt\" height=\"485pt\"\n",
       " viewBox=\"0.00 0.00 255.00 485.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 481)\">\n",
       "<title>Metadata</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-481 251,-481 251,4 -4,4\"/>\n",
       "<!-- store -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>store</title>\n",
       "<path fill=\"#ffec8b\" stroke=\"#000000\" d=\"M12,-256.5C12,-256.5 235,-256.5 235,-256.5 241,-256.5 247,-262.5 247,-268.5 247,-268.5 247,-464.5 247,-464.5 247,-470.5 241,-476.5 235,-476.5 235,-476.5 12,-476.5 12,-476.5 6,-476.5 0,-470.5 0,-464.5 0,-464.5 0,-268.5 0,-268.5 0,-262.5 6,-256.5 12,-256.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.5\" y=\"-461.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">store</text>\n",
       "<polyline fill=\"none\" stroke=\"#000000\" points=\"0,-453.5 247,-453.5 \"/>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-438.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">CompetitionOpenSinceYear : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-423.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">CompetitionOpenSinceMonth : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-408.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">CompetitionDistance : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-393.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Promo2 : boolean</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-378.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Promo2SinceYear : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-363.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Promo2SinceWeek : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-348.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Store : id</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-333.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">StoreType : categorical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-318.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Assortment : categorical</text>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-303.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">PromoInterval : categorical</text>\n",
       "<polyline fill=\"none\" stroke=\"#000000\" points=\"0,-295.5 247,-295.5 \"/>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-280.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Primary key: Store</text>\n",
       "</g>\n",
       "<!-- historical -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>historical</title>\n",
       "<path fill=\"#ffec8b\" stroke=\"#000000\" d=\"M53.5,-.5C53.5,-.5 193.5,-.5 193.5,-.5 199.5,-.5 205.5,-6.5 205.5,-12.5 205.5,-12.5 205.5,-192.5 205.5,-192.5 205.5,-198.5 199.5,-204.5 193.5,-204.5 193.5,-204.5 53.5,-204.5 53.5,-204.5 47.5,-204.5 41.5,-198.5 41.5,-192.5 41.5,-192.5 41.5,-12.5 41.5,-12.5 41.5,-6.5 47.5,-.5 53.5,-.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.5\" y=\"-189.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">historical</text>\n",
       "<polyline fill=\"none\" stroke=\"#000000\" points=\"41.5,-181.5 205.5,-181.5 \"/>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-166.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Date : datetime</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-151.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Store : id</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-136.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">DayOfWeek : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-121.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Promo : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-106.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">StateHoliday : categorical</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-91.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Id : id</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-76.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Open : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-61.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">SchoolHoliday : numerical</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-46.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Customers : numerical</text>\n",
       "<polyline fill=\"none\" stroke=\"#000000\" points=\"41.5,-38.5 205.5,-38.5 \"/>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-23.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Primary key: Id</text>\n",
       "<text text-anchor=\"start\" x=\"49.5\" y=\"-8.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">Foreign key (store): Store</text>\n",
       "</g>\n",
       "<!-- store&#45;&gt;historical -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>store&#45;&gt;historical</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M123.5,-256.4519C123.5,-242.81 123.5,-228.8462 123.5,-215.1656\"/>\n",
       "<polygon fill=\"none\" stroke=\"#000000\" points=\"127,-204.7692 123.5001,-214.7693 120,-204.7693 127,-204.7692\"/>\n",
       "<text text-anchor=\"middle\" x=\"167\" y=\"-226.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\"> &#160;Store → Store</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe91e5ed6c0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw the table diagram with every column listed and the relationship labelled.\n",
    "metadata.visualize(show_table_details='full', show_relationship_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the two tables of the demo out of the mapping returned by download_demo.\n",
    "store_df, historical_df = (real_data[table] for table in ('store', 'historical'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column roles for the `store` table; its `Store` column is used as the label.\n",
    "store_cat_cols = ['StoreType', 'Assortment', 'Promo2', 'PromoInterval']\n",
    "store_num_cols = [\n",
    "    'CompetitionDistance',\n",
    "    'CompetitionOpenSinceMonth',\n",
    "    'CompetitionOpenSinceYear',\n",
    "    'Promo2SinceWeek',\n",
    "    'Promo2SinceYear',\n",
    "]\n",
    "store_y_col = 'Store'\n",
    "\n",
    "# n_classes is the number of distinct label values found in the table.\n",
    "store_df_info = {\n",
    "    'cat_cols': store_cat_cols,\n",
    "    'num_cols': store_num_cols,\n",
    "    'y_col': store_y_col,\n",
    "    'n_classes': len(set(store_df[store_y_col].values)),\n",
    "    'task_type': 'multiclass',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column roles for the `historical` table as used for clustering (no label column yet).\n",
    "historical_cat_cols = ['Store', 'Date', 'Open', 'Promo', 'StateHoliday', 'SchoolHoliday']\n",
    "historical_num_cols = ['Id', 'DayOfWeek', 'Customers']\n",
    "\n",
    "historical_df_info = {\n",
    "    'cat_cols': historical_cat_cols,\n",
    "    'num_cols': historical_num_cols,\n",
    "    'task_type': 'multiclass',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "\n",
    "def prepare_for_clustering(df, cat_cols, num_cols):\n",
    "    \"\"\"Turn selected columns of `df` into integer-coded and standardized arrays.\n",
    "\n",
    "    Args:\n",
    "        df: Source DataFrame; it is not modified.\n",
    "        cat_cols: Columns that are cast to str and label-encoded, each on its own.\n",
    "        num_cols: Columns that are standardized together by one StandardScaler.\n",
    "\n",
    "    Returns:\n",
    "        (cat_array, num_array): NumPy arrays with the encoded `cat_cols` and the\n",
    "        scaled `num_cols`, in the column order given.\n",
    "    \"\"\"\n",
    "    # Copy only the columns that get rewritten so the caller's frame stays intact.\n",
    "    encoded = df[cat_cols].copy()\n",
    "    for name in cat_cols:\n",
    "        encoded[name] = LabelEncoder().fit_transform(encoded[name].astype(str))\n",
    "\n",
    "    scaled = df[num_cols].copy()\n",
    "    scaled[num_cols] = StandardScaler().fit_transform(scaled)\n",
    "\n",
    "    return encoded.to_numpy(), scaled.to_numpy()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/scratch/ssd004/scratch/weipang/env/tab_ddpm/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:1412: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning\n",
      "  super()._check_params_vs_input(X, default_n_init=10)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Encode/scale the historical table, then stack both parts into one feature matrix.\n",
    "historical_cat, historical_num = prepare_for_clustering(\n",
    "    historical_df,\n",
    "    historical_cat_cols,\n",
    "    historical_num_cols\n",
    ")\n",
    "combined_array = np.concatenate((historical_cat, historical_num), axis=1)\n",
    "\n",
    "# Train K-Means clustering algorithm.\n",
    "# random_state pins the centroid initialisation so the cluster ids (used later as\n",
    "# labels) are the same on every run; n_init=10 keeps the current default explicitly\n",
    "# and silences the scikit-learn FutureWarning about that default changing.\n",
    "num_clusters = 100  # Number of clusters you want\n",
    "kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=0)\n",
    "clusters = kmeans.fit_predict(combined_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store each row's K-Means cluster id as a new column (mutates historical_df in place).\n",
    "historical_df['cluster'] = clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Redefine the column roles now that `cluster` exists: it is listed as a\n",
    "# categorical column and is also the label column; `Store` is no longer listed.\n",
    "historical_cat_cols = ['Date', 'Open', 'Promo', 'StateHoliday', 'SchoolHoliday', 'cluster']\n",
    "historical_num_cols = ['Id', 'DayOfWeek', 'Customers']\n",
    "historical_y_col = 'cluster'\n",
    "\n",
    "historical_df_info = {\n",
    "    'cat_cols': historical_cat_cols,\n",
    "    'num_cols': historical_num_cols,\n",
    "    'n_classes': len(set(historical_df[historical_y_col].values)),\n",
    "    'task_type': 'multiclass',\n",
    "    'y_col': historical_y_col,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "\n",
    "# Keyword arguments later expanded into lib.Transformations(**T_dict).\n",
    "T_dict = dict(\n",
    "    seed=seed,\n",
    "    normalization=\"quantile\",\n",
    "    num_nan_policy=None,  # alternative that was tried: 'mean'\n",
    "    cat_nan_policy=None,\n",
    "    cat_min_frequency=None,\n",
    "    cat_encoding=None,\n",
    "    y_policy=\"default\",\n",
    ")\n",
    "\n",
    "# Settings of the denoising model that is built later with get_model.\n",
    "model_params = dict(\n",
    "    num_classes=2,\n",
    "    is_y_cond=True,\n",
    "    rtdl_params=dict(\n",
    "        d_layers=[256, 1024, 1024, 1024, 1024, 256],\n",
    "        dropout=0.0,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Directory holding the .npy files of the real dataset.\n",
    "real_data_path = os.path.normpath('data/adult')\n",
    "\n",
    "# Diffusion settings.\n",
    "model_type = 'mlp'\n",
    "num_timesteps = 100\n",
    "gaussian_loss_type = \"mse\"\n",
    "\n",
    "# Optimisation settings.\n",
    "steps = 30000\n",
    "lr = 0.0020099410620098234\n",
    "weight_decay = 0.0\n",
    "batch_size = 4096\n",
    "scheduler = 'cosine'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the seeding helper first, then build the dataset from the clustered table.\n",
    "zero.improve_reproducibility(seed)\n",
    "T = lib.Transformations(**T_dict)\n",
    "\n",
    "# NOTE(review): ratios are presumably the train/val/test fractions -- confirm in make_dataset_from_df.\n",
    "dataset = make_dataset_from_df(\n",
    "    historical_df, T, is_y_cond=True, ratios=[0.7, 0.2, 0.1], df_info=historical_df_info\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loader per split, all built with the same batch size and 'long' labels.\n",
    "train_loader, val_loader, test_loader = (\n",
    "    lib.prepare_fast_dataloader(\n",
    "        dataset,\n",
    "        split=split_name,\n",
    "        batch_size=batch_size,\n",
    "        y_type='long'\n",
    "    )\n",
    "    for split_name in ('train', 'val', 'test')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tab_ddpm.resample import create_named_schedule_sampler\n",
    "import torch.nn.functional as F\n",
    "from tab_ddpm import logger\n",
    "from tab_ddpm.utils import *\n",
    "from tab_ddpm.modules import MLPDiffusion\n",
    "\n",
    "def split_microbatches(microbatch, *args):\n",
    "    \"\"\"Yield `args` whole, or cut into consecutive chunks of `microbatch` items.\n",
    "\n",
    "    The batch size is taken from the first positional sequence. When `microbatch`\n",
    "    is -1 or at least that size, a single tuple with the untouched arguments is\n",
    "    produced. Otherwise each yielded tuple holds the same slice of every argument;\n",
    "    arguments that are None are passed through as None.\n",
    "    \"\"\"\n",
    "    total = len(args[0])\n",
    "    if microbatch != -1 and microbatch < total:\n",
    "        for start in range(0, total, microbatch):\n",
    "            stop = start + microbatch\n",
    "            yield tuple(None if item is None else item[start:stop] for item in args)\n",
    "    else:\n",
    "        yield tuple(args)\n",
    "\n",
    "def compute_top_k(logits, labels, k, reduction=\"mean\"):\n",
    "    \"\"\"Check, per row, whether the true label is among the k highest logits.\n",
    "\n",
    "    Args:\n",
    "        logits: Tensor whose last dimension indexes the classes.\n",
    "        labels: 1-D tensor with one class index per row of `logits`.\n",
    "        k: Number of top-scoring classes to accept; capped at the number of\n",
    "            classes so a small output layer cannot make torch.topk fail.\n",
    "        reduction: \"mean\" returns the hit rate as a Python float; \"none\"\n",
    "            returns the per-row 0/1 hits as a float tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `reduction` is neither \"mean\" nor \"none\".\n",
    "    \"\"\"\n",
    "    if reduction not in (\"mean\", \"none\"):\n",
    "        raise ValueError(f\"Unsupported reduction: {reduction!r}\")\n",
    "    # torch.topk raises 'selected index k out of range' when k exceeds the class count.\n",
    "    k = min(k, logits.shape[-1])\n",
    "    _, top_ks = torch.topk(logits, k, dim=-1)\n",
    "    hits = (top_ks == labels[:, None]).float().sum(dim=-1)\n",
    "    if reduction == \"mean\":\n",
    "        return hits.mean().item()\n",
    "    return hits\n",
    "\n",
    "def log_loss_dict(diffusion, ts, losses):\n",
    "    \"\"\"Log the mean of every entry in `losses`, overall and per timestep quartile.\n",
    "\n",
    "    Args:\n",
    "        diffusion: Object exposing `num_timesteps`, used to bucket the timesteps.\n",
    "        ts: Tensor of timesteps, one per sample.\n",
    "        losses: Mapping from metric name to a tensor of per-sample values.\n",
    "    \"\"\"\n",
    "    timesteps = ts.cpu().numpy()\n",
    "    for name, per_sample in losses.items():\n",
    "        logger.logkv_mean(name, per_sample.mean().item())\n",
    "        # Bucket each sample by which quarter of the schedule its timestep falls in.\n",
    "        per_sample_np = per_sample.detach().cpu().numpy()\n",
    "        for step_t, value in zip(timesteps, per_sample_np):\n",
    "            bucket = int(4 * step_t / diffusion.num_timesteps)\n",
    "            logger.logkv_mean(f\"{name}_q{bucket}\", value)\n",
    "\n",
    "def classifier_forward_backward_log(\n",
    "        classifier, \n",
    "        optimizer, \n",
    "        data_loader, \n",
    "        dataset, \n",
    "        schedule_sampler, \n",
    "        diffusion, \n",
    "        prefix=\"train\"\n",
    "    ):\n",
    "    \"\"\"Run the classifier on one noised batch, log its metrics and backpropagate.\n",
    "\n",
    "    One batch is drawn from `data_loader`; its numerical part is noised with\n",
    "    `diffusion.gaussian_q_sample` and its categorical part with\n",
    "    `diffusion.q_sample` at timesteps drawn from `schedule_sampler`. Loss and\n",
    "    top-1 / top-5 accuracy are logged under `prefix`. Gradients are computed only\n",
    "    when the loss requires grad, so a call under torch.no_grad() just logs.\n",
    "    The optimizer is zeroed here but never stepped; the caller does the step.\n",
    "\n",
    "    Args:\n",
    "        classifier: Callable as `classifier(x, t)`, returning class logits.\n",
    "        optimizer: Optimizer whose gradients are reset before the backward pass.\n",
    "        data_loader: Iterator yielding `(batch, labels)` pairs.\n",
    "        dataset: Object exposing `n_num_features`, the width of the numerical part.\n",
    "        schedule_sampler: Object whose `sample(n, device)` returns timesteps first.\n",
    "        diffusion: Diffusion object providing the two noising functions and\n",
    "            `num_classes`.\n",
    "        prefix: Name prefix for the logged metrics.\n",
    "\n",
    "    Note:\n",
    "        Reads the notebook-level `device` variable.\n",
    "    \"\"\"\n",
    "    batch, labels = next(data_loader)\n",
    "    labels = labels.to(device)\n",
    "    n_num = dataset.n_num_features\n",
    "    num_batch = batch[:, :n_num].to(device)\n",
    "    cat_batch = batch[:, n_num:].to(device)\n",
    "\n",
    "    # Noise both feature blocks at the same randomly drawn timesteps.\n",
    "    t, _ = schedule_sampler.sample(batch.shape[0], device)\n",
    "    num_batch = diffusion.gaussian_q_sample(num_batch, t)\n",
    "    log_x_cat = index_to_log_onehot(cat_batch.long(), diffusion.num_classes).to(device)\n",
    "    log_x_cat = diffusion.q_sample(log_x_cat, t)\n",
    "    batch = torch.cat((num_batch, log_x_cat), dim=1)\n",
    "\n",
    "    for i, (sub_batch, sub_labels, sub_t) in enumerate(\n",
    "        split_microbatches(-1, batch, labels, t)\n",
    "    ):\n",
    "        logits = classifier(sub_batch, sub_t)\n",
    "        loss = F.cross_entropy(logits, sub_labels, reduction=\"none\")\n",
    "\n",
    "        losses = {}\n",
    "        losses[f\"{prefix}_loss\"] = loss.detach()\n",
    "        losses[f\"{prefix}_acc@1\"] = compute_top_k(\n",
    "            logits, sub_labels, k=1, reduction=\"none\"\n",
    "        )\n",
    "        losses[f\"{prefix}_acc@5\"] = compute_top_k(\n",
    "            logits, sub_labels, k=5, reduction=\"none\"\n",
    "        )\n",
    "        log_loss_dict(diffusion, sub_t, losses)\n",
    "        del losses\n",
    "        loss = loss.mean()\n",
    "        if loss.requires_grad:\n",
    "            if i == 0:\n",
    "                optimizer.zero_grad()\n",
    "            # Weight the microbatch by its share of the full batch. The scaled loss\n",
    "            # has to be the tensor that is differentiated: handing it to backward()\n",
    "            # as the `gradient` argument would multiply every gradient by the loss.\n",
    "            (loss * len(sub_batch) / len(batch)).backward()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[942   2   2   4   2 100]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1055"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Category sizes of the training split; replaced by [0] when there are none\n",
    "# or when the categorical features are one-hot encoded.\n",
    "K = np.array(dataset.get_category_sizes('train'))\n",
    "K = np.array([0]) if (len(K) == 0 or T_dict['cat_encoding'] == 'one-hot') else K\n",
    "print(K)\n",
    "num_numerical_features = 0 if dataset.X_num is None else dataset.X_num['train'].shape[1]\n",
    "# Sum of the category sizes plus the number of numerical columns.\n",
    "np.sum(K) + num_numerical_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[942   2   2   4   2 100]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "selected index k out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[21], line 53\u001b[0m\n\u001b[1;32m     48\u001b[0m logger\u001b[38;5;241m.\u001b[39mlogkv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep\u001b[39m\u001b[38;5;124m\"\u001b[39m, step \u001b[38;5;241m+\u001b[39m resume_step)\n\u001b[1;32m     49\u001b[0m logger\u001b[38;5;241m.\u001b[39mlogkv(\n\u001b[1;32m     50\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     51\u001b[0m     (step \u001b[38;5;241m+\u001b[39m resume_step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m*\u001b[39m batch_size,\n\u001b[1;32m     52\u001b[0m )\n\u001b[0;32m---> 53\u001b[0m \u001b[43mclassifier_forward_backward_log\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     54\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclassifier\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     55\u001b[0m \u001b[43m    \u001b[49m\u001b[43mclassifier_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     56\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     57\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     58\u001b[0m \u001b[43m    \u001b[49m\u001b[43mschedule_sampler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     59\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdiffusion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     60\u001b[0m \u001b[43m    \u001b[49m\u001b[43mprefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m     61\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     63\u001b[0m classifier_optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     64\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m step \u001b[38;5;241m%\u001b[39m eval_interval:\n",
      "Cell \u001b[0;32mIn[15], line 62\u001b[0m, in \u001b[0;36mclassifier_forward_backward_log\u001b[0;34m(classifier, optimizer, data_loader, dataset, schedule_sampler, diffusion, prefix)\u001b[0m\n\u001b[1;32m     58\u001b[0m losses[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mdetach()\n\u001b[1;32m     59\u001b[0m losses[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_acc@1\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m compute_top_k(\n\u001b[1;32m     60\u001b[0m     logits, sub_labels, k\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     61\u001b[0m )\n\u001b[0;32m---> 62\u001b[0m losses[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_acc@5\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_top_k\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     63\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlogits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msub_labels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnone\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m     64\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     65\u001b[0m log_loss_dict(diffusion, sub_t, losses)\n\u001b[1;32m     66\u001b[0m 
\u001b[38;5;28;01mdel\u001b[39;00m losses\n",
      "Cell \u001b[0;32mIn[15], line 16\u001b[0m, in \u001b[0;36mcompute_top_k\u001b[0;34m(logits, labels, k, reduction)\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_top_k\u001b[39m(logits, labels, k, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m---> 16\u001b[0m     _, top_ks \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtopk\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlogits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     17\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m reduction \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m     18\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m (top_ks \u001b[38;5;241m==\u001b[39m labels[:, \u001b[38;5;28;01mNone\u001b[39;00m])\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39msum(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem()\n",
      "\u001b[0;31mRuntimeError\u001b[0m: selected index k out of range"
     ]
    }
   ],
   "source": [
    "# Train a timestep-conditioned classifier of the cluster label on noised rows.\n",
    "eval_interval = 5     # run a validation pass every this many steps\n",
    "log_interval = 10     # flush the logger every this many steps\n",
    "classifier_steps = 150000\n",
    "\n",
    "K = np.array(dataset.get_category_sizes('train'))\n",
    "if len(K) == 0 or T_dict['cat_encoding'] == 'one-hot':\n",
    "    K = np.array([0])\n",
    "print(K)\n",
    "num_numerical_features = dataset.X_num['train'].shape[1] if dataset.X_num is not None else 0\n",
    "\n",
    "# NOTE(review): classifier_forward_backward_log feeds the classifier the numerical\n",
    "# block concatenated with the log one-hot categorical block, so d_in below looks too\n",
    "# narrow and the commented-out value looks like the matching width -- confirm\n",
    "# against MLPDiffusion before changing it.\n",
    "classifier = MLPDiffusion(\n",
    "    # d_in=np.sum(K) + num_numerical_features,\n",
    "    d_in = num_numerical_features,\n",
    "    num_classes=historical_df_info['n_classes'],\n",
    "    is_y_cond=True,\n",
    "    rtdl_params={\n",
    "        'd_layers': [\n",
    "            128,\n",
    "            512,\n",
    "            512,\n",
    "            512,\n",
    "        ],\n",
    "        'dropout': 0.\n",
    "    },\n",
    "    dim_t=128\n",
    ").to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "classifier_optimizer = optim.AdamW(classifier.parameters(), lr=0.001)\n",
    "\n",
    "# The diffusion object is only used for its forward (noising) process here,\n",
    "# hence no denoising network is attached.\n",
    "diffusion = GaussianMultinomialDiffusion(\n",
    "    num_classes=K,\n",
    "    num_numerical_features=num_numerical_features,\n",
    "    denoise_fn=None,\n",
    "    gaussian_loss_type=gaussian_loss_type,\n",
    "    num_timesteps=num_timesteps,\n",
    "    scheduler=scheduler,\n",
    "    device=device\n",
    ")\n",
    "diffusion.to(device)\n",
    "\n",
    "schedule_sampler = create_named_schedule_sampler(\n",
    "    'uniform', diffusion\n",
    ")\n",
    "\n",
    "\n",
    "resume_step = 0\n",
    "for step in range(classifier_steps):\n",
    "    logger.logkv(\"step\", step + resume_step)\n",
    "    logger.logkv(\n",
    "        \"samples\",\n",
    "        (step + resume_step + 1) * batch_size,\n",
    "    )\n",
    "    classifier_forward_backward_log(\n",
    "        classifier, \n",
    "        classifier_optimizer, \n",
    "        train_loader, \n",
    "        dataset, \n",
    "        schedule_sampler, \n",
    "        diffusion, \n",
    "        prefix=\"train\"\n",
    "    )\n",
    "\n",
    "    classifier_optimizer.step()\n",
    "    if not step % eval_interval:\n",
    "        with torch.no_grad():\n",
    "            classifier.eval()\n",
    "            # Evaluate on the held-out split; drawing from train_loader here would\n",
    "            # report training metrics under the 'val' prefix.\n",
    "            classifier_forward_backward_log(\n",
    "                classifier, \n",
    "                classifier_optimizer, \n",
    "                val_loader, \n",
    "                dataset, \n",
    "                schedule_sampler, \n",
    "                diffusion, \n",
    "                prefix=\"val\"\n",
    "            )\n",
    "            classifier.train()\n",
    "\n",
    "    if not step % log_interval:\n",
    "        logger.dumpkvs()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the seeding helper again and rebuild `dataset`, this time from the files\n",
    "# under real_data_path (this replaces the DataFrame-based dataset built earlier).\n",
    "zero.improve_reproducibility(seed)\n",
    "T = lib.Transformations(**T_dict)\n",
    "dataset = make_dataset(\n",
    "    real_data_path, T,\n",
    "    num_classes=model_params['num_classes'],\n",
    "    is_y_cond=model_params['is_y_cond'],\n",
    "    change_val=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Category sizes of the training split; replaced by [0] when there are none\n",
    "# or when the categorical features are one-hot encoded.\n",
    "K = np.array(dataset.get_category_sizes('train'))\n",
    "K = np.array([0]) if (len(K) == 0 or T_dict['cat_encoding'] == 'one-hot') else K\n",
    "print(K)\n",
    "\n",
    "# Model input width: sum of the category sizes plus the numerical columns.\n",
    "num_numerical_features = 0 if dataset.X_num is None else dataset.X_num['train'].shape[1]\n",
    "d_in = np.sum(K) + num_numerical_features\n",
    "model_params['d_in'] = d_in\n",
    "print(d_in)\n",
    "\n",
    "print(model_params)\n",
    "model = get_model(\n",
    "    model_type, model_params, num_numerical_features,\n",
    "    category_sizes=dataset.get_category_sizes('train'),\n",
    ")\n",
    "model.to(device)\n",
    "\n",
    "# Rebinds train_loader to the file-based dataset.\n",
    "train_loader = lib.prepare_fast_dataloader(dataset, split='train', batch_size=batch_size)\n",
    "\n",
    "diffusion = GaussianMultinomialDiffusion(\n",
    "    denoise_fn=model,\n",
    "    num_classes=K,\n",
    "    num_numerical_features=num_numerical_features,\n",
    "    num_timesteps=num_timesteps,\n",
    "    gaussian_loss_type=gaussian_loss_type,\n",
    "    scheduler=scheduler,\n",
    "    device=device,\n",
    ")\n",
    "diffusion.to(device)\n",
    "diffusion.train()\n",
    "\n",
    "trainer = Trainer(\n",
    "    diffusion, train_loader,\n",
    "    lr=lr, weight_decay=weight_decay, steps=steps, device=device,\n",
    ")\n",
    "trainer.run_loop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_num_samples = 50000\n",
    "test_batch_size = 10000\n",
    "\n",
    "diffusion.eval()\n",
    "# Per-class counts of the training labels, handed to sample_all as the class distribution.\n",
    "train_labels = torch.from_numpy(dataset.y['train'])\n",
    "_, empirical_class_dist = torch.unique(train_labels, return_counts=True)\n",
    "x_gen, y_gen = diffusion.sample_all(\n",
    "    test_num_samples,\n",
    "    test_batch_size,\n",
    "    empirical_class_dist.float(),\n",
    "    ddim=False,\n",
    ")\n",
    "X_gen, y_gen = x_gen.numpy(), y_gen.numpy()\n",
    "\n",
    "# One extra leading column is counted for regression without label conditioning\n",
    "# (presumably the generated target -- confirm against the dataset builder).\n",
    "extra_target_col = int(dataset.is_regression and not model_params[\"is_y_cond\"])\n",
    "num_numerical_features_sample = num_numerical_features + extra_target_col\n",
    "\n",
    "X_num_ = X_gen\n",
    "cat_start = num_numerical_features_sample\n",
    "if cat_start < X_gen.shape[1]:\n",
    "    # Everything to the right of the numerical block is categorical.\n",
    "    if T_dict['cat_encoding'] == 'one-hot':\n",
    "        X_gen[:, cat_start:] = to_good_ohe(dataset.cat_transform.steps[0][1], X_num_[:, cat_start:])\n",
    "    X_cat = dataset.cat_transform.inverse_transform(X_gen[:, cat_start:])\n",
    "\n",
    "if num_numerical_features != 0:\n",
    "    X_num_ = dataset.num_transform.inverse_transform(X_gen[:, :cat_start])\n",
    "    X_num = X_num_[:, :cat_start]\n",
    "\n",
    "    # NOTE: allow_pickle=True lets np.load unpickle objects; only use it on trusted files.\n",
    "    X_num_real = np.load(os.path.join(real_data_path, \"X_num_train.npy\"), allow_pickle=True)\n",
    "\n",
    "    # A real column counts as discrete when it has at most 32 distinct values, all integral.\n",
    "    disc_cols = []\n",
    "    for idx in range(X_num_real.shape[1]):\n",
    "        uniq = np.unique(X_num_real[:, idx])\n",
    "        if len(uniq) > 32:\n",
    "            continue\n",
    "        if ((uniq - np.round(uniq)) == 0).all():\n",
    "            disc_cols.append(idx)\n",
    "    print(\"Discrete cols:\", disc_cols)\n",
    "\n",
    "    if model_params['num_classes'] == 0:\n",
    "        # With num_classes == 0 the first numerical column is taken as the target.\n",
    "        y_gen, X_num = X_num[:, 0], X_num[:, 1:]\n",
    "    if disc_cols:\n",
    "        X_num = lib.round_columns(X_num_real, X_num, disc_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the real categorical training features stored next to the numerical ones.\n",
    "# NOTE: allow_pickle=True lets np.load unpickle objects; only use it on trusted files.\n",
    "x_cat_train_file = os.path.join(real_data_path, \"X_cat_train.npy\")\n",
    "X_cat_real = np.load(x_cat_train_file, allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lay out real and generated data identically: categorical columns first, then\n",
    "# the numerical columns cast to int.\n",
    "total_real = np.concatenate((X_cat_real, X_num_real.astype(int)), axis=1)\n",
    "gen_real = np.concatenate((X_cat, X_num.astype(int)), axis=1)\n",
    "\n",
    "df_total = pd.DataFrame(total_real)\n",
    "df_gen = pd.DataFrame(gen_real)\n",
    "\n",
    "# The numerical columns start right after the categorical block. Deriving that\n",
    "# boundary from the data replaces the previous hard-coded `col > 7`, which was only\n",
    "# right for exactly eight categorical columns.\n",
    "n_cat_cols = X_cat_real.shape[1]\n",
    "for col in df_total.columns[n_cat_cols:]:\n",
    "    df_total[col] = df_total[col].astype(int)\n",
    "    df_gen[col] = df_gen[col].astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# df_total holds the categorical columns first, so every column index from\n",
    "# X_cat_real.shape[1] onwards is numerical. Deriving the boundary replaces the\n",
    "# previous hard-coded `int(col) > 7`, which assumed eight categorical columns.\n",
    "n_cat_cols = X_cat_real.shape[1]\n",
    "\n",
    "# Auto-detect column types, then force the numerical block to sdtype 'numerical'.\n",
    "metadata = SingleTableMetadata()\n",
    "metadata.detect_from_dataframe(data=df_total)\n",
    "for col in df_total.columns:\n",
    "    if int(col) >= n_cat_cols:\n",
    "        metadata.update_column(\n",
    "            column_name=col,\n",
    "            sdtype='numerical'\n",
    "        )\n",
    "\n",
    "# Compare the real and generated tables with SDV's quality report.\n",
    "quality_report = evaluate_quality(\n",
    "    df_total,\n",
    "    df_gen,\n",
    "    metadata\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
