{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "76ce232f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Imports\n",
    "# ==========================================\n",
    "\n",
    "import numpy as np\n",
    "from anova_module import ModelAnalysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ae4ceeb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# Context : Function f and support of X\n",
    "# ==========================================\n",
    "\n",
    "# Support of X\n",
    "def generate_joint_support(N: int) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Enumerate the support of the categorical vector X = (X1, X2, X3, X4, X5).\n",
    "\n",
    "    Only X1, X2 and X4 vary freely, each over {0, ..., N-1}; the remaining\n",
    "    coordinates are tied to them:\n",
    "        - X3 is a copy of X2 (X3 = X2 almost surely).\n",
    "        - X5 is identically 0.\n",
    "\n",
    "    The support therefore has r = N^3 points, one per (X1, X2, X4) triple.\n",
    "    Rows follow the row-major order of the (N, N, N) grid: X4 varies fastest\n",
    "    and X1 slowest.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    N : int\n",
    "        Number of values taken by each of the free variables X1, X2, X4.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Integer matrix of shape (N^3, 5); each row is one realization of X,\n",
    "        with columns ordered as [X1, X2, X3, X4, X5].\n",
    "    \"\"\"\n",
    "    # np.indices yields one index array per axis of the (N, N, N) grid;\n",
    "    # flattening each of them gives three coordinate vectors of length N^3.\n",
    "    free_coords = np.indices((N, N, N)).reshape(3, -1)\n",
    "    first, second, fourth = free_coords\n",
    "\n",
    "    # X5 carries no information: a column of zeros of matching length.\n",
    "    constant = np.zeros(N**3, dtype=int)\n",
    "\n",
    "    # X3 reuses the X2 coordinates, which enforces X3 = X2 on every row.\n",
    "    return np.column_stack((first, second, second, fourth, constant))\n",
    "\n",
    "# Function f\n",
    "def compute_linear_threshold(\n",
    "    X: np.ndarray,\n",
    "    *,\n",
    "    alpha: float = 1,\n",
    "    beta: float = -1,\n",
    "    gamma: float = 0.5,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Computes the sign of a linear combination of the first three variables.\n",
    "\n",
    "    Given an input matrix X whose first three columns are X1, X2, X3, this\n",
    "    function calculates:\n",
    "        y = sign(alpha * X1 + beta * X2 + gamma * X3)\n",
    "\n",
    "    The coefficients default to alpha = 1, beta = -1, gamma = 0.5, so calling\n",
    "    the function with X alone uses exactly those weights.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Input data matrix of shape (n, 5) with columns [X1, X2, X3, X4, X5].\n",
    "        Only columns 0, 1 and 2 are read. Any array-like accepted by\n",
    "        np.asarray is allowed.\n",
    "    alpha : float, keyword-only\n",
    "        Coefficient applied to X1 (column 0).\n",
    "    beta : float, keyword-only\n",
    "        Coefficient applied to X2 (column 1).\n",
    "    gamma : float, keyword-only\n",
    "        Coefficient applied to X3 (column 2).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Output vector of shape (n,) with values in {-1, 0, 1}. The dtype follows\n",
    "        NumPy promotion of X and the coefficients (floating point with the\n",
    "        default gamma = 0.5).\n",
    "    \"\"\"\n",
    "    # No-op for ndarray inputs; lets nested lists and other array-likes through.\n",
    "    X = np.asarray(X)\n",
    "\n",
    "    # Element-wise arithmetic on whole columns: one vectorized pass, no Python loop.\n",
    "    linear_combination = alpha * X[:, 0] + beta * X[:, 1] + gamma * X[:, 2]\n",
    "\n",
    "    # np.sign returns -1 if x < 0, 0 if x == 0, 1 if x > 0\n",
    "    return np.sign(linear_combination)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9dade6c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix:   0%|\u001b[32m          \u001b[0m| 0/27 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Constructing Basis Matrix: 100%|\u001b[32m██████████\u001b[0m| 27/27 [00:00<00:00, 4777.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computations complete. Results ready.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# ==========================================\n",
    "# Run the functional ANOVA framework on the example\n",
    "# ==========================================\n",
    "\n",
    "N = 3  # number of modalities of each free variable\n",
    "d = 5  # number of variables\n",
    "f_model = compute_linear_threshold  # function to analyse\n",
    "X = generate_joint_support(N)  # support of X\n",
    "\n",
    "# NOTE(review): 100, 1e-3 and 1e-10 are passed positionally; their meaning is\n",
    "# fixed by the ModelAnalysis signature in anova_module - confirm there.\n",
    "A = ModelAnalysis(X, f_model, 100, 1e-3, 1e-10)\n",
    "S, Matrix = A.functional_anova()  # sets and f_A\n",
    "P = A.get_P()  # probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "73b6acfd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[], [1], [2], [4], [1, 2], [1, 4], [2, 4], [1, 2, 4]]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sets S\n",
    "\n",
    "S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ea16f17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.11111111e-01, 5.18518518e-01, 7.40740741e-02, 9.39334830e-37,\n",
       "       7.40740741e-02, 8.34964294e-37, 8.34964294e-37, 0.00000000e+00])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Squared L^2(P) norms of the components f_S(X_S), i.e. E[f_S(X_S)^2], one value per set in S\n",
    "\n",
    "np.sum( (Matrix**2).T * P , axis=1 )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hfd_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
