{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e65b1a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef1ac9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# For reproducibility\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "class TabularDataset(Dataset):\n",
    "    \"\"\"Custom PyTorch Dataset for tabular data.\"\"\"\n",
    "    def __init__(self, features, labels):\n",
    "        labels = np.array(labels)\n",
    "        self.features = torch.tensor(features, dtype=torch.float32)\n",
    "        self.labels = torch.tensor(labels, dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.features)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.features[idx], self.labels[idx]\n",
    "\n",
    "class LGMVAE(nn.Module):\n",
    "    def __init__(self, input_dim, z_dim, c_dim, y_dim, beta_1=1.0, beta_2=1.0):\n",
    "        super().__init__()\n",
    "        self.input_dim, self.z_dim, self.c_dim, self.y_dim = input_dim, z_dim, c_dim, y_dim\n",
    "        self.beta_1, self.beta_2 = beta_1, beta_2\n",
    "\n",
    "        if c_dim % y_dim != 0:\n",
    "            raise ValueError(\"c_dim must be a multiple of y_dim for even cluster distribution.\")\n",
    "        self.clusters_per_class = c_dim // y_dim\n",
    "\n",
    "        # Inference Network: q(c|x,y) and q(z|x,c,y)\n",
    "        self.qc_net = nn.Sequential(nn.Linear(input_dim + y_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, c_dim))\n",
    "        self.qz_net = nn.Sequential(nn.Linear(input_dim + c_dim + y_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU())\n",
    "        self.qz_mean = nn.Linear(512, z_dim)\n",
    "        self.qz_logvar = nn.Linear(512, z_dim)\n",
    "\n",
    "        # Generative Network: p(z|c) and p(x|z) \n",
    "        self.pc_mean = nn.Linear(c_dim, z_dim)\n",
    "        self.pc_logvar = nn.Linear(c_dim, z_dim)\n",
    "        self.px_net = nn.Sequential(nn.Linear(z_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, input_dim))\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def forward(self, x, y_one_hot):\n",
    "        batch_size = x.size(0)\n",
    "        x_y = torch.cat([x, y_one_hot], dim=1)\n",
    "        qc_logits = self.qc_net(x_y)\n",
    "        qc_probs = F.softmax(qc_logits, dim=1)\n",
    "\n",
    "        c_cats = torch.eye(self.c_dim, device=x.device).unsqueeze(0).expand(batch_size, -1, -1)\n",
    "        x_expanded = x.unsqueeze(1).expand(-1, self.c_dim, -1)\n",
    "        y_expanded = y_one_hot.unsqueeze(1).expand(-1, self.c_dim, -1)\n",
    "\n",
    "        x_c_y = torch.cat([x_expanded, c_cats, y_expanded], dim=2).reshape(-1, self.input_dim + self.c_dim + self.y_dim)\n",
    "        c_flat = c_cats.reshape(-1, self.c_dim)\n",
    "\n",
    "        qz_hidden = self.qz_net(x_c_y)\n",
    "        qz_mu, qz_logvar = self.qz_mean(qz_hidden), self.qz_logvar(qz_hidden)\n",
    "        z_sample = self.reparameterize(qz_mu, qz_logvar)\n",
    "\n",
    "        pc_mu, pc_logvar = self.pc_mean(c_flat), self.pc_logvar(c_flat)\n",
    "        px_recon = self.px_net(z_sample)\n",
    "\n",
    "        qz_mu, qz_logvar = qz_mu.reshape(batch_size, self.c_dim, -1), qz_logvar.reshape(batch_size, self.c_dim, -1)\n",
    "        pc_mu, pc_logvar = pc_mu.reshape(batch_size, self.c_dim, -1), pc_logvar.reshape(batch_size, self.c_dim, -1)\n",
    "        px_recon = px_recon.reshape(batch_size, self.c_dim, -1)\n",
    "\n",
    "        recon_loss = F.mse_loss(px_recon, x_expanded, reduction='none').sum(dim=2)\n",
    "        \n",
    "        kl_z = 0.5 * torch.sum(pc_logvar - qz_logvar - 1 + (qz_logvar.exp() + (qz_mu - pc_mu).pow(2)) / pc_logvar.exp(), dim=2)\n",
    "        y_labels = torch.argmax(y_one_hot, dim=1)\n",
    "        pc_prior = torch.zeros_like(qc_probs)\n",
    "        for i in range(batch_size):\n",
    "            label = y_labels[i]\n",
    "            start_cluster, end_cluster = label * self.clusters_per_class, (label + 1) * self.clusters_per_class\n",
    "            pc_prior[i, start_cluster:end_cluster] = 1.0 / self.clusters_per_class\n",
    "        \n",
    "        kl_c = torch.sum(qc_probs * (torch.log(qc_probs + 1e-10) - torch.log(pc_prior + 1e-10)), dim=1)\n",
    "\n",
    "        loss_per_c = torch.sum(qc_probs * (recon_loss + self.beta_1 * kl_z), dim=1)\n",
    "        final_loss = torch.mean(loss_per_c + self.beta_2 * kl_c)\n",
    "\n",
    "        return final_loss, torch.mean(recon_loss), self.beta_1 * torch.mean(kl_z), self.beta_2 * torch.mean(kl_c)\n",
    "\n",
    "\n",
    "    def sample(self, y_label, num_samples=1):\n",
    "        self.eval()\n",
    "        with torch.no_grad():\n",
    "            # Identify the clusters for the given class label ---\n",
    "            start_cluster = y_label * self.clusters_per_class\n",
    "            end_cluster = (y_label + 1) * self.clusters_per_class\n",
    "            cluster_range = list(range(start_cluster, end_cluster))\n",
    "            num_clusters_in_class = len(cluster_range)\n",
    "            \n",
    "            # Logic for balanced sampling from clusters\n",
    "            if num_clusters_in_class == 0:\n",
    "                return torch.empty(0, self.input_dim)\n",
    "\n",
    "            samples_per_cluster = num_samples // num_clusters_in_class\n",
    "            remainder = num_samples % num_clusters_in_class\n",
    "\n",
    "            c_idx_list = []\n",
    "            for i, cluster_idx in enumerate(cluster_range):\n",
    "                # Determine the total number of samples for this specific cluster\n",
    "                num_to_sample = samples_per_cluster\n",
    "                if i < remainder:\n",
    "                    # Distribute the remainder across the first few clusters\n",
    "                    num_to_sample += 1\n",
    "                \n",
    "                if num_to_sample > 0:\n",
    "                    c_idx_list.extend([cluster_idx] * num_to_sample)\n",
    "            \n",
    "            c_idx = torch.tensor(c_idx_list, dtype=torch.long, device=next(self.parameters()).device)\n",
    "            c_idx = c_idx[torch.randperm(len(c_idx))]\n",
    "            \n",
    "            c_cats = F.one_hot(c_idx, num_classes=self.c_dim).float()\n",
    "            pc_mu, pc_logvar = self.pc_mean(c_cats), self.pc_logvar(c_cats)\n",
    "            z = self.reparameterize(pc_mu, pc_logvar)\n",
    "            \n",
    "            recon_features = self.px_net(z)\n",
    "            return recon_features\n",
    "        \n",
    "def validate_epoch(model, val_loader, device):\n",
    "    model.eval()\n",
    "    total_val_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for data, labels in val_loader:\n",
    "            data, labels = data.to(device), labels.to(device)\n",
    "            y_one_hot = F.one_hot(labels, num_classes=model.y_dim).float()\n",
    "            # MODIFIED: Removed image-specific binarization\n",
    "            loss, _, _, _ = model(data, y_one_hot)\n",
    "            total_val_loss += loss.item()\n",
    "    return total_val_loss / len(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4116b212",
   "metadata": {},
   "outputs": [],
   "source": [
    "dname = \"wine\"\n",
    "\n",
    "from ucimlrepo import fetch_ucirepo\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.utils import resample\n",
    "\n",
    "# fetch dataset \n",
    "wine_quality = fetch_ucirepo(id=186) \n",
    "  \n",
    "# data (as pandas dataframes) \n",
    "X = wine_quality.data.features\n",
    "y = wine_quality.data.targets\n",
    "y = y['quality'].apply(lambda value: 0 if value <= 5 else 1)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(X)\n",
    "\n",
    "df_large = pd.DataFrame(X)\n",
    "df_large['target'] = y\n",
    "df_majority = df_large[df_large.target == 1]\n",
    "df_minority = df_large[df_large.target == 0]\n",
    "df_minority_upsampled = resample(\n",
    "    df_minority, \n",
    "    replace=True,                  \n",
    "    n_samples=int(len(df_minority) * 1.5),\n",
    "    random_state=42\n",
    ")\n",
    "df_train_upsampled = pd.concat([df_majority, df_minority_upsampled])\n",
    "df_train_upsampled = df_train_upsampled.sample(frac=1, random_state=42).reset_index(drop=True)\n",
    "X_upsampled = df_train_upsampled.drop(columns='target').values\n",
    "y_upsampled = df_train_upsampled['target'].values\n",
    "\n",
    "X_train, X_temp, y_train, y_temp = train_test_split(\n",
    "    X_upsampled, y_upsampled,\n",
    "    test_size=0.3, \n",
    "    random_state=42, \n",
    "    stratify=y_upsampled,\n",
    ")\n",
    "\n",
    "X_val, X_test, y_val, y_test = train_test_split(\n",
    "    X_temp, y_temp,\n",
    "    test_size=0.5,\n",
    "    random_state=42,\n",
    "    stratify=y_temp\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "870762d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ce_dataset_model import *\n",
    "mname='rf'\n",
    "clf, retrained_clfs = get_models(mname, X_train, y_train, X_val, y_val, X_test, y_test, n_est=25, max_depth=5)\n",
    "rf_classifier = clf\n",
    "\n",
    "y_pred_train = rf_classifier.predict(X_train)\n",
    "test_set_size = 0.1\n",
    "\n",
    "# Split both the features (X_train) and the predicted labels (y_pred_train)\n",
    "X_train_new, X_test_new, y_train_new, y_test_new = train_test_split(\n",
    "    X_train, \n",
    "    y_pred_train, \n",
    "    test_size=test_set_size, \n",
    "    random_state=42,\n",
    "    stratify=y_pred_train \n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05a76190",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 512 \n",
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 3e-4\n",
    "INPUT_DIM = X_train_new.shape[1]\n",
    "Y_DIM = len(np.unique(y_train))\n",
    "Z_DIM = 8\n",
    "C_DIM = Y_DIM * 5\n",
    "BETA_1 = 0.5\n",
    "BETA_2 = 0.3\n",
    "PATIENCE = 5\n",
    "BEST_MODEL_PATH = f'{dname}_{mname}_lgmvae.pth'\n",
    "device = \"cuda:1\"\n",
    "print(f\"Input Dim: {INPUT_DIM}, Num Classes: {Y_DIM}, Num Clusters: {C_DIM}\")\n",
    "\n",
    "train_dataset = TabularDataset(X_train_new, np.array(y_train_new))\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_dataset = TabularDataset(X_test_new, np.array(y_test_new))\n",
    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "model = LGMVAE(\n",
    "    input_dim=INPUT_DIM, z_dim=Z_DIM, c_dim=C_DIM, y_dim=Y_DIM,\n",
    "    beta_1=BETA_1, beta_2=BETA_2\n",
    ").to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "best_val_loss = float('inf')\n",
    "patience_counter = 0\n",
    "\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    model.train()\n",
    "    total_loss, total_recon_loss, total_kl_z, total_kl_c = 0, 0, 0, 0\n",
    "    pbar = tqdm(train_loader, desc=f\"Epoch {epoch}/{EPOCHS}\")\n",
    "\n",
    "    for batch_idx, (data, labels) in enumerate(pbar):\n",
    "        data, labels = data.to(device), labels.to(device)\n",
    "        y_one_hot = F.one_hot(labels, num_classes=Y_DIM).float()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss, recon_loss, kl_z, kl_c = model(data, y_one_hot)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        total_recon_loss += recon_loss.item()\n",
    "        total_kl_z += kl_z.item()\n",
    "        total_kl_c += kl_c.item()\n",
    "    \n",
    "    num_batches = len(train_loader)\n",
    "    avg_train_loss = total_loss / num_batches\n",
    "    avg_recon = total_recon_loss / num_batches\n",
    "    avg_kl_z = total_kl_z / num_batches\n",
    "    avg_kl_c = total_kl_c / num_batches\n",
    "    \n",
    "    print(f\"====> Epoch: {epoch} Avg Train Loss: {avg_train_loss:.4f} | \"\n",
    "          f\"Recon: {avg_recon:.4f} | KL_z: {avg_kl_z:.4f} | KL_c: {avg_kl_c:.4f}\")\n",
    "\n",
    "    avg_val_loss = validate_epoch(model, val_loader, device)\n",
    "    if avg_val_loss < best_val_loss:\n",
    "        best_val_loss = avg_val_loss\n",
    "        patience_counter = 0\n",
    "        torch.save(model.state_dict(), BEST_MODEL_PATH)\n",
    "        print(f\"--> Val Loss improved to {avg_val_loss:.4f}. Saving model.\")\n",
    "    else:\n",
    "        patience_counter += 1\n",
    "        print(f\"--> Val Loss did not improve. Patience: {patience_counter}/{PATIENCE}\")\n",
    "\n",
    "    if patience_counter >= PATIENCE:\n",
    "        print(\"Early stopping triggered.\")\n",
    "        break\n",
    "\n",
    "print(\"\\nTraining finished.\")\n",
    "\n",
    "# evaluate its utility for CEs\n",
    "_, _ = evaluate_generative_model_utility(model, rf_classifier, X_train_new, y_train_new, X_test_new, y_test_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44d788bf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llma2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
