{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d028510e",
   "metadata": {},
   "source": [
    "# Mixture of Gaussians Experiment Suite (Unlearning Context)\n",
    "\n",
    "This notebook defines a modular experiment suite for a mixture of two isotropic Gaussians in $\\mathbb{R}^d$, with editable parameters for means, variances, mixture weights, and sample size. The two components represent the **forget set** ($p_f$) and **retain set** ($p_r$), as used in machine unlearning experiments. It samples data, applies a quadratic feature map, and fits logistic regression.\n",
    "\n",
    "The notebook uses an OOP design for easy experimentation, debugging, and parameter sweeps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f70f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and GaussianMixtureUnlearning class definition\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.metrics import accuracy_score\n",
    "from scipy.stats import multivariate_normal\n",
    "from matplotlib.patches import Ellipse\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import os\n",
    "\n",
    "class GaussianMixtureUnlearning:\n",
    "    \"\"\"\n",
    "    A class for experimenting with Gaussian mixture models in the context of machine unlearning.\n",
    "    \n",
    "    The mixture consists of a forget set (p_f) and a retain set (p_r).\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, d=1, mean_sep=4.0, var_f=1.0, var_r=1.0, gamma=0.5, n=500, seed=613, l2_reg_coeff=1.0, T=1.0):\n",
    "        \"\"\"\n",
    "        Initialize the Gaussian mixture model parameters.\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        d : int\n",
    "            Dimension of the data\n",
    "        mean_sep : float\n",
    "            Separation between means of the two components\n",
    "        var_f : float\n",
    "            Variance of the forget set (p_f)\n",
    "        var_r : float\n",
    "            Variance of the retain set (p_r)\n",
    "        gamma : float\n",
    "            Mixture weight for the forget set (p_f)\n",
    "        n : int\n",
    "            Number of samples to generate\n",
    "        seed : int\n",
    "            Random seed for reproducibility\n",
    "        l2_reg_coeff : float\n",
    "            L2 regularization coefficient\n",
    "        T: float\n",
    "            Base model temperature\n",
    "        \"\"\"\n",
    "        self.d = d\n",
    "        self.mean_sep = mean_sep\n",
    "        self.var_f = var_f\n",
    "        self.var_r = var_r\n",
    "        self.gamma = gamma\n",
    "        self.n = n\n",
    "        self.seed = seed\n",
    "        self.l2_reg_coeff = l2_reg_coeff\n",
    "        assert T >= 1.0, \"Temperature T must be >= 1.0\"\n",
    "        self.T = T\n",
    "        self.train_gamma = gamma  # mixture weight used for sampled training data\n",
    "        \n",
    "        # Set random seed\n",
    "        np.random.seed(self.seed)\n",
    "        \n",
    "        # Initialize means and covariances\n",
    "        self.mu_f = np.zeros(d)\n",
    "        self.mu_r = np.zeros(d)\n",
    "        self.mu_r[0] = mean_sep\n",
    "        self.cov_f = np.eye(d) * var_f\n",
    "        self.cov_r = np.eye(d) * var_r\n",
    "        \n",
    "        # Initialize data containers\n",
    "        self.X_f = None\n",
    "        self.X_r = None\n",
    "        self.X = None\n",
    "        self.y = None\n",
    "        self.n_f = None\n",
    "        self.n_r = None\n",
    "        \n",
    "        # Initialize model components\n",
    "        self.poly = PolynomialFeatures(degree=2, include_bias=True)\n",
    "        self.X_quad = None\n",
    "        self.clf = None\n",
    "        \n",
    "        # Initialize metrics\n",
    "        self.accuracy = None\n",
    "        self.kl_divergence = None\n",
    "        self.mae = None\n",
    "        self.pop_risk_learned = None\n",
    "        self.pop_risk_bayes = None\n",
    "        self.excess_risk = None\n",
    "        self.partition_function_estimate = None\n",
    "        \n",
    "    def sample_data(self, train_gamma=None, sample=True):\n",
    "        \"\"\"Sample `n` points using an optional training gamma (defaults to the true gamma).\"\"\"\n",
    "        effective_gamma = self.gamma if train_gamma is None else train_gamma\n",
    "        self.train_gamma = effective_gamma\n",
    "        if sample:\n",
    "            labels = np.random.binomial(1, 1-effective_gamma, size=self.n)\n",
    "        else:\n",
    "            num_f = np.ceil(self.n * effective_gamma).astype(int)\n",
    "            labels = np.array([0]*num_f + [1]*(self.n - num_f))\n",
    "        self.n_f = np.sum(labels == 0)\n",
    "        self.n_r = np.sum(labels == 1)\n",
    "        \n",
    "        self.X_f = np.random.multivariate_normal(self.mu_f, self.cov_f, self.n_f)\n",
    "        self.X_r = np.random.multivariate_normal(self.mu_r, self.cov_r, self.n_r)\n",
    "        \n",
    "        self.X = np.vstack([self.X_f, self.X_r])\n",
    "        self.y = np.array([0]*self.n_f + [1]*self.n_r)\n",
    "        \n",
    "        return self.X, self.y\n",
    "    \n",
    "    def visualize_mixture(self, show_plot=True):\n",
    "        \"\"\"\n",
    "        Visualize the mixture and samples (for d=1 or d=2).\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        show_plot : bool\n",
    "            Whether to display the plot\n",
    "        \"\"\"\n",
    "        if self.X is None:\n",
    "            raise ValueError(\"Must call sample_data() first\")\n",
    "        \n",
    "        if self.d == 2:\n",
    "            fig, ax = plt.subplots(figsize=(6,6))\n",
    "            ax.scatter(self.X_f[:,0], self.X_f[:,1], alpha=0.5, label='Forget set (p_f)', color='blue')\n",
    "            ax.scatter(self.X_r[:,0], self.X_r[:,1], alpha=0.5, label='Retain set (p_r)', color='red')\n",
    "            self._plot_gaussian_ellipse(self.mu_f, self.cov_f, ax, color='blue', label='Gaussian p_f')\n",
    "            self._plot_gaussian_ellipse(self.mu_r, self.cov_r, ax, color='red', label='Gaussian p_r')\n",
    "            ax.set_title('Samples from Mixture of Gaussians')\n",
    "            ax.legend()\n",
    "            if show_plot:\n",
    "                plt.show()\n",
    "            return fig, ax\n",
    "        elif self.d == 1:\n",
    "            fig, ax = plt.subplots(figsize=(7,3))\n",
    "            ax.hist(self.X_f[:,0], bins=30, alpha=0.5, label='Forget set (p_f)', color='blue', density=True)\n",
    "            ax.hist(self.X_r[:,0], bins=30, alpha=0.5, label='Retain set (p_r)', color='red', density=True)\n",
    "            xs = np.linspace(min(self.X[:,0])-2, max(self.X[:,0])+2, 300).reshape(-1, 1)\n",
    "            p_f = self.gamma * multivariate_normal.pdf(xs, mean=self.mu_f, cov=self.cov_f)\n",
    "            p_r = (1-self.gamma) * multivariate_normal.pdf(xs, mean=self.mu_r, cov=self.cov_r)\n",
    "            ax.plot(xs, p_f, color='blue', lw=2, label='True Gaussian p_f')\n",
    "            ax.plot(xs, p_r, color='red', lw=2, label='True Gaussian p_r')\n",
    "            ax.plot(xs, p_f+p_r, color='black', lw=2, label='Mixture')\n",
    "            ax.set_title('Samples from Mixture of Gaussians (1D)')\n",
    "            ax.legend()\n",
    "            if show_plot:\n",
    "                plt.show()\n",
    "            return fig, ax\n",
    "        else:\n",
    "            print(f\"Visualization not available for d={self.d}\")\n",
    "            return None, None\n",
    "    \n",
    "    def apply_feature_map(self):\n",
    "        \"\"\"Apply quadratic feature map to the data.\"\"\"\n",
    "        if self.X is None:\n",
    "            raise ValueError(\"Must call sample_data() first\")\n",
    "        \n",
    "        self.X_quad = self.poly.fit_transform(self.X)\n",
    "        return self.X_quad\n",
    "    \n",
    "    def train_classifier(self):\n",
    "        \"\"\"Train logistic regression classifier with L2 regularization.\"\"\"\n",
    "        if self.X_quad is None:\n",
    "            raise ValueError(\"Must call apply_feature_map() first\")\n",
    "        \n",
    "        self.clf = LogisticRegression(C=1/self.l2_reg_coeff, l1_ratio=0, solver='lbfgs', max_iter=1000)\n",
    "        self.clf.fit(self.X_quad, self.y)\n",
    "        \n",
    "        y_pred = self.clf.predict(self.X_quad)\n",
    "        self.accuracy = accuracy_score(self.y, y_pred)\n",
    "        \n",
    "        return self.clf, self.accuracy\n",
    "    \n",
    "    def visualize_classifier(self, show_plot=True):\n",
    "        \"\"\"\n",
    "        Visualize the learned classification rule (for d=1 or d=2).\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        show_plot : bool\n",
    "            Whether to display the plot\n",
    "        \"\"\"\n",
    "        if self.clf is None:\n",
    "            raise ValueError(\"Must call train_classifier() first\")\n",
    "        \n",
    "        if self.d == 2:\n",
    "            xx, yy = np.meshgrid(\n",
    "                np.linspace(self.X[:,0].min()-1, self.X[:,0].max()+1, 200),\n",
    "                np.linspace(self.X[:,1].min()-1, self.X[:,1].max()+1, 200)\n",
    "            )\n",
    "            grid = np.c_[xx.ravel(), yy.ravel()]\n",
    "            grid_quad = self.poly.transform(grid)\n",
    "            probs = self.clf.predict_proba(grid_quad)[:,1].reshape(xx.shape)\n",
    "            \n",
    "            fig, ax = plt.subplots(figsize=(6,6))\n",
    "            ax.contourf(xx, yy, probs, levels=20, cmap='RdBu', alpha=0.5)\n",
    "            ax.scatter(self.X_f[:,0], self.X_f[:,1], alpha=0.5, label='Forget set (p_f)', color='blue')\n",
    "            ax.scatter(self.X_r[:,0], self.X_r[:,1], alpha=0.5, label='Retain set (p_r)', color='red')\n",
    "            ax.contour(xx, yy, probs, levels=[0.5], colors='k', linewidths=2)\n",
    "            ax.set_title('Learned Decision Boundary')\n",
    "            ax.legend()\n",
    "            if show_plot:\n",
    "                plt.show()\n",
    "            return fig, ax\n",
    "        elif self.d == 1:\n",
    "            xs = np.linspace(self.X[:,0].min()-2, self.X[:,0].max()+2, 500).reshape(-1,1)\n",
    "            xs_quad = self.poly.transform(xs)\n",
    "            probs = self.clf.predict_proba(xs_quad)[:,1]\n",
    "            \n",
    "            fig, ax = plt.subplots(figsize=(7,3))\n",
    "            ax.hist(self.X_f[:,0], bins=30, alpha=0.5, label='Forget set (p_f)', color='blue', density=True)\n",
    "            ax.hist(self.X_r[:,0], bins=30, alpha=0.5, label='Retain set (p_r)', color='red', density=True)\n",
    "            ax.plot(xs, probs, color='black', lw=2, label='Logistic Regression $P(y=1|x)$')\n",
    "            ax.set_title('Learned Classification Rule (1D)')\n",
    "            ax.legend()\n",
    "            if show_plot:\n",
    "                plt.show()\n",
    "            return fig, ax\n",
    "        else:\n",
    "            print(f\"Visualization not available for d={self.d}\")\n",
    "            return None, None\n",
    "    \n",
    "    def estimate_retain_density(self, show_plot=True):\n",
    "        \"\"\"\n",
    "        Estimate retain set density using the classifier (for d=1 only).\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        show_plot : bool\n",
    "            Whether to display the plot\n",
    "        \"\"\"\n",
    "        if self.clf is None:\n",
    "            raise ValueError(\"Must call train_classifier() first\")\n",
    "        \n",
    "        if self.d != 1:\n",
    "            print(f\"Density estimation visualization only available for d=1, current d={self.d}\")\n",
    "            return None, None, None\n",
    "        \n",
    "        xs = np.linspace(self.X[:,0].min()-2, self.X[:,0].max()+2, 500).reshape(-1, 1)\n",
    "        p_f = multivariate_normal.pdf(xs, mean=self.mu_f, cov=self.cov_f)\n",
    "        p_r = multivariate_normal.pdf(xs, mean=self.mu_r, cov=self.cov_r)\n",
    "        mixture_density = self.gamma*p_f + (1-self.gamma)*p_r\n",
    "        \n",
    "        xs_quad = self.poly.transform(xs)\n",
    "        classifier_prob = self.clf.predict_proba(xs_quad)[:,1]\n",
    "\n",
    "        self.compute_partition_function_estimate()\n",
    "        estimated_p_r = mixture_density**(1/self.T) * classifier_prob / self.partition_function_estimate\n",
    "        \n",
    "        if show_plot:\n",
    "            fig, ax = plt.subplots(figsize=(7,3))\n",
    "            ax.plot(xs, p_r, color='red', lw=2, label='True Retain Set Density (p_r)')\n",
    "            ax.plot(xs, estimated_p_r, color='black', lw=2, label='Estimated Retain Set Density')\n",
    "            ax.plot(xs, mixture_density, color='gray', lw=1, linestyle='--', label='Mixture Density')\n",
    "            ax.hist(self.X_r[:,0], bins=30, alpha=0.3, color='red', density=True, label='Retain Set Samples (p_r)')\n",
    "            ax.set_title('True vs Estimated Retain Set Density (1D)')\n",
    "            ax.legend()\n",
    "            plt.show()\n",
    "            return xs, p_r, estimated_p_r\n",
    "        \n",
    "        return xs, p_r, estimated_p_r\n",
    "    \n",
    "    def compute_bayes_optimal_prob(self, X):\n",
    "        \"\"\"\n",
    "        Compute Bayes optimal probability P(y=1|x) for given inputs.\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        X : np.ndarray\n",
    "            Input samples\n",
    "        \n",
    "        Returns:\n",
    "        --------\n",
    "        np.ndarray\n",
    "            Bayes optimal probabilities\n",
    "        \"\"\"\n",
    "        p_f = multivariate_normal.pdf(X, mean=self.mu_f, cov=self.cov_f)\n",
    "        p_r = multivariate_normal.pdf(X, mean=self.mu_r, cov=self.cov_r)\n",
    "        bayes_prob = (1 - self.gamma) * p_r / (self.gamma * p_f + (1 - self.gamma) * p_r)\n",
    "        return bayes_prob\n",
    "    \n",
    "    def compute_population_risk(self, n_mc=10000, eps=0.0):\n",
    "        \"\"\"\n",
    "        Compute population risk (binary cross-entropy) for both learned and Bayes classifiers.\n",
    "        \n",
    "        Uses Monte Carlo sampling from the mixture distribution.\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        n_mc : int\n",
    "            Number of Monte Carlo samples\n",
    "        \n",
    "        Returns:\n",
    "        --------\n",
    "        pop_risk_learned : float\n",
    "            Population risk of the learned classifier\n",
    "        pop_risk_bayes : float\n",
    "            Population risk of the Bayes optimal classifier\n",
    "        excess_risk : float\n",
    "            Excess risk = pop_risk_learned - pop_risk_bayes\n",
    "        \"\"\"\n",
    "        if self.clf is None:\n",
    "            raise ValueError(\"Must call train_classifier() first\")\n",
    "        \n",
    "        # Sample from mixture\n",
    "        labels = np.random.binomial(1, 1-self.gamma, size=n_mc)\n",
    "        n_f_mc = np.sum(labels == 0)\n",
    "        n_r_mc = np.sum(labels == 1)\n",
    "        \n",
    "        X_mc = np.vstack([\n",
    "            np.random.multivariate_normal(self.mu_f, self.cov_f, n_f_mc),\n",
    "            np.random.multivariate_normal(self.mu_r, self.cov_r, n_r_mc)\n",
    "        ])\n",
    "        y_mc = np.array([0]*n_f_mc + [1]*n_r_mc)\n",
    "        \n",
    "        # Learned classifier probabilities\n",
    "        X_mc_quad = self.poly.transform(X_mc)\n",
    "        p_learned = self.clf.predict_proba(X_mc_quad)[:,1]\n",
    "        \n",
    "        # Bayes optimal probabilities\n",
    "        p_bayes = self.compute_bayes_optimal_prob(X_mc)\n",
    "        \n",
    "        # Binary cross-entropy loss\n",
    "        bce_learned = -y_mc * np.log(p_learned + eps) - (1 - y_mc) * np.log(1 - p_learned + eps)\n",
    "        bce_bayes = -y_mc * np.log(p_bayes + eps) - (1 - y_mc) * np.log(1 - p_bayes + eps)\n",
    "        \n",
    "        self.pop_risk_learned = np.mean(bce_learned)\n",
    "        self.pop_risk_bayes = np.mean(bce_bayes)\n",
    "        self.excess_risk = self.pop_risk_learned - self.pop_risk_bayes\n",
    "        \n",
    "        return self.pop_risk_learned, self.pop_risk_bayes, self.excess_risk\n",
    "\n",
    "    def compute_partition_function_estimate(self, num_samples=10000, overwrite=False):\n",
    "        if self.clf is None:\n",
    "            raise ValueError(\"Must call train_classifier() first\")\n",
    "\n",
    "        if self.partition_function_estimate is not None and not overwrite:\n",
    "            print(\"Warning: partition_function_estimate already computed. Set overwrite=True to recompute.\")\n",
    "\n",
    "        forget_weight_unnorm = self.gamma**(1/self.T)*(self.T**(self.d/2))*(2*np.pi*self.var_f)**(self.d*(1-1/self.T)/2)\n",
    "        retain_weight_unnorm = (1-self.gamma)**(1/self.T)*(self.T**(self.d/2))*(2*np.pi*self.var_r)**(self.d*(1-1/self.T)/2)\n",
    "        retain_weight = retain_weight_unnorm / (forget_weight_unnorm + retain_weight_unnorm)\n",
    "        forget_weight = 1 - retain_weight\n",
    "\n",
    "        labels = np.random.binomial(1, retain_weight, size=num_samples)\n",
    "        n_f_pf = np.sum(labels == 0)\n",
    "        n_r_pf = np.sum(labels == 1)\n",
    "        \n",
    "        X_pf = np.vstack([\n",
    "            np.random.multivariate_normal(self.mu_f, self.cov_f*self.T, n_f_pf),\n",
    "            np.random.multivariate_normal(self.mu_r, self.cov_r*self.T, n_r_pf)\n",
    "        ])\n",
    "        \n",
    "        # Get classifier probabilities at these points\n",
    "        X_pf_quad = self.poly.transform(X_pf)\n",
    "        classifier_prob_pf = self.clf.predict_proba(X_pf_quad)[:,1]\n",
    "        \n",
    "        # take mean of classifier probabilities at these points\n",
    "        self.partition_function_estimate = (forget_weight_unnorm + retain_weight_unnorm) * np.mean(classifier_prob_pf)\n",
    "\n",
    "    def evaluate_metrics(self, n_mc=10000, eps=1e-12):\n",
    "        \"\"\"\n",
    "        Evaluate KL divergence, MAE, and population risks using Monte Carlo sampling.\n",
    "        \n",
    "        Parameters:\n",
    "        -----------\n",
    "        n_mc : int\n",
    "            Number of Monte Carlo samples\n",
    "        eps : float\n",
    "            Small constant to avoid log(0)\n",
    "        \n",
    "        Returns:\n",
    "        --------\n",
    "        kl_divergence : float\n",
    "            KL(p_r || p_r_est)\n",
    "        mae : float\n",
    "            E_{x~p_f}[|p_r(x) - p_r_est(x)|]\n",
    "        pop_risk_learned : float\n",
    "            Population risk of learned classifier\n",
    "        pop_risk_bayes : float\n",
    "            Population risk of Bayes classifier\n",
    "        excess_risk : float\n",
    "            Excess risk of learned classifier\n",
    "        \"\"\"\n",
    "        if self.clf is None:\n",
    "            raise ValueError(\"Must call train_classifier() first\")\n",
    "        \n",
    "        # Estimate partition function (pf)\n",
    "        self.compute_partition_function_estimate(num_samples=n_mc, overwrite=True)\n",
    "\n",
    "        # KL divergence\n",
    "        x_mc_r = np.random.multivariate_normal(self.mu_r, self.cov_r, n_mc)\n",
    "        p_r_vals = multivariate_normal.pdf(x_mc_r, mean=self.mu_r, cov=self.cov_r)\n",
    "        p_f_vals = multivariate_normal.pdf(x_mc_r, mean=self.mu_f, cov=self.cov_f)\n",
    "        mixture_vals = self.gamma * p_f_vals + (1-self.gamma) * p_r_vals\n",
    "        xs_quad = self.poly.transform(x_mc_r)\n",
    "        classifier_prob = self.clf.predict_proba(xs_quad)[:,1]\n",
    "        p_r_est_vals = mixture_vals**(1/self.T) * classifier_prob / self.partition_function_estimate\n",
    "        \n",
    "        self.kl_divergence = np.mean(np.log(p_r_vals / np.maximum(p_r_est_vals,eps)))\n",
    "        \n",
    "        # MAE under p_f\n",
    "        x_mc_f = np.random.multivariate_normal(self.mu_f, self.cov_f, n_mc)\n",
    "        p_r_vals_f = multivariate_normal.pdf(x_mc_f, mean=self.mu_r, cov=self.cov_r)\n",
    "        p_f_vals_f = multivariate_normal.pdf(x_mc_f, mean=self.mu_f, cov=self.cov_f)\n",
    "        mixture_vals_f = self.gamma * p_f_vals_f + (1-self.gamma) * p_r_vals_f\n",
    "        xs_quad_f = self.poly.transform(x_mc_f)\n",
    "        classifier_prob_f = self.clf.predict_proba(xs_quad_f)[:,1]\n",
    "        p_r_est_vals_f = mixture_vals_f**(1/self.T) * classifier_prob_f / self.partition_function_estimate\n",
    "        \n",
    "        self.mae = np.mean(np.abs(p_r_vals_f - p_r_est_vals_f))\n",
    "        \n",
    "        # Population risks\n",
    "        self.compute_population_risk(n_mc=n_mc, eps=eps)\n",
    "        \n",
    "        return self.kl_divergence, self.mae, self.pop_risk_learned, self.pop_risk_bayes, self.excess_risk\n",
    "    \n",
    "    def run_full_pipeline(self, show_plots=True, n_mc=10000, eps=1e-12, train_gamma=None, sample_train_data=True):\n",
    "        \"\"\"\n",
    "        Run sampling, feature mapping, training, optional plotting and evaluation in order.\n",
    "\n",
    "        Parameters:\n",
    "        -----------\n",
    "        show_plots : bool\n",
    "            Whether to show plots (only for d=1 or d=2)\n",
    "        n_mc : int\n",
    "            Number of Monte Carlo samples for evaluation\n",
    "        eps : float\n",
    "            Small constant forwarded to evaluate_metrics()\n",
    "        train_gamma : float\n",
    "            Mixture weight for the forget set (p_f) used during classifier training\n",
    "        sample_train_data : bool\n",
    "            Forwarded to sample_data() as its `sample` argument\n",
    "\n",
    "        Returns:\n",
    "        --------\n",
    "        dict\n",
    "            Dictionary containing all metrics\n",
    "        \"\"\"\n",
    "        self.sample_data(train_gamma=train_gamma, sample=sample_train_data)\n",
    "\n",
    "        # Plots exist only for one- and two-dimensional data\n",
    "        can_plot = show_plots and self.d in [1, 2]\n",
    "        if can_plot:\n",
    "            self.visualize_mixture(show_plot=True)\n",
    "\n",
    "        self.apply_feature_map()\n",
    "        self.train_classifier()\n",
    "\n",
    "        if can_plot:\n",
    "            self.visualize_classifier(show_plot=True)\n",
    "            # The density comparison plot is 1-D only\n",
    "            if self.d == 1:\n",
    "                self.estimate_retain_density(show_plot=True)\n",
    "\n",
    "        self.evaluate_metrics(n_mc=n_mc, eps=eps)\n",
    "\n",
    "        metric_names = ('accuracy', 'kl_divergence', 'mae',\n",
    "                        'pop_risk_learned', 'pop_risk_bayes', 'excess_risk')\n",
    "        return {name: getattr(self, name) for name in metric_names}\n",
    "    \n",
    "    def _plot_gaussian_ellipse(self, mean, cov, ax, color='black', label=None):\n",
    "        \"\"\"\n",
    "        Add an unfilled ellipse for a 2D Gaussian to `ax`.\n",
    "\n",
    "        The ellipse is centred at `mean`, oriented along the eigenvector with the\n",
    "        largest eigenvalue of `cov`, and its width and height are twice the\n",
    "        square roots of the eigenvalues.\n",
    "        \"\"\"\n",
    "        eigvals, eigvecs = np.linalg.eigh(cov)\n",
    "        # Sort the eigen-decomposition by decreasing eigenvalue\n",
    "        descending = eigvals.argsort()[::-1]\n",
    "        eigvals = eigvals[descending]\n",
    "        eigvecs = eigvecs[:, descending]\n",
    "        major_axis = eigvecs[:,0]\n",
    "        # Reversing the vector feeds arctan2 its (y, x) arguments\n",
    "        angle_deg = np.degrees(np.arctan2(*major_axis[::-1]))\n",
    "        width, height = 2 * np.sqrt(eigvals)\n",
    "        outline = Ellipse(xy=mean, width=width, height=height, angle=angle_deg,\n",
    "                          edgecolor=color, fc='None', lw=2, label=label)\n",
    "        ax.add_patch(outline)\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return (f\"GaussianMixtureUnlearning(d={self.d}, mean_sep={self.mean_sep}, \"\n",
    "                f\"var_f={self.var_f}, var_r={self.var_r}, gamma={self.gamma}, \"\n",
    "                f\"n={self.n}, seed={self.seed}, l2_reg_coeff={self.l2_reg_coeff})\")\n",
    "\n",
    "# Helper function to format var_f for LaTeX\n",
    "def format_var_f_label(var_f):\n",
    "    \"\"\"Format var_f value as LaTeX scientific notation.\"\"\"\n",
    "    exp = int(np.log10(var_f))\n",
    "    mantissa = int(var_f / (10**exp))\n",
    "    return f'$v_f = {mantissa} \\\\times 10^{{{exp}}}$'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87a10393",
   "metadata": {},
   "source": [
    "## Example Usage: Single Experiment\n",
    "\n",
    "Create an experiment instance and run the full pipeline on a 1-D mixture. All parameters match the class defaults except `gamma=0.1` (the default is `gamma=0.5`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "289c5ca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a 1-D experiment in which the forget set has mixture weight 0.1\n",
    "experiment = GaussianMixtureUnlearning(\n",
    "    d=1, mean_sep=4.0,\n",
    "    var_f=1.0, var_r=1.0,\n",
    "    gamma=0.1,\n",
    "    n=500, seed=613,\n",
    ")\n",
    "\n",
    "# Sample, train, plot and evaluate in one call\n",
    "results = experiment.run_full_pipeline(show_plots=True, n_mc=10000)\n",
    "\n",
    "# Report the metrics, one per line\n",
    "print(\n",
    "    \"\\nResults:\",\n",
    "    f\"Accuracy on training data: {results['accuracy']:.3f}\",\n",
    "    f\"KL(p_r, p_r_est) [MC]: {results['kl_divergence']:.4f}\",\n",
    "    f\"E_{{x~p_f}}[|p_r(x) - p_r_est(x)|] [MC]: {results['mae']:.4f}\",\n",
    "    \"\\nPopulation Risk Metrics:\",\n",
    "    f\"Population Risk (Learned): {results['pop_risk_learned']:.4f}\",\n",
    "    f\"Population Risk (Bayes Optimal): {results['pop_risk_bayes']:.4f}\",\n",
    "    f\"Excess Risk: {results['excess_risk']:.4f}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9757a850",
   "metadata": {},
   "outputs": [],
   "source": [
    "var_fs = [1e-6, 1e-3, 1e-0]\n",
    "ns = [50, 50, 50]\n",
    "\n",
    "fixed_params = {\n",
    "    'd': 1,\n",
    "    'mean_sep': 1.0,\n",
    "    'var_r': 1.0,\n",
    "    'gamma': 0.1,\n",
    "}\n",
    "\n",
    "lambda_range = [1e-8, 1e-4]\n",
    "\n",
    "# Temperature values to sweep\n",
    "T_vals = np.linspace(1.0, 3.0, 9)\n",
    "n_lambda_trials = 100\n",
    "n_trials_per_lambda = 10\n",
    "n_trials = 200\n",
    "\n",
    "line_color = \"#00274C\"\n",
    "\n",
    "# Results file path\n",
    "results_path = \"./synthetic-data/all_results_sweep_var_f.pkl\"\n",
    "\n",
    "# Check if saved data exists\n",
    "if os.path.exists(results_path):\n",
    "    print(f\"Loading saved results from {results_path}\")\n",
    "    with open(results_path, 'rb') as f:\n",
    "        all_T_results = pickle.load(f)\n",
    "    print(f\"Loaded results for var_f values: {list(all_T_results.keys())}\")\n",
    "else:\n",
    "    print(f\"No saved results found at {results_path}. Running experiments...\")\n",
    "    os.makedirs(os.path.dirname(results_path), exist_ok=True)\n",
    "\n",
    "    # Master dictionary to store results for all var_f values\n",
    "    all_T_results = {}  # all_T_results[var_f][T] = [result1, result2, ...]\n",
    "\n",
    "    for var_f_idx, (var_f, n) in enumerate(zip(var_fs, ns)):\n",
    "        print(f\"\\n{'='*60}\")\n",
    "        print(f\"Processing var_f = {var_f:.1e}\")\n",
    "        print(f\"{'='*60}\")\n",
    "        \n",
    "        # Sample lambda log-uniformly to find lambda for this var_f\n",
    "        lambdas_sampled = np.exp(np.random.uniform(np.log(lambda_range[0]), np.log(lambda_range[1]), n_lambda_trials))\n",
    "        \n",
    "        print(f\"Finding best lambda for var_f={var_f:.1e}...\")\n",
    "        lambda_results = []\n",
    "        for idx,lam in tqdm(enumerate(lambdas_sampled), desc=\"Lambda search\"):\n",
    "            lam_results = np.zeros(n_trials_per_lambda)\n",
    "            for inner_idx in range(n_trials_per_lambda):\n",
    "                # Unique seed for each (var_f, lambda, trial) combination\n",
    "                seed = var_f_idx * n_lambda_trials * n_trials_per_lambda + idx * n_trials_per_lambda + inner_idx\n",
    "                exp = GaussianMixtureUnlearning(\n",
    "                    n=n,\n",
    "                    var_f=var_f,\n",
    "                    l2_reg_coeff=lam,\n",
    "                    seed=seed,\n",
    "                    **fixed_params\n",
    "                )\n",
    "                result = exp.run_full_pipeline(show_plots=False, n_mc=100000, eps=1e-30, train_gamma=None, sample_train_data=False)\n",
    "                lam_results[inner_idx] = result[\"excess_risk\"]\n",
    "\n",
    "            lambda_results.append({\"lambda\":lam, \"excess_risk\": lam_results.mean()})\n",
    "        \n",
    "        # Find best lambda for minimizing excess risk\n",
    "        excess_risks = np.array([r['excess_risk'] for r in lambda_results])\n",
    "        lambda_best = lambda_results[np.argmin(excess_risks)]['lambda']\n",
    "        print(f\"Best lambda: {lambda_best:.4e}\")\n",
    "        \n",
    "        # Now sweep over T values with best lambda\n",
    "        T_results = {T: [] for T in T_vals}\n",
    "        # Offset temperature sweep seeds to avoid overlap with lambda search seeds\n",
    "        seed_offset = var_f_idx * n_lambda_trials * n_trials_per_lambda + n_lambda_trials * n_trials_per_lambda\n",
    "        for T_idx, T in enumerate(T_vals):\n",
    "            for trial in tqdm(range(n_trials), desc=f\"T={T:.2f}\", leave=False):\n",
    "                # Each temperature gets the same seed set\n",
    "                seed = seed_offset + trial\n",
    "                exp = GaussianMixtureUnlearning(\n",
    "                    n=n,\n",
    "                    var_f=var_f,\n",
    "                    l2_reg_coeff=lambda_best,\n",
    "                    T=T,\n",
    "                    seed=seed,\n",
    "                    **fixed_params\n",
    "                )\n",
    "                result = exp.run_full_pipeline(show_plots=False, n_mc=100000, eps=1e-30, train_gamma=None, sample_train_data=False)\n",
    "                result['T'] = T\n",
    "                T_results[T].append(result)\n",
    "        \n",
    "        # Store results for this var_f in master dictionary\n",
    "        all_T_results[var_f] = T_results\n",
    "\n",
    "    # Save results to file\n",
    "    print(f\"\\nSaving results to {results_path}\")\n",
    "    with open(results_path, 'wb') as f:\n",
    "        pickle.dump(all_T_results, f)\n",
    "    print(\"Results saved successfully!\")\n",
    "\n",
    "print(\"\\n\\nGenerating plots: Mean ± Standard Error\")\n",
    "\n",
    "# One row of panels per var_f value: the 'kl_divergence' metric (labelled 'Retain Error')\n",
    "# on the left and the 'mae' metric (labelled 'Forget Error') on the right.\n",
    "# squeeze=False keeps axes2 two-dimensional even when var_fs has a single entry;\n",
    "# without it plt.subplots returns a 1-D array and axes2[row, col] raises IndexError.\n",
    "fig2, axes2 = plt.subplots(len(var_fs), 2, figsize=(14, 5*len(var_fs)), squeeze=False)\n",
    "\n",
    "# Error-bar styling shared by every panel (loop-invariant, so defined once)\n",
    "error_kw = {'capsize': 5, 'elinewidth': 2, 'ecolor': 'gray', 'alpha': 0.7}\n",
    "\n",
    "# n is unpacked alongside var_f but is not used inside this plotting loop.\n",
    "for var_f_idx, (var_f, n) in enumerate(zip(var_fs, ns)):\n",
    "    T_results = all_T_results[var_f]\n",
    "    \n",
    "    # Mean and standard error of each metric at every temperature.\n",
    "    # Standard error = sample standard deviation (ddof=1) / sqrt(number of trials).\n",
    "    kl_means = []\n",
    "    kl_ses = []\n",
    "    \n",
    "    mae_means = []\n",
    "    mae_ses = []\n",
    "    \n",
    "    for T in T_vals:\n",
    "        kls = np.array([r['kl_divergence'] for r in T_results[T]])\n",
    "        maes = np.array([r['mae'] for r in T_results[T]])\n",
    "        \n",
    "        # KL stats\n",
    "        kl_means.append(np.mean(kls))\n",
    "        kl_ses.append(np.std(kls, ddof=1) / np.sqrt(len(kls)))\n",
    "        \n",
    "        # MAE stats\n",
    "        mae_means.append(np.mean(maes))\n",
    "        mae_ses.append(np.std(maes, ddof=1) / np.sqrt(len(maes)))\n",
    "    \n",
    "    # Convert to numpy arrays for easy plotting\n",
    "    kl_means = np.array(kl_means)\n",
    "    kl_ses = np.array(kl_ses)\n",
    "    \n",
    "    mae_means = np.array(mae_means)\n",
    "    mae_ses = np.array(mae_ses)\n",
    "\n",
    "    # Mean curve of the KL metric\n",
    "    ax_kl = axes2[var_f_idx, 0]\n",
    "    ax_kl.plot(T_vals, kl_means, color=line_color, linewidth=2, label='Mean')\n",
    "    \n",
    "    # Mean curve of the MAE metric\n",
    "    ax_mae = axes2[var_f_idx, 1]\n",
    "    ax_mae.plot(T_vals, mae_means, color=line_color, linewidth=2, label='Mean')\n",
    "\n",
    "    # Markers and error bars only (fmt='o'): the connecting line is already drawn by\n",
    "    # plot() above, so it is not drawn a second time here.\n",
    "    ax_kl.errorbar(T_vals, kl_means, yerr=kl_ses, \n",
    "                   fmt='o', color=line_color, **error_kw)\n",
    "    ax_mae.errorbar(T_vals, mae_means, yerr=mae_ses, \n",
    "                    fmt='o', color=line_color, **error_kw)\n",
    "\n",
    "    # --- STYLING ---\n",
    "    ax_kl.set_xlabel('Base Model Temperature', fontsize=12, fontweight='bold')\n",
    "    ax_kl.set_ylabel('Retain Error', fontsize=12, fontweight='bold')\n",
    "    ax_kl.set_yscale('log')\n",
    "    ax_kl.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "    \n",
    "    ax_mae.set_xlabel('Base Model Temperature', fontsize=12, fontweight='bold')\n",
    "    ax_mae.set_ylabel('Forget Error', fontsize=12, fontweight='bold')\n",
    "    ax_mae.set_yscale('log')\n",
    "    ax_mae.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"\\n\\nGenerating combined plots for all var_f values\")\n",
    "\n",
    "# One viridis colour per var_f value, in the order of var_fs\n",
    "colors_var_f = plt.cm.viridis(np.linspace(0, 1, len(var_fs)))\n",
    "\n",
    "# Figure 1: mean 'kl_divergence' (labelled 'Retain Error') against temperature, one curve per var_f\n",
    "fig_kl = plt.figure(figsize=(10, 6))\n",
    "ax_kl_combined = fig_kl.add_subplot(111)\n",
    "\n",
    "for var_f, curve_color in zip(var_fs, colors_var_f):\n",
    "    T_results = all_T_results[var_f]\n",
    "    # Average the metric over all trials stored for each temperature\n",
    "    kl_means = np.array([\n",
    "        np.mean(np.array([r['kl_divergence'] for r in T_results[T]]))\n",
    "        for T in T_vals\n",
    "    ])\n",
    "    ax_kl_combined.plot(\n",
    "        T_vals, kl_means,\n",
    "        color=curve_color, linewidth=2.5,\n",
    "        marker='o', markersize=6,\n",
    "        label=format_var_f_label(var_f),\n",
    "    )\n",
    "\n",
    "ax_kl_combined.set_xlabel('Base Model Temperature', fontsize=16, fontweight='bold')\n",
    "ax_kl_combined.set_ylabel('Retain Error', fontsize=16, fontweight='bold')\n",
    "ax_kl_combined.tick_params(axis='both', labelsize=14)\n",
    "ax_kl_combined.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "# Legend intentionally left disabled on this figure:\n",
    "# ax_kl_combined.legend(fontsize=13, loc='best')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Figure 2: mean 'mae' (labelled 'Forget Error') against temperature, one curve per var_f\n",
    "fig_mae = plt.figure(figsize=(10, 6))\n",
    "ax_mae_combined = fig_mae.add_subplot(111)\n",
    "\n",
    "for var_f, curve_color in zip(var_fs, colors_var_f):\n",
    "    T_results = all_T_results[var_f]\n",
    "    # Average the metric over all trials stored for each temperature\n",
    "    mae_means = np.array([\n",
    "        np.mean(np.array([r['mae'] for r in T_results[T]]))\n",
    "        for T in T_vals\n",
    "    ])\n",
    "    ax_mae_combined.plot(\n",
    "        T_vals, mae_means,\n",
    "        color=curve_color, linewidth=2.5,\n",
    "        marker='o', markersize=6,\n",
    "        label=format_var_f_label(var_f),\n",
    "    )\n",
    "\n",
    "ax_mae_combined.set_xlabel('Base Model Temperature', fontsize=16, fontweight='bold')\n",
    "ax_mae_combined.set_ylabel('Forget Error', fontsize=16, fontweight='bold')\n",
    "ax_mae_combined.tick_params(axis='both', labelsize=14)\n",
    "ax_mae_combined.set_yscale('log')\n",
    "ax_mae_combined.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "ax_mae_combined.legend(fontsize=15, loc='best')\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a7fe865",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration: sweep over training-set size n at a single forget-set variance ---\n",
    "var_f = 1e-3\n",
    "ns = [25, 50, 100, 200, 400]\n",
    "\n",
    "# Keyword arguments passed unchanged to every GaussianMixtureUnlearning instance below\n",
    "fixed_params = {\n",
    "    'd': 1,\n",
    "    'mean_sep': 1.0,\n",
    "    'var_r': 1.0,\n",
    "    'gamma': 0.1,\n",
    "    'var_f': var_f,\n",
    "}\n",
    "\n",
    "# Lower/upper bounds of the log-uniform search range for l2_reg_coeff\n",
    "lambda_range = [1e-8, 1e-4]\n",
    "\n",
    "# Temperature values to sweep\n",
    "T_vals = np.linspace(1.0, 3.0, 9)\n",
    "n_lambda_trials = 100       # candidate lambdas sampled per n\n",
    "n_trials_per_lambda = 10    # seeds averaged per candidate lambda\n",
    "n_trials = 200              # seeds per temperature in the final sweep\n",
    "\n",
    "line_color = \"#00274C\"\n",
    "\n",
    "# Results file path\n",
    "results_path = \"./synthetic-data/all_results_sweep_n.pkl\"\n",
    "\n",
    "# Check if saved data exists\n",
    "if os.path.exists(results_path):\n",
    "    print(f\"Loading saved results from {results_path}\")\n",
    "    # SECURITY: pickle.load can execute arbitrary code; only load cache files that\n",
    "    # this notebook wrote itself.\n",
    "    with open(results_path, 'rb') as f:\n",
    "        all_results_sweep_n = pickle.load(f)\n",
    "    print(f\"Loaded results for n values: {list(all_results_sweep_n.keys())}\")\n",
    "\n",
    "    # The cache is keyed only by file path, so it can be stale relative to `ns`.\n",
    "    # Fail here with an actionable message instead of a bare KeyError while plotting.\n",
    "    missing_ns = [n_val for n_val in ns if n_val not in all_results_sweep_n]\n",
    "    if missing_ns:\n",
    "        raise KeyError(\n",
    "            f\"Cached results at {results_path} have no entry for n = {missing_ns}; \"\n",
    "            \"delete the file to re-run the experiments.\"\n",
    "        )\n",
    "else:\n",
    "    print(f\"No saved results found at {results_path}. Running experiments...\")\n",
    "    os.makedirs(os.path.dirname(results_path), exist_ok=True)\n",
    "\n",
    "    # Master dictionary to store results for all n values\n",
    "    all_results_sweep_n = {}  # all_results_sweep_n[n][T] = [result1, result2, ...]\n",
    "\n",
    "    for n_idx, n in enumerate(ns):\n",
    "        print(f\"\\n{'='*60}\")\n",
    "        print(f\"Processing n = {n}\")\n",
    "        print(f\"{'='*60}\")\n",
    "        \n",
    "        # Sample lambda log-uniformly to find lambda for this n\n",
    "        # NOTE(review): this draws from NumPy's global RNG, so the sampled lambdas depend\n",
    "        # on whatever state earlier code left it in - confirm this is reproducible enough.\n",
    "        lambdas_sampled = np.exp(np.random.uniform(np.log(lambda_range[0]), np.log(lambda_range[1]), n_lambda_trials))\n",
    "        \n",
    "        print(f\"Finding best lambda for n={n}...\")\n",
    "        lambda_results = []\n",
    "        # tqdm wraps the array itself (not the enumerate object) so it knows the total\n",
    "        # and can show a percentage and ETA.\n",
    "        for idx, lam in enumerate(tqdm(lambdas_sampled, desc=\"Lambda search\")):\n",
    "            lam_results = np.zeros(n_trials_per_lambda)\n",
    "            for inner_idx in range(n_trials_per_lambda):\n",
    "                # Unique seed for each (n, lambda, trial) combination\n",
    "                seed = n_idx * n_lambda_trials * n_trials_per_lambda + idx * n_trials_per_lambda + inner_idx\n",
    "                exp = GaussianMixtureUnlearning(\n",
    "                    n=n,\n",
    "                    l2_reg_coeff=lam,\n",
    "                    seed=seed,\n",
    "                    **fixed_params\n",
    "                )\n",
    "                result = exp.run_full_pipeline(show_plots=False, n_mc=100000, eps=1e-30, train_gamma=None, sample_train_data=False)\n",
    "                lam_results[inner_idx] = result[\"excess_risk\"]\n",
    "\n",
    "            # Score each candidate lambda by its excess risk averaged over the inner trials\n",
    "            lambda_results.append({\"lambda\": lam, \"excess_risk\": lam_results.mean()})\n",
    "        \n",
    "        # Find best lambda for minimizing excess risk\n",
    "        excess_risks = np.array([r['excess_risk'] for r in lambda_results])\n",
    "        lambda_best = lambda_results[np.argmin(excess_risks)]['lambda']\n",
    "        print(f\"Best lambda: {lambda_best:.4e}\")\n",
    "        \n",
    "        # Now sweep over T values with best lambda\n",
    "        T_results = {T: [] for T in T_vals}\n",
    "\n",
    "        # With K = n_lambda_trials * n_trials_per_lambda, this n's lambda search used seeds\n",
    "        # n_idx*K .. (n_idx+1)*K - 1. The temperature sweep starts at (n_idx+1)*K, so the\n",
    "        # two stages never share a seed for the same n.\n",
    "        # NOTE(review): these seeds coincide with the first n_trials lambda-search seeds of\n",
    "        # the next n value.\n",
    "        seed_offset = n_idx * n_lambda_trials * n_trials_per_lambda + n_lambda_trials * n_trials_per_lambda\n",
    "\n",
    "        for T in T_vals:\n",
    "            for trial in tqdm(range(n_trials), desc=f\"T={T:.2f}\", leave=False):\n",
    "                # Each temperature gets the same seed set\n",
    "                seed = seed_offset + trial\n",
    "                exp = GaussianMixtureUnlearning(\n",
    "                    n=n,\n",
    "                    l2_reg_coeff=lambda_best,\n",
    "                    T=T,\n",
    "                    seed=seed,\n",
    "                    **fixed_params\n",
    "                )\n",
    "                result = exp.run_full_pipeline(show_plots=False, n_mc=100000, eps=1e-30, train_gamma=None, sample_train_data=False)\n",
    "                # Tag the result with its temperature before storing it\n",
    "                result['T'] = T\n",
    "                T_results[T].append(result)\n",
    "        \n",
    "        # Store results for this n in master dictionary\n",
    "        all_results_sweep_n[n] = T_results\n",
    "\n",
    "    # Save results to file\n",
    "    print(f\"\\nSaving results to {results_path}\")\n",
    "    with open(results_path, 'wb') as f:\n",
    "        pickle.dump(all_results_sweep_n, f)\n",
    "    print(\"Results saved successfully!\")\n",
    "\n",
    "print(\"\\n\\nGenerating plots: Mean ± Standard Error\")\n",
    "\n",
    "# Error-bar styling shared by every figure (loop-invariant, so defined once)\n",
    "error_kw = {'capsize': 5, 'elinewidth': 2, 'ecolor': 'gray', 'alpha': 0.7}\n",
    "\n",
    "# One two-panel figure per n: the 'kl_divergence' metric (labelled 'Retain Error') on the\n",
    "# left and the 'mae' metric (labelled 'Forget Error') on the right.\n",
    "for n in ns:\n",
    "    print(f\"\\nn = {n}\")\n",
    "    \n",
    "    fig2, axes2 = plt.subplots(1, 2, figsize=(14, 5))\n",
    "    \n",
    "    T_results = all_results_sweep_n[n]\n",
    "    \n",
    "    # Mean and standard error of each metric at every temperature.\n",
    "    # Standard error = sample standard deviation (ddof=1) / sqrt(number of trials).\n",
    "    kl_means = []\n",
    "    kl_ses = []\n",
    "    \n",
    "    mae_means = []\n",
    "    mae_ses = []\n",
    "    \n",
    "    for T in T_vals:\n",
    "        kls = np.array([r['kl_divergence'] for r in T_results[T]])\n",
    "        maes = np.array([r['mae'] for r in T_results[T]])\n",
    "        \n",
    "        # KL stats\n",
    "        kl_means.append(np.mean(kls))\n",
    "        kl_ses.append(np.std(kls, ddof=1) / np.sqrt(len(kls)))\n",
    "        \n",
    "        # MAE stats\n",
    "        mae_means.append(np.mean(maes))\n",
    "        mae_ses.append(np.std(maes, ddof=1) / np.sqrt(len(maes)))\n",
    "    \n",
    "    # Convert to numpy arrays for easy plotting\n",
    "    kl_means = np.array(kl_means)\n",
    "    kl_ses = np.array(kl_ses)\n",
    "    \n",
    "    mae_means = np.array(mae_means)\n",
    "    mae_ses = np.array(mae_ses)\n",
    "\n",
    "    # Plot Retain Error (KL)\n",
    "    ax_kl = axes2[0]\n",
    "    ax_kl.plot(T_vals, kl_means, color=line_color, linewidth=2, label='Mean')\n",
    "    \n",
    "    # Plot Forget Error (MAE)\n",
    "    ax_mae = axes2[1]\n",
    "    ax_mae.plot(T_vals, mae_means, color=line_color, linewidth=2, label='Mean')\n",
    "\n",
    "    # Markers and error bars only (fmt='o'): the connecting line is already drawn by\n",
    "    # plot() above, so it is not drawn a second time here.\n",
    "    ax_kl.errorbar(T_vals, kl_means, yerr=kl_ses, \n",
    "                   fmt='o', color=line_color, **error_kw)\n",
    "    ax_mae.errorbar(T_vals, mae_means, yerr=mae_ses, \n",
    "                    fmt='o', color=line_color, **error_kw)\n",
    "\n",
    "    # --- STYLING ---\n",
    "    ax_kl.set_xlabel('Base Model Temperature', fontsize=12, fontweight='bold')\n",
    "    ax_kl.set_ylabel('Retain Error', fontsize=12, fontweight='bold')\n",
    "    ax_kl.set_yscale('log')\n",
    "    ax_kl.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "    \n",
    "    ax_mae.set_xlabel('Base Model Temperature', fontsize=12, fontweight='bold')\n",
    "    ax_mae.set_ylabel('Forget Error', fontsize=12, fontweight='bold')\n",
    "    ax_mae.set_yscale('log')\n",
    "    ax_mae.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    # Release the figure so figures do not accumulate across loop iterations\n",
    "    plt.close(fig2)\n",
    "\n",
    "print(\"\\n\\nGenerating combined plots for all n values\")\n",
    "\n",
    "# One viridis colour per n value, in the order of ns\n",
    "colors_n = plt.cm.viridis(np.linspace(0, 1, len(ns)))\n",
    "\n",
    "# Figure 1: mean 'kl_divergence' (labelled 'Retain Error') against temperature, one curve per n\n",
    "fig_kl = plt.figure(figsize=(10, 6))\n",
    "ax_kl_combined = fig_kl.add_subplot(111)\n",
    "\n",
    "for n, curve_color in zip(ns, colors_n):\n",
    "    T_results = all_results_sweep_n[n]\n",
    "    # Average the metric over all trials stored for each temperature\n",
    "    kl_means = np.array([\n",
    "        np.mean(np.array([r['kl_divergence'] for r in T_results[T]]))\n",
    "        for T in T_vals\n",
    "    ])\n",
    "    ax_kl_combined.plot(\n",
    "        T_vals, kl_means,\n",
    "        color=curve_color, linewidth=2.5,\n",
    "        marker='o', markersize=6,\n",
    "        label=f'n={n}',\n",
    "    )\n",
    "\n",
    "ax_kl_combined.set_xlabel('Base Model Temperature', fontsize=16, fontweight='bold')\n",
    "ax_kl_combined.set_ylabel('Retain Error', fontsize=16, fontweight='bold')\n",
    "ax_kl_combined.tick_params(axis='both', labelsize=14)\n",
    "# Log scale and legend are intentionally left disabled on this figure:\n",
    "# ax_kl_combined.set_yscale('log')\n",
    "ax_kl_combined.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "# ax_kl_combined.legend(fontsize=13, loc='best')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Figure 2: mean 'mae' (labelled 'Forget Error') against temperature, one curve per n\n",
    "fig_mae = plt.figure(figsize=(10, 6))\n",
    "ax_mae_combined = fig_mae.add_subplot(111)\n",
    "\n",
    "for n, curve_color in zip(ns, colors_n):\n",
    "    T_results = all_results_sweep_n[n]\n",
    "    # Average the metric over all trials stored for each temperature\n",
    "    mae_means = np.array([\n",
    "        np.mean(np.array([r['mae'] for r in T_results[T]]))\n",
    "        for T in T_vals\n",
    "    ])\n",
    "    ax_mae_combined.plot(\n",
    "        T_vals, mae_means,\n",
    "        color=curve_color, linewidth=2.5,\n",
    "        marker='o', markersize=6,\n",
    "        label=f'n={n}',\n",
    "    )\n",
    "\n",
    "ax_mae_combined.set_xlabel('Base Model Temperature', fontsize=16, fontweight='bold')\n",
    "ax_mae_combined.set_ylabel('Forget Error', fontsize=16, fontweight='bold')\n",
    "ax_mae_combined.tick_params(axis='both', labelsize=14)\n",
    "# Log scale is intentionally left disabled on this figure:\n",
    "# ax_mae_combined.set_yscale('log')\n",
    "ax_mae_combined.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "# Fixed y ticks from 0.06 to 0.18; only every other tick carries a text label\n",
    "ticks = [0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18]\n",
    "labels = ['0.06', '', '0.10', '', '0.14', '', '0.18']\n",
    "ax_mae_combined.set_yticks(ticks)\n",
    "ax_mae_combined.set_yticklabels(labels, fontsize=14)\n",
    "ax_mae_combined.legend(fontsize=15, loc='best')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
