{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!/usr/bin/env python3\n",
    "\"\"\"\n",
    "Log-sum-exp scaling-law fits with optional log-interaction terms.\n",
    "Requires: numpy, pandas, scipy, matplotlib (optional for plots).\n",
    "\"\"\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.optimize import minimize\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "# import ace_tools as tools\n",
    "\n",
    "# 1. Load the latest dense2jp.txt\n",
    "file_path = Path(\"../data/base2code.txt\")\n",
    "text = file_path.read_text().strip().splitlines()\n",
    "\n",
    "# parse rows\n",
    "rows = []\n",
    "for line in text:\n",
    "    parts = line.strip().split()\n",
    "    if len(parts) == 5:\n",
    "        _, _, D1, D2, y_true = parts\n",
    "    else:\n",
    "        # if fewer columns, assume last three are D1 D2 Loss\n",
    "        D1, D2, y_true = parts[-3:]\n",
    "    rows.append([float(D1), float(D2), float(y_true.replace('E+', 'e+'))])\n",
    "\n",
    "df = pd.DataFrame(rows, columns=['D1', 'D2', 'Loss'])\n",
    "D1, D2, y_true = df['D1'].values * 1024 * 512, df['D2'].values* 1024 * 512, df['Loss'].values\n",
    "inp_upc = np.column_stack((D1, D2, y_true))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### multiplicative with interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, b, c, d1, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1) - b * np.log(x2) + c * np.log(x1)* np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "# import numpy as np\n",
    "# from sklearn.metrics import mean_squared_error\n",
    "from itertools import product\n",
    "import random\n",
    "\n",
    "\n",
    "def leave_one_out_cross_validation_random_grid_early_stopping(inp, param_grid, max_iterations=100):\n",
    "    \"\"\"\n",
    "    Perform Leave-One-Out Cross-Validation (LOOCV) with random sampling from the grid (without replacement)\n",
    "    and early stopping after a fixed number of iterations.\n",
    "\n",
    "    Parameters:\n",
    "    - inp: The dataset, a numpy array of shape (n_samples, 3), where columns represent x1, x2, and y.\n",
    "    - param_grid: Dictionary defining the grid of parameter values.\n",
    "    - max_iterations: Maximum number of iterations for grid search per training set.\n",
    "\n",
    "    Returns:\n",
    "    - avg_rmse: Average RMSE across all leave-one-out splits.\n",
    "    \"\"\"\n",
    "    n_samples = inp.shape[0]\n",
    "    loo_rmse = []\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'],\n",
    "        param_grid['b'],\n",
    "        param_grid['c'],\n",
    "        param_grid['d1'],\n",
    "        param_grid['d3']\n",
    "    ))\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Leave out the i-th data point\n",
    "        test_point = inp[i:i+1]\n",
    "        train_data = np.delete(inp, i, axis=0)\n",
    "\n",
    "        min_loss = float('inf')\n",
    "        best_params = None\n",
    "        no_improve_rounds = 0\n",
    "        iteration_count = 0\n",
    "\n",
    "        # Shuffle the combinations for random sampling\n",
    "        random.shuffle(param_combinations)\n",
    "\n",
    "        for params in param_combinations:\n",
    "            # Increment iteration count\n",
    "            iteration_count += 1\n",
    "\n",
    "            # Run optimization with the current parameter combination as initialization\n",
    "            objective_grad = grad(objective_fn)\n",
    "            result = minimize(\n",
    "                objective_fn,\n",
    "                params,\n",
    "                args=(train_data,),\n",
    "                method='BFGS',\n",
    "                jac=objective_grad\n",
    "            )\n",
    "\n",
    "            # Check if the loss improves\n",
    "            if result.fun < min_loss:\n",
    "                min_loss = result.fun\n",
    "                best_params = result.x\n",
    "                no_improve_rounds = 0  # Reset early stopping counter\n",
    "            else:\n",
    "                no_improve_rounds += 1\n",
    "\n",
    "            # Apply early stopping if max_iterations is reached\n",
    "            if iteration_count >= max_iterations:\n",
    "                print(f\"Early stopping applied after {max_iterations} iterations for this split.\")\n",
    "                break\n",
    "\n",
    "        # Use the best parameters to predict on the test point\n",
    "        test_loss = loss_scipy_upcycle(best_params, test_point, loss_type=\"mse\")\n",
    "        test_rmse = np.sqrt(test_loss / len(test_point))\n",
    "\n",
    "        # Store the RMSE for this test sample\n",
    "        loo_rmse.append(test_rmse)\n",
    "\n",
    "    # Return the average RMSE across all leave-one-out splits\n",
    "    avg_rmse = np.mean(loo_rmse), np.std(loo_rmse)\n",
    "    return avg_rmse\n",
    "\n",
    "def grid_search_optimize(inp, param_grid, early_stopping_rounds=10):\n",
    "    min_loss = float('inf')\n",
    "    best_params = None\n",
    "    no_improve_rounds = 0\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'], \n",
    "        param_grid['b'], \n",
    "        param_grid['c'], \n",
    "        param_grid['d1'], \n",
    "        param_grid['d3'], \n",
    "    ))\n",
    "    random.shuffle(param_combinations)\n",
    "\n",
    "    for params in param_combinations:\n",
    "        a, b, c, d1, E = params\n",
    "        init_params = (a, b, c, d1, E)  # For upcycle model\n",
    "\n",
    "        objective_grad = grad(objective_fn)\n",
    "\n",
    "        # Run optimization with the Jacobian\n",
    "        result = minimize(\n",
    "            objective_fn, \n",
    "            init_params, \n",
    "            args=inp,\n",
    "            method='BFGS', \n",
    "            jac=objective_grad\n",
    "        )\n",
    "\n",
    "        l = result.fun\n",
    "        optimized_params = result.x\n",
    "\n",
    "        # Update best params if loss is lower\n",
    "        if l < min_loss:\n",
    "            min_loss = l\n",
    "            best_params = optimized_params\n",
    "            no_improve_rounds = 0  # Reset early stopping counter\n",
    "        else:\n",
    "            no_improve_rounds += 1\n",
    "\n",
    "        # Check for early stopping\n",
    "        if no_improve_rounds >= early_stopping_rounds:\n",
    "            print(f\"Early stopping after {early_stopping_rounds} rounds with no improvement.\")\n",
    "            break\n",
    "\n",
    "    return min_loss, best_params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Average LOOCV RMSE (mean, std): (np.float64(0.0015732781524633755), np.float64(0.0009297851770543415))\n"
     ]
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'a': np.linspace(0, 0.5, 5),\n",
    "    'b': np.linspace(0, 0.5, 5),\n",
    "    'c': np.linspace(0, 0.5, 5),\n",
    "    'd1': np.linspace(0, 10, 5),\n",
    "    'd3': np.linspace(0, 10, 5)\n",
    "}\n",
    "\n",
    "avg_rmse = leave_one_out_cross_validation_random_grid_early_stopping(inp_upc, param_grid)\n",
    "print(f\"Average LOOCV RMSE (mean, std): {avg_rmse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### multi without interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, b, d1, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1) - b * np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "# import numpy as np\n",
    "# from sklearn.metrics import mean_squared_error\n",
    "from itertools import product\n",
    "import random\n",
    "\n",
    "\n",
    "def leave_one_out_cross_validation_random_grid_early_stopping(inp, param_grid, max_iterations=100):\n",
    "    \"\"\"\n",
    "    Perform Leave-One-Out Cross-Validation (LOOCV) with random sampling from the grid (without replacement)\n",
    "    and early stopping after a fixed number of iterations.\n",
    "\n",
    "    Parameters:\n",
    "    - inp: The dataset, a numpy array of shape (n_samples, 3), where columns represent x1, x2, and y.\n",
    "    - param_grid: Dictionary defining the grid of parameter values.\n",
    "    - max_iterations: Maximum number of iterations for grid search per training set.\n",
    "\n",
    "    Returns:\n",
    "    - avg_rmse: Average RMSE across all leave-one-out splits.\n",
    "    \"\"\"\n",
    "    n_samples = inp.shape[0]\n",
    "    loo_rmse = []\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'],\n",
    "        param_grid['b'],\n",
    "        param_grid['d1'],\n",
    "        param_grid['d3']\n",
    "    ))\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Leave out the i-th data point\n",
    "        test_point = inp[i:i+1]\n",
    "        train_data = np.delete(inp, i, axis=0)\n",
    "\n",
    "        min_loss = float('inf')\n",
    "        best_params = None\n",
    "        no_improve_rounds = 0\n",
    "        iteration_count = 0\n",
    "\n",
    "        # Shuffle the combinations for random sampling\n",
    "        random.shuffle(param_combinations)\n",
    "\n",
    "        for params in param_combinations:\n",
    "            # Increment iteration count\n",
    "            iteration_count += 1\n",
    "\n",
    "            # Run optimization with the current parameter combination as initialization\n",
    "            objective_grad = grad(objective_fn)\n",
    "            result = minimize(\n",
    "                objective_fn,\n",
    "                params,\n",
    "                args=(train_data,),\n",
    "                method='BFGS',\n",
    "                jac=objective_grad\n",
    "            )\n",
    "\n",
    "            # Check if the loss improves\n",
    "            if result.fun < min_loss:\n",
    "                min_loss = result.fun\n",
    "                best_params = result.x\n",
    "                no_improve_rounds = 0  # Reset early stopping counter\n",
    "            else:\n",
    "                no_improve_rounds += 1\n",
    "\n",
    "            # Apply early stopping if max_iterations is reached\n",
    "            if iteration_count >= max_iterations:\n",
    "                print(f\"Early stopping applied after {max_iterations} iterations for this split.\")\n",
    "                break\n",
    "\n",
    "        # Use the best parameters to predict on the test point\n",
    "        test_loss = loss_scipy_upcycle(best_params, test_point, loss_type=\"mse\")\n",
    "        test_rmse = np.sqrt(test_loss / len(test_point))\n",
    "\n",
    "        # Store the RMSE for this test sample\n",
    "        loo_rmse.append(test_rmse)\n",
    "\n",
    "    # Return the average RMSE across all leave-one-out splits\n",
    "    avg_rmse = np.mean(loo_rmse)\n",
    "    return avg_rmse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Average LOOCV RMSE: 0.0022454509880876915\n"
     ]
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'a': np.linspace(0, 1, 5),\n",
    "    'b': np.linspace(0, 1, 5),\n",
    "    'd1': np.linspace(0, 10, 5),\n",
    "    'd3': np.linspace(0, 10, 5)\n",
    "}\n",
    "\n",
    "avg_rmse = leave_one_out_cross_validation_random_grid_early_stopping(inp_upc, param_grid)\n",
    "print(f\"Average LOOCV RMSE: {avg_rmse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### additive "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, b,  d1, d2, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1)  \n",
    "    term2 = d2 - b*np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1, term2,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "# import numpy as np\n",
    "# from sklearn.metrics import mean_squared_error\n",
    "from itertools import product\n",
    "import random\n",
    "\n",
    "\n",
    "def leave_one_out_cross_validation_random_grid_early_stopping(inp, param_grid, max_iterations=100):\n",
    "    \"\"\"\n",
    "    Perform Leave-One-Out Cross-Validation (LOOCV) with random sampling from the grid (without replacement)\n",
    "    and early stopping after a fixed number of iterations.\n",
    "\n",
    "    Parameters:\n",
    "    - inp: The dataset, a numpy array of shape (n_samples, 3), where columns represent x1, x2, and y.\n",
    "    - param_grid: Dictionary defining the grid of parameter values.\n",
    "    - max_iterations: Maximum number of iterations for grid search per training set.\n",
    "\n",
    "    Returns:\n",
    "    - avg_rmse: Average RMSE across all leave-one-out splits.\n",
    "    \"\"\"\n",
    "    n_samples = inp.shape[0]\n",
    "    loo_rmse = []\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'],\n",
    "        param_grid['b'],\n",
    "        param_grid['d1'],\n",
    "        param_grid['d2'],\n",
    "        param_grid['d3']\n",
    "    ))\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Leave out the i-th data point\n",
    "        test_point = inp[i:i+1]\n",
    "        train_data = np.delete(inp, i, axis=0)\n",
    "\n",
    "        min_loss = float('inf')\n",
    "        best_params = None\n",
    "        no_improve_rounds = 0\n",
    "        iteration_count = 0\n",
    "\n",
    "        # Shuffle the combinations for random sampling\n",
    "        random.shuffle(param_combinations)\n",
    "\n",
    "        for params in param_combinations:\n",
    "            # Increment iteration count\n",
    "            iteration_count += 1\n",
    "\n",
    "            # Run optimization with the current parameter combination as initialization\n",
    "            objective_grad = grad(objective_fn)\n",
    "            result = minimize(\n",
    "                objective_fn,\n",
    "                params,\n",
    "                args=(train_data,),\n",
    "                method='BFGS',\n",
    "                jac=objective_grad\n",
    "            )\n",
    "\n",
    "            # Check if the loss improves\n",
    "            if result.fun < min_loss:\n",
    "                min_loss = result.fun\n",
    "                best_params = result.x\n",
    "                no_improve_rounds = 0  # Reset early stopping counter\n",
    "            else:\n",
    "                no_improve_rounds += 1\n",
    "\n",
    "            # Apply early stopping if max_iterations is reached\n",
    "            if iteration_count >= max_iterations:\n",
    "                print(f\"Early stopping applied after {max_iterations} iterations for this split.\")\n",
    "                break\n",
    "\n",
    "        # Use the best parameters to predict on the test point\n",
    "        test_loss = loss_scipy_upcycle(best_params, test_point, loss_type=\"mse\")\n",
    "        test_rmse = np.sqrt(test_loss / len(test_point))\n",
    "\n",
    "        # Store the RMSE for this test sample\n",
    "        loo_rmse.append(test_rmse)\n",
    "\n",
    "    # Return the average RMSE across all leave-one-out splits\n",
    "    avg_rmse = np.mean(loo_rmse)\n",
    "    return avg_rmse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Average LOOCV RMSE: 0.0039130386681377515\n"
     ]
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'a': np.linspace(0, 1, 5),\n",
    "    'b': np.linspace(0, 1, 5),\n",
    "    'd1': np.linspace(0, 10, 5),\n",
    "    'd2': np.linspace(0, 10, 5),\n",
    "    'd3': np.linspace(0, 10, 5)\n",
    "}\n",
    "\n",
    "avg_rmse = leave_one_out_cross_validation_random_grid_early_stopping(inp_upc, param_grid)\n",
    "print(f\"Average LOOCV RMSE: {avg_rmse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### hybrid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, b,  d1, d2, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1)  - b*np.log(x2)\n",
    "    term2 = d2 - b*np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1, term2,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "# import numpy as np\n",
    "# from sklearn.metrics import mean_squared_error\n",
    "from itertools import product\n",
    "import random\n",
    "\n",
    "\n",
    "def leave_one_out_cross_validation_random_grid_early_stopping(inp, param_grid, max_iterations=100):\n",
    "    \"\"\"\n",
    "    Perform Leave-One-Out Cross-Validation (LOOCV) with random sampling from the grid (without replacement)\n",
    "    and early stopping after a fixed number of iterations.\n",
    "\n",
    "    Parameters:\n",
    "    - inp: The dataset, a numpy array of shape (n_samples, 3), where columns represent x1, x2, and y.\n",
    "    - param_grid: Dictionary defining the grid of parameter values.\n",
    "    - max_iterations: Maximum number of iterations for grid search per training set.\n",
    "\n",
    "    Returns:\n",
    "    - avg_rmse: Average RMSE across all leave-one-out splits.\n",
    "    \"\"\"\n",
    "    n_samples = inp.shape[0]\n",
    "    loo_rmse = []\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'],\n",
    "        param_grid['b'],\n",
    "        param_grid['d1'],\n",
    "        param_grid['d2'],\n",
    "        param_grid['d3']\n",
    "    ))\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Leave out the i-th data point\n",
    "        test_point = inp[i:i+1]\n",
    "        train_data = np.delete(inp, i, axis=0)\n",
    "\n",
    "        min_loss = float('inf')\n",
    "        best_params = None\n",
    "        no_improve_rounds = 0\n",
    "        iteration_count = 0\n",
    "\n",
    "        # Shuffle the combinations for random sampling\n",
    "        random.shuffle(param_combinations)\n",
    "\n",
    "        for params in param_combinations:\n",
    "            # Increment iteration count\n",
    "            iteration_count += 1\n",
    "\n",
    "            # Run optimization with the current parameter combination as initialization\n",
    "            objective_grad = grad(objective_fn)\n",
    "            result = minimize(\n",
    "                objective_fn,\n",
    "                params,\n",
    "                args=(train_data,),\n",
    "                method='BFGS',\n",
    "                jac=objective_grad\n",
    "            )\n",
    "\n",
    "            # Check if the loss improves\n",
    "            if result.fun < min_loss:\n",
    "                min_loss = result.fun\n",
    "                best_params = result.x\n",
    "                no_improve_rounds = 0  # Reset early stopping counter\n",
    "            else:\n",
    "                no_improve_rounds += 1\n",
    "\n",
    "            # Apply early stopping if max_iterations is reached\n",
    "            if iteration_count >= max_iterations:\n",
    "                print(f\"Early stopping applied after {max_iterations} iterations for this split.\")\n",
    "                break\n",
    "\n",
    "        # Use the best parameters to predict on the test point\n",
    "        test_loss = loss_scipy_upcycle(best_params, test_point, loss_type=\"mse\")\n",
    "        test_rmse = np.sqrt(test_loss / len(test_point))\n",
    "\n",
    "        # Store the RMSE for this test sample\n",
    "        loo_rmse.append(test_rmse)\n",
    "\n",
    "    # Return the average RMSE across all leave-one-out splits\n",
    "    avg_rmse = np.mean(loo_rmse)\n",
    "    return avg_rmse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Average LOOCV RMSE: 0.0022454321099805875\n"
     ]
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'a': np.linspace(0, 1, 5),\n",
    "    'b': np.linspace(0, 1, 5),\n",
    "    'd1': np.linspace(0, 10, 5),\n",
    "    'd2': np.linspace(0, 10, 5),\n",
    "    'd3': np.linspace(0, 10, 5)\n",
    "}\n",
    "\n",
    "avg_rmse = leave_one_out_cross_validation_random_grid_early_stopping(inp_upc, param_grid)\n",
    "print(f\"Average LOOCV RMSE: {avg_rmse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### fit D1+D2 scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, d1, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1+x2)  \n",
    "    # term2 = d2 - b*np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "# import numpy as np\n",
    "# from sklearn.metrics import mean_squared_error\n",
    "from itertools import product\n",
    "import random\n",
    "\n",
    "\n",
    "def leave_one_out_cross_validation_random_grid_early_stopping(inp, param_grid, max_iterations=100):\n",
    "    \"\"\"\n",
    "    Perform Leave-One-Out Cross-Validation (LOOCV) with random sampling from the grid (without replacement)\n",
    "    and early stopping after a fixed number of iterations.\n",
    "\n",
    "    Parameters:\n",
    "    - inp: The dataset, a numpy array of shape (n_samples, 3), where columns represent x1, x2, and y.\n",
    "    - param_grid: Dictionary defining the grid of parameter values.\n",
    "    - max_iterations: Maximum number of iterations for grid search per training set.\n",
    "\n",
    "    Returns:\n",
    "    - avg_rmse: Average RMSE across all leave-one-out splits.\n",
    "    \"\"\"\n",
    "    n_samples = inp.shape[0]\n",
    "    loo_rmse = []\n",
    "\n",
    "    # Generate all combinations of parameter values\n",
    "    param_combinations = list(product(\n",
    "        param_grid['a'],\n",
    "        param_grid['d1'],\n",
    "        param_grid['d3']\n",
    "    ))\n",
    "\n",
    "    for i in range(n_samples):\n",
    "        # Leave out the i-th data point\n",
    "        test_point = inp[i:i+1]\n",
    "        train_data = np.delete(inp, i, axis=0)\n",
    "\n",
    "        min_loss = float('inf')\n",
    "        best_params = None\n",
    "        no_improve_rounds = 0\n",
    "        iteration_count = 0\n",
    "\n",
    "        # Shuffle the combinations for random sampling\n",
    "        random.shuffle(param_combinations)\n",
    "\n",
    "        for params in param_combinations:\n",
    "            # Increment iteration count\n",
    "            iteration_count += 1\n",
    "\n",
    "            # Run optimization with the current parameter combination as initialization\n",
    "            objective_grad = grad(objective_fn)\n",
    "            result = minimize(\n",
    "                objective_fn,\n",
    "                params,\n",
    "                args=(train_data,),\n",
    "                method='BFGS',\n",
    "                jac=objective_grad\n",
    "            )\n",
    "\n",
    "            # Check if the loss improves\n",
    "            if result.fun < min_loss:\n",
    "                min_loss = result.fun\n",
    "                best_params = result.x\n",
    "                no_improve_rounds = 0  # Reset early stopping counter\n",
    "            else:\n",
    "                no_improve_rounds += 1\n",
    "\n",
    "            # Apply early stopping if max_iterations is reached\n",
    "            if iteration_count >= max_iterations:\n",
    "                print(f\"Early stopping applied after {max_iterations} iterations for this split.\")\n",
    "                break\n",
    "\n",
    "        # Use the best parameters to predict on the test point\n",
    "        test_loss = loss_scipy_upcycle(best_params, test_point, loss_type=\"mse\")\n",
    "        test_rmse = np.sqrt(test_loss / len(test_point))\n",
    "\n",
    "        # Store the RMSE for this test sample\n",
    "        loo_rmse.append(test_rmse)\n",
    "\n",
    "    # Return the average RMSE across all leave-one-out splits\n",
    "    avg_rmse = np.mean(loo_rmse)\n",
    "    return avg_rmse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Early stopping applied after 100 iterations for this split.\n",
      "Average LOOCV RMSE: 0.021284841952871805\n"
     ]
    }
   ],
   "source": [
    "param_grid = {\n",
    "    'a': np.linspace(0, 1, 5),\n",
    "    'd1': np.linspace(0, 10, 5),\n",
    "    'd3': np.linspace(0, 10, 5)\n",
    "}\n",
    "\n",
    "avg_rmse = leave_one_out_cross_validation_random_grid_early_stopping(inp_upc, param_grid)\n",
    "print(f\"Average LOOCV RMSE: {avg_rmse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fit all to get coeff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import autograd.numpy as np\n",
    "from autograd import grad\n",
    "from scipy.optimize import minimize\n",
    "from tqdm import tqdm\n",
    "\n",
    "def custom_logsumexp(terms):\n",
    "    max_term = np.max(terms, axis=0)\n",
    "    sum_exp = np.sum(np.exp(terms - max_term), axis=0)\n",
    "    return max_term + np.log(sum_exp)\n",
    "\n",
    "\n",
    "def loss_scipy_upcycle(params, inp, loss_type=\"huber\", delta = 1e-3):\n",
    "    # a, b, d1 = params\n",
    "    a, b, c, d1, d3 = params\n",
    "    x1, x2, y = inp[:, 0], inp[:, 1], inp[:, 2]\n",
    "\n",
    "    term1 = d1  - a * np.log( x1) - b * np.log(x2) + c * np.log(x1)* np.log(x2)\n",
    "    term3 = d3 * np.ones_like(x1) # use fit from scratch models\n",
    "\n",
    "    # Use custom_logsumexp for numerical stability\n",
    "    terms = np.stack([term1,term3], axis=0)\n",
    "    post_lse = custom_logsumexp(terms)\n",
    "    \n",
    "    if loss_type == \"huber\":\n",
    "        residual = post_lse - np.log(y)\n",
    "        loss = np.where(np.abs(residual) <= delta,\n",
    "                                0.5 * residual**2,\n",
    "                                delta * (np.abs(residual) - 0.5 * delta)).sum()\n",
    "    elif loss_type == \"msle\":\n",
    "        loss = ((post_lse - np.log(y))**2).sum()\n",
    "    elif loss_type == \"mse\":\n",
    "        loss = ((np.exp(post_lse) - y)**2).sum()\n",
    "    else:\n",
    "        raise NotImplementedError(f\"loss {loss_type} not implemented!\")\n",
    "    \n",
    "    return loss \n",
    "\n",
    "def objective_fn(params, inp):\n",
    "    loss3 = loss_scipy_upcycle(params, inp)\n",
    "    total_loss = loss3\n",
    "    return total_loss\n",
    "\n",
    "def random_search_optimize( inp, n_iter=100, early_stopping_rounds=10):\n",
    "    min_loss = float('inf')\n",
    "    best_params = None\n",
    "    no_improve_rounds = 0\n",
    "\n",
    "    for _ in range(n_iter):\n",
    "        # Sample random parameters from the specified ranges\n",
    "        a = np.random.uniform(0, 1)\n",
    "        b = np.random.uniform(0, 1)\n",
    "        c = np.random.uniform(0, 1)\n",
    "        d1 = np.random.uniform(0, 10)\n",
    "        d2 = np.random.uniform(0, 10)\n",
    "        # d3 = np.random.uniform(0, 10)\n",
    "        init_params = (a, b,c, d1, d2) # For upcycle model\n",
    "        \n",
    "        objective_grad = grad(objective_fn)\n",
    "        \n",
    "        # Run optimization with the Jacobian\n",
    "        result = minimize(\n",
    "            objective_fn, \n",
    "            init_params, \n",
    "            args=inp,\n",
    "            method='BFGS', \n",
    "            # method='L-BFGS-B', \n",
    "            jac=objective_grad\n",
    "        )\n",
    "        \n",
    "        l = result.fun\n",
    "        params = result.x\n",
    "        \n",
    "        # Update best params if loss is lower\n",
    "        if l < min_loss:\n",
    "            min_loss = l\n",
    "            best_params = params\n",
    "            no_improve_rounds = 0  # Reset early stopping counter\n",
    "        else:\n",
    "            no_improve_rounds += 1\n",
    "\n",
    "        # Check for early stopping\n",
    "        if no_improve_rounds >= early_stopping_rounds:\n",
    "            print(f\"Early stopping after {early_stopping_rounds} rounds with no improvement.\")\n",
    "            break\n",
    "\n",
    "    return min_loss, best_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.534215629898169e-05\n",
      "[ 1.06052382e-01  1.46479714e-01  4.40965531e-03  3.53211018e+00\n",
      " -3.26973671e+01]\n"
     ]
    }
   ],
   "source": [
    "loss, param = random_search_optimize(inp_upc, n_iter=100, early_stopping_rounds=100)\n",
    "param_grid = {\n",
    "    'a': np.linspace(0, 1, 5),  # 5 equally spaced values between 0 and 1\n",
    "    'b': np.linspace(0, 1, 5),\n",
    "    'c': np.linspace(0, 1, 5),\n",
    "    'd1': np.linspace(-5, 5, 5),\n",
    "    'd2': np.linspace(-5, 5, 5),\n",
    "}\n",
    "# loss, param = grid_search_optimize(inp_upc, param_grid, early_stopping_rounds=1000)\n",
    "\n",
    "print(loss)\n",
    "print(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
