{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36204073",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from src.usflows.explib.datasets import DistributionDataset\n",
    "from src.usflows.distributions import GMM\n",
    "import scipy.stats as stats\n",
    "from scipy.stats import gaussian_kde\n",
    "from matplotlib.patches import Circle\n",
    "import copy\n",
    "\n",
    "import time\n",
    "from typing import Optional, Tuple, List\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "import logging\n",
    "\n",
    "# ---------- logging ----------\n",
    "# Configure the root logger once for the whole notebook: INFO level and\n",
    "# records formatted as timestamp, level name and message.\n",
    "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
    "\n",
    "\n",
    "# ---------- reproducibility ----------\n",
    "def seed_everything(seed: int = 42) -> None:\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.\n",
    "\n",
    "    Args:\n",
    "        seed: Value passed to every random number generator.\n",
    "    \"\"\"\n",
    "    import random as py_random  # function-local import, as in the original helper\n",
    "\n",
    "    # The CPU-side generators share the same seeding signature.\n",
    "    for seed_fn in (py_random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed)\n",
    "\n",
    "    # CUDA generators are seeded only when a GPU is present.\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    # Deterministic kernels on, autotuner off.\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "\n",
    "# Seed every RNG once, before any data or model is created.\n",
    "seed_everything(42)\n",
    "\n",
    "# Global compute device: CUDA when available, otherwise CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Data dimensionalities iterated over by the experiment loop.\n",
    "DIMS = [2, 8, 32, 128]\n",
    "# Architecture label embedded in plot titles and output file names.\n",
    "ARCH = \"Deep SVDD\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3f2bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --------------------------- utilities ---------------------------\n",
    "\n",
    "def resolve_device(requested: str = \"cuda\") -> torch.device:\n",
    "    \"\"\"Turn a device specifier into a usable ``torch.device``.\n",
    "\n",
    "    Args:\n",
    "        requested: Device specifier accepted by ``torch.device``.\n",
    "\n",
    "    Returns:\n",
    "        The requested device, or the CPU device (after logging a warning) when\n",
    "        CUDA was requested but ``torch.cuda.is_available()`` is False.\n",
    "    \"\"\"\n",
    "    target = torch.device(requested)\n",
    "    cuda_missing = target.type == \"cuda\" and not torch.cuda.is_available()\n",
    "    if not cuda_missing:\n",
    "        return target\n",
    "\n",
    "    logging.warning(\"CUDA requested but not available. Falling back to CPU.\")\n",
    "    return torch.device(\"cpu\")\n",
    "\n",
    "\n",
    "def he_init_leaky_relu(m: nn.Module) -> None:\n",
    "    \"\"\"Kaiming-normal initializer suitable for ``nn.Module.apply``.\n",
    "\n",
    "    Only ``nn.Linear`` modules are modified: the weight is drawn with\n",
    "    ``nonlinearity=\"leaky_relu\"`` and an existing bias is set to zero.\n",
    "    Every other module type is left untouched.\n",
    "    \"\"\"\n",
    "    if not isinstance(m, nn.Linear):\n",
    "        return\n",
    "\n",
    "    nn.init.kaiming_normal_(m.weight, nonlinearity=\"leaky_relu\")\n",
    "    bias = m.bias\n",
    "    if bias is not None:\n",
    "        nn.init.zeros_(bias)\n",
    "\n",
    "\n",
    "class NumpyArrayDataset(Dataset):\n",
    "    \"\"\"Map-style dataset over in-memory NumPy arrays returning ``(x, y, idx)`` items.\n",
    "\n",
    "    Features are held as float32 and labels as int64; when no labels are\n",
    "    given, every sample receives the label 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X: np.ndarray, y: Optional[np.ndarray] = None):\n",
    "        assert X.ndim == 2, \"X must be of shape (n_samples, n_features)\"\n",
    "        n_rows = X.shape[0]\n",
    "\n",
    "        # copy=False avoids a copy when the dtype already matches.\n",
    "        features = X.astype(np.float32, copy=False)\n",
    "        labels = (\n",
    "            np.zeros((n_rows,), dtype=np.int64)\n",
    "            if y is None\n",
    "            else y.astype(np.int64, copy=False)\n",
    "        )\n",
    "\n",
    "        self.X = torch.from_numpy(features)\n",
    "        self.y = torch.from_numpy(labels)\n",
    "        # Row position of each sample, returned alongside it by __getitem__.\n",
    "        self.idxs = torch.arange(n_rows, dtype=torch.long)\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return self.X.size(0)\n",
    "\n",
    "    def __getitem__(self, idx: int):\n",
    "        sample, label, position = self.X[idx], self.y[idx], self.idxs[idx]\n",
    "        return sample, label, position\n",
    "\n",
    "\n",
    "class NumpyDatasetWrapper:\n",
    "    \"\"\"\n",
    "    Bundles train/validation/test feature matrices as ``NumpyArrayDataset`` splits.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        X_train: np.ndarray,\n",
    "        X_val: np.ndarray,\n",
    "        X_test: np.ndarray,\n",
    "        device: torch.device,\n",
    "    ):\n",
    "        # No labels are passed, so each split uses NumpyArrayDataset's default labels.\n",
    "        self.train_set, self.val_set, self.test_set = (\n",
    "            NumpyArrayDataset(split, None) for split in (X_train, X_val, X_test)\n",
    "        )\n",
    "        self._device = device\n",
    "\n",
    "    def loaders(self, batch_size: int = 128, num_workers: int = 0) -> Tuple[DataLoader, DataLoader, DataLoader]:\n",
    "        \"\"\"Build the ``(train, val, test)`` data loaders.\n",
    "\n",
    "        Only the training loader shuffles. No loader drops the last partial\n",
    "        batch, and memory pinning is enabled exactly when the stored device\n",
    "        is a CUDA device.\n",
    "        \"\"\"\n",
    "        shared = dict(\n",
    "            batch_size=batch_size,\n",
    "            drop_last=False,\n",
    "            num_workers=num_workers,\n",
    "            pin_memory=(self._device.type == \"cuda\"),\n",
    "        )\n",
    "        train_loader = DataLoader(self.train_set, shuffle=True, **shared)\n",
    "        val_loader = DataLoader(self.val_set, shuffle=False, **shared)\n",
    "        test_loader = DataLoader(self.test_set, shuffle=False, **shared)\n",
    "        return train_loader, val_loader, test_loader\n",
    "\n",
    "\n",
    "# --------------------------- model ---------------------------\n",
    "\n",
    "class DeepSVDDNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Dynamically constructed encoder for DeepSVDD.\n",
    "\n",
    "    Width rule (given input_dim = x):\n",
    "      - 2 layers:  x,        x\n",
    "      - 3 layers:  x,       2x,        x\n",
    "      - 4 layers:  x,       4x,       2x,        x\n",
    "      - 5 layers:  x,       8x,       4x,       2x,        x\n",
    "      - 6 layers:  x,      16x,       8x,       4x,       2x,        x\n",
    "\n",
    "    Each listed width is a hidden layer (bias-free Linear + [BN] + LeakyReLU),\n",
    "    followed by a final bias-free Linear to `latent_dim` without BN/activation,\n",
    "    so the output always has `rep_dim == latent_dim` features.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of input features (x in the width rule).\n",
    "        latent_dim: Number of output features; stored as `rep_dim`.\n",
    "        batch_norm: Whether each hidden layer is followed by BatchNorm1d.\n",
    "        depth: Number of hidden layers, between 2 and 6.\n",
    "    \"\"\"\n",
    "\n",
    "    # Hidden-width multipliers (relative to input_dim) for every supported depth.\n",
    "    _WIDTH_FACTORS = {\n",
    "        2: (1, 1),\n",
    "        3: (1, 2, 1),\n",
    "        4: (1, 4, 2, 1),\n",
    "        5: (1, 8, 4, 2, 1),\n",
    "        6: (1, 16, 8, 4, 2, 1),\n",
    "    }\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int = 10,\n",
    "        latent_dim: int = 2,\n",
    "        batch_norm: bool = True,\n",
    "        depth: int = 3,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        # --- config ---\n",
    "        assert 2 <= depth <= 6, \"Depth must be between 2 and 6.\"\n",
    "        self.rep_dim = int(latent_dim)\n",
    "        self.batch_norm = bool(batch_norm)\n",
    "        self.depth = int(depth)\n",
    "        d = int(input_dim)\n",
    "\n",
    "        # --- build hidden width schedule based on rule ---\n",
    "        hidden_dims = [int(f * d) for f in self._WIDTH_FACTORS[self.depth]]\n",
    "        assert len(hidden_dims) == self.depth\n",
    "\n",
    "        # --- construct layers ---\n",
    "        self.fcs = nn.ModuleList()\n",
    "        self.bns = nn.ModuleList() if self.batch_norm else None\n",
    "\n",
    "        in_width = d\n",
    "        for out_width in hidden_dims:\n",
    "            self.fcs.append(nn.Linear(in_width, out_width, bias=False))\n",
    "            if self.batch_norm:\n",
    "                self.bns.append(nn.BatchNorm1d(out_width, eps=1e-4, affine=False))\n",
    "            in_width = out_width\n",
    "\n",
    "        # Final projection to the latent space. Without it the output width would be\n",
    "        # input_dim rather than rep_dim, and a center of size rep_dim would not match.\n",
    "        self.out = nn.Linear(in_width, self.rep_dim, bias=False)\n",
    "\n",
    "        # init\n",
    "        self.apply(he_init_leaky_relu)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Map a batch of inputs to latent codes with `rep_dim` features.\"\"\"\n",
    "        if self.batch_norm:\n",
    "            for fc, bn in zip(self.fcs, self.bns):\n",
    "                x = F.leaky_relu(bn(fc(x)))\n",
    "        else:\n",
    "            for fc in self.fcs:\n",
    "                x = F.leaky_relu(fc(x))\n",
    "        # No BN/activation on the output layer.\n",
    "        return self.out(x)\n",
    "\n",
    "\n",
    "# --------------------------- trainer ---------------------------\n",
    "\n",
    "class DeepSVDDTrainer:\n",
    "    \"\"\"\n",
    "    Trainer/evaluator for a Deep SVDD encoder with validation-based early stopping.\n",
    "\n",
    "    Objectives:\n",
    "      - \"one-class\":     mean squared distance of the outputs to the center c.\n",
    "      - \"soft-boundary\": R^2 + (1/nu) * mean(max(0, dist - R^2)).\n",
    "\n",
    "    Args:\n",
    "        objective: Either \"one-class\" or \"soft-boundary\".\n",
    "        R: Initial hypersphere radius.\n",
    "        c: Optional center; when None it is estimated from the training data in `train`.\n",
    "        nu: Soft-boundary trade-off; also the quantile level used when updating R.\n",
    "        optimizer_name: \"amsgrad\" switches Adam to its AMSGrad variant; other values use plain Adam.\n",
    "        lr: Learning rate.\n",
    "        n_epochs: Maximum number of training epochs.\n",
    "        lr_milestones: Epoch counts after which the learning rate is multiplied by 0.1.\n",
    "        early_stopping_patience: Epochs without validation improvement before stopping.\n",
    "        batch_size: Batch size requested from `dataset.loaders`.\n",
    "        weight_decay: Weight decay passed to the optimizer.\n",
    "        device: Requested device; resolved through `resolve_device`.\n",
    "        n_jobs_dataloader: Number of data-loader workers.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        objective: str = \"one-class\",\n",
    "        R: float = 0.0,\n",
    "        c=None,\n",
    "        nu: float = 0.1,\n",
    "        optimizer_name: str = \"adam\",\n",
    "        lr: float = 1e-3,\n",
    "        n_epochs: int = 150,\n",
    "        lr_milestones: tuple = (),\n",
    "        early_stopping_patience: int = 3,\n",
    "        batch_size: int = 128,\n",
    "        weight_decay: float = 1e-6,\n",
    "        device: str = \"cuda\",\n",
    "        n_jobs_dataloader: int = 0,\n",
    "    ):\n",
    "        self.optimizer_name = optimizer_name\n",
    "        self.lr = lr\n",
    "        self.n_epochs = n_epochs\n",
    "        self.lr_milestones = lr_milestones\n",
    "        self.batch_size = batch_size\n",
    "        self.weight_decay = weight_decay\n",
    "        self.early_stopping_patience = early_stopping_patience\n",
    "\n",
    "        self.device = resolve_device(device)\n",
    "\n",
    "        self.n_jobs_dataloader = n_jobs_dataloader\n",
    "\n",
    "        assert objective in (\"one-class\", \"soft-boundary\"), \"Objective must be either 'one-class' or 'soft-boundary'.\"\n",
    "        self.objective = objective\n",
    "\n",
    "        # Deep SVDD parameters\n",
    "        self.R = torch.tensor(float(R), device=self.device)  # radius R initialized with 0 by default.\n",
    "        self.c = torch.as_tensor(c, dtype=torch.float32, device=self.device) if c is not None else None\n",
    "        self.nu = float(nu)\n",
    "\n",
    "        # Optimization parameters\n",
    "        self.warm_up_n_epochs = 10  # number of epochs before R updates (soft-boundary)\n",
    "\n",
    "        # Results\n",
    "        self.train_time: Optional[float] = None\n",
    "        self.test_auc: Optional[float] = None\n",
    "        self.test_time: Optional[float] = None\n",
    "        self.test_scores = None\n",
    "\n",
    "        self.test_input_output_pairs: List[Tuple[np.ndarray, np.ndarray]] = []  # (x_i, z_i)\n",
    "\n",
    "    def _compute_batch_loss(self, outputs: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Compute Deep SVDD loss for a batch given current objective/R/c.\"\"\"\n",
    "        dist = torch.sum((outputs - self.c) ** 2, dim=1)\n",
    "        if self.objective == \"soft-boundary\":\n",
    "            scores = dist - self.R ** 2\n",
    "            loss = self.R ** 2 + (1.0 / self.nu) * torch.mean(torch.clamp(scores, min=0.0))\n",
    "        else:  # one-class\n",
    "            loss = torch.mean(dist)\n",
    "        return loss\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _eval_val_loss(self, net: nn.Module, val_loader: DataLoader) -> float:\n",
    "        \"\"\"Evaluate mean validation loss over val_loader (average of per-batch losses).\"\"\"\n",
    "        net.eval()\n",
    "        total = 0.0\n",
    "        n_batches = 0\n",
    "        for data in val_loader:\n",
    "            inputs, _, _ = data\n",
    "            inputs = inputs.to(self.device)\n",
    "            outputs = net(inputs)\n",
    "            loss = self._compute_batch_loss(outputs)\n",
    "            total += float(loss.item())\n",
    "            n_batches += 1\n",
    "        return total / max(n_batches, 1)\n",
    "\n",
    "    def train(self, dataset, net: nn.Module) -> Tuple[nn.Module, float]:\n",
    "        \"\"\"Train `net` and return it with the best validation loss.\n",
    "\n",
    "        Args:\n",
    "            dataset: Object whose `loaders(batch_size=..., num_workers=...)` returns\n",
    "                (train, val, test) loaders yielding (inputs, labels, idx) batches.\n",
    "            net: Network to train; after a center initialization it must expose `rep_dim`.\n",
    "\n",
    "        Returns:\n",
    "            (net, best_val_loss), with `net` restored to the weights of the best epoch.\n",
    "        \"\"\"\n",
    "        logger = logging.getLogger()\n",
    "\n",
    "        # Set device for network\n",
    "        net = net.to(self.device)\n",
    "\n",
    "        # Get train & validation data loaders (the test loader is not needed here)\n",
    "        train_loader, val_loader, _ = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)\n",
    "\n",
    "        # Optimizer (Adam/AMSGrad)\n",
    "        optimizer = optim.Adam(\n",
    "            net.parameters(), lr=self.lr, weight_decay=self.weight_decay,\n",
    "            amsgrad=self.optimizer_name == \"amsgrad\"\n",
    "        )\n",
    "\n",
    "        # LR scheduler\n",
    "        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(self.lr_milestones), gamma=0.1)\n",
    "\n",
    "        # Initialize hypersphere center c (if not provided)\n",
    "        if self.c is None:\n",
    "            logger.info(\"Initializing center c...\")\n",
    "            self.c = self.init_center_c(train_loader, net)\n",
    "            logger.info(\"Center c initialized.\")\n",
    "        else:\n",
    "            # Keep net.c available when the center was supplied by the caller,\n",
    "            # mirroring what init_center_c does.\n",
    "            net.c = self.c\n",
    "\n",
    "        # Training\n",
    "        logger.info(\"Starting training with validation-based early stopping...\")\n",
    "        start_time = time.time()\n",
    "        net.train()\n",
    "\n",
    "        early_stopping_counter = 0\n",
    "        best_val_loss = float(\"inf\")\n",
    "        best_state = copy.deepcopy(net.state_dict())\n",
    "\n",
    "        for epoch in range(self.n_epochs):\n",
    "            loss_epoch = 0.0\n",
    "            n_batches = 0\n",
    "            epoch_start_time = time.time()\n",
    "\n",
    "            # -------- train loop --------\n",
    "            net.train()\n",
    "            for data in train_loader:\n",
    "                inputs, _, _ = data\n",
    "                inputs = inputs.to(self.device)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                outputs = net(inputs)\n",
    "                loss = self._compute_batch_loss(outputs)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                # Update R for soft-boundary after warmup (based on current batch distances)\n",
    "                if (self.objective == \"soft-boundary\") and (epoch >= self.warm_up_n_epochs):\n",
    "                    with torch.no_grad():\n",
    "                        dist = torch.sum((outputs - self.c) ** 2, dim=1)\n",
    "                        self.R.data = torch.tensor(get_radius(dist, self.nu), device=self.device)\n",
    "\n",
    "                loss_epoch += float(loss.item())\n",
    "                n_batches += 1\n",
    "\n",
    "            # -------- scheduler step --------\n",
    "            scheduler.step()\n",
    "            # step() has just advanced the scheduler to epoch + 1, so a decay takes\n",
    "            # effect exactly when (epoch + 1) is one of the milestones.\n",
    "            if (epoch + 1) in self.lr_milestones:\n",
    "                logger.info(\"  LR scheduler: new learning rate is %g\", float(scheduler.get_last_lr()[0]))\n",
    "\n",
    "            # -------- validation --------\n",
    "            val_loss = self._eval_val_loss(net, val_loader)\n",
    "\n",
    "            # -------- logging --------\n",
    "            epoch_train_time = time.time() - epoch_start_time\n",
    "            logger.info(\n",
    "                \"  Epoch %d/%d\\t Time: %.3f\\t TrainLoss: %.8f\\t ValLoss: %.8f\",\n",
    "                epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / max(n_batches, 1), val_loss\n",
    "            )\n",
    "\n",
    "            # -------- early stopping on validation loss --------\n",
    "            if val_loss < best_val_loss:\n",
    "                best_val_loss = val_loss\n",
    "                best_state = copy.deepcopy(net.state_dict())\n",
    "                early_stopping_counter = 0\n",
    "            else:\n",
    "                early_stopping_counter += 1\n",
    "\n",
    "            if early_stopping_counter >= self.early_stopping_patience:\n",
    "                logger.info(\"Early stopping triggered (no val improvement for %d epochs).\", self.early_stopping_patience)\n",
    "                break\n",
    "\n",
    "        # Restore best weights\n",
    "        net.load_state_dict(best_state)\n",
    "\n",
    "        self.train_time = time.time() - start_time\n",
    "        logger.info(\"Training time: %.3f\", self.train_time)\n",
    "        logger.info(\"Finished training.\")\n",
    "        return net, best_val_loss\n",
    "\n",
    "    def test(self, dataset, net: nn.Module):\n",
    "        \"\"\"Score the test split and store the results on the trainer.\n",
    "\n",
    "        Fills `test_scores` with (idx, label, score, latent) tuples,\n",
    "        `test_input_output_pairs` with (input, latent) NumPy pairs, `test_time`,\n",
    "        and `test_auc` (None when the AUC cannot be computed).\n",
    "        \"\"\"\n",
    "        logger = logging.getLogger()\n",
    "\n",
    "        # Set device for network\n",
    "        net = net.to(self.device)\n",
    "\n",
    "        _, _, test_loader = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)\n",
    "\n",
    "        # Testing\n",
    "        logger.info(\"Starting testing...\")\n",
    "        start_time = time.time()\n",
    "        idx_label_score = []\n",
    "        self.test_input_output_pairs = []  # reset\n",
    "\n",
    "        net.eval()\n",
    "        with torch.no_grad():\n",
    "            for data in test_loader:\n",
    "                inputs, labels, idx = data\n",
    "                inputs = inputs.to(self.device)\n",
    "                outputs = net(inputs)\n",
    "                dist = torch.sum((outputs - self.c) ** 2, dim=1)\n",
    "                scores = dist - self.R ** 2 if self.objective == \"soft-boundary\" else dist\n",
    "\n",
    "                # Save tuples (idx, label, score, latent output)\n",
    "                idx_label_score += list(zip(\n",
    "                    idx.cpu().numpy().tolist(),\n",
    "                    labels.cpu().numpy().tolist(),\n",
    "                    scores.cpu().numpy().tolist(),\n",
    "                    outputs.cpu().numpy().tolist()\n",
    "                ))\n",
    "\n",
    "                # Store (input, output) pairs as numpy arrays for every sample\n",
    "                self.test_input_output_pairs += list(zip(\n",
    "                    inputs.cpu().numpy(),         # original (possibly standardized) input x_i\n",
    "                    outputs.cpu().numpy()         # latent representation z_i = f_theta(x_i)\n",
    "                ))\n",
    "\n",
    "        self.test_time = time.time() - start_time\n",
    "        logger.info(\"Testing time: %.3f\", self.test_time)\n",
    "        self.test_scores = idx_label_score\n",
    "\n",
    "        # Compute AUC if labels provided (0/1). An empty test set is skipped here,\n",
    "        # because unpacking zip(*[]) would raise outside the try block.\n",
    "        self.test_auc = None\n",
    "        if idx_label_score:\n",
    "            _, labels, scores, _ = zip(*idx_label_score)\n",
    "            try:\n",
    "                self.test_auc = roc_auc_score(np.array(labels), np.array(scores))\n",
    "                logger.info(\"Test set AUC: %.2f%%\", 100.0 * self.test_auc)\n",
    "            except ValueError:\n",
    "                self.test_auc = None\n",
    "        if self.test_auc is None:\n",
    "            logger.info(\"AUC not computed (labels may be missing or degenerate).\")\n",
    "\n",
    "        logger.info(\"Finished testing.\")\n",
    "\n",
    "    def init_center_c(self, train_loader: DataLoader, net: nn.Module, eps: float = 0.1) -> torch.Tensor:\n",
    "        \"\"\"Initialize hypersphere center c as the mean of f(x) over the training data.\n",
    "\n",
    "        Coordinates whose magnitude is below `eps` are pushed to -eps (negative\n",
    "        values) or +eps (zero and positive values). The result is also stored\n",
    "        on the network as `net.c`.\n",
    "        \"\"\"\n",
    "        n_samples = 0\n",
    "        c = torch.zeros(net.rep_dim, device=self.device)\n",
    "        net.eval()\n",
    "        with torch.no_grad():\n",
    "            for data in train_loader:\n",
    "                inputs, _, _ = data\n",
    "                inputs = inputs.to(self.device)\n",
    "                outputs = net(inputs)\n",
    "                n_samples += outputs.shape[0]\n",
    "                c += torch.sum(outputs, dim=0)\n",
    "\n",
    "        c /= max(n_samples, 1)\n",
    "\n",
    "        # Avoid trivial zero-dimension. Exact zeros are included (>= 0), since a\n",
    "        # coordinate of exactly 0 is the case this clamp is meant to prevent.\n",
    "        near_zero = torch.abs(c) < eps\n",
    "        c[near_zero & (c < 0)] = -eps\n",
    "        c[near_zero & (c >= 0)] = eps\n",
    "\n",
    "        net.c = c  # store in net for convenience\n",
    "        return c\n",
    "\n",
    "\n",
    "def get_radius(dist: torch.Tensor, nu: float) -> float:\n",
    "    \"\"\"Return the radius R as the (1 - nu)-quantile of the distances.\n",
    "\n",
    "    Args:\n",
    "        dist: Squared distances; the square root is taken here.\n",
    "        nu: Quantile complement, i.e. R is the (1 - nu)-quantile.\n",
    "    \"\"\"\n",
    "    distances = torch.sqrt(dist).cpu().numpy()\n",
    "    quantile_level = 1.0 - nu\n",
    "    return float(np.quantile(distances, quantile_level))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "582dfcea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll_norm_scatter_plot(deepsvdd_net, dim, ref_distribution, n_samples=10000):\n",
    "    \"\"\"\n",
    "    Scatter plot of reference NLL vs. latent distance to the center c.\n",
    "\n",
    "    Samples are drawn from `ref_distribution` and mapped through `deepsvdd_net`.\n",
    "    The Euclidean norm of (latent - c) is plotted against the negative\n",
    "    log-likelihood, annotated with the Pearson, Spearman and Kendall\n",
    "    coefficients. The figure is saved as a PNG whose name contains `dim` and\n",
    "    the global ARCH. Also updates the global matplotlib rcParams and style.\n",
    "\n",
    "    Args:\n",
    "        deepsvdd_net: Trained network exposing the center as attribute `c`.\n",
    "        dim: Data dimensionality; used for the title and file name only.\n",
    "        ref_distribution: Distribution providing `sample` and `log_prob`.\n",
    "        n_samples: Number of samples drawn from `ref_distribution`.\n",
    "    \"\"\"\n",
    "    plt.rcParams.update({\n",
    "        \"pgf.texsystem\": \"pdflatex\",\n",
    "        \"text.usetex\": False,\n",
    "        \"pgf.rcfonts\": False,\n",
    "        \"font.size\": 14,\n",
    "        \"axes.labelsize\": 16,\n",
    "        \"xtick.labelsize\": 13,\n",
    "        \"ytick.labelsize\": 13,\n",
    "        \"legend.fontsize\": 12\n",
    "    })\n",
    "    plt.style.use('ggplot')\n",
    "\n",
    "    scatter_fig, ax = plt.subplots()\n",
    "\n",
    "    # Sample from the reference distribution and flatten to (n_samples, features)\n",
    "    base_samples = ref_distribution.sample((n_samples,)).to(device)\n",
    "    base_samples = base_samples.view(base_samples.shape[0], -1)\n",
    "\n",
    "    # Run the network in eval mode so BatchNorm uses (and does not update) its\n",
    "    # running statistics; the previous mode is restored afterwards.\n",
    "    was_training = deepsvdd_net.training\n",
    "    deepsvdd_net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            nlls = -ref_distribution.log_prob(base_samples).cpu().numpy()\n",
    "            mappings = deepsvdd_net(base_samples)\n",
    "            c = deepsvdd_net.c\n",
    "            latent_norms = (mappings - c).norm(p=2, dim=1).cpu().numpy()\n",
    "    finally:\n",
    "        deepsvdd_net.train(was_training)\n",
    "\n",
    "    # Linear and rank correlations between NLL and latent distance\n",
    "    pearson_r, _ = stats.pearsonr(nlls, latent_norms)\n",
    "    spearman_rho, _ = stats.spearmanr(nlls, latent_norms)\n",
    "    kendall_tau, _ = stats.kendalltau(nlls, latent_norms)\n",
    "\n",
    "    # Scatter plot\n",
    "    ax.scatter(nlls, latent_norms, alpha=0.5)\n",
    "    ax.set_xlabel(\"Negative Log-Likelihood\")\n",
    "    ax.set_ylabel(\"Latent Norm\")\n",
    "\n",
    "    ax.text(0.6, 0.05, f\"Pearson R: {pearson_r:.2f}\\nSpearman Rho: {spearman_rho:.2f}\\nKendall Tau: {kendall_tau:.2f}\", \n",
    "            transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5))\n",
    "\n",
    "    ax.set_title(f\"NLL vs Latent Norm ({dim}D)\")\n",
    "    scatter_fig.savefig(f\"gmm_eval_nll_vs_latent_norms_{dim}D_{ARCH}_lr5.png\", dpi=400)\n",
    "\n",
    "    \n",
    "def plot_gmm_contour(deepsvdd_net, test_input_output_pairs, color_scale):\n",
    "    \"\"\"Scatter the centered 2D latent codes with a 99th-percentile distance circle.\n",
    "\n",
    "    The figure is saved as a PNG whose name contains the global ARCH and is\n",
    "    then shown.\n",
    "\n",
    "    Args:\n",
    "        deepsvdd_net: Network exposing the center tensor as attribute `c`.\n",
    "        test_input_output_pairs: Iterable of (input, latent) pairs; only the\n",
    "            latents are used, and their first two coordinates are plotted.\n",
    "        color_scale: Per-sample color values; a tensor (on any device) or an\n",
    "            array-like accepted by matplotlib.\n",
    "    \"\"\"\n",
    "    _, ax = plt.subplots(figsize=(5, 5))\n",
    "    ax.set_facecolor('white')\n",
    "\n",
    "    # matplotlib cannot convert CUDA tensors, so hand it a NumPy array.\n",
    "    if torch.is_tensor(color_scale):\n",
    "        color_scale = color_scale.detach().cpu().numpy()\n",
    "\n",
    "    # Center the latent codes on the hypersphere center c.\n",
    "    center = deepsvdd_net.c.detach().cpu().numpy()\n",
    "    latents = np.array([z for _, z in test_input_output_pairs]) - center\n",
    "\n",
    "    x = latents[:, 0]\n",
    "    y = latents[:, 1]\n",
    "\n",
    "    ax.scatter(x, y, s=5, c=color_scale, cmap=\"magma\", alpha=1)\n",
    "\n",
    "    # Customize plot\n",
    "    ax.set_title(f'Centered Latent Data Distribution\\n2D GMM ({ARCH})')\n",
    "\n",
    "    # Draw circle at 99th percentile of centered distances from origin\n",
    "    distances = np.linalg.norm(latents, axis=1)\n",
    "    radius_99 = np.percentile(distances, 99)\n",
    "    ax.add_patch(Circle((0, 0), radius_99,fill=False, edgecolor='black', linewidth=.5, linestyle='--'))\n",
    "\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(alpha=0.2)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"contour_latent_2d_gmm_{ARCH}_density_lr5.png\", dpi=400)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7afac4cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-selection grid: every (batch_norm, depth) combination is trained per\n",
    "# dimensionality and the one with the lowest validation loss is evaluated.\n",
    "DEPTH = [2, 3, 4, 5, 6]\n",
    "# NOTE(review): both entries are True, so each depth is trained twice with batch\n",
    "# norm and never without it - confirm whether [True, False] was intended.\n",
    "BATCH_NORM = [True, True]\n",
    "\n",
    "for dim in DIMS:\n",
    "    print(f\"{dim}D GMM\")\n",
    "    hdim = dim // 2\n",
    "    # Two-component GMM: an anisotropic component at -1 and an isotropic one at +1.\n",
    "    distribution=GMM(\n",
    "        loc=torch.stack([-torch.ones(dim), torch.ones(dim)]), \n",
    "        covariance_matrix=torch.stack([\n",
    "            torch.diag(torch.tensor([5.]*hdim + [.5]*hdim)), \n",
    "            torch.eye(dim)\n",
    "        ]),\n",
    "        mixture_weights=torch.ones(2)/2\n",
    "    )\n",
    "    ref_dist = distribution\n",
    "\n",
    "    ds_train = DistributionDataset(\n",
    "        distribution=distribution,\n",
    "        num_samples=10000\n",
    "    )[:][0]\n",
    "    \n",
    "    ds_test = DistributionDataset(\n",
    "        distribution=distribution,\n",
    "        num_samples=1000\n",
    "    )[:][0]\n",
    "    ds_val = DistributionDataset(\n",
    "        distribution=distribution,\n",
    "        num_samples=1000\n",
    "    )[:][0]\n",
    "    ds_train = ds_train.numpy()\n",
    "    ds_test = ds_test.numpy()\n",
    "    ds_val = ds_val.numpy()\n",
    "\n",
    "    dataset = NumpyDatasetWrapper(X_train=ds_train, X_val=ds_val, X_test=ds_test, device=resolve_device(device))\n",
    "\n",
    "    # Parallel lists: entry i belongs to the i-th trained configuration.\n",
    "    val_losses = []\n",
    "    nets = []\n",
    "    trainers = []\n",
    "    for batch_norm in BATCH_NORM:\n",
    "        print(f\"Batch Norm: {batch_norm}\")\n",
    "        for depth in DEPTH:\n",
    "            print(f\"Depth: {depth}\")\n",
    "\n",
    "            # Instantiate and train model\n",
    "            net = DeepSVDDNet(input_dim=dim, latent_dim=dim, batch_norm=batch_norm, depth=depth)\n",
    "            trainer = DeepSVDDTrainer(\n",
    "                objective=\"one-class\",  # or 'soft-boundary'\n",
    "                nu=0.10,\n",
    "                lr=1e-5,\n",
    "                n_epochs=100,\n",
    "                lr_milestones=(100,),  # no decay\n",
    "                early_stopping_patience=3,\n",
    "                batch_size=32,\n",
    "                weight_decay=1e-3, # Checked 1e-3, 1e-4 1e-5\n",
    "                device=device,\n",
    "            )\n",
    "            net, val_loss = trainer.train(dataset, net)\n",
    "            val_losses.append(val_loss)\n",
    "            nets.append(net)\n",
    "            trainers.append(trainer)\n",
    "            \n",
    "    # Select best model based on validation loss\n",
    "    best_idx = int(np.argmin(val_losses))\n",
    "    net = nets[best_idx]\n",
    "    # Evaluate with the trainer that fitted this network, so that the scores are\n",
    "    # computed against its own center c rather than the last trained model's.\n",
    "    trainer = trainers[best_idx]\n",
    "    trainer.test(dataset, net)\n",
    "    \n",
    "    # Print used depth and batch norm of best model\n",
    "    best_depth = DEPTH[best_idx % len(DEPTH)]\n",
    "    print(f\"Best model - Depth: {best_depth}, Batch Norm: {BATCH_NORM[best_idx // len(DEPTH)]}\")\n",
    "\n",
    "    # (input, output) for every test sample\n",
    "    test_input_output_pairs = trainer.test_input_output_pairs \n",
    "\n",
    "    # Reference density of each test point, as a CPU array so matplotlib can use it.\n",
    "    with torch.no_grad():\n",
    "        tensor_ds_test = torch.tensor(ds_test).to(device)\n",
    "        color_scale = torch.exp(distribution.log_prob(tensor_ds_test)).cpu().numpy()\n",
    "\n",
    "    nll_norm_scatter_plot(deepsvdd_net=net, dim=dim, ref_distribution=ref_dist)\n",
    "    if dim == 2:\n",
    "        plot_gmm_contour(deepsvdd_net=net, test_input_output_pairs=test_input_output_pairs, color_scale=color_scale)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nf4ad-ls",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
