{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46947d4c",
   "metadata": {},
   "source": [
    "# Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be16c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from model_module import CGCLR\n",
    "from trainer_module import training\n",
    "import time\n",
    "from scipy import stats\n",
    "from mpl_toolkits import mplot3d\n",
    "import warnings\n",
    "# Silence every warning category for the whole session to keep cell output short (this also hides useful diagnostics)\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b54226e2",
   "metadata": {},
   "source": [
    "# Synthetic Data Setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a355e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed numpy (data sampling below) and torch (used by the models trained later) so reruns draw the same data;\n",
    "# TODO confirm CGCLR/training rely only on torch's global RNG (GPU kernels may still be nondeterministic)\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "n = 50  # number of samples for each region\n",
    "noise_std = 0.1  # standard deviation of the Gaussian noise added to the responses\n",
    "\n",
    "# covariates sampling\n",
    "def sample_in_region(n, condition, x_range=(-1, 1), y_range=(-2, 2)):\n",
    "    \"\"\"Rejection-sample `n` points uniformly from the box `x_range` x `y_range`.\n",
    "\n",
    "    A candidate (x1, x2) is drawn from numpy's global RNG (x1 first, then x2) and kept\n",
    "    only if `condition(x1, x2)` is true. Loops forever if the region has zero\n",
    "    probability inside the box.\n",
    "\n",
    "    Returns:\n",
    "        (x1, x2): two float arrays of length `n` (both empty when `n <= 0`).\n",
    "    \"\"\"\n",
    "    x1_kept, x2_kept = [], []\n",
    "    while len(x1_kept) < n:\n",
    "        x1 = np.random.uniform(x_range[0], x_range[1])\n",
    "        x2 = np.random.uniform(y_range[0], y_range[1])\n",
    "        if condition(x1, x2):\n",
    "            x1_kept.append(x1)\n",
    "            x2_kept.append(x2)\n",
    "    # Separate 1-D lists (instead of slicing an (n, 2) array) keep n == 0 from raising IndexError\n",
    "    return np.array(x1_kept, dtype=float), np.array(x2_kept, dtype=float)\n",
    "\n",
    "# Four cones split by the diagonals |x1| = |x2|, sampled in this fixed order (it fixes each region's RNG draws)\n",
    "x1_north, x2_north = sample_in_region(n, lambda u, v: v > abs(u))   # north: x2 >  |x1|\n",
    "x1_west, x2_west = sample_in_region(n, lambda u, v: u < -abs(v))    # west:  x1 < -|x2|\n",
    "x1_south, x2_south = sample_in_region(n, lambda u, v: v < -abs(u))  # south: x2 < -|x1|\n",
    "x1_east, x2_east = sample_in_region(n, lambda u, v: u > abs(v))     # east:  x1 >  |x2|\n",
    "\n",
    "# customized clusterwise linear functions (one plane per cluster)\n",
    "def f1(x1, x2):\n",
    "    \"\"\"Plane y = x1 + x2 - 10 (the data uses it for both the north and south cones).\"\"\"\n",
    "    return x1 + x2 - 10\n",
    "def f2(x1, x2):\n",
    "    \"\"\"Plane y = -x1 - x2 - 3 (the data uses it for the west cone).\"\"\"\n",
    "    return -x1 - x2 - 3\n",
    "def f3(x1, x2):\n",
    "    \"\"\"Plane y = x1 + 2*x2 + 2 (the data uses it for the east cone).\"\"\"\n",
    "    return x1 + 2 * x2 + 2\n",
    "\n",
    "# generate responses: each cone's plane plus i.i.d. N(0, noise_std^2) noise\n",
    "# (north and south share f1; noise is drawn in the order north, west, south, east)\n",
    "y_north = f1(x1_north, x2_north) + np.random.normal(0, noise_std, n)\n",
    "y_west  = f2(x1_west, x2_west) + np.random.normal(0, noise_std, n)\n",
    "y_south = f1(x1_south, x2_south) + np.random.normal(0, noise_std, n)\n",
    "y_east  = f3(x1_east, x2_east) + np.random.normal(0, noise_std, n)\n",
    "\n",
    "# sampling results: stack the regions (north, west, south, east) into one (4n, 3) table\n",
    "x1_all, x2_all, y_all = (\n",
    "    np.concatenate(parts)\n",
    "    for parts in (\n",
    "        [x1_north, x1_west, x1_south, x1_east],\n",
    "        [x2_north, x2_west, x2_south, x2_east],\n",
    "        [y_north, y_west, y_south, y_east],\n",
    "    )\n",
    ")\n",
    "dataset = np.column_stack((x1_all, x2_all, y_all))\n",
    "df = pd.DataFrame(dataset, columns=[\"x1\", \"x2\", \"y\"])\n",
    "X = np.array(df[['x1','x2']])  # covariates, shape (4n, 2)\n",
    "Y = np.array(df[['y']])        # responses,  shape (4n, 1)\n",
    "\n",
    "# Train Dataset Preprocessing: standardise X and y, then convert to float32 tensors\n",
    "scaler_x = StandardScaler()\n",
    "scaler_y = StandardScaler()\n",
    "train_X = torch.tensor(scaler_x.fit_transform(X)).float()\n",
    "train_Y = torch.tensor(scaler_y.fit_transform(Y)).reshape(-1, 1).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "860f4a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back to original units for plotting (train_X / train_Y hold standardised tensors)\n",
    "train_X_np = scaler_x.inverse_transform(train_X.cpu().numpy())  # shape: (4*n, 2)\n",
    "train_Y_np = scaler_y.inverse_transform(train_Y.cpu().numpy()).ravel()  # shape: (4*n,)\n",
    "\n",
    "# evaluation covariates: a num_grid x num_grid mesh over the bounding box of the training points\n",
    "num_grid = 1000\n",
    "x1_min, x2_min = train_X_np.min(axis=0)\n",
    "x1_max, x2_max = train_X_np.max(axis=0)\n",
    "x1_range = np.linspace(x1_min, x1_max, num_grid)\n",
    "x2_range = np.linspace(x2_min, x2_max, num_grid)\n",
    "x1_mesh, x2_mesh = np.meshgrid(x1_range, x2_range)\n",
    "grid_points = np.column_stack([x1_mesh.ravel(), x2_mesh.ravel()])  # shape: (num_grid**2, 2)\n",
    "\n",
    "# denoised response\n",
    "def ground_truth_function(x1, x2):\n",
    "    \"\"\"Noise-free response of the synthetic data-generating process.\n",
    "\n",
    "    Applies f1 on the north cone (x2 > |x1|), f2 on the west cone (x1 < -|x2|),\n",
    "    f1 on the south cone (x2 < -|x1|) and f3 on the east cone (x1 > |x2|).\n",
    "    Points on the diagonals |x1| == |x2| belong to no cone and map to NaN.\n",
    "    Inputs may be array-likes (lists or arrays); they are broadcast against each other.\n",
    "    \"\"\"\n",
    "    # Convert first: on plain lists f1/f2/f3 would concatenate instead of doing arithmetic\n",
    "    x1 = np.asarray(x1)\n",
    "    x2 = np.asarray(x2)\n",
    "    cond_north = x2 > np.abs(x1)\n",
    "    cond_west  = x1 < -np.abs(x2)\n",
    "    cond_south = x2 < -np.abs(x1)\n",
    "    cond_east  = x1 > np.abs(x2)\n",
    "    return np.select([cond_north, cond_west, cond_south, cond_east],\n",
    "                     [f1(x1,x2), f2(x1,x2), f1(x1,x2), f3(x1,x2)],\n",
    "                     default=np.nan)\n",
    "\n",
    "def ground_truth_weight_function(x1, x2):\n",
    "    \"\"\"True per-point coefficients of the data-generating process.\n",
    "\n",
    "    Returns an array of shape broadcast(x1, x2).shape + (3,) whose last axis is\n",
    "    [coef_x1, coef_x2, intercept] -- the same layout as [beta_x1, beta_x2, beta_0]\n",
    "    used for the recovery-error maps. Points on |x1| == |x2| get NaN.\n",
    "    \"\"\"\n",
    "    x1, x2 = np.asarray(x1), np.asarray(x2)\n",
    "    region_masks = [\n",
    "        x2 > np.abs(x1),    # north: f1 = x1 + x2 - 10\n",
    "        x1 < -np.abs(x2),   # west:  f2 = -x1 - x2 - 3\n",
    "        x2 < -np.abs(x1),   # south: f1 = x1 + x2 - 10\n",
    "        x1 > np.abs(x2),    # east:  f3 = x1 + 2*x2 + 2\n",
    "    ]\n",
    "    coef_x1 = np.select(region_masks, [1, -1, 1, 1], default=np.nan)\n",
    "    coef_x2 = np.select(region_masks, [1, -1, 1, 2], default=np.nan)\n",
    "    intercept = np.select(region_masks, [-10, -3, -10, 2], default=np.nan)\n",
    "    return np.stack([coef_x1, coef_x2, intercept], axis=-1)\n",
    "\n",
    "Y_true_ground = ground_truth_function(x1_mesh, x2_mesh)\n",
    "\n",
    "# 2D Contour Map with training samples\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, Y_true_ground, levels=100, cmap='viridis', alpha=0.8, vmin=-13, vmax=5)\n",
    "plt.scatter(X[:, 0], X[:, 1], color='black', marker='x', s=30, label='Train Data')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"Ground Truth Function with training samples\")\n",
    "plt.legend(loc='center left')\n",
    "plt.colorbar(contour, label=\"$y$\")\n",
    "plt.savefig('./img/True_Function.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# 3D view of the ground truth, drawn as 500 stacked contour levels\n",
    "fig = plt.figure(figsize=(12, 10))\n",
    "# fig.add_subplot is the documented way to request a 3D axes (plt.axes' positional argument is a rect)\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.view_init(45)\n",
    "surface = ax.contour3D(x1_mesh, x2_mesh, Y_true_ground,500,cmap='viridis')\n",
    "ax.set_xlabel(\"$x^1$\")\n",
    "ax.set_ylabel(\"$x^2$\")\n",
    "ax.set_zlabel(\"$y$\")\n",
    "ax.set_title(\"Ground Truth Function Surface\")\n",
    "fig.colorbar(surface, shrink=0.5, aspect=5, label=\"$y$\")\n",
    "# No plt.legend() here: nothing on this axes has a label, so it would only emit a warning\n",
    "plt.savefig('./img/True_Function_Surface.png', dpi=500)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3130b639",
   "metadata": {},
   "source": [
    "# CG-CLR (F-test with Prediction plot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cf1c96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# H0 Baseline: K=1\n",
    "# With a single expert the codebook head presumably reduces to one global linear model;\n",
    "# its training SSE (H0_SSE) seeds the sequential F-test in the next cell.\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "p_dim = train_X.shape[1]\n",
    "num_K = 1\n",
    "\n",
    "model = CGCLR(\n",
    "    input_dim=p_dim,\n",
    "    expert_num= num_K,\n",
    "    output_dim=1,\n",
    "    proxy_hidden_shape=[128,128,128],\n",
    "    dropout = 0.0,\n",
    "    device = device\n",
    ")\n",
    "\n",
    "pre_time = time.time()\n",
    "# NOTE(review): the training set is also passed as the validation set and batch_size covers the whole set --\n",
    "# presumably full-batch training with early stopping on training loss; confirm in trainer_module\n",
    "best_model = training(\n",
    "    model = model,\n",
    "    train_X = train_X,\n",
    "    train_Y = train_Y,\n",
    "    val_X = train_X,\n",
    "    val_Y = train_Y,\n",
    "    max_epochs = 10000,\n",
    "    patience = 1000,\n",
    "    lr = 1e-3,\n",
    "    batch_size = train_X.shape[0],\n",
    "    y_scaler = scaler_y,\n",
    "    device = device,\n",
    "    LAMBDA = 1,\n",
    "    verbose = False\n",
    ")\n",
    "print(\"Training Time(s): \",int(time.time() - pre_time))\n",
    "\n",
    "best_model.eval()\n",
    "with torch.no_grad():\n",
    "    _, _, _, y_tilde = best_model(train_X.to(device))\n",
    "# Move predictions to the CPU first: train_Y is a CPU tensor\n",
    "y_tilde_cpu = y_tilde.detach().cpu()\n",
    "if scaler_y is not None:\n",
    "    # SSE in the original y units\n",
    "    TRAIN_SSE = np.sum((scaler_y.inverse_transform(y_tilde_cpu.numpy()) - scaler_y.inverse_transform(train_Y.numpy()))**2)\n",
    "else:\n",
    "    TRAIN_SSE = torch.sum((y_tilde_cpu - train_Y)**2).numpy()\n",
    "H0_SSE = TRAIN_SSE\n",
    "\n",
    "\n",
    "# Inference for plot: evaluate the fitted model on the dense evaluation grid (standardised x)\n",
    "grid_points_scaled = scaler_x.transform(grid_points)\n",
    "grid_points_scaled = torch.tensor(grid_points_scaled, dtype=torch.float32).to(device)\n",
    "with torch.no_grad():\n",
    "    w_hat, w_tilde, _, y_tilde_grid = best_model(grid_points_scaled)\n",
    "# Codebook prediction in original y units; scaler_y was fitted on one column, so inverse-transform\n",
    "# an (N, 1) column and only then lay it out on the mesh\n",
    "Y_pred = y_tilde_grid.cpu().numpy().reshape(x1_mesh.shape)\n",
    "Y_pred_orig = scaler_y.inverse_transform(Y_pred.reshape(-1, 1)).reshape(x1_mesh.shape)\n",
    "# Proxy prediction: [x, 1] . w_hat, i.e. the last entry of each w_hat row is the bias\n",
    "grid_aug = torch.concat([grid_points_scaled.cpu(),torch.ones(size=(len(grid_points_scaled),1))],axis=1)\n",
    "Y_pred_proxy = (grid_aug * w_hat.cpu()).sum(1)\n",
    "Y_pred_proxy = Y_pred_proxy.cpu().numpy().reshape(x1_mesh.shape)\n",
    "Y_pred_proxy_orig = scaler_y.inverse_transform(Y_pred_proxy.reshape(-1, 1)).reshape(x1_mesh.shape)\n",
    "# Absolute errors against the noise-free ground truth (NaN on the region boundaries)\n",
    "Y_true_ground = ground_truth_function(x1_mesh, x2_mesh)\n",
    "Y_pred_err_orig = np.abs(Y_pred_orig - Y_true_ground)\n",
    "Y_pred_proxy_err_orig = np.abs(Y_pred_proxy_orig - Y_true_ground)\n",
    "weight_true_ground = ground_truth_weight_function(x1_mesh, x2_mesh)\n",
    "# Undo the standardisation of the weights:\n",
    "#   y = mu_y + sigma_y * (w . (x - mu_x) / sigma_x + w0)\n",
    "#   => beta = sigma_y * w / sigma_x,  beta_0 = mu_y + sigma_y * w0 - beta . mu_x\n",
    "sigma_y = scaler_y.scale_[0]\n",
    "mu_y    = scaler_y.mean_[0]\n",
    "sigma_x = scaler_x.scale_\n",
    "mu_x    = scaler_x.mean_\n",
    "# [beta_x1, beta_x2, beta_0] per grid point, laid out on the mesh (follows num_grid, no hard-coded size)\n",
    "coef_shape = (*x1_mesh.shape, 3)\n",
    "w       = w_tilde[:, :-1].cpu().detach().numpy()\n",
    "w0      = w_tilde[:,-1].cpu().detach().numpy()\n",
    "beta_tilde    = (sigma_y * w) / sigma_x\n",
    "beta_0  = mu_y + sigma_y * w0 - np.dot(beta_tilde, mu_x)\n",
    "beta_0_tilde = beta_0.reshape(-1,1)\n",
    "# Coefficient recovery error: root-mean-square over the 3 coefficients at each grid point\n",
    "weight_err = ((np.concatenate([beta_tilde, beta_0_tilde],axis=1).reshape(coef_shape) - weight_true_ground)**2).mean(2)**0.5\n",
    "w       = w_hat[:, :-1].cpu().detach().numpy()\n",
    "w0      = w_hat[:,-1].cpu().detach().numpy()\n",
    "beta_hat    = (sigma_y * w) / sigma_x\n",
    "beta_0  = mu_y + sigma_y * w0 - np.dot(beta_hat, mu_x)\n",
    "beta_0_hat = beta_0.reshape(-1,1)\n",
    "weight_proxy_err = ((np.concatenate([beta_hat, beta_0_hat],axis=1).reshape(coef_shape) - weight_true_ground)**2).mean(2)**0.5\n",
    "\n",
    "# Cluster assignment of every training point (colours the scatter overlays below);\n",
    "# uses `device` rather than a hard-coded 'cuda' so the cell also runs on CPU-only machines\n",
    "with torch.no_grad():\n",
    "    _, _, cluster_indices, y_tilde_train = best_model(train_X.to(device))\n",
    "cluster_indices = cluster_indices.detach().cpu().numpy()\n",
    "\n",
    "# Boolean mask of the training points in each cluster, shared by every overlay below\n",
    "cluster_masks = [(cluster_indices == i).flatten() for i in range(num_K)]\n",
    "# The palette is cycled (i % len) so the overlays never run out of colours\n",
    "color_list = ['g','b','r','brown','yellow','black']\n",
    "\n",
    "# Codebook Prediction Contour Plot\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, Y_pred_orig, levels=100, cmap='viridis', alpha=0.8, vmin=-13, vmax=5)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Codebook) Prediction Function\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$\\\\boldsymbol{x}_i^{\\\\top}\\\\tilde{w}_{z_i}$\")\n",
    "plt.savefig(f'./img/CARVE(Codebook)_Pred_K{num_K}.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Codebook prediction as a 3D surface, with the training data\n",
    "fig = plt.figure(figsize=(12, 10))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "surface = ax.plot_surface(x1_mesh, x2_mesh, Y_pred_orig, cmap='viridis', edgecolor='none', alpha=0.8)\n",
    "ax.scatter(train_X_np[:, 0], train_X_np[:, 1], train_Y_np, color='r', marker='x',\n",
    "           s=30, label='Training Data', depthshade=True)\n",
    "ax.set_xlabel(\"$x^1$\")\n",
    "ax.set_ylabel(\"$x^2$\")\n",
    "ax.set_zlabel(\"$y$\")\n",
    "ax.set_title(\"3D Surface Plot of CG-CLR(Codebook) Predictions\")\n",
    "fig.colorbar(surface, shrink=0.5, aspect=5, label=\"Predicted y\")\n",
    "plt.legend()\n",
    "plt.savefig(f'./img/CARVE(Codebook)_Pred_K{num_K}_3D.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Codebook Prediction Error Plot: |prediction - noise-free ground truth|\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, Y_pred_err_orig, levels=100, cmap='viridis', alpha=0.8, vmin=0, vmax=2)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Codebook) Prediction Error\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$|\\\\boldsymbol{x}_i^{\\\\!\\\\top}\\\\tilde{w}_{z_i} - y_i|$\")\n",
    "plt.savefig(f'./img/CARVE(Codebook)_Error_K{num_K}.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Codebook Recovery Error Plot: per-point RMS error of the recovered coefficients\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, weight_err, levels=100, cmap='viridis', alpha=0.8, vmin=0, vmax=4)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Codebook) Coefficient Recovery Error\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$||\\\\tilde{w}_{z_i} - w^*_i||$\")\n",
    "plt.savefig(f'./img/CARVE(Codebook)_Recovery_Error_K{num_K}.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# Proxy Prediction Contour Plot (this surface uses the proxy weights w_hat, so the colour bar says hat{w}_i)\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, Y_pred_proxy_orig, levels=100, cmap='viridis', alpha=0.8, vmin=-13, vmax=5)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Proxy) Prediction Function\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$\\\\boldsymbol{x}_i^{\\\\top}\\\\hat{w}_i$\")\n",
    "plt.savefig(f'./img/CARVE(Proxy)_Pred_K{num_K}.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Proxy prediction as a 3D surface, with the training data\n",
    "fig = plt.figure(figsize=(12, 10))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "surface = ax.plot_surface(x1_mesh, x2_mesh, Y_pred_proxy_orig, cmap='viridis', edgecolor='none', alpha=0.8)\n",
    "ax.scatter(train_X_np[:, 0], train_X_np[:, 1], train_Y_np, color='r', marker='x',\n",
    "           s=30, label='Training Data', depthshade=True)\n",
    "ax.set_xlabel(\"$x^1$\")\n",
    "ax.set_ylabel(\"$x^2$\")\n",
    "ax.set_zlabel(\"$y$\")\n",
    "ax.set_title(\"3D Surface Plot of CG-CLR(Proxy) Predictions\")\n",
    "fig.colorbar(surface, shrink=0.5, aspect=5, label=\"Predicted y\")\n",
    "plt.legend()\n",
    "plt.savefig(f'./img/CARVE(Proxy)_Pred_K{num_K}_3D.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Proxy Prediction Error Plot: |prediction - noise-free ground truth|\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, Y_pred_proxy_err_orig, levels=100, cmap='viridis', alpha=0.8, vmin=0, vmax=2)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Proxy) Prediction Error\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$|\\\\boldsymbol{x}_i^{\\\\!\\\\top}\\\\hat{w}_i - y_i|$\")\n",
    "plt.savefig(f'./img/CARVE(Proxy)_Error_K{num_K}.png', dpi = 500)\n",
    "plt.show()\n",
    "\n",
    "# Proxy Recovery Error Plot: per-point RMS error of the recovered coefficients\n",
    "plt.figure(figsize=(10, 8))\n",
    "contour = plt.contourf(x1_mesh, x2_mesh, weight_proxy_err, levels=100, cmap='viridis', alpha=0.8, vmin=0, vmax=4)\n",
    "for i, mask in enumerate(cluster_masks):\n",
    "    plt.scatter(train_X_np[mask][:, 0], train_X_np[mask][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "plt.xlabel(\"$x^1$\")\n",
    "plt.ylabel(\"$x^2$\")\n",
    "plt.title(\"CG-CLR(Proxy) Coefficient Recovery Error\")\n",
    "plt.legend()\n",
    "plt.colorbar(contour, label=\"$||\\\\hat{w}_i - w^*_i||$\")\n",
    "plt.savefig(f'./img/CARVE(Proxy)_Recovery_Error_K{num_K}.png', dpi = 500)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d248316",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential F-test over K: compare the (K-1)-cluster fit (H0) with the K-cluster fit (H1) and stop\n",
    "# at the first K whose extra cluster does not significantly lower the training SSE.\n",
    "# Uses H0_SSE, p_dim, device, the scalers and the evaluation mesh from the cells above.\n",
    "F_TEST_ALPHA = 0.01\n",
    "color_list = ['g','b','r','brown','yellow','black']\n",
    "\n",
    "# Loop-invariant inputs, computed once\n",
    "sigma_y = scaler_y.scale_[0]\n",
    "mu_y    = scaler_y.mean_[0]\n",
    "sigma_x = scaler_x.scale_\n",
    "mu_x    = scaler_x.mean_\n",
    "coef_shape = (*x1_mesh.shape, 3)\n",
    "grid_points_scaled = torch.tensor(scaler_x.transform(grid_points), dtype=torch.float32).to(device)\n",
    "grid_aug = torch.concat([grid_points_scaled.cpu(),torch.ones(size=(len(grid_points_scaled),1))],axis=1)\n",
    "Y_true_ground = ground_truth_function(x1_mesh, x2_mesh)\n",
    "weight_true_ground = ground_truth_weight_function(x1_mesh, x2_mesh)\n",
    "\n",
    "\n",
    "def _fit_cgclr(num_K):\n",
    "    \"\"\"Build a CGCLR with `num_K` experts and train it with the same settings as the K=1 baseline.\"\"\"\n",
    "    model = CGCLR(\n",
    "        input_dim=p_dim,\n",
    "        expert_num= num_K,\n",
    "        output_dim=1,\n",
    "        proxy_hidden_shape=[128,128,128],\n",
    "        dropout = 0.0,\n",
    "        device = device\n",
    "    )\n",
    "    return training(\n",
    "        model = model,\n",
    "        train_X = train_X,\n",
    "        train_Y = train_Y,\n",
    "        val_X = train_X,\n",
    "        val_Y = train_Y,\n",
    "        max_epochs = 10000,\n",
    "        patience = 1000,\n",
    "        lr = 1e-3,\n",
    "        batch_size = train_X.shape[0],\n",
    "        y_scaler = scaler_y,\n",
    "        device = device,\n",
    "        LAMBDA = 1,\n",
    "        verbose = False\n",
    "    )\n",
    "\n",
    "\n",
    "def _train_sse(fitted_model):\n",
    "    \"\"\"Training SSE of `fitted_model` (in original y units when scaler_y is set).\"\"\"\n",
    "    fitted_model.eval()\n",
    "    with torch.no_grad():\n",
    "        _, _, _, y_fit = fitted_model(train_X.to(device))\n",
    "    y_fit = y_fit.detach().cpu()  # train_Y is a CPU tensor\n",
    "    if scaler_y is not None:\n",
    "        return np.sum((scaler_y.inverse_transform(y_fit.numpy()) - scaler_y.inverse_transform(train_Y.numpy()))**2)\n",
    "    return torch.sum((y_fit - train_Y)**2).numpy()\n",
    "\n",
    "\n",
    "def _coefficient_recovery_error(weights):\n",
    "    \"\"\"Per-grid-point RMS error of the original-unit coefficients [beta_x1, beta_x2, beta_0].\n",
    "\n",
    "    `weights` holds one row of standardised weights per grid point, bias in the last column;\n",
    "    beta = sigma_y * w / sigma_x and beta_0 = mu_y + sigma_y * w0 - beta . mu_x.\n",
    "    \"\"\"\n",
    "    w = weights[:, :-1].cpu().detach().numpy()\n",
    "    w0 = weights[:, -1].cpu().detach().numpy()\n",
    "    beta = (sigma_y * w) / sigma_x\n",
    "    beta_0 = mu_y + sigma_y * w0 - np.dot(beta, mu_x)\n",
    "    coefs = np.concatenate([beta, beta_0.reshape(-1, 1)], axis=1).reshape(coef_shape)\n",
    "    return ((coefs - weight_true_ground)**2).mean(2)**0.5\n",
    "\n",
    "\n",
    "def _plot_cluster_contour(values, vmin, vmax, title, cbar_label, save_path, cluster_ids, n_clusters):\n",
    "    \"\"\"Filled contour of `values` on the mesh with the training points coloured by cluster; saves and shows it.\"\"\"\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    contour = plt.contourf(x1_mesh, x2_mesh, values, levels=100, cmap='viridis', alpha=0.8, vmin=vmin, vmax=vmax)\n",
    "    for i in range(n_clusters):\n",
    "        in_cluster = (cluster_ids == i).flatten()\n",
    "        # Cycle the palette: n_clusters can exceed len(color_list)\n",
    "        plt.scatter(train_X_np[in_cluster][:, 0], train_X_np[in_cluster][:, 1], color=color_list[i % len(color_list)], marker='x', s=30, label=f'cluster-{i}')\n",
    "    plt.xlabel(\"$x^1$\")\n",
    "    plt.ylabel(\"$x^2$\")\n",
    "    plt.title(title)\n",
    "    plt.legend()\n",
    "    plt.colorbar(contour, label=cbar_label)\n",
    "    plt.savefig(save_path, dpi = 500)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def _plot_prediction_surface(values, title, save_path):\n",
    "    \"\"\"3D surface of `values` on the mesh with the training data scattered on top; saves and shows it.\"\"\"\n",
    "    fig = plt.figure(figsize=(12, 10))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    surface = ax.plot_surface(x1_mesh, x2_mesh, values, cmap='viridis', edgecolor='none', alpha=0.8)\n",
    "    ax.scatter(train_X_np[:, 0], train_X_np[:, 1], train_Y_np, color='r', marker='x',\n",
    "               s=30, label='Training Data', depthshade=True)\n",
    "    ax.set_xlabel(\"$x^1$\")\n",
    "    ax.set_ylabel(\"$x^2$\")\n",
    "    ax.set_zlabel(\"$y$\")\n",
    "    ax.set_title(title)\n",
    "    fig.colorbar(surface, shrink=0.5, aspect=5, label=\"Predicted y\")\n",
    "    plt.legend()\n",
    "    plt.savefig(save_path, dpi = 500)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "for num_K in range(2, 50):\n",
    "    # H1 modeling\n",
    "    pre_time = time.time()\n",
    "    best_model = _fit_cgclr(num_K)\n",
    "    print(\"Training Time(s): \",int(time.time() - pre_time))\n",
    "\n",
    "    # Nested-model F-test: H1 has one more linear expert, i.e. p_dim+1 extra parameters\n",
    "    H1_SSE = _train_sse(best_model)\n",
    "    resid_dof = train_X.shape[0] - num_K*(p_dim+1)\n",
    "    f_stat = resid_dof/(p_dim+1) * (H0_SSE / H1_SSE - 1)\n",
    "    # One-sided upper tail: only a large F (H1 fits much better) is evidence against H0\n",
    "    p_val = stats.f.sf(f_stat, p_dim+1, resid_dof)\n",
    "    print(f\"F={f_stat}, P_val={p_val}, H0: K={num_K-1}\")\n",
    "    H0_SSE = H1_SSE\n",
    "\n",
    "    # Inference for plot\n",
    "    with torch.no_grad():\n",
    "        w_hat, w_tilde, _, y_tilde_grid = best_model(grid_points_scaled)\n",
    "    Y_pred_orig = scaler_y.inverse_transform(y_tilde_grid.cpu().numpy().reshape(-1, 1)).reshape(x1_mesh.shape)\n",
    "    Y_pred_proxy = (grid_aug * w_hat.cpu()).sum(1).numpy()\n",
    "    Y_pred_proxy_orig = scaler_y.inverse_transform(Y_pred_proxy.reshape(-1, 1)).reshape(x1_mesh.shape)\n",
    "    Y_pred_err_orig = np.abs(Y_pred_orig - Y_true_ground)\n",
    "    Y_pred_proxy_err_orig = np.abs(Y_pred_proxy_orig - Y_true_ground)\n",
    "    weight_err = _coefficient_recovery_error(w_tilde)\n",
    "    weight_proxy_err = _coefficient_recovery_error(w_hat)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        _, _, cluster_indices, _ = best_model(train_X.to(device))\n",
    "    cluster_indices = cluster_indices.detach().cpu().numpy()\n",
    "\n",
    "    # Codebook head: prediction, 3D surface, prediction error, coefficient recovery error\n",
    "    _plot_cluster_contour(Y_pred_orig, -13, 5, \"CG-CLR(Codebook) Prediction Function\",\n",
    "                          \"$\\\\boldsymbol{x}_i^{\\\\top}\\\\tilde{w}_{z_i}$\",\n",
    "                          f'./img/CARVE(Codebook)_Pred_K{num_K}.png', cluster_indices, num_K)\n",
    "    _plot_prediction_surface(Y_pred_orig, \"3D Surface Plot of CG-CLR(Codebook) Predictions\",\n",
    "                             f'./img/CARVE(Codebook)_Pred_K{num_K}_3D.png')\n",
    "    _plot_cluster_contour(Y_pred_err_orig, 0, 2, \"CG-CLR(Codebook) Prediction Error\",\n",
    "                          \"$|\\\\boldsymbol{x}_i^{\\\\!\\\\top}\\\\tilde{w}_{z_i} - y_i|$\",\n",
    "                          f'./img/CARVE(Codebook)_Error_K{num_K}.png', cluster_indices, num_K)\n",
    "    _plot_cluster_contour(weight_err, 0, 4, \"CG-CLR(Codebook) Coefficient Recovery Error\",\n",
    "                          \"$||\\\\tilde{w}_{z_i} - w^*_i||$\",\n",
    "                          f'./img/CARVE(Codebook)_Recovery_Error_K{num_K}.png', cluster_indices, num_K)\n",
    "\n",
    "    # Proxy head (per-point weights w_hat): the same four panels\n",
    "    _plot_cluster_contour(Y_pred_proxy_orig, -13, 5, \"CG-CLR(Proxy) Prediction Function\",\n",
    "                          \"$\\\\boldsymbol{x}_i^{\\\\top}\\\\hat{w}_i$\",\n",
    "                          f'./img/CARVE(Proxy)_Pred_K{num_K}.png', cluster_indices, num_K)\n",
    "    _plot_prediction_surface(Y_pred_proxy_orig, \"3D Surface Plot of CG-CLR(Proxy) Predictions\",\n",
    "                             f'./img/CARVE(Proxy)_Pred_K{num_K}_3D.png')\n",
    "    _plot_cluster_contour(Y_pred_proxy_err_orig, 0, 2, \"CG-CLR(Proxy) Prediction Error\",\n",
    "                          \"$|\\\\boldsymbol{x}_i^{\\\\!\\\\top}\\\\hat{w}_i - y_i|$\",\n",
    "                          f'./img/CARVE(Proxy)_Error_K{num_K}.png', cluster_indices, num_K)\n",
    "    _plot_cluster_contour(weight_proxy_err, 0, 4, \"CG-CLR(Proxy) Coefficient Recovery Error\",\n",
    "                          \"$||\\\\hat{w}_i - w^*_i||$\",\n",
    "                          f'./img/CARVE(Proxy)_Recovery_Error_K{num_K}.png', cluster_indices, num_K)\n",
    "\n",
    "    if p_val > F_TEST_ALPHA:\n",
    "        print(f\"Cannot Reject H0, Thus K={num_K-1}\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a1e3c0",
   "metadata": {},
   "source": [
    "##"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CGCLR",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
