{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "654d0dd7",
   "metadata": {},
   "source": [
    "Let's explore the last-layer embeddings of the QMIA attack model on CINIC-10 (`cinic10/0_16`) with class 1 (automobile) held out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6771d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "from lightning_utils import LightningQMIA\n",
    "from data_utils import CustomDataModule\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "device = torch.device(\"cuda:3\")\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"3\"\n",
    "\n",
    "def str2bool(v):\n",
    "    if isinstance(v, bool):\n",
    "        return v\n",
    "    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n",
    "        return True\n",
    "    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n",
    "        return False\n",
    "    else:\n",
    "        raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n",
    "\n",
    "\n",
    "parser = argparse.ArgumentParser(description=\"QMIA evaluation\")\n",
    "parser.add_argument(\"--seed\", type=int, default=0, help=\"random seed\")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--batch_size\", type=int, default=16, help=\"batch size\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--image_size\",\n",
    "    type=int,\n",
    "    default=-1,\n",
    "    help=\"image input size, set to -1 to use dataset's default value\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--base_image_size\",\n",
    "    type=int,\n",
    "    default=-1,\n",
    "    help=\"image input size to base model, set to -1 to use dataset's default value\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--architecture\",\n",
    "    type=str,\n",
    "    default=\"facebook/convnext-tiny-224\",\n",
    "    help=\"Attack Model Type\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--base_architecture\",\n",
    "    type=str,\n",
    "    default=\"resnet-18\",\n",
    "    help=\"Base Model Type\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--score_fn\",\n",
    "    type=str,\n",
    "    default=\"top_two_margin\",\n",
    "    help=\"score function (true_logit_margin, top_two_margin)\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--loss_fn\",\n",
    "    type=str,\n",
    "    default=\"gaussian\",\n",
    "    help=\"loss function (gaussian, pinball)\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--base_model_dataset\",\n",
    "    type=str,\n",
    "    default=\"cinic10/0_16\",\n",
    "    help=\"dataset (i.e. cinic10/0_16, imagenet/0_16, cifar100/0_16)\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--attack_dataset\",\n",
    "    type=str,\n",
    "    default=None,\n",
    "    help=\"dataset (i.e. cinic10/0_16, imagenet/0_16, cifar100/0_16), if None, use the same as base_model_dataset\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--checkpoint\",\n",
    "    type=str,\n",
    "    default=\"best_val_loss\",\n",
    "    help=\"checkpoint path (either best_val_loss or last)\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--model_root\",\n",
    "    type=str,\n",
    "    default=\"./models/\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--data_root\",\n",
    "    type=str,\n",
    "    default=\"./data/\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--results_root\",\n",
    "    type=str,\n",
    "    default=\"./outputs/\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--data_mode\",\n",
    "    type=str,\n",
    "    default=\"eval\",\n",
    "    help=\"data mode (either base, mia, or eval)\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--cls_drop\",\n",
    "    type=int,\n",
    "    nargs=\"*\",\n",
    "    default=[],\n",
    "    help=\"drop classes from the dataset, e.g. --cls_drop 1 3 7\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--cls_drop_range\",\n",
    "    type=str,\n",
    "    default=None,\n",
    "    help=\"drop classes from the dataset, e.g. --cls_drop_range 0-500\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--DEBUG\",\n",
    "    action=\"store_true\",\n",
    "    help=\"debug mode, set to True to run on CPU and with fewer epochs\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\n",
    "    \"--rerun\", action=\"store_true\", help=\"whether to rerun the evaluation\"\n",
    ")\n",
    "\n",
    "# Notebook stand-in for the command line: these strings are parsed exactly like argv.\n",
    "NOTEBOOK_ARGV = [\n",
    "    \"--attack_dataset=cinic10/0_16\",\n",
    "    \"--base_model_dataset=cinic10/0_16\",\n",
    "    \"--architecture=facebook/convnext-tiny-224\",\n",
    "    \"--base_architecture=cifar-resnet-50\",\n",
    "    \"--model_root=./models/\",\n",
    "    \"--data_root=../image_QMIA_v3/data/\",\n",
    "    \"--batch_size=128\",\n",
    "    \"--image_size=224\",\n",
    "    \"--score_fn=top_two_margin\",\n",
    "    \"--loss_fn=gaussian\",\n",
    "    \"--checkpoint=last\",\n",
    "    \"--cls_drop=1\",\n",
    "    \"--rerun\",\n",
    "]\n",
    "args = parser.parse_args(NOTEBOOK_ARGV)\n",
    "\n",
    "# Seed every random number generator the notebook uses.\n",
    "seed = args.seed\n",
    "for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):\n",
    "    _seed_fn(seed)\n",
    "\n",
    "# Without an explicit attack dataset, attack the dataset the base model was trained on.\n",
    "if args.attack_dataset is None:\n",
    "    args.attack_dataset = args.base_model_dataset\n",
    "\n",
    "# Held-out classes are given either as an explicit list or as a \"start-end\" range, never both.\n",
    "if args.cls_drop and args.cls_drop_range:\n",
    "    raise ValueError(\"You can only specify one of --cls_drop and --cls_drop_range\")\n",
    "\n",
    "if args.cls_drop:\n",
    "    # Class ids are concatenated without a separator, e.g. [1, 3, 7] -> \"137\".\n",
    "    cls_drop_str = \"\".join(map(str, args.cls_drop))\n",
    "elif args.cls_drop_range:\n",
    "    start, end = (int(bound) for bound in args.cls_drop_range.split(\"-\"))\n",
    "    cls_drop_str = f\"{start}to{end}\"\n",
    "    # range() excludes `end`, so class `end` itself is not dropped.\n",
    "    args.cls_drop = list(range(start, end))\n",
    "else:\n",
    "    cls_drop_str = \"none\"\n",
    "\n",
    "# Attack checkpoints are stored in a directory tree keyed by every experiment setting.\n",
    "_attack_path_parts = [\n",
    "    args.model_root,\n",
    "    \"mia\",\n",
    "    \"base_\" + args.base_model_dataset,\n",
    "    args.base_architecture,\n",
    "    \"attack_\" + args.attack_dataset,\n",
    "    args.architecture,\n",
    "    \"score_fn_\" + args.score_fn,\n",
    "    \"loss_fn_\" + args.loss_fn,\n",
    "    \"cls_drop_\" + cls_drop_str,\n",
    "]\n",
    "args.attack_checkpoint_path = os.path.join(*_attack_path_parts)\n",
    "args.base_checkpoint_path = os.path.join(\n",
    "    args.model_root, \"base\", args.base_model_dataset, args.base_architecture\n",
    ")\n",
    "args.attack_results_path = os.path.join(args.attack_checkpoint_path, \"predictions\")\n",
    "args.attack_plots_path = os.path.join(args.attack_results_path, \"plots\")\n",
    "\n",
    "# Number of base-model classes, inferred from the dataset name: the first matching tag wins\n",
    "# and anything unrecognised is treated as a 10-class dataset.\n",
    "_dataset_key = args.base_model_dataset.lower()\n",
    "_CLASS_COUNTS = ((\"cifar100\", 100), (\"imagenet\", 1000), (\"cifar20\", 20))\n",
    "args.num_base_classes = next(\n",
    "    (count for tag, count in _CLASS_COUNTS if tag in _dataset_key),\n",
    "    10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf961fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained attack model from \"<attack_checkpoint_path>/<checkpoint>.ckpt\" and\n",
    "# switch it to evaluation mode.\n",
    "checkpoint_file = os.path.join(args.attack_checkpoint_path, f\"{args.checkpoint}.ckpt\")\n",
    "lightning_model = LightningQMIA.load_from_checkpoint(checkpoint_file)\n",
    "lightning_model.eval()\n",
    "\n",
    "# DEBUG runs use a batch of 2; otherwise the configured batch size applies.\n",
    "effective_batch_size = 2 if args.DEBUG else args.batch_size\n",
    "\n",
    "datamodule = CustomDataModule(\n",
    "    dataset_name=args.base_model_dataset,\n",
    "    stage=args.data_mode,\n",
    "    num_workers=16,\n",
    "    image_size=args.image_size,\n",
    "    base_image_size=args.base_image_size,\n",
    "    batch_size=effective_batch_size,\n",
    "    data_root=args.data_root,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c10c09aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm.notebook import tqdm\n",
    "import umap\n",
    "\n",
    "# NOTE(review): the datamodule was constructed with stage=args.data_mode (\"eval\" by default)\n",
    "# but is set up with stage=\"mia\" here - confirm which split val_dataloader() is meant to return.\n",
    "datamodule.setup(stage=\"mia\")\n",
    "dataloader = datamodule.val_dataloader()\n",
    "\n",
    "def get_batch(batch):\n",
    "    if len(batch) == 2:\n",
    "        raise ValueError(\n",
    "            \"Batch should contain 3 elements: samples, targets, and base_samples\"\n",
    "        )\n",
    "    else:\n",
    "        samples, targets, base_samples = batch\n",
    "    return samples, targets, base_samples\n",
    "\n",
    "# Embeddings are cached on disk next to the attack predictions; reuse them when present.\n",
    "embeddings_file = os.path.join(args.attack_results_path, \"embeddings.pt\")\n",
    "\n",
    "if os.path.exists(embeddings_file):\n",
    "    print(\"Embeddings already exist. Loading...\")\n",
    "    # weights_only=False runs the full pickle loader (the cache stores NumPy arrays), so only\n",
    "    # load cache files that this notebook wrote itself.\n",
    "    combined = torch.load(embeddings_file, weights_only=False)\n",
    "    embeddings = combined[\"embeddings\"]\n",
    "    targets = combined[\"targets\"]\n",
    "else:\n",
    "    print(\"Extracting embeddings...\")\n",
    "    # ConvNeXt backbone of the attack model, looked up once instead of once per batch.\n",
    "    convnext_model = lightning_model.model.model_base.convnext\n",
    "    # Send inputs to whichever device the backbone's weights are on, so inputs and weights\n",
    "    # always match regardless of where the checkpoint was loaded.\n",
    "    backbone_device = next(convnext_model.parameters()).device\n",
    "\n",
    "    all_embeddings = []\n",
    "    all_targets = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader, desc=\"Extracting embeddings\"):\n",
    "            samples, targets, base_samples = get_batch(batch)\n",
    "            samples = samples.to(backbone_device)\n",
    "\n",
    "            # Run the backbone only (embeddings + encoder), skipping the classification head.\n",
    "            x = convnext_model.embeddings(samples)\n",
    "            x = convnext_model.encoder(x)\n",
    "\n",
    "            # Global average pooling over axes 2 and 3 of the final hidden state.\n",
    "            embeddings = x.last_hidden_state.mean(dim=(2, 3))\n",
    "\n",
    "            all_embeddings.append(embeddings.cpu())\n",
    "            all_targets.append(targets.cpu())\n",
    "\n",
    "    # Concatenate all batches\n",
    "    embeddings = torch.cat(all_embeddings, dim=0).numpy()\n",
    "    targets = torch.cat(all_targets, dim=0).numpy()\n",
    "\n",
    "    print(f\"Extracted {len(embeddings)} embeddings with shape: {embeddings.shape}\")\n",
    "\n",
    "    # The predictions directory may not exist yet; create it so the save cannot fail after\n",
    "    # the (expensive) extraction has already finished.\n",
    "    os.makedirs(args.attack_results_path, exist_ok=True)\n",
    "    torch.save({\n",
    "        \"embeddings\": embeddings,\n",
    "        \"targets\": targets\n",
    "    }, embeddings_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f396c611",
   "metadata": {},
   "source": [
    "UMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f1446d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seen_classes = [c for c in range(args.num_base_classes) if c not in args.cls_drop]\n",
    "\n",
    "# # Check if targets are in seen classes\n",
    "# is_seen = np.isin(targets, seen_classes)\n",
    "\n",
    "# # Create UMAP reducer\n",
    "# reducer = umap.UMAP(\n",
    "#     n_neighbors=15,  # Lower for more local structure, higher for more global\n",
    "#     min_dist=0.1,    # Lower for tighter clusters\n",
    "#     n_components=2,  # 2D for visualization\n",
    "#     random_state=42  # For reproducibility\n",
    "# )\n",
    "\n",
    "# # Fit and transform the embeddings\n",
    "# print(\"Computing UMAP (this may take a minute)...\")\n",
    "# embedding = reducer.fit_transform(embeddings)\n",
    "\n",
    "# # Create plots\n",
    "# plt.figure(figsize=(12, 10))\n",
    "\n",
    "# # Color points by class\n",
    "# scatter = plt.scatter(\n",
    "#     embedding[:, 0], \n",
    "#     embedding[:, 1], \n",
    "#     c=targets, \n",
    "#     cmap='tab10', \n",
    "#     alpha=0.7,\n",
    "#     s=10\n",
    "# )\n",
    "\n",
    "# # Add a colorbar\n",
    "# plt.colorbar(scatter, label='Class')\n",
    "# plt.title('UMAP of ConvNeXt Embeddings by Class')\n",
    "# plt.tight_layout()\n",
    "# plt.savefig('umap_by_class.png', dpi=300)\n",
    "# plt.show()\n",
    "\n",
    "# # If we have seen/unseen class information, create that plot too\n",
    "# if is_seen is not None:\n",
    "#     plt.figure(figsize=(12, 10))\n",
    "    \n",
    "#     # Create a scatter plot colored by seen/unseen status\n",
    "#     scatter = plt.scatter(\n",
    "#         embedding[:, 0], \n",
    "#         embedding[:, 1], \n",
    "#         c=is_seen, \n",
    "#         cmap=plt.cm.get_cmap('coolwarm', 2),  # Red for unseen, blue for seen\n",
    "#         alpha=0.7,\n",
    "#         s=10\n",
    "#     )\n",
    "    \n",
    "#     # Add a colorbar with labels\n",
    "#     cbar = plt.colorbar(scatter, ticks=[0, 1])\n",
    "#     cbar.set_ticklabels(['Unseen', 'Seen'])\n",
    "    \n",
    "#     plt.title('UMAP of ConvNeXt Embeddings - Seen vs. Unseen Classes')\n",
    "#     plt.tight_layout()\n",
    "#     plt.savefig('umap_seen_unseen.png', dpi=300)\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "622a2d5c",
   "metadata": {},
   "source": [
    "### Gaussian Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f850cb9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import torch\n",
    "\n",
    "# Standard scale the data\n",
    "scaler = StandardScaler()\n",
    "scaled_embeds = scaler.fit_transform(embeddings)\n",
    "\n",
    "# Reduce the embedding size\n",
    "N_COMPONENTS = 2\n",
    "# random_state pins the solver's RNG, so the projection is reproducible even when\n",
    "# scikit-learn selects a randomized SVD solver for this input size.\n",
    "pca = PCA(n_components=N_COMPONENTS, random_state=args.seed)\n",
    "pca_embeds = pca.fit_transform(scaled_embeds)\n",
    "\n",
    "# Check if targets are in seen classes (every class that was not held out)\n",
    "seen_classes = [c for c in range(args.num_base_classes) if c not in args.cls_drop]\n",
    "is_seen = np.isin(targets, seen_classes)\n",
    "\n",
    "P_indices = np.where(is_seen)[0] # indices of seen data\n",
    "Q_indices = range(len(targets)) # indices of all data\n",
    "\n",
    "# keep_size = len(P_indices) # keep the same number of seen and unseen data\n",
    "# Q_indices = np.random.choice(Q_indices, size=keep_size, replace=False)\n",
    "\n",
    "# P: projected embeddings of seen-class samples only; Q: projected embeddings of all samples.\n",
    "P = torch.tensor(pca_embeds[P_indices]).to('cuda') # seen data\n",
    "Q = torch.tensor(pca_embeds).to('cuda')  # all data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10478aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from math import pi\n",
    "from scipy.special import logsumexp\n",
    "\n",
    "def calculate_matmul_n_times(n_components, mat_a, mat_b):\n",
    "    \"\"\"\n",
    "    Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].\n",
    "    Bypasses torch.matmul to reduce memory footprint.\n",
    "    args:\n",
    "        mat_a:      torch.Tensor (n, k, 1, d)\n",
    "        mat_b:      torch.Tensor (1, k, d, d)\n",
    "    \"\"\"\n",
    "    res = torch.zeros(mat_a.shape).to(mat_a.device)\n",
    "    \n",
    "    for i in range(n_components):\n",
    "        mat_a_i = mat_a[:, i, :, :].squeeze(-2)\n",
    "        mat_b_i = mat_b[0, i, :, :].squeeze()\n",
    "        res[:, i, :, :] = mat_a_i.mm(mat_b_i).unsqueeze(1)\n",
    "    \n",
    "    return res\n",
    "\n",
    "\n",
    "def calculate_matmul(mat_a, mat_b):\n",
    "    \"\"\"\n",
    "    Calculate matrix product of two matrics with mat_a[0] >= mat_b[0].\n",
    "    Bypasses torch.matmul to reduce memory footprint.\n",
    "    args:\n",
    "        mat_a:      torch.Tensor (n, k, 1, d)\n",
    "        mat_b:      torch.Tensor (n, k, d, 1)\n",
    "    \"\"\"\n",
    "    assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1\n",
    "    return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=2, keepdim=True)\n",
    "\n",
    "class GaussianMixture(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Fits a mixture of k=1,..,K Gaussians to the input data (K is supplied via n_components).\n",
    "    Input tensors are expected to be flat with dimensions (n: number of samples, d: number of features).\n",
    "    The model then extends them to (n, 1, d).\n",
    "    The model parametrization (mu, sigma) is stored as (1, k, d),\n",
    "    probabilities are shaped (n, k, 1) if they relate to an individual sample,\n",
    "    or (1, k, 1) if they assign membership probabilities to one of the mixture components.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_components, n_features, covariance_type=\"full\", eps=1.e-6, init_params=\"kmeans\", mu_init=None, var_init=None):\n",
    "        \"\"\"\n",
    "        Initializes the model and brings all tensors into their required shape.\n",
    "        The class expects data to be fed as a flat tensor in (n, d).\n",
    "        The class owns:\n",
    "            x:               torch.Tensor (n, 1, d)\n",
    "            mu:              torch.Tensor (1, k, d)\n",
    "            var:             torch.Tensor (1, k, d) or (1, k, d, d)\n",
    "            pi:              torch.Tensor (1, k, 1)\n",
    "            covariance_type: str\n",
    "            eps:             float\n",
    "            init_params:     str\n",
    "            log_likelihood:  float\n",
    "            n_components:    int\n",
    "            n_features:      int\n",
    "        args:\n",
    "            n_components:    int\n",
    "            n_features:      int\n",
    "        options:\n",
    "            mu_init:         torch.Tensor (1, k, d)\n",
    "            var_init:        torch.Tensor (1, k, d) or (1, k, d, d)\n",
    "            covariance_type: str\n",
    "            eps:             float\n",
    "            init_params:     str\n",
    "        \"\"\"\n",
    "        super(GaussianMixture, self).__init__()\n",
    "\n",
    "        self.n_components = n_components\n",
    "        self.n_features = n_features\n",
    "\n",
    "        self.mu_init = mu_init\n",
    "        self.var_init = var_init\n",
    "        self.eps = eps\n",
    "\n",
    "        self.log_likelihood = -np.inf\n",
    "\n",
    "        self.covariance_type = covariance_type\n",
    "        self.init_params = init_params\n",
    "\n",
    "        assert self.covariance_type in [\"full\", \"diag\"]\n",
    "        assert self.init_params in [\"kmeans\", \"random\"]\n",
    "\n",
    "        self._init_params()\n",
    "\n",
    "\n",
    "    def _init_params(self):\n",
    "        if self.mu_init is not None:\n",
    "            assert self.mu_init.size() == (1, self.n_components, self.n_features), \"Input mu_init does not have required tensor dimensions (1, %i, %i)\" % (self.n_components, self.n_features)\n",
    "            # (1, k, d)\n",
    "            self.mu = torch.nn.Parameter(self.mu_init, requires_grad=False)\n",
    "        else:\n",
    "            self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)\n",
    "\n",
    "        if self.covariance_type == \"diag\":\n",
    "            if self.var_init is not None:\n",
    "                # (1, k, d)\n",
    "                assert self.var_init.size() == (1, self.n_components, self.n_features), \"Input var_init does not have required tensor dimensions (1, %i, %i)\" % (self.n_components, self.n_features)\n",
    "                self.var = torch.nn.Parameter(self.var_init, requires_grad=False)\n",
    "            else:\n",
    "                self.var = torch.nn.Parameter(torch.ones(1, self.n_components, self.n_features), requires_grad=False)\n",
    "        elif self.covariance_type == \"full\":\n",
    "            if self.var_init is not None:\n",
    "                # (1, k, d, d)\n",
    "                assert self.var_init.size() == (1, self.n_components, self.n_features, self.n_features), \"Input var_init does not have required tensor dimensions (1, %i, %i, %i)\" % (self.n_components, self.n_features, self.n_features)\n",
    "                self.var = torch.nn.Parameter(self.var_init, requires_grad=False)\n",
    "            else:\n",
    "                self.var = torch.nn.Parameter(\n",
    "                    torch.eye(self.n_features).reshape(1, 1, self.n_features, self.n_features).repeat(1, self.n_components, 1, 1),\n",
    "                    requires_grad=False\n",
    "                )\n",
    "\n",
    "        # (1, k, 1)\n",
    "        self.pi = torch.nn.Parameter(torch.Tensor(1, self.n_components, 1), requires_grad=False).fill_(1. / self.n_components)\n",
    "        self.params_fitted = False\n",
    "\n",
    "\n",
    "    def check_size(self, x):\n",
    "        if len(x.size()) == 2:\n",
    "            # (n, d) --> (n, 1, d)\n",
    "            x = x.unsqueeze(1)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "    def bic(self, x):\n",
    "        \"\"\"\n",
    "        Bayesian information criterion for a batch of samples.\n",
    "        args:\n",
    "            x:      torch.Tensor (n, d) or (n, 1, d)\n",
    "        returns:\n",
    "            bic:    float\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "        n = x.shape[0]\n",
    "\n",
    "        # Free parameters for covariance, means and mixture components\n",
    "        free_params = self.n_features * self.n_components + self.n_features + self.n_components - 1\n",
    "\n",
    "        bic = -2. * self.__score(x, as_average=False).mean() * n + free_params * np.log(n)\n",
    "\n",
    "        return bic\n",
    "\n",
    "\n",
    "    def fit(self, x, delta=1e-3, n_iter=100, warm_start=False):\n",
    "        \"\"\"\n",
    "        Fits model to the data with the EM algorithm. Stops after n_iter iterations or as soon\n",
    "        as the log-likelihood improves by less than delta. The step that fails to improve by\n",
    "        delta has its mu and var rolled back to the values from before that step.\n",
    "        args:\n",
    "            x:          torch.Tensor (n, d) or (n, 1, d)\n",
    "        options:\n",
    "            delta:      float, minimum log-likelihood improvement to keep iterating\n",
    "            n_iter:     int, maximum number of EM iterations\n",
    "            warm_start: bool, keep the parameters of a previous fit as the starting point\n",
    "        \"\"\"\n",
    "        if not warm_start and self.params_fitted:\n",
    "            device = self.mu.device\n",
    "            self._init_params()\n",
    "            # _init_params creates fresh parameters without a device argument; put them back\n",
    "            # where the model was, otherwise a re-fit on GPU data mixes devices.\n",
    "            for p in self.parameters():\n",
    "                p.data = p.data.to(device)\n",
    "\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        if self.init_params == \"kmeans\" and self.mu_init is None:\n",
    "            mu = self.get_kmeans_mu(x, n_centers=self.n_components)\n",
    "            self.mu.data = mu\n",
    "\n",
    "        i = 0\n",
    "        j = np.inf\n",
    "\n",
    "        # i counts completed iterations, so `<` gives at most n_iter EM steps\n",
    "        while (i < n_iter) and (j >= delta):\n",
    "\n",
    "            log_likelihood_old = self.log_likelihood\n",
    "            # Snapshot the values rather than the Parameter objects: a reference would follow\n",
    "            # any in-place update made by the EM step and turn the rollback below into a no-op.\n",
    "            mu_old = self.mu.detach().clone()\n",
    "            var_old = self.var.detach().clone()\n",
    "\n",
    "            self.__em(x)\n",
    "            self.log_likelihood = self.__score(x)\n",
    "\n",
    "            if torch.isinf(self.log_likelihood.abs()) or torch.isnan(self.log_likelihood):\n",
    "                device = self.mu.device\n",
    "                # When the log-likelihood assumes unbound values, reinitialize model\n",
    "                # (passing init_params along so a \"random\" model is not turned into \"kmeans\")\n",
    "                self.__init__(self.n_components,\n",
    "                    self.n_features,\n",
    "                    covariance_type=self.covariance_type,\n",
    "                    mu_init=self.mu_init,\n",
    "                    var_init=self.var_init,\n",
    "                    eps=self.eps,\n",
    "                    init_params=self.init_params)\n",
    "                for p in self.parameters():\n",
    "                    p.data = p.data.to(device)\n",
    "                if self.init_params == \"kmeans\":\n",
    "                    # assign the centers as returned, exactly like the initial k-means step above\n",
    "                    self.mu.data = self.get_kmeans_mu(x, n_centers=self.n_components)\n",
    "\n",
    "            i += 1\n",
    "            j = self.log_likelihood - log_likelihood_old\n",
    "\n",
    "            if j <= delta:\n",
    "                # When score decreases, revert to old parameters\n",
    "                self.__update_mu(mu_old)\n",
    "                self.__update_var(var_old)\n",
    "\n",
    "            print(\"Iteration %i, log-likelihood: %.3f, delta: %.3f\" % (i, self.log_likelihood, j))\n",
    "\n",
    "        self.params_fitted = True\n",
    "\n",
    "\n",
    "    def predict(self, x, probs=False):\n",
    "        \"\"\"\n",
    "        Assigns input data to one of the mixture components by evaluating the likelihood under each.\n",
    "        If probs=True returns normalized probabilities of class membership.\n",
    "        args:\n",
    "            x:          torch.Tensor (n, d) or (n, 1, d)\n",
    "            probs:      bool\n",
    "        returns:\n",
    "            p_k:        torch.Tensor (n, k)\n",
    "            (or)\n",
    "            y:          torch.LongTensor (n)\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)\n",
    "\n",
    "        if probs:\n",
    "            p_k = torch.exp(weighted_log_prob)\n",
    "            return torch.squeeze(p_k / (p_k.sum(1, keepdim=True)))\n",
    "        else:\n",
    "            return torch.squeeze(torch.max(weighted_log_prob, 1)[1].type(torch.LongTensor))\n",
    "\n",
    "\n",
    "    def predict_proba(self, x):\n",
    "        \"\"\"\n",
    "        Returns normalized probabilities of class membership (calls predict with probs=True).\n",
    "        args:\n",
    "            x:          torch.Tensor (n, d) or (n, 1, d)\n",
    "        returns:\n",
    "            p_k:        torch.Tensor (n, k)\n",
    "        \"\"\"\n",
    "        return self.predict(x, probs=True)\n",
    "\n",
    "\n",
    "    def sample(self, n):\n",
    "        \"\"\"\n",
    "        Samples from the model.\n",
    "        args:\n",
    "            n:          int\n",
    "        returns:\n",
    "            x:          torch.Tensor (n, d)\n",
    "            y:          torch.Tensor (n)\n",
    "        \"\"\"\n",
    "        counts = torch.distributions.multinomial.Multinomial(total_count=n, probs=self.pi.squeeze()).sample()\n",
    "        x = torch.empty(0, device=counts.device)\n",
    "        y = torch.cat([torch.full([int(sample)], j, device=counts.device) for j, sample in enumerate(counts)])\n",
    "\n",
    "        # Only iterate over components with non-zero counts\n",
    "        for k in np.arange(self.n_components)[counts > 0]: \n",
    "            if self.covariance_type == \"diag\":\n",
    "                x_k = self.mu[0, k] + torch.randn(int(counts[k]), self.n_features, device=x.device) * torch.sqrt(self.var[0, k])\n",
    "            elif self.covariance_type == \"full\":\n",
    "                d_k = torch.distributions.multivariate_normal.MultivariateNormal(self.mu[0, k], self.var[0, k])\n",
    "                x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))])\n",
    "\n",
    "            x = torch.cat((x, x_k), dim=0)\n",
    "\n",
    "        return x, y\n",
    "\n",
    "\n",
    "    def score_samples(self, x):\n",
    "        \"\"\"\n",
    "        Computes log-likelihood of samples under the current model.\n",
    "        args:\n",
    "            x:          torch.Tensor (n, d) or (n, 1, d)\n",
    "        returns:\n",
    "            score:      torch.LongTensor (n)\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        score = self.__score(x, as_average=False)\n",
    "        return score\n",
    "\n",
    "\n",
    "    def _estimate_log_prob(self, x):\n",
    "        \"\"\"\n",
    "        Returns a tensor with dimensions (n, k, 1), which indicates the log-likelihood that samples belong to the k-th Gaussian.\n",
    "        args:\n",
    "            x:            torch.Tensor (n, d) or (n, 1, d)\n",
    "        returns:\n",
    "            log_prob:     torch.Tensor (n, k, 1)\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        if self.covariance_type == \"full\":\n",
    "            mu = self.mu\n",
    "            var = self.var\n",
    "\n",
    "            precision = torch.inverse(var)\n",
    "            d = x.shape[-1]\n",
    "\n",
    "            log_2pi = d * np.log(2. * pi)\n",
    "\n",
    "            log_det = self._calculate_log_det(precision)\n",
    "\n",
    "            x_mu_T = (x - mu).unsqueeze(-2)\n",
    "            x_mu = (x - mu).unsqueeze(-1)\n",
    "\n",
    "            x_mu_T_precision = calculate_matmul_n_times(self.n_components, x_mu_T, precision)\n",
    "            x_mu_T_precision_x_mu = calculate_matmul(x_mu_T_precision, x_mu)\n",
    "\n",
    "            return -.5 * (log_2pi - log_det + x_mu_T_precision_x_mu)\n",
    "\n",
    "        elif self.covariance_type == \"diag\":\n",
    "            mu = self.mu\n",
    "            prec = torch.rsqrt(self.var)\n",
    "\n",
    "            log_p = torch.sum((mu * mu + x * x - 2 * x * mu) * prec, dim=2, keepdim=True)\n",
    "            log_det = torch.sum(torch.log(prec), dim=2, keepdim=True)\n",
    "\n",
    "            return -.5 * (self.n_features * np.log(2. * pi) + log_p - log_det)\n",
    "\n",
    "\n",
    "    def _calculate_log_det(self, var):\n",
    "        \"\"\"\n",
    "        Calculate log determinant in log space, to prevent overflow errors.\n",
    "        args:\n",
    "            var:            torch.Tensor (1, k, d, d)\n",
    "        \"\"\"\n",
    "        log_det = torch.empty(size=(self.n_components,)).to(var.device)\n",
    "        \n",
    "        for k in range(self.n_components):\n",
    "            log_det[k] = 2 * torch.log(torch.diagonal(torch.linalg.cholesky(var[0,k]))).sum()\n",
    "\n",
    "        return log_det.unsqueeze(-1)\n",
    "\n",
    "\n",
    "    def _e_step(self, x):\n",
    "        \"\"\"\n",
    "        Computes log-responses that indicate the (logarithmic) posterior belief (sometimes called responsibilities) that a data point was generated by one of the k mixture components.\n",
    "        Also returns the mean of the mean of the logarithms of the probabilities (as is done in sklearn).\n",
    "        This is the so-called expectation step of the EM-algorithm.\n",
    "        args:\n",
    "            x:              torch.Tensor (n, d) or (n, 1, d)\n",
    "        returns:\n",
    "            log_prob_norm:  torch.Tensor (1)\n",
    "            log_resp:       torch.Tensor (n, k, 1)\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)\n",
    "\n",
    "        log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True)\n",
    "        log_resp = weighted_log_prob - log_prob_norm\n",
    "\n",
    "        return torch.mean(log_prob_norm), log_resp\n",
    "\n",
    "\n",
    "    def _m_step(self, x, log_resp):\n",
    "        \"\"\"\n",
    "        From the log-probabilities, computes new parameters pi, mu, var (that maximize the log-likelihood). This is the maximization step of the EM-algorithm.\n",
    "        args:\n",
    "            x:          torch.Tensor (n, d) or (n, 1, d)\n",
    "            log_resp:   torch.Tensor (n, k, 1)\n",
    "        returns:\n",
    "            pi:         torch.Tensor (1, k, 1)\n",
    "            mu:         torch.Tensor (1, k, d)\n",
    "            var:        torch.Tensor (1, k, d)\n",
    "        \"\"\"\n",
    "        x = self.check_size(x)\n",
    "\n",
    "        resp = torch.exp(log_resp)\n",
    "\n",
    "        pi = torch.sum(resp, dim=0, keepdim=True) + self.eps\n",
    "        mu = torch.sum(resp * x, dim=0, keepdim=True) / pi\n",
    "\n",
    "        if self.covariance_type == \"full\":\n",
    "            eps = (torch.eye(self.n_features) * self.eps).to(x.device)\n",
    "            var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,\n",
    "                            keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps\n",
    "\n",
    "        elif self.covariance_type == \"diag\":\n",
    "            x2 = (resp * x * x).sum(0, keepdim=True) / pi\n",
    "            mu2 = mu * mu\n",
    "            xmu = (resp * mu * x).sum(0, keepdim=True) / pi\n",
    "            var = x2 - 2 * xmu + mu2 + self.eps\n",
    "\n",
    "        pi = pi / x.shape[0]\n",
    "\n",
    "        return pi, mu, var\n",
    "\n",
    "\n",
    "    def __em(self, x):\n",
    "        \"\"\"\n",
    "        Performs one iteration of the expectation-maximization algorithm by calling the respective subroutines.\n",
    "        args:\n",
    "            x:          torch.Tensor (n, 1, d)\n",
    "        \"\"\"\n",
    "        _, log_resp = self._e_step(x)\n",
    "        pi, mu, var = self._m_step(x, log_resp)\n",
    "\n",
    "        self.__update_pi(pi)\n",
    "        self.__update_mu(mu)\n",
    "        self.__update_var(var)\n",
    "\n",
    "\n",
    "    def __score(self, x, as_average=True):\n",
    "        \"\"\"\n",
    "        Computes the log-likelihood of the data under the model.\n",
    "        args:\n",
    "            x:                  torch.Tensor (n, 1, d)\n",
    "            sum_data:           bool\n",
    "        returns:\n",
    "            score:              torch.Tensor (1)\n",
    "            (or)\n",
    "            per_sample_score:   torch.Tensor (n)\n",
    "\n",
    "        \"\"\"\n",
    "        weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)\n",
    "        per_sample_score = torch.logsumexp(weighted_log_prob, dim=1)\n",
    "\n",
    "        if as_average:\n",
    "            return per_sample_score.mean()\n",
    "        else:\n",
    "            return torch.squeeze(per_sample_score)\n",
    "\n",
    "\n",
    "    def __update_mu(self, mu):\n",
    "        \"\"\"\n",
    "        Updates mean to the provided value.\n",
    "        args:\n",
    "            mu:         torch.FloatTensor\n",
    "        \"\"\"\n",
    "        assert mu.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], \"Input mu does not have required tensor dimensions (%i, %i) or (1, %i, %i)\" % (self.n_components, self.n_features, self.n_components, self.n_features)\n",
    "\n",
    "        if mu.size() == (self.n_components, self.n_features):\n",
    "            self.mu = mu.unsqueeze(0)\n",
    "        elif mu.size() == (1, self.n_components, self.n_features):\n",
    "            self.mu.data = mu\n",
    "\n",
    "\n",
    "    def __update_var(self, var):\n",
    "        \"\"\"\n",
    "        Updates variance to the provided value.\n",
    "        args:\n",
    "            var:        torch.FloatTensor\n",
    "        \"\"\"\n",
    "        if self.covariance_type == \"full\":\n",
    "            assert var.size() in [(self.n_components, self.n_features, self.n_features), (1, self.n_components, self.n_features, self.n_features)], \"Input var does not have required tensor dimensions (%i, %i, %i) or (1, %i, %i, %i)\" % (self.n_components, self.n_features, self.n_features, self.n_components, self.n_features, self.n_features)\n",
    "\n",
    "            if var.size() == (self.n_components, self.n_features, self.n_features):\n",
    "                self.var = var.unsqueeze(0)\n",
    "            elif var.size() == (1, self.n_components, self.n_features, self.n_features):\n",
    "                self.var.data = var\n",
    "\n",
    "        elif self.covariance_type == \"diag\":\n",
    "            assert var.size() in [(self.n_components, self.n_features), (1, self.n_components, self.n_features)], \"Input var does not have required tensor dimensions (%i, %i) or (1, %i, %i)\" % (self.n_components, self.n_features, self.n_components, self.n_features)\n",
    "\n",
    "            if var.size() == (self.n_components, self.n_features):\n",
    "                self.var = var.unsqueeze(0)\n",
    "            elif var.size() == (1, self.n_components, self.n_features):\n",
    "                self.var.data = var\n",
    "\n",
    "\n",
    "    def __update_pi(self, pi):\n",
    "        \"\"\"\n",
    "        Updates pi to the provided value.\n",
    "        args:\n",
    "            pi:         torch.FloatTensor\n",
    "        \"\"\"\n",
    "        assert pi.size() in [(1, self.n_components, 1)], \"Input pi does not have required tensor dimensions (%i, %i, %i)\" % (1, self.n_components, 1)\n",
    "\n",
    "        self.pi.data = pi\n",
    "\n",
    "\n",
    "    def get_kmeans_mu(self, x, n_centers, init_times=50, min_delta=1e-3):\n",
    "        \"\"\"\n",
    "        Find an initial value for the mean. Requires a threshold min_delta for the k-means algorithm to stop iterating.\n",
    "        The algorithm is repeated init_times often, after which the best centerpoint is returned.\n",
    "        args:\n",
    "            x:            torch.FloatTensor (n, d) or (n, 1, d)\n",
    "            init_times:   init\n",
    "            min_delta:    int\n",
    "        \"\"\"\n",
    "        if len(x.size()) == 3:\n",
    "            x = x.squeeze(1)\n",
    "        x_min, x_max = x.min(), x.max()\n",
    "        x = (x - x_min) / (x_max - x_min)\n",
    "        \n",
    "        min_cost = np.inf\n",
    "\n",
    "        for i in range(init_times):\n",
    "            tmp_center = x[np.random.choice(np.arange(x.shape[0]), size=n_centers, replace=False), ...]\n",
    "            l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - tmp_center), p=2, dim=2)\n",
    "            l2_cls = torch.argmin(l2_dis, dim=1)\n",
    "\n",
    "            cost = 0\n",
    "            for c in range(n_centers):\n",
    "                cost += torch.norm(x[l2_cls == c] - tmp_center[c], p=2, dim=1).mean()\n",
    "\n",
    "            if cost < min_cost:\n",
    "                min_cost = cost\n",
    "                center = tmp_center\n",
    "\n",
    "        delta = np.inf\n",
    "\n",
    "        while delta > min_delta:\n",
    "            l2_dis = torch.norm((x.unsqueeze(1).repeat(1, n_centers, 1) - center), p=2, dim=2)\n",
    "            l2_cls = torch.argmin(l2_dis, dim=1)\n",
    "            center_old = center.clone()\n",
    "\n",
    "            for c in range(n_centers):\n",
    "                center[c] = x[l2_cls == c].mean(dim=0)\n",
    "\n",
    "            delta = torch.norm((center_old - center), dim=1).max()\n",
    "\n",
    "        return (center.unsqueeze(0)*(x_max - x_min) + x_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cb2573",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# Number of components for each GMM\n",
    "n_components_P = args.num_base_classes - len(args.cls_drop)\n",
    "n_components_Q = args.num_base_classes\n",
    "\n",
    "# Fit GMM to P\n",
    "gmm_P = GaussianMixture(n_components_P, N_COMPONENTS, covariance_type=\"full\").to('cuda')\n",
    "gmm_P.fit(P, n_iter=1000, delta=1e-6)\n",
    "\n",
    "# Fit GMM to Q\n",
    "gmm_Q = GaussianMixture(n_components_Q, N_COMPONENTS, covariance_type=\"full\").to('cuda')\n",
    "gmm_Q.fit(Q, n_iter=1000, delta=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a26e34c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from matplotlib import rcParams\n",
    "\n",
    "def simple_gmm_visualization(P, Q, gmm_P, gmm_Q):\n",
    "    \"\"\"\n",
    "    Simple visualization of 2D data and GMM contours.\n",
    "    \"\"\"\n",
    "    # Convert tensors to numpy if needed\n",
    "    if isinstance(P, torch.Tensor):\n",
    "        P = P.cpu().numpy()\n",
    "    if isinstance(Q, torch.Tensor):\n",
    "        Q = Q.cpu().numpy()\n",
    "    \n",
    "    # Create a mesh grid for contours\n",
    "    x_min, x_max = min(P[:, 0].min(), Q[:, 0].min()) - 1, max(P[:, 0].max(), Q[:, 0].max()) + 1\n",
    "    y_min, y_max = min(P[:, 1].min(), Q[:, 1].min()) - 1, max(P[:, 1].max(), Q[:, 1].max()) + 1\n",
    "    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200))\n",
    "    mesh_points = np.c_[xx.ravel(), yy.ravel()]\n",
    "    \n",
    "    # Convert mesh points to tensor and move to the same device as the GMMs\n",
    "    device = next(gmm_P.parameters()).device\n",
    "    mesh_tensor = torch.tensor(mesh_points, dtype=torch.float32).to(device)\n",
    "    \n",
    "    # Get densities using score_samples method\n",
    "    P_log_density = gmm_P.score_samples(mesh_tensor).reshape(xx.shape).cpu()\n",
    "    Q_log_density = gmm_Q.score_samples(mesh_tensor).reshape(xx.shape).cpu()\n",
    "    \n",
    "    # Convert log densities to regular densities\n",
    "    P_density = np.exp(P_log_density)\n",
    "    Q_density = np.exp(Q_log_density)\n",
    "    \n",
    "    # Create figure with two subplots\n",
    "    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 13.5))\n",
    "\n",
    "    rcParams['font.family'] = 'Linux Libertine O'\n",
    "    rcParams['font.size']=30\n",
    "\n",
    "    # Fix same colorscale for both plots\n",
    "    colors = plt.cm.tab10(np.linspace(0, 1, args.num_base_classes))\n",
    "    \n",
    "    # Plot P data and GMM\n",
    "    for i in seen_classes:\n",
    "        ax1.scatter(P[targets[P_indices] == i, 0], P[targets[P_indices] == i, 1], s=10, alpha=0.3, label=f'Class {i}', color=colors[i], rasterized=True)\n",
    "    # ax1.scatter(P[:, 0], P[:, 1], s=10, alpha=0.6, c='blue', label='P Data')\n",
    "    ax1.contour(xx, yy, P_density, levels=10, colors='black', alpha=0.8)\n",
    "    ax1.set_title('Training Subset', fontsize=40)\n",
    "    # ax1.set_xlabel('PCA Component 1')\n",
    "    # ax1.set_ylabel('PCA Component 2')\n",
    "    ax1.grid(True, alpha=0.3)\n",
    "    leg = ax1.legend(loc='upper left', handletextpad=0.3, borderaxespad=0.1, handlelength=0.5, labelspacing=0.2)\n",
    "    ax1.set_xlim(-70,40)\n",
    "    ax1.tick_params(axis='both', which='major')\n",
    "    for handle in leg.legendHandles:\n",
    "        handle.set_sizes([120])\n",
    "        handle.set_alpha(0.8)\n",
    "    \n",
    "    # Plot Q data and GMM\n",
    "    for i in seen_classes:\n",
    "        ax2.scatter(Q[targets[Q_indices] == i, 0], Q[targets[Q_indices] == i, 1], s=10, alpha=0.3, label=f'Class {i}', color=colors[i], rasterized=True)\n",
    "    for i in args.cls_drop:\n",
    "        ax2.scatter(Q[targets[Q_indices] == i, 0], Q[targets[Q_indices] == i, 1], s=10, alpha=0.3, label=f'Class {i}', color=colors[i], rasterized=True)\n",
    "    # ax2.scatter(Q[:, 0], Q[:, 1], s=10, alpha=0.6, c='red', label='Q Data')\n",
    "    ax2.contour(xx, yy, Q_density, levels=10, colors='black', alpha=0.8)\n",
    "    ax2.set_title('Full Dataset', fontsize=40)\n",
    "    # ax2.set_xlabel('PCA Component 1')\n",
    "    # ax2.set_ylabel('PCA Component 2')\n",
    "    ax2.grid(True, alpha=0.3)\n",
    "    leg = ax2.legend(loc='upper left', handletextpad=0.3, borderaxespad=0.1, handlelength=0.5, labelspacing=0.2, )\n",
    "    ax2.set_xlim(-70,40)\n",
    "    ax2.tick_params(axis='both', which='both')\n",
    "    for handle in leg.legendHandles:\n",
    "        handle.set_sizes([120])\n",
    "        handle.set_alpha(0.8)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    fig.savefig(os.path.join(args.attack_plots_path, \"gmm_visualization.png\"), dpi=300, bbox_inches='tight')  \n",
    "    \n",
    "# Usage example\n",
    "simple_gmm_visualization(P, Q, gmm_P, gmm_Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "180e2e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_density_ratio(gmm_P, gmm_Q, X):\n",
    "    \"\"\"\n",
    "    Compute the density ratio P(x)/Q(x) for points in X\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    gmm_P: GaussianMixture\n",
    "        The numerator GMM (P)\n",
    "    gmm_Q: GaussianMixture\n",
    "        The denominator GMM (Q)\n",
    "    X: torch.Tensor\n",
    "        Points at which to evaluate the density ratio (n, d) or (n, 1, d)\n",
    "        \n",
    "    Returns:\n",
    "    --------\n",
    "    torch.Tensor\n",
    "        Density ratio values for each point in X (n)\n",
    "    \"\"\"\n",
    "    # Get log probabilities from each model\n",
    "    # The score_samples method returns log probabilities\n",
    "    log_prob_P = gmm_P.score_samples(X)  # Returns (n)\n",
    "    log_prob_Q = gmm_Q.score_samples(X)  # Returns (n)\n",
    "    \n",
    "    # Compute ratio in log space first (for numerical stability)\n",
    "    log_ratio = log_prob_P - log_prob_Q\n",
    "    \n",
    "    # Convert to normal space\n",
    "    ratio = torch.exp(log_ratio)\n",
    "\n",
    "    return ratio\n",
    "\n",
    "density_ratios = compute_density_ratio(gmm_P, gmm_Q, Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03b00ef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Median density ratio for seen classes:\", torch.median(density_ratios[is_seen]).item())\n",
    "print(\"Median density ratio for unseen classes:\", torch.median(density_ratios[~is_seen]).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6145a83c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "sns.histplot(density_ratios.cpu().numpy(), bins=50, kde=True)\n",
    "plt.title('Histogram of Density Ratios')\n",
    "plt.xlabel('Density Ratio (P(x)/Q(x))')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2581a62e",
   "metadata": {},
   "source": [
    "Fit a Linear Regressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e511bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "\n",
    "# Prepare training data\n",
    "X_train = scaled_embeds\n",
    "y_train = density_ratios.cpu().numpy()\n",
    "# Initialize and fit the linear regression model\n",
    "lin_reg = LinearRegression()\n",
    "lin_reg.fit(X_train, y_train)\n",
    "# Print the training accuracy\n",
    "y_pred = lin_reg.predict(X_train)\n",
    "print(\"Training R-Squared:\", r2_score(y_train, y_pred))\n",
    "print(\"Training MSE:\", mean_squared_error(y_train, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2db0da24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Compute prediction error: positive if prediction > ground truth, negative otherwise\n",
    "errors = y_pred - y_train  # shape: (n_samples,)\n",
    "\n",
    "# Reduce X_train to 2 dimensions for plotting\n",
    "X_2d = Q.cpu().numpy()  # shape: (n_samples, 2)\n",
    "\n",
    "# Color: red if error is positive, blue if negative\n",
    "colors = ['green' if abs(err) < 0.05 else 'indianred' if err >= 0 else 'steelblue' for err in errors]\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.scatter(X_2d[:, 0], X_2d[:, 1], c=colors, s=5, alpha=0.3)\n",
    "plt.title('2D Visualization of X_train Colored by Prediction Error')\n",
    "plt.xlabel('PCA Component 1')\n",
    "plt.ylabel('PCA Component 2')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qmiaenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
