{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PwizGCKo0UId"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.nn.utils import parameters_to_vector, vector_to_parameters\n",
        "from torch.func import functional_call\n",
        "import numpy as np\n",
        "from sklearn.datasets import make_friedman1\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from sklearn.model_selection import train_test_split, KFold\n",
        "import matplotlib.pyplot as plt\n",
        "import time\n",
        "from scipy.sparse.linalg import LinearOperator\n",
        "from scipy.sparse.linalg import cg\n",
        "from copy import deepcopy\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define Model"
      ],
      "metadata": {
        "id": "T6UYsZO70iVy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class TwoLayerNet(nn.Module):\n",
        "    \"\"\"Two-layer neural network\"\"\"\n",
        "    def __init__(self, input_dim=5, hidden_dim=100, output_dim=1):\n",
        "        super(TwoLayerNet, self).__init__()\n",
        "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
        "        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.relu(self.fc1(x))\n",
        "        x = self.fc2(x)\n",
        "        return x\n"
      ],
      "metadata": {
        "id": "iBnNQsle0j6Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generate Data"
      ],
      "metadata": {
        "id": "wWYPXIpN0loE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def generate_friedman1_data(n_samples=300, n_features=5, noise=0.1):\n",
        "    \"\"\"Generate Friedman1 dataset\"\"\"\n",
        "    X, y = make_friedman1(n_samples=n_samples, n_features=n_features,\n",
        "                         noise=noise, random_state=42)\n",
        "    return X, y\n",
        "\n",
        "def add_random_features(X, n_additional=495):\n",
        "    \"\"\"Add random features\"\"\"\n",
        "    n_samples = X.shape[0]\n",
        "    random_features = np.random.randn(n_samples, n_additional)\n",
        "    return np.concatenate([X, random_features], axis=1)\n",
        "\n",
        "def l1_proximal_operator(vector, lambda_reg):\n",
        "    \"\"\"\n",
        "    Apply L1 proximal operator (soft thresholding) to a vector\n",
        "    prox_{λ||·||_1}(v) = sign(v) * max(0, |v| - λ)\n",
        "\n",
        "    Args:\n",
        "        vector: Input vector (torch.Tensor)\n",
        "        lambda_reg: Regularization parameter (float)\n",
        "\n",
        "    Returns:\n",
        "        Proximal operator result (torch.Tensor)\n",
        "    \"\"\"\n",
        "    return torch.sign(vector) * torch.clamp(torch.abs(vector) - lambda_reg, min=0.0)\n",
        "\n",
        "def l1_loss(model, lambda_reg=0.01):\n",
        "    \"\"\"Calculate L1 regularization loss\"\"\"\n",
        "    l1_penalty = 0\n",
        "    for param in model.parameters():\n",
        "        l1_penalty += torch.sum(torch.abs(param))\n",
        "    return lambda_reg * l1_penalty\n",
        "\n",
        "def l2_loss(model, lambda_reg=0.01):\n",
        "    \"\"\"Calculate L2 regularization loss\"\"\"\n",
        "    l2_penalty = 0\n",
        "    for param in model.parameters():\n",
        "        l2_penalty += torch.sum(param ** 2)\n",
        "    return lambda_reg * l2_penalty\n"
      ],
      "metadata": {
        "id": "sC1GTm1F0pw7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Interact with the Model"
      ],
      "metadata": {
        "id": "BTBaCQ_x0qXC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def set_flat_params(model, flat_params):\n",
        "    with torch.no_grad():\n",
        "        vector_to_parameters(flat_params, model.parameters())\n",
        "\n",
        "def get_flat_params(model):\n",
        "    return parameters_to_vector(list(model.parameters())).detach().clone()\n",
        "\n",
        "def set_flat_params(model, flat_params):\n",
        "    with torch.no_grad():\n",
        "        vector_to_parameters(flat_params, model.parameters())\n",
        "\n",
        "def _flat_to_param_dict(model, flat):\n",
        "    items  = list(model.named_parameters())\n",
        "    sizes  = [p.numel() for _, p in items]\n",
        "    splits = torch.split(flat, sizes)\n",
        "    return {name: t.view(p.shape) for (name, p), t in zip(items, splits)}\n"
      ],
      "metadata": {
        "id": "P68eOKFr0trj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Compute HVP and FVP"
      ],
      "metadata": {
        "id": "fdDu58pv0uFU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_hvp(model, loss_fn, data, targets, vector, damping=0.0):\n",
        "    model.zero_grad()\n",
        "    outputs = model(data)\n",
        "    loss = loss_fn(outputs, targets)\n",
        "    grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad],\n",
        "                                create_graph=True)\n",
        "    flat_g = torch.cat([g.reshape(-1) for g in grads])\n",
        "    gvp = torch.dot(flat_g, vector)\n",
        "    hvp = torch.autograd.grad(gvp, [p for p in model.parameters() if p.requires_grad],\n",
        "                              retain_graph=False)\n",
        "    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])\n",
        "    if damping > 0.0:\n",
        "        flat_hvp = flat_hvp + damping * vector\n",
        "    return flat_hvp\n",
        "\n",
        "\n",
        "def compute_fisher_vector_product(model, data, targets, vector, loss_fn=None, damping=0.0):\n",
        "    if loss_fn is None:\n",
        "        loss_fn = nn.MSELoss(reduction='mean')\n",
        "    flat0 = get_flat_params(model)\n",
        "\n",
        "    def out_f(flat):\n",
        "        params = _flat_to_param_dict(model, flat)\n",
        "        # Stateless forward: out depends on `flat` through `params` (differentiable)\n",
        "        return functional_call(model, params, (data,))\n",
        "\n",
        "    # JVP: (y, J v)\n",
        "    y, jvp = torch.autograd.functional.jvp(out_f, flat0, v=vector)\n",
        "\n",
        "    # H_y * (J v): for MSE(mean), H_y = (1/n)*I\n",
        "    def loss_y(y_):\n",
        "        return loss_fn(y_, targets)\n",
        "\n",
        "    _, Hy_Jv = torch.autograd.functional.hvp(loss_y, y, v=jvp)\n",
        "\n",
        "    # VJP: J^T (H_y J v)\n",
        "    _, vjp = torch.autograd.functional.vjp(out_f, flat0, v=Hy_Jv)\n",
        "    gnhvp = vjp\n",
        "\n",
        "    if damping > 0.0:\n",
        "        gnhvp = gnhvp + damping * vector\n",
        "    return gnhvp\n",
        "\n",
        "def solve_linear_system_cg(hvp_func, b, max_iter=200):\n",
        "    b_np = b.detach().cpu().numpy()\n",
        "\n",
        "    def matvec(v):\n",
        "        v_t = torch.as_tensor(v, dtype=b.dtype, device=b.device)\n",
        "        r = hvp_func(v_t)\n",
        "        return r.detach().cpu().numpy()\n",
        "\n",
        "    A = LinearOperator((b_np.size, b_np.size), matvec=matvec)\n",
        "    x_np, info = cg(A, b_np, maxiter=max_iter)\n",
        "    if info > 0:\n",
        "        print(f\"Warning: CG hit iteration limit ({info}).\")\n",
        "    elif info < 0:\n",
        "        print(f\"Warning: CG failed with code {info}.\")\n",
        "    return torch.as_tensor(x_np, dtype=b.dtype, device=b.device)"
      ],
      "metadata": {
        "id": "RT5cAFb10yw8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Estimate Influence"
      ],
      "metadata": {
        "id": "isZp9BuS0zTM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def estimate_influence_hessian(model, data, targets, vector, lambda_reg=0.01, use_l1_prox=True, lambda_l2=1e-3, cg_max_iter=5000):\n",
        "    \"\"\"\n",
        "    Estimate influence using Hessian inverse approximation\n",
        "    Computes prox_{λ||·||_1}(θ + H^{-1}v) or θ + H^{-1}v\n",
        "\n",
        "    Args:\n",
        "        model: Neural network model\n",
        "        data: Input data\n",
        "        targets: Target values\n",
        "        vector: Influence vector\n",
        "        lambda_reg: L1 regularization parameter\n",
        "        use_l1_prox: Whether to apply L1 proximal operator\n",
        "        cg_max_iter: Maximum CG iterations\n",
        "\n",
        "    Returns:\n",
        "        Updated parameters\n",
        "    \"\"\"\n",
        "    criterion = nn.MSELoss(reduction='mean')\n",
        "\n",
        "    # Define HVP function\n",
        "    def hvp_func(v):\n",
        "        return compute_hvp(model, criterion, data, targets, v, damping=lambda_l2)\n",
        "\n",
        "    h_inv_v = solve_linear_system_cg(hvp_func, vector, max_iter=cg_max_iter)\n",
        "\n",
        "    # Get current parameters\n",
        "    current_params = get_flat_params(model)\n",
        "\n",
        "    # Compute update: θ + H^{-1}v\n",
        "    updated_params = current_params + h_inv_v\n",
        "\n",
        "    # Apply L1 proximal operator if requested\n",
        "    if use_l1_prox:\n",
        "        updated_params = l1_proximal_operator(updated_params, lambda_reg)\n",
        "\n",
        "    return updated_params\n",
        "\n",
        "def estimate_influence_fisher(model, data, targets, vector, lambda_reg=0.01, use_l1_prox=True, lambda_l2=1e-3, cg_max_iter=5000):\n",
        "    \"\"\"\n",
        "    Estimate influence using Fisher Information Matrix inverse approximation\n",
        "    Computes prox_{λ||·||_1}(θ + F^{-1}v) or θ + F^{-1}v\n",
        "\n",
        "    Args:\n",
        "        model: Neural network model\n",
        "        data: Input data\n",
        "        targets: Target values\n",
        "        vector: Influence vector\n",
        "        lambda_reg: L1 regularization parameter\n",
        "        use_l1_prox: Whether to apply L1 proximal operator\n",
        "        cg_max_iter: Maximum CG iterations\n",
        "\n",
        "    Returns:\n",
        "        Updated parameters\n",
        "    \"\"\"\n",
        "    # Define FIM-vector product function\n",
        "    def fim_func(v):\n",
        "        return compute_fisher_vector_product(model, data, targets, v, damping=lambda_l2)\n",
        "    f_inv_v = solve_linear_system_cg(fim_func, vector, max_iter=cg_max_iter)\n",
        "\n",
        "    # Get current parameters\n",
        "    current_params = get_flat_params(model)\n",
        "\n",
        "    # Compute update: θ + F^{-1}v\n",
        "    updated_params = current_params + f_inv_v\n",
        "\n",
        "    # Apply L1 proximal operator if requested\n",
        "    if use_l1_prox:\n",
        "        updated_params = l1_proximal_operator(updated_params, lambda_reg)\n",
        "\n",
        "    return updated_params"
      ],
      "metadata": {
        "id": "XdHfv-M701xi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluate Model"
      ],
      "metadata": {
        "id": "3eIOuu7L04Jh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def evaluate_model_with_params(model, params, X, y):\n",
        "    device = next(model.parameters()).device\n",
        "    orig = get_flat_params(model)\n",
        "    with torch.no_grad():\n",
        "        set_flat_params(model, params)\n",
        "        model.eval()\n",
        "        X_t = torch.as_tensor(X, dtype=torch.float32, device=device)\n",
        "        y_t = torch.as_tensor(y, dtype=torch.float32, device=device).unsqueeze(1)\n",
        "        loss = nn.MSELoss(reduction='mean')(model(X_t), y_t).item()\n",
        "        set_flat_params(model, orig)\n",
        "    return loss\n",
        "\n",
        "def estimate_cv_with_influence(\n",
        "    model,\n",
        "    X, y,\n",
        "    n_folds=5,\n",
        "    influence_method='hessian',\n",
        "    lambda_reg=0.01,\n",
        "    use_l1_prox=True,\n",
        "    lambda_l2=1e-3\n",
        "):\n",
        "    \"\"\"\n",
        "    Estimate cross-validation score using influence functions on a CLONED model,\n",
        "    leaving the original 'model' completely unchanged.\n",
        "\n",
        "    Args:\n",
        "        model: Trained torch.nn.Module\n",
        "        X (np.ndarray or torch.Tensor): inputs\n",
        "        y (np.ndarray or torch.Tensor): targets (shape [N] or [N,1])\n",
        "        n_folds (int): number of folds\n",
        "        influence_method (str): 'hessian' or 'fisher'\n",
        "        lambda_reg (float): L1 regularization parameter (for proximal step)\n",
        "        use_l1_prox (bool): whether to apply L1 proximal operator\n",
        "\n",
        "    Returns:\n",
        "        cv_estimates (list of float), mean_estimate (float), std_estimate (float)\n",
        "    \"\"\"\n",
        "    # Resolve device from the original model (kept untouched)\n",
        "    try:\n",
        "        device = next(model.parameters()).device\n",
        "    except StopIteration:\n",
        "        device = torch.device('cpu')\n",
        "\n",
        "    # Ensure numpy arrays for KFold indexing\n",
        "    X_np = X.detach().cpu().numpy() if torch.is_tensor(X) else np.asarray(X)\n",
        "    y_np = y.detach().cpu().numpy() if torch.is_tensor(y) else np.asarray(y)\n",
        "\n",
        "    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)\n",
        "    cv_estimates = []\n",
        "\n",
        "    print(f\"  Estimating {n_folds}-fold CV using {influence_method} influence (on cloned models)...\")\n",
        "    time_start = time.time()\n",
        "    for fold, (train_idx, val_idx) in enumerate(kf.split(X_np), start=1):\n",
        "        # ---- Clone the model for THIS fold ----\n",
        "        work_model = deepcopy(model).to(device)\n",
        "        work_model.train()\n",
        "\n",
        "        # Get fold data (torch tensors on correct device)\n",
        "        X_train_tensor = torch.as_tensor(X_np[train_idx], dtype=torch.float32, device=device)\n",
        "        y_train_tensor = torch.as_tensor(y_np[train_idx], dtype=torch.float32, device=device).reshape(-1, 1)\n",
        "\n",
        "        X_val_tensor   = torch.as_tensor(X_np[val_idx],   dtype=torch.float32, device=device)\n",
        "        y_val_tensor   = torch.as_tensor(y_np[val_idx],   dtype=torch.float32, device=device).reshape(-1, 1)\n",
        "\n",
        "        # Create influence vector (negative gradient w.r.t. validation loss)\n",
        "        work_model.zero_grad(set_to_none=True)\n",
        "\n",
        "        # grads of MEAN validation loss\n",
        "        criterion_val = nn.MSELoss(reduction='mean')\n",
        "        val_loss = criterion_val(work_model(X_val_tensor), y_val_tensor)\n",
        "        val_grads = torch.autograd.grad(val_loss, [p for p in work_model.parameters() if p.requires_grad])\n",
        "\n",
        "        g_val_mean = torch.cat([g.reshape(-1) for g in val_grads]).detach()\n",
        "\n",
        "        # scale to mean and to \"per-train-example\" factor\n",
        "        scale = 1/(len(train_idx))\n",
        "        influence_vector = scale * g_val_mean\n",
        "\n",
        "        try:\n",
        "\n",
        "            if influence_method == 'hessian':\n",
        "                updated_params = estimate_influence_hessian(\n",
        "                    work_model, X_train_tensor, y_train_tensor, influence_vector,\n",
        "                    lambda_reg=lambda_reg, use_l1_prox=use_l1_prox, lambda_l2=lambda_l2\n",
        "                )\n",
        "            elif influence_method == 'fisher':\n",
        "                updated_params = estimate_influence_fisher(\n",
        "                    work_model, X_train_tensor, y_train_tensor, influence_vector,\n",
        "                    lambda_reg=lambda_reg, use_l1_prox=use_l1_prox, lambda_l2=lambda_l2\n",
        "                )\n",
        "            else:\n",
        "                raise ValueError(f\"Unknown influence method: {influence_method}\")\n",
        "\n",
        "            val_loss_estimate = evaluate_model_with_params(\n",
        "                work_model, updated_params,  # <- evaluate on the clone\n",
        "                X_np[val_idx], y_np[val_idx]\n",
        "            )\n",
        "            cv_estimates.append(float(val_loss_estimate))\n",
        "            # print(f\"    Fold {fold} influence estimate: {val_loss_estimate:.4f}\")\n",
        "\n",
        "        except Exception as e:\n",
        "            print(f\"    Fold {fold} failed with error: {e}\")\n",
        "\n",
        "            current_params = get_flat_params(work_model)\n",
        "            current_val_loss = evaluate_model_with_params(\n",
        "                work_model, current_params,\n",
        "                X_np[val_idx], y_np[val_idx]\n",
        "            )\n",
        "            cv_estimates.append(float(current_val_loss))\n",
        "            print(f\"    Fold {fold} fallback estimate: {current_val_loss:.4f}\")\n",
        "\n",
        "        del work_model\n",
        "        torch.cuda.empty_cache() if device.type == 'cuda' else None\n",
        "    time_end = time.time()\n",
        "    total_time = time_end - time_start\n",
        "    mean_estimate = float(np.mean(cv_estimates))\n",
        "    std_estimate  = float(np.std(cv_estimates))\n",
        "    print(f\"  Influence CV estimate: {mean_estimate:.4f} ± {std_estimate:.4f}\")\n",
        "    print(f\"  Time per fold: {total_time/n_folds:.4f}\")\n",
        "    return cv_estimates, mean_estimate, std_estimate, total_time/n_folds\n",
        "\n",
        "def train_model_cv(model, X, y, epochs=500, lr=1e-4, lambda_reg=0.0, lambda_l2=1e-3, batch_size=32, reg_type='l1'):\n",
        "    \"\"\"Train model with optional L1 regularization for a single fold\"\"\"\n",
        "    # Convert to PyTorch tensors\n",
        "    X_tensor = torch.FloatTensor(X)\n",
        "    y_tensor = torch.FloatTensor(y).unsqueeze(1)\n",
        "\n",
        "    # Create data loader\n",
        "    dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)\n",
        "    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "    criterion = nn.MSELoss(reduction='mean')\n",
        "\n",
        "    model.train()\n",
        "    for epoch in range(epochs):\n",
        "        for batch_x, batch_y in loader:\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = model(batch_x)\n",
        "            mse_loss = criterion(outputs, batch_y)\n",
        "\n",
        "            # Add regularization if specified\n",
        "            total_loss = mse_loss\n",
        "            if lambda_reg > 0:\n",
        "                if reg_type == 'l1':\n",
        "                    total_loss += l1_loss(model, lambda_reg)\n",
        "                total_loss += l2_loss(model, lambda_l2)\n",
        "\n",
        "            # Backward pass\n",
        "            total_loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "def evaluate_model_cv(model, X, y):\n",
        "    \"\"\"Evaluate model on given data\"\"\"\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        X_tensor = torch.FloatTensor(X)\n",
        "        y_tensor = torch.FloatTensor(y).unsqueeze(1)\n",
        "        outputs = model(X_tensor)\n",
        "        criterion = nn.MSELoss(reduction='mean')\n",
        "        loss = criterion(outputs, y_tensor)\n",
        "        return loss.item()\n",
        "\n",
        "def run_cross_validation(X, y, build_model, n_folds=5, epochs=500, lr=1e-4,\n",
        "                         lambda_reg=0.0, reg_type='l1', lambda_l2=1e-3, batch_size=32):\n",
        "    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)\n",
        "    scores = []\n",
        "    time_start = time.time()\n",
        "    for train_idx, val_idx in kf.split(X):\n",
        "        m = build_model()\n",
        "        train_model_cv(m, X[train_idx], y[train_idx], epochs=epochs,\n",
        "                       lr=lr, lambda_reg=lambda_reg, lambda_l2=lambda_l2, batch_size=batch_size, reg_type=reg_type)\n",
        "        scores.append(evaluate_model_cv(m, X[val_idx], y[val_idx]))\n",
        "    time_end = time.time()\n",
        "    print(f\"  Actual CV estimate is given by: {np.mean(scores):.4f}\")\n",
        "    print(f\"  Time per fold: {(time_end - time_start)/n_folds:.4f}\")\n",
        "    return float(np.mean(scores)), float(np.std(scores)), (time_end - time_start)/n_folds\n",
        "\n",
        "def train_model_with_cv_tracking(model, train_loader, test_loader, X_full, y_full, epochs=1000, lr=1e-4, lambda_reg=0.0, reg_type='l1', cv_interval=100):\n",
        "    \"\"\"Train model with CV estimation tracking during training\"\"\"\n",
        "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "    criterion = nn.MSELoss(reduction='mean')\n",
        "\n",
        "    train_losses = []\n",
        "    test_losses = []\n",
        "    cv_estimates_hessian = []\n",
        "    cv_estimates_fisher = []\n",
        "    cv_epochs = []\n",
        "    cv_estimates_full = []\n",
        "    cv_estimates_hessian_times = []\n",
        "    cv_estimates_fisher_times = []\n",
        "    cv_estimates_full_times = []\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        total_train_loss = 0\n",
        "\n",
        "        for batch_x, batch_y in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = model(batch_x)\n",
        "            mse_loss = criterion(outputs, batch_y)\n",
        "\n",
        "            # Add regularization if specified\n",
        "            total_loss = mse_loss\n",
        "            if lambda_reg > 0:\n",
        "                if reg_type == 'l1':\n",
        "                    total_loss += l1_loss(model, lambda_reg)\n",
        "                total_loss += l2_loss(model, 1e-3)  # small L2 for stability\n",
        "\n",
        "            # Backward pass\n",
        "            total_loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            total_train_loss += total_loss.item()\n",
        "\n",
        "        # Evaluate on test set\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            test_loss = 0\n",
        "            for batch_x, batch_y in test_loader:\n",
        "                outputs = model(batch_x)\n",
        "                test_loss += criterion(outputs, batch_y).item()\n",
        "\n",
        "            train_losses.append(total_train_loss / len(train_loader))\n",
        "            test_losses.append(test_loss / len(test_loader))\n",
        "\n",
        "        # Estimate CV using influence functions periodically\n",
        "        if epoch % cv_interval == 0:\n",
        "            print(f\"\\nEpoch {epoch}: Computing influence-based CV estimates...\")\n",
        "            cv_epochs.append(epoch)\n",
        "\n",
        "            # Estimate CV using Hessian influence\n",
        "            try:\n",
        "                _, cv_mean_h, cv_std_h, total_time_hessian = estimate_cv_with_influence(\n",
        "                    model, X_full, y_full, n_folds=5, influence_method='hessian',\n",
        "                    lambda_reg=lambda_reg, use_l1_prox=(reg_type=='l1' and lambda_reg > 0), lambda_l2=1e-3\n",
        "                )\n",
        "                cv_estimates_hessian.append(cv_mean_h)\n",
        "                cv_estimates_hessian_times.append(total_time_hessian)\n",
        "            except Exception as e:\n",
        "                print(f\"Hessian influence estimation failed: {e}\")\n",
        "                cv_estimates_hessian.append(np.nan)\n",
        "\n",
        "            # Estimate CV using Fisher influence\n",
        "            try:\n",
        "                _, cv_mean_f, cv_std_f, total_time_fisher = estimate_cv_with_influence(\n",
        "                    model, X_full, y_full, n_folds=5, influence_method='fisher',\n",
        "                    lambda_reg=lambda_reg, use_l1_prox=(reg_type=='l1' and lambda_reg > 0), lambda_l2=1e-3\n",
        "                )\n",
        "                cv_estimates_fisher.append(cv_mean_f)\n",
        "                cv_estimates_fisher_times.append(total_time_fisher)\n",
        "            except Exception as e:\n",
        "                print(f\"Fisher influence estimation failed: {e}\")\n",
        "                cv_estimates_fisher.append(np.nan)\n",
        "\n",
        "            # Direct CV calculation\n",
        "            print(\"Running actual 5-fold cross validation for comparison...\")\n",
        "            cv_mean_direct, cv_std_direct, total_time_cv = run_cross_validation(\n",
        "                X_full, y_full, build_model=lambda: TwoLayerNet(input_dim=X_full.shape[1]),\n",
        "                n_folds=5, epochs=epoch, lr=lr, lambda_reg=lambda_reg, reg_type=reg_type, lambda_l2=1e-3\n",
        "            )\n",
        "            cv_estimates_full.append(cv_mean_direct)\n",
        "            cv_estimates_full_times.append(total_time_cv)\n",
        "\n",
        "        print('=======================================')\n",
        "        print(f'Epoch {epoch}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}')\n",
        "        print('=======================================')\n",
        "\n",
        "    return train_losses, test_losses, cv_epochs, cv_estimates_hessian, cv_estimates_fisher, cv_estimates_full, cv_estimates_hessian_times, cv_estimates_fisher_times, cv_estimates_full_times\n",
        "\n",
        "def train_model(model, train_loader, test_loader, epochs=1000, lr=1e-4, lambda_reg=0.0, reg_type='l1'):\n",
        "    \"\"\"Train model with optional regularization (original function for compatibility)\"\"\"\n",
        "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "    criterion = nn.MSELoss(reduction='mean')\n",
        "\n",
        "    train_losses = []\n",
        "    test_losses = []\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        total_train_loss = 0\n",
        "\n",
        "        for batch_x, batch_y in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # Forward pass\n",
        "            outputs = model(batch_x)\n",
        "            mse_loss = criterion(outputs, batch_y)\n",
        "\n",
        "            # Add regularization if specified\n",
        "            total_loss = mse_loss\n",
        "            if lambda_reg > 0:\n",
        "                if reg_type == 'l1':\n",
        "                    total_loss += l1_loss(model, lambda_reg)\n",
        "                total_loss += l2_loss(model, 1e-3)\n",
        "\n",
        "            # Backward pass\n",
        "            total_loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            total_train_loss += total_loss.item()\n",
        "\n",
        "        # Evaluate on test set\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            test_loss = 0\n",
        "            for batch_x, batch_y in test_loader:\n",
        "                outputs = model(batch_x)\n",
        "                test_loss += criterion(outputs, batch_y).item()\n",
        "\n",
        "            train_losses.append(total_train_loss / len(train_loader))\n",
        "            test_losses.append(test_loss / len(test_loader))\n",
        "\n",
        "        if epoch % 100 == 0:\n",
        "            print(f'Epoch {epoch}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}')\n",
        "\n",
        "    return train_losses, test_losses\n"
      ],
      "metadata": {
        "id": "QR5zzJ7v057l"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run Experiments"
      ],
      "metadata": {
        "id": "urFdreDA1FTz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def run_enhanced_experiment(reg_type='l1', epochs=10, lambda_reg=0.01, cv_interval=1, seed=None):\n",
        "    \"\"\"Train a regularized and an unregularized TwoLayerNet with CV tracking.\n",
        "\n",
        "    Data is produced by ``generate_friedman1_data`` and widened by\n",
        "    ``add_random_features``, split 70/30 into train/test, and standardized with\n",
        "    a scaler fitted on the training split only. Two fresh models are then\n",
        "    trained through ``train_model_with_cv_tracking``: one with penalty strength\n",
        "    ``lambda_reg`` and one with ``lambda_reg=0.0`` as a baseline.\n",
        "\n",
        "    Args:\n",
        "        reg_type: Regularizer name forwarded to ``train_model_with_cv_tracking``.\n",
        "        epochs: Number of epochs used for both training runs.\n",
        "        lambda_reg: Penalty strength of the regularized run.\n",
        "        cv_interval: Forwarded unchanged to ``train_model_with_cv_tracking``.\n",
        "        seed: Optional int. When given, the global torch and NumPy RNGs are\n",
        "            seeded first; ``None`` leaves the RNG state untouched.\n",
        "\n",
        "    Returns:\n",
        "        dict with train/test losses, the Hessian / Fisher / full CV estimates\n",
        "        and their timing lists for both runs (baseline keys carry a ``_no_reg``\n",
        "        suffix), ``'cv_epochs'``, and ``'models'`` holding both trained networks.\n",
        "    \"\"\"\n",
        "    print(\"\\n=== Start Running Experiment ===\")\n",
        "\n",
        "    # Optional reproducibility hook; skipped entirely when no seed is requested\n",
        "    if seed is not None:\n",
        "        torch.manual_seed(seed)\n",
        "        np.random.seed(seed)\n",
        "\n",
        "    # Generate data\n",
        "    X_base, y = generate_friedman1_data(n_samples=2000, n_features=100)\n",
        "    X = add_random_features(X_base, n_additional=300)\n",
        "\n",
        "    # Standardize features: fit on the training split only, then apply everywhere\n",
        "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n",
        "    scaler = StandardScaler().fit(X_train)\n",
        "    X_train = scaler.transform(X_train)\n",
        "    X_test = scaler.transform(X_test)\n",
        "    X_full = scaler.transform(X)\n",
        "\n",
        "    # Convert to PyTorch tensors (targets become column vectors)\n",
        "    X_train_tensor = torch.FloatTensor(X_train)\n",
        "    X_test_tensor = torch.FloatTensor(X_test)\n",
        "    y_train_tensor = torch.FloatTensor(y_train).unsqueeze(1)\n",
        "    y_test_tensor = torch.FloatTensor(y_test).unsqueeze(1)\n",
        "\n",
        "    # Create data loaders\n",
        "    train_dataset = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)\n",
        "    test_dataset = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)\n",
        "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
        "    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
        "\n",
        "    # Train model with regularization and CV tracking\n",
        "    model_with_reg = TwoLayerNet(input_dim=X.shape[1])\n",
        "    (train_losses_reg, test_losses_reg, cv_epochs,\n",
        "     cv_estimates_h, cv_estimates_f, full_cv_estimate,\n",
        "     cv_estimates_hessian_times, cv_estimates_fisher_times,\n",
        "     cv_estimates_full_times) = train_model_with_cv_tracking(\n",
        "        model_with_reg, train_loader, test_loader, X_full, y,\n",
        "        epochs=epochs, lambda_reg=lambda_reg, reg_type=reg_type, cv_interval=cv_interval\n",
        "    )\n",
        "\n",
        "    # Train model without regularization for comparison.\n",
        "    # cv_epochs is rebound here; both runs share epochs and cv_interval, so the\n",
        "    # two values are presumably identical -- verify against the trainer.\n",
        "    model_no_reg = TwoLayerNet(input_dim=X.shape[1])\n",
        "    (train_losses_no_reg, test_losses_no_reg, cv_epochs,\n",
        "     cv_estimates_h_no_reg, cv_estimates_f_no_reg, full_cv_estimate_no_reg,\n",
        "     cv_estimates_hessian_times_no_reg, cv_estimates_fisher_times_no_reg,\n",
        "     cv_estimates_full_times_no_reg) = train_model_with_cv_tracking(\n",
        "        model_no_reg, train_loader, test_loader, X_full, y,\n",
        "        epochs=epochs, lambda_reg=0.0, reg_type=reg_type, cv_interval=cv_interval\n",
        "    )\n",
        "\n",
        "    return {\n",
        "        'train_losses_no_reg': train_losses_no_reg,\n",
        "        'test_losses_no_reg': test_losses_no_reg,\n",
        "        'train_losses_reg': train_losses_reg,\n",
        "        'test_losses_reg': test_losses_reg,\n",
        "        'cv_epochs': cv_epochs,\n",
        "        'cv_estimates_hessian': cv_estimates_h,\n",
        "        'cv_estimates_fisher': cv_estimates_f,\n",
        "        'cv_estimates_full': full_cv_estimate,\n",
        "        'cv_estimates_hessian_no_reg': cv_estimates_h_no_reg,\n",
        "        'cv_estimates_fisher_no_reg': cv_estimates_f_no_reg,\n",
        "        'cv_estimates_full_no_reg': full_cv_estimate_no_reg,\n",
        "        'cv_estimates_hessian_times': cv_estimates_hessian_times,\n",
        "        'cv_estimates_fisher_times': cv_estimates_fisher_times,\n",
        "        'cv_estimates_full_times': cv_estimates_full_times,\n",
        "        # Baseline timings were previously computed and then thrown away\n",
        "        'cv_estimates_hessian_times_no_reg': cv_estimates_hessian_times_no_reg,\n",
        "        'cv_estimates_fisher_times_no_reg': cv_estimates_fisher_times_no_reg,\n",
        "        'cv_estimates_full_times_no_reg': cv_estimates_full_times_no_reg,\n",
        "        'models': {'no_reg': model_no_reg, 'with_reg': model_with_reg}\n",
        "    }\n"
      ],
      "metadata": {
        "id": "Z2thuSH61G5k"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run Experiments"
      ],
      "metadata": {
        "id": "zsuWDMdw1Hco"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Kick off the regularized-vs-unregularized comparison using an L1 penalty\n",
        "results = run_enhanced_experiment('l1')"
      ],
      "metadata": {
        "id": "EMZIJ-2W1J7s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Main Plotting"
      ],
      "metadata": {
        "id": "SiBRIikr1MBG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Two subfigures: (1) losses & CV estimates, (2) runtimes across epochs ---\n",
        "# Uses `np` and `plt` from the notebook's first (import) cell and the\n",
        "# `results` dict returned by run_enhanced_experiment.\n",
        "\n",
        "# Unique styles for ALL series (no duplicates across panels)\n",
        "styles = {\n",
        "    # Panel 1 (losses)\n",
        "    'test_loss':        dict(color='#1f77b4', ls='-',  marker='o'),  # blue\n",
        "    'cv_hessian':       dict(color='#ff7f0e', ls='--', marker='^'),  # orange\n",
        "    'cv_fisher':        dict(color='#2ca02c', ls='-.', marker='s'),  # green\n",
        "    'cv_full':          dict(color='#d62728', ls=':',  marker='D'),  # red\n",
        "    # Panel 2 (times)\n",
        "    'time_hessian':     dict(color='#9467bd', ls=(0,(3,1,1,1)), marker='P'),  # purple\n",
        "    'time_fisher':      dict(color='#8c564b', ls=(0,(5,1)),     marker='X'),  # brown\n",
        "    'time_full':        dict(color='#e377c2', ls=(0,(1,1)),     marker='v'),  # pink\n",
        "}\n",
        "\n",
        "# Line width and marker size shared by every curve\n",
        "lw_main, ms_main = 6.0, 10\n",
        "\n",
        "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))\n",
        "\n",
        "# ===== Panel 1: Losses + CV estimates =====\n",
        "# NOTE(review): every curve in this panel is drawn against the per-epoch loss\n",
        "# axis, so each CV series needs exactly one value per epoch; confirm this\n",
        "# still holds before running with a sparser CV schedule.\n",
        "epochs_loss = np.arange(len(results['train_losses_reg'])) + 1\n",
        "\n",
        "# (results key, legend label, style key), listed in legend order\n",
        "loss_series = [\n",
        "    ('test_losses_reg',      'Real Test Loss',            'test_loss'),\n",
        "    ('cv_estimates_hessian', 'Approx. CV: Hessian',       'cv_hessian'),\n",
        "    ('cv_estimates_fisher',  'Approx. CV: Fisher (ours)', 'cv_fisher'),\n",
        "    ('cv_estimates_full',    'Full CV',                   'cv_full'),\n",
        "]\n",
        "for result_key, label, style_key in loss_series:\n",
        "    ax1.plot(\n",
        "        epochs_loss, results[result_key],\n",
        "        label=label, linewidth=lw_main, markersize=ms_main,\n",
        "        **styles[style_key]\n",
        "    )\n",
        "\n",
        "ax1.set_yscale('log')\n",
        "ax1.set_ylabel('Loss', fontsize=26)\n",
        "\n",
        "# ===== Panel 2: Runtimes per epoch (seconds) =====\n",
        "time_series = [\n",
        "    ('cv_estimates_hessian_times', 'Time: Approx. CV (Hessian)',         'time_hessian'),\n",
        "    ('cv_estimates_fisher_times',  'Time: Approx. CV via Fisher (ours)', 'time_fisher'),\n",
        "    ('cv_estimates_full_times',    'Time: Full CV',                      'time_full'),\n",
        "]\n",
        "for result_key, label, style_key in time_series:\n",
        "    timings = results[result_key]\n",
        "    # Each timing series gets its own 1-based x-axis sized to its length\n",
        "    ax2.plot(\n",
        "        np.arange(len(timings)) + 1, timings,\n",
        "        label=label, linewidth=lw_main, markersize=ms_main,\n",
        "        **styles[style_key]\n",
        "    )\n",
        "\n",
        "ax2.set_ylabel('Time (s)', fontsize=26)\n",
        "\n",
        "# ===== Formatting shared by both panels =====\n",
        "for ax in (ax1, ax2):\n",
        "    ax.set_xlabel('Epoch', fontsize=26)\n",
        "    ax.grid(True, which='both', alpha=0.3)\n",
        "    ax.tick_params(axis='both', which='major', labelsize=26, length=6, width=1.5)\n",
        "    ax.tick_params(axis='both', which='minor', labelsize=26, length=3, width=1.2)\n",
        "\n",
        "# ===== Global title =====\n",
        "fig.suptitle('CV Estimates and Calculation Time', fontsize=28, y=0.97)\n",
        "\n",
        "# ===== One legend below both subplots (3 columns) =====\n",
        "handles1, labels1 = ax1.get_legend_handles_labels()\n",
        "handles2, labels2 = ax2.get_legend_handles_labels()\n",
        "fig.legend(\n",
        "    handles1 + handles2, labels1 + labels2,\n",
        "    loc='upper center', bbox_to_anchor=(0.5, 0.015),  # previously -0.08; moved closer to the axes\n",
        "    ncol=3, frameon=False, fontsize=22,\n",
        "    borderaxespad=0.2,    # tighter gap to axes\n",
        "    labelspacing=0.5,     # tighter line spacing within legend\n",
        "    columnspacing=0.8,    # tighter column spacing\n",
        "    handlelength=1.6\n",
        ")\n",
        "\n",
        "# tighten bottom space since legend is closer now\n",
        "plt.tight_layout(rect=[0, 0.11, 1, 0.9])\n",
        "plt.subplots_adjust(bottom=0.14, top=0.88)\n",
        "\n",
        "# Save before showing: depending on the backend, plt.show() may close the\n",
        "# figure, so writing the PDF first does not rely on it surviving.\n",
        "fig.savefig(\"cv_estimates_and_time.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "u_dy59XW5hUk"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}