{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e91dd5-48e3-474a-8c0c-c687f9ef01eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pyro\n",
    "import torch\n",
    "import random\n",
    "import inspect \n",
    "import logging\n",
    "import itertools\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from copy import deepcopy\n",
    "from scipy.stats import norm\n",
    "from tedvae_gpu import TEDVAE\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.nn.functional import mse_loss\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import inspect\n",
    "import textwrap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f918dd82-d016-48cf-b06a-caa0d74054ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.getLogger(\"pyro\").setLevel(logging.DEBUG)\n",
    "logging.getLogger(\"pyro\").handlers[0].setLevel(logging.DEBUG)\n",
    "pyro.enable_validation(True)\n",
    "torch.set_default_tensor_type('torch.FloatTensor')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5e292b4-7f87-4c11-8a37-9cfdc1cdf7ba",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d42e45d8-e6be-4705-9f4f-34ba636ed7ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DoubleEncoder(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, latent_dim, use_auxiliary_latents=False):\n",
    "        super().__init__()\n",
    "        self.use_aux = use_auxiliary_latents\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "\n",
    "        self.mean_u = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.logvar_u = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.mean_eps = nn.Linear(hidden_dim, input_dim)\n",
    "        self.logvar_eps = nn.Linear(hidden_dim, input_dim)\n",
    "\n",
    "        if self.use_aux:\n",
    "            self.mean_ua0 = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.logvar_ua0 = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.mean_ua1 = nn.Linear(hidden_dim, latent_dim)\n",
    "            self.logvar_ua1 = nn.Linear(hidden_dim, latent_dim)\n",
    "\n",
    "    def forward(self, z):\n",
    "        h = F.relu(self.fc1(z))\n",
    "\n",
    "        mu_u = self.mean_u(h)\n",
    "        logvar_u = self.logvar_u(h)\n",
    "        std_u = torch.exp(0.5 * logvar_u)\n",
    "\n",
    "        mu_eps = self.mean_eps(h)\n",
    "        logvar_eps = self.logvar_eps(h)\n",
    "        std_eps = torch.exp(0.5 * logvar_eps)\n",
    "\n",
    "        if self.use_aux:\n",
    "            mu_ua0 = self.mean_ua0(h)\n",
    "            logvar_ua0 = self.logvar_ua0(h)\n",
    "            std_ua0 = torch.exp(0.5 * logvar_ua0)\n",
    "\n",
    "            mu_ua1 = self.mean_ua1(h)\n",
    "            logvar_ua1 = self.logvar_ua1(h)\n",
    "            std_ua1 = torch.exp(0.5 * logvar_ua1)\n",
    "\n",
    "            return (mu_u, std_u, logvar_u, mu_eps, std_eps, logvar_eps,\n",
    "                    mu_ua0, std_ua0, logvar_ua0, mu_ua1, std_ua1, logvar_ua1)\n",
    "\n",
    "        return mu_u, std_u, logvar_u, mu_eps, std_eps, logvar_eps\n",
    "\n",
    "# ===== Additive Decoder =====\n",
    "class AdditiveDecoder(nn.Module):\n",
    "    def __init__(self, latent_dim, hidden_dim, output_dim, use_auxiliary_latents=False):\n",
    "        super().__init__()\n",
    "        self.use_aux = use_auxiliary_latents\n",
    "        factor = 1 + (2 if use_auxiliary_latents else 0)\n",
    "        self.fc1 = nn.Linear(latent_dim * factor, hidden_dim)\n",
    "        self.out = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, u, eps, ua0=None, ua1=None):\n",
    "        if self.use_aux:\n",
    "            x = torch.cat([u, ua0, ua1], dim=1)\n",
    "        else:\n",
    "            x = u\n",
    "        h = F.relu(self.fc1(x))\n",
    "        return self.out(h) + eps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d7eea1-3e13-42f8-bec8-2a553eb55bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== GPVAE with optional auxiliary latents =====\n",
    "class GPVAEwithNoise(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, latent_dim, gp_lengthscale=1.0, gp_variance=1.0, gp_noise=1e-2,\n",
    "                 use_auxiliary_latents=False):\n",
    "        super().__init__()\n",
    "        self.use_aux = use_auxiliary_latents\n",
    "        self.encoder = DoubleEncoder(input_dim, hidden_dim, latent_dim, use_auxiliary_latents)\n",
    "        self.decoder = AdditiveDecoder(latent_dim, hidden_dim, input_dim, use_auxiliary_latents)\n",
    "        self.gp_lengthscale = gp_lengthscale\n",
    "        self.gp_variance = gp_variance\n",
    "        self.gp_noise = gp_noise\n",
    "\n",
    "    def reparameterize(self, mu, std):\n",
    "        return mu + torch.randn_like(std) * std\n",
    "\n",
    "    def rbf_kernel(self, x1, x2=None):\n",
    "        if x2 is None:\n",
    "            x2 = x1\n",
    "        dists = torch.cdist(x1, x2).pow(2)\n",
    "        return self.gp_variance * torch.exp(-0.5 * dists / self.gp_lengthscale**2)\n",
    "\n",
    "    def gp_prior_kl(self, z_train, mu_u, std_u):\n",
    "        n = z_train.shape[0]\n",
    "        y_lower = (mu_u - std_u).squeeze()\n",
    "        y_upper = (mu_u + std_u).squeeze()\n",
    "        K = self.rbf_kernel(z_train) + self.gp_noise * torch.eye(n, device=z_train.device)\n",
    "\n",
    "        jitter = 1e-4\n",
    "        for _ in range(5):\n",
    "            try:\n",
    "                L = torch.linalg.cholesky(K + jitter * torch.eye(n, device=z_train.device))\n",
    "                break\n",
    "            except RuntimeError:\n",
    "                jitter *= 10\n",
    "        else:\n",
    "            return torch.tensor(0.0, device=z_train.device)\n",
    "\n",
    "        log_prob = 0.0\n",
    "        for i in range(n):\n",
    "            k_star = K[i, i] + jitter\n",
    "            std_i = torch.sqrt(k_star)\n",
    "            dist = torch.distributions.Normal(loc=0.0, scale=std_i)\n",
    "            prob = dist.cdf(y_upper[i]) - dist.cdf(y_lower[i]) + 1e-6\n",
    "            log_prob += torch.log(prob)\n",
    "        return -log_prob\n",
    "\n",
    "    def get_latent_stats(self, z, var_name='u'):\n",
    "        outputs = self.encoder(z)\n",
    "    \n",
    "        if var_name == 'u':\n",
    "            return outputs[0], outputs[1], outputs[2]  # mu_u, std_u, logvar_u\n",
    "        elif var_name == 'eps':\n",
    "            return outputs[3], outputs[4], outputs[5]  # mu_eps, std_eps, logvar_eps\n",
    "        elif var_name == 'ua0':\n",
    "            if not self.use_aux:\n",
    "                return None, None, None\n",
    "            return outputs[6], outputs[7], outputs[8]  # mu_ua0, std_ua0, logvar_ua0\n",
    "        elif var_name == 'ua1':\n",
    "            if not self.use_aux:\n",
    "                return None, None, None\n",
    "            return outputs[9], outputs[10], outputs[11]  # mu_ua1, std_ua1, logvar_ua1\n",
    "        else:\n",
    "            raise ValueError(f\"Invalid var_name '{var_name}'. Must be one of 'u', 'eps', 'ua0', or 'ua1'.\")\n",
    "\n",
    "    #DF02-07-25: a new membership function to retrieve desired variable (samples) easily from var_name\n",
    "    def sample_latent(self, z, var_name='u'):\n",
    "        outputs = self.encoder(z)\n",
    "        \n",
    "        if var_name == 'u':\n",
    "            mu, std = outputs[0], outputs[1]\n",
    "        elif var_name == 'eps':\n",
    "            mu, std = outputs[3], outputs[4]\n",
    "        elif var_name == 'ua0':\n",
    "            if not self.use_aux:\n",
    "                return None\n",
    "            mu, std = outputs[6], outputs[7]\n",
    "        elif var_name == 'ua1':\n",
    "            if not self.use_aux:\n",
    "                return None\n",
    "            mu, std = outputs[9], outputs[10]\n",
    "        else:\n",
    "            raise ValueError(f\"Invalid var_name '{var_name}'.\")\n",
    "        \n",
    "        return self.reparameterize(mu, std)\n",
    "\n",
    "\n",
    "    def forward(self, z):\n",
    "        if self.use_aux:\n",
    "            (mu_u, std_u, logvar_u, mu_eps, std_eps, logvar_eps,\n",
    "             mu_ua0, std_ua0, logvar_ua0, mu_ua1, std_ua1, logvar_ua1) = self.encoder(z)\n",
    "            ua0 = self.reparameterize(mu_ua0, std_ua0)\n",
    "            ua1 = self.reparameterize(mu_ua1, std_ua1)\n",
    "            z_recon = self.decoder(self.reparameterize(mu_u, std_u),\n",
    "                                   self.reparameterize(mu_eps, std_eps),\n",
    "                                   ua0=ua0, ua1=ua1)\n",
    "        else:\n",
    "            mu_u, std_u, logvar_u, mu_eps, std_eps, logvar_eps = self.encoder(z)\n",
    "            ua0 = ua1 = None\n",
    "            z_recon = self.decoder(self.reparameterize(mu_u, std_u),\n",
    "                                   self.reparameterize(mu_eps, std_eps))\n",
    "\n",
    "        recon_loss = F.mse_loss(z_recon, z, reduction='sum')\n",
    "        kl_u = -0.5 * torch.sum(1 + logvar_u - mu_u.pow(2) - logvar_u.exp())\n",
    "        kl_eps = -0.5 * torch.sum(1 + logvar_eps - mu_eps.pow(2) - logvar_eps.exp())\n",
    "        gp_kl = self.gp_prior_kl(z, mu_u, std_u)\n",
    "        std_penalty = torch.sum(std_u ** 2)\n",
    "\n",
    "        total_loss = recon_loss + kl_u + kl_eps + gp_kl + 10 * std_penalty\n",
    "\n",
    "        if self.use_aux:\n",
    "            kl_ua0 = -0.5 * torch.sum(1 + logvar_ua0 - mu_ua0.pow(2) - logvar_ua0.exp())\n",
    "            kl_ua1 = -0.5 * torch.sum(1 + logvar_ua1 - mu_ua1.pow(2) - logvar_ua1.exp())\n",
    "            total_loss += kl_ua0 + kl_ua1\n",
    "\n",
    "        return total_loss, {\n",
    "            'recon_loss': recon_loss.item(),\n",
    "            'kl_u': kl_u.item(),\n",
    "            'kl_eps': kl_eps.item(),\n",
    "            'gp_kl': gp_kl.item(),\n",
    "            'u': self.reparameterize(mu_u, std_u),\n",
    "            'eps': self.reparameterize(mu_eps, std_eps),\n",
    "            'ua0': ua0,\n",
    "            'ua1': ua1,\n",
    "            'mu_u': mu_u,\n",
    "            'std_u': std_u, \n",
    "            'z_recon': z_recon   #DF02--7-25: return the recon so that can directly use vae() as as function\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8073f2fe-d5fe-428a-9e08-ad149cd00383",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Flexible Causal Head =====\n",
    "class FlexibleCausalHead(nn.Module):\n",
    "    def __init__(self, latent_dim_u, z_y_dim=None, use_auxiliary_latents=False):\n",
    "        super().__init__()\n",
    "        self.use_aux = use_auxiliary_latents\n",
    "        self.z_y_dim = z_y_dim\n",
    "\n",
    "        input_dim = latent_dim_u + 1  # for u and t\n",
    "        if z_y_dim is not None:\n",
    "            input_dim += z_y_dim\n",
    "        if use_auxiliary_latents:\n",
    "            input_dim += latent_dim_u  # assuming ua1 has same dim as u\n",
    "\n",
    "        self.fc = nn.Linear(input_dim, 64)\n",
    "        self.out = nn.Linear(64, 1)\n",
    "\n",
    "    def forward(self, u, t, z_y=None, ua1=None):\n",
    "        t = t.view(-1, 1)\n",
    "        parts = [u, t]\n",
    "\n",
    "        if self.z_y_dim is not None:\n",
    "            if z_y is None:\n",
    "                raise ValueError(\"Expected z_y input but got None\")\n",
    "            parts.append(z_y)\n",
    "\n",
    "        if self.use_aux:\n",
    "            if ua1 is None:\n",
    "                raise ValueError(\"Expected ua1 input but got None\")\n",
    "            parts.append(ua1)\n",
    "\n",
    "        x = torch.cat(parts, dim=1)\n",
    "        h = F.relu(self.fc(x))\n",
    "        return self.out(h)\n",
    "\n",
    "class CausalGPVAEwithNoise(nn.Module):\n",
    "    def __init__(self, \n",
    "                 input_dim, \n",
    "                 hidden_dim, \n",
    "                 latent_dim, \n",
    "                 z_y_dim=None, \n",
    "                 use_auxiliary_latents=False,\n",
    "                 gp_lengthscale=1.0, \n",
    "                 gp_variance=1.0, \n",
    "                 gp_noise=1e-2):\n",
    "        super().__init__()\n",
    "        self.z_y_dim = z_y_dim\n",
    "        self.use_aux = use_auxiliary_latents\n",
    "\n",
    "        # GP-VAE backbone\n",
    "        self.vae = GPVAEwithNoise(input_dim, hidden_dim, latent_dim,\n",
    "                                  gp_lengthscale=gp_lengthscale,\n",
    "                                  gp_variance=gp_variance,\n",
    "                                  gp_noise=gp_noise,\n",
    "                                  use_auxiliary_latents=use_auxiliary_latents\n",
    "                                 )\n",
    "\n",
    "        # Flexible causal head\n",
    "        self.causal_head = FlexibleCausalHead(\n",
    "            latent_dim_u=latent_dim,\n",
    "            z_y_dim=z_y_dim,\n",
    "            use_auxiliary_latents=use_auxiliary_latents\n",
    "        )\n",
    "\n",
    "    #DF02-07-05: compute z_y from inside not outside!!\n",
    "    def forward(self, z, t, y):\n",
    "\n",
    "        loss_vae, vae_info = self.vae(z)\n",
    "        u = vae_info['u']\n",
    "        ua1 = vae_info['ua1'] if self.use_aux else None\n",
    "    \n",
    "        # Internally extract z_y if needed\n",
    "        if self.z_y_dim is not None:\n",
    "            if z.shape[1] < self.z_y_dim:\n",
    "                raise ValueError(f\"z has {z.shape[1]} dims, but z_y_dim={self.z_y_dim}\")\n",
    "            z_y = z[:, -self.z_y_dim:]  # Take last `z_y_dim` dims\n",
    "            y_pred = self.causal_head(u, t, z_y=z_y, ua1=ua1)\n",
    "        else:\n",
    "            y_pred = self.causal_head(u, t, ua1=ua1)\n",
    "    \n",
    "        causal_loss = F.mse_loss(y_pred, y.unsqueeze(-1), reduction='sum')\n",
    "        total_loss = loss_vae + causal_loss\n",
    "    \n",
    "        return total_loss, {\n",
    "            **vae_info,\n",
    "            'y_pred': y_pred,\n",
    "            'causal_loss': causal_loss.item()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d056f0df-7fd1-45d8-9598-b1c392a3bfba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_synthetic_data_with_aux_uas(n=1000, noise_std=0.1, seed=42, num_proxies=6, proxy_funcs=None, outcome_func=None, treatment_func=None):\n",
    "    \"\"\"\n",
    "    Generate synthetic data where:\n",
    "        u, ua0, ua1 → z\n",
    "        u, ua0 → t\n",
    "        u, t, ua1 → y\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Latent variables\n",
    "    u = np.random.normal(0, 1, size=(n, 1))\n",
    "    ua0 = np.random.normal(0, 1, size=(n, 1))\n",
    "    ua1 = np.random.normal(0, 1, size=(n, 1))\n",
    "\n",
    "    u_tensor = torch.tensor(u, dtype=torch.float32)\n",
    "    ua0_tensor = torch.tensor(ua0, dtype=torch.float32)\n",
    "    ua1_tensor = torch.tensor(ua1, dtype=torch.float32)\n",
    "\n",
    "    if proxy_funcs is None:\n",
    "        proxy_funcs = [\n",
    "            lambda u, ua0, ua1: u,                                     # linear in u (identity)\n",
    "            lambda u, ua0, ua1: u ** 2,                                # quadratic in u\n",
    "            lambda u, ua0, ua1: torch.sin(u),                          # nonlinear periodic\n",
    "            lambda u, ua0, ua1: torch.tanh(u),                         # saturating nonlinearity\n",
    "            lambda u, ua0, ua1: torch.exp(-u ** 2),                    # Gaussian shape\n",
    "            lambda u, ua0, ua1: torch.sin(2 * u) + 0.1 * ua1,          # mostly u, weak ua1\n",
    "            lambda u, ua0, ua1: torch.log(torch.abs(u) + 1e-3),        # log-scale info\n",
    "            lambda u, ua0, ua1: torch.cos(u + 0.5 * ua0),              # slight entanglement\n",
    "            lambda u, ua0, ua1: torch.relu(u),                         # piecewise linear\n",
    "            #lambda u, ua0, ua1: (u + 0.1 * ua0) * (1 + 0.1 * ua1),     # dominant u\n",
    "        ]\n",
    "\n",
    "    proxy_funcs = proxy_funcs[:num_proxies]\n",
    "    clean_z = torch.cat([g(u_tensor, ua0_tensor, ua1_tensor) for g in proxy_funcs], dim=1)\n",
    "    epsilon = torch.randn(n, num_proxies) * noise_std\n",
    "    z_tensor = clean_z + epsilon\n",
    "    z_np = z_tensor.numpy()\n",
    "\n",
    "    # Treatment assignment: u, ua0 → t\n",
    "    if treatment_func is None:\n",
    "        def treatment_func(u_np, ua0_np=ua0):\n",
    "            logits = 1.5 * u_np + 0.8 * ua0_np\n",
    "            probs = 1 / (1 + np.exp(-logits))\n",
    "            return np.random.binomial(1, probs).astype(np.float32)\n",
    "\n",
    "    t = treatment_func(u, ua0)\n",
    "    t_tensor = torch.tensor(t).squeeze()\n",
    "\n",
    "    # Outcome: u, t, ua1 → y\n",
    "    if outcome_func is None:\n",
    "        def outcome_func(u_np, t_np, z_np, ua1_np=ua1):\n",
    "            ua_contrib = 0.4 * np.sin(ua1_np)\n",
    "            return (\n",
    "                np.sin(u_np) +\n",
    "                0.3 * t_np.reshape(-1, 1) +\n",
    "                0.6 * u_np * t_np.reshape(-1, 1) +\n",
    "                ua_contrib +\n",
    "                np.random.normal(0, noise_std, size=u_np.shape)\n",
    "            )\n",
    "\n",
    "    y = outcome_func(u, t, ua0, ua1).astype(np.float32) #DF03-07-25: add ua0 and ua1\n",
    "    y_tensor = torch.tensor(y).squeeze()\n",
    "\n",
    "    return {'z': z_tensor,\n",
    "            't': t_tensor,\n",
    "            'y': y_tensor,\n",
    "            'u': u_tensor.squeeze(),\n",
    "            'ua0': ua0_tensor.squeeze(),\n",
    "            'ua1': ua1_tensor.squeeze(),\n",
    "            'proxy_funcs': proxy_funcs,\n",
    "            'outcome_func': outcome_func,\n",
    "            'treatment_func': treatment_func}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fe4f82-e1fe-4175-95db-9b85489337ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def call_outcome_func(outcome_func, u_np, t_np, z_np=None, ua0=None, ua1=None):\n",
    "    \"\"\"\n",
    "    Calls the outcome function with appropriate arguments.\n",
    "    Supports both:\n",
    "        - outcome_func(u_np, t_np)\n",
    "        - outcome_func(u_np, t_np, z_np)\n",
    "    \"\"\"\n",
    "    sig = inspect.signature(outcome_func)\n",
    "    params = sig.parameters\n",
    "\n",
    "    if len(params) == 2:\n",
    "        return outcome_func(u_np, t_np)\n",
    "    elif len(params) == 3:\n",
    "        if z_np is None:\n",
    "            raise ValueError(\"Outcome function expects z_np but none was provided.\")\n",
    "        return outcome_func(u_np, t_np, z_np)\n",
    "    #DF03-07-25: add this branch \n",
    "    elif len(params) == 4:      \n",
    "        return outcome_func(u_np, t_np, ua0, ua1)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported outcome function signature.\")\n",
    "\n",
    "\n",
    "def simple_outcome_func_with_aux_u(u_np, t_np, ua0, ua1):\n",
    "    \"\"\"\n",
    "    Outcome model where y depends on u, t, and optionally on global ua1.\n",
    "    This keeps the interface: outcome_func(u, t, z) → y\n",
    "\n",
    "    Global:\n",
    "        ua1_true: optional (n, 1) torch.Tensor used for auxiliary latent effect\n",
    "    \"\"\"\n",
    "    base = u_np + t_np.reshape(-1, 1) * (0.5 + 0.3 * u_np)\n",
    "    base += 0.4 * ua1 \n",
    "\n",
    "    noise = np.random.normal(0, 0.1, size=base.shape)\n",
    "    return base + noise\n",
    "\n",
    "def select_train_from_z_test_from_external(z, u_true, t, y, data_test, n_train=10, seed=42):\n",
    "\n",
    "    # -- Randomly select training indices from original data\n",
    "    rng = np.random.default_rng(seed)\n",
    "    all_indices = np.arange(len(z))\n",
    "    train_indices = rng.choice(all_indices, size=n_train, replace=False)\n",
    "\n",
    "    z_train_gp = z[train_indices]\n",
    "    true_u_train = u_true[train_indices]\n",
    "\n",
    "    # -- Extract test data from external dataset\n",
    "    z_test = data_test['z']\n",
    "    true_u_test = data_test['u']\n",
    "    t_test = data_test['t']\n",
    "    y_test = data_test['y']\n",
    "\n",
    "    #DF03-07-25:  allow aux u\n",
    "    if ua0_true is not None:\n",
    "        true_ua0_train = ua0_true[train_indices]\n",
    "        true_ua0_test  = data_test['ua0']\n",
    "    else:\n",
    "        true_ua0_train = None\n",
    "        true_ua0_test = None\n",
    "        \n",
    "    if ua1_true is not None:\n",
    "        true_ua1_train = ua1_true[train_indices]\n",
    "        true_ua1_test  = data_test['ua1']\n",
    "    else:\n",
    "        true_ua1_train = None\n",
    "        true_ua1_test  = None\n",
    "    \n",
    "    return z_train_gp, z_test, true_u_train, true_u_test,t_test, y_test, train_indices,\\\n",
    "    true_ua0_train,true_ua1_train,true_ua0_test,true_ua1_test\n",
    "\n",
    "\n",
    "#DF09-06-25: Interval visualization Estimated u vs True u\n",
    "# ===== Sample to get Latent u =====\n",
    "def estimate_latent_u_posterior(model, z, n_samples=100):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        mu_u, std_u = model.vae.encoder(z)[0:2]\n",
    "        u_samples = []\n",
    "        for _ in range(n_samples):\n",
    "            u_sample = model.vae.reparameterize(mu_u, std_u)\n",
    "            u_samples.append(u_sample.squeeze())\n",
    "        u_stack = torch.stack(u_samples, dim=0)  # (n_samples, n_points)\n",
    "        return u_stack.mean(0), u_stack.std(0)\n",
    "\n",
    "def estimate_counterfactual(model, z, t_cf, n_samples=100):\n",
    "    \"\"\"\n",
    "    成对重采样（paired）：每个 MC 轮次 s 同时采样 u^{(s)} ~ q(u|z) 和 ua1^{(s)} ~ q(ua1|z)，\n",
    "    并用同一对 (u^{(s)}, ua1^{(s)}) 计算给定 t_cf 下的 y(t_cf)。\n",
    "    返回：y(t_cf) 的均值和标准差（按样本维度 N 聚合）。\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # ---- 编码器统计量 ----\n",
    "        mu_u, std_u, _ = model.vae.get_latent_stats(z, var_name='u')   # (N, d_u)\n",
    "        N, d_u = mu_u.shape\n",
    "        device, dtype = mu_u.device, mu_u.dtype\n",
    "\n",
    "        mu_ua1, std_ua1, _ = model.vae.get_latent_stats(z, var_name='ua1')  # (N, d_ua) or (None,...)\n",
    "        has_ua1 = (mu_ua1 is not None) and (std_ua1 is not None)\n",
    "\n",
    "        # ---- z_y 选择逻辑（与你原实现一致）----\n",
    "        if model.causal_head.z_y_dim == z.shape[1]:\n",
    "            z_y = z\n",
    "        elif model.causal_head.z_y_dim:\n",
    "            z_y = z[:, z.shape[1] // 2:]\n",
    "        else:\n",
    "            z_y = None\n",
    "\n",
    "        # ---- 采样温度 ----\n",
    "        temp_u  = getattr(model.vae, \"sample_temp\", 2.0)\n",
    "        temp_ua = getattr(model.vae, \"sample_temp_ua\", temp_u)\n",
    "\n",
    "        # ---- 成对重采：按相同的第1维 s 同时生成 u 与 ua1 ----\n",
    "        eps_u = torch.randn(n_samples, N, d_u, device=device)\n",
    "        u_samples = mu_u.unsqueeze(0) + eps_u * (std_u.unsqueeze(0) * temp_u)   # (S,N,d_u)\n",
    "\n",
    "        if has_ua1:\n",
    "            d_ua = mu_ua1.shape[1]\n",
    "            eps_ua1 = torch.randn(n_samples, N, d_ua, device=device)\n",
    "            ua1_samples = mu_ua1.unsqueeze(0) + eps_ua1 * (std_ua1.unsqueeze(0) * temp_ua)  # (S,N,d_ua)\n",
    "            UA1 = ua1_samples.reshape(-1, d_ua)  # (S*N, d_ua)\n",
    "        else:\n",
    "            UA1 = None\n",
    "\n",
    "        # ---- 展平批量 ----\n",
    "        U = u_samples.reshape(-1, d_u)  # (S*N, d_u)\n",
    "        if z_y is not None:\n",
    "            ZY = z_y.unsqueeze(0).expand(n_samples, -1, -1).reshape(-1, z_y.shape[1])  # (S*N, z_y_dim)\n",
    "        else:\n",
    "            ZY = None\n",
    "\n",
    "        # ---- 广播/检查 t_cf：可为标量或长度 N 的向量 ----\n",
    "        if not torch.is_tensor(t_cf):\n",
    "            t_cf = torch.tensor(t_cf, device=device, dtype=dtype)\n",
    "        else:\n",
    "            t_cf = t_cf.to(device=device, dtype=dtype)\n",
    "\n",
    "        if t_cf.ndim == 0:\n",
    "            # 标量 -> (S,N)\n",
    "            t_flat = t_cf.expand(n_samples, N).reshape(-1)\n",
    "        elif t_cf.ndim == 1 and t_cf.numel() == N:\n",
    "            # (N,) -> (S,N)\n",
    "            t_flat = t_cf.unsqueeze(0).expand(n_samples, -1).reshape(-1)\n",
    "        else:\n",
    "            raise ValueError(f\"t_cf 形状 {tuple(t_cf.shape)} 与 batch 大小 N={N} 不兼容（需标量或长度为 N 的向量）\")\n",
    "\n",
    "        # ---- 单臂前向：只计算 y(t_cf) ----\n",
    "        y_all = model.causal_head(U, t_flat, z_y=ZY, ua1=UA1).squeeze(-1)  # (S*N,)\n",
    "        y_stack = y_all.view(n_samples, N)  # (S, N)\n",
    "\n",
    "        return y_stack.mean(0), y_stack.std(0)\n",
    "\n",
    "\n",
    "        \n",
    "# def estimate_counterfactual(model, z, t_cf, n_samples=100):\n",
    "#     model.eval()\n",
    "#     with torch.no_grad():\n",
    "        \n",
    "#         mu_u, std_u,_ = model.vae.get_latent_stats(z, var_name='u')\n",
    "#         ua1 = model.vae.sample_latent(z, var_name='ua1')\n",
    "\n",
    "#         z_y = z if model.causal_head.z_y_dim == z.shape[1] else (\n",
    "#             z[:, z.shape[1] // 2:] if model.causal_head.z_y_dim else None\n",
    "#         )\n",
    "#         outcomes = []\n",
    "#         for _ in range(n_samples):\n",
    "#             u_sample = model.vae.reparameterize(mu_u, std_u)\n",
    "#             y_cf = model.causal_head(u_sample, t_cf, z_y=z_y, ua1=ua1)\n",
    "#             outcomes.append(y_cf)\n",
    "#         y_stack = torch.stack(outcomes, dim=0)\n",
    "#         return y_stack.mean(0), y_stack.std(0)\n",
    "\n",
    "#DF02-07-25: use the new membership function to retrieve desired variable (u and ua1) easily from var_name\n",
    "# 2. True u + Causal Head\n",
    "def estimate_true_causal_head(model, true_u, z, t_cf): #DF03-07-25: made it similar to def estimate_counterfactual(\n",
    "    with torch.no_grad():\n",
    "\n",
    "        #compute z_y\n",
    "        z_y = z if model.causal_head.z_y_dim == z.shape[1] else (\n",
    "            z[:, z.shape[1] // 2:] if model.causal_head.z_y_dim else None\n",
    "        )\n",
    "        # ===== Predict y using causal head =====\n",
    "        ua1 = model.vae.sample_latent(z, var_name='ua1') #DF02-07-25: added\n",
    "        y = model.causal_head(true_u.unsqueeze(1), t_cf, z_y=z_y, ua1=ua1)\n",
    "        return y.squeeze(), torch.zeros_like(y.squeeze())\n",
    "\n",
    "# 3. True f(u, t)\n",
    "def estimate_true_f_outcome(model, true_u, t_cf, outcome_func):\n",
    "    u_np = true_u.unsqueeze(1).cpu().numpy() #DF26-06-25 gpu (add .cpu())\n",
    "    t_np = t_cf.reshape(-1, 1).cpu().numpy() #DF26-06-25 gpu (add .cpu())\n",
    "    y = torch.from_numpy(call_outcome_func(outcome_func,u_np, t_np, z_test.cpu().numpy(),ua0_test_np, ua1_test_np)).float() #DF26-06-25 gpu (add .cpu())\n",
    "    return y.squeeze(), torch.zeros_like(y.squeeze())\n",
    "\n",
    "# ===== RBF Kernel =====\n",
    "def rbf_kernel(x1, x2, lengthscale, variance):\n",
    "    if x1.ndim == 2:\n",
    "        x1 = x1.mean(dim=1, keepdim=True)  # Average over proxy features\n",
    "    if x2.ndim == 2:\n",
    "        x2 = x2.mean(dim=1, keepdim=True)\n",
    "    dists = torch.cdist(x1, x2).pow(2)  # Output is (N, M)\n",
    "    return variance * torch.exp(-0.5 * dists / lengthscale**2)\n",
    "\n",
    "\n",
    "# ===== GP Posterior for u* from z with Interval Data =====\n",
    "def gp_posterior_u(model, z_train, z_test, gp_lengthscale, gp_variance, gp_noise):\n",
    "    with torch.no_grad():\n",
    "        mu_u, std_u = model.vae.encoder(z_train)[0:2]\n",
    "    \n",
    "    u_lower = (mu_u - std_u).squeeze()\n",
    "    u_upper = (mu_u + std_u).squeeze()\n",
    "\n",
    "    K = rbf_kernel(z_train, z_train, gp_lengthscale, gp_variance)\n",
    "    K += gp_noise * torch.eye(len(z_train), device=z_train.device)\n",
    "\n",
    "    K_s = rbf_kernel(z_train, z_test, gp_lengthscale, gp_variance)\n",
    "    K_ss = rbf_kernel(z_test, z_test, gp_lengthscale, gp_variance) + 1e-6 * torch.eye(len(z_test), device=z_test.device)\n",
    "\n",
    "    jitter = 1e-4\n",
    "    for _ in range(5):\n",
    "        try:\n",
    "            L = torch.linalg.cholesky(K + jitter * torch.eye(len(z_train), device=K.device))\n",
    "            break\n",
    "        except RuntimeError:\n",
    "            jitter *= 10\n",
    "    else:\n",
    "        raise RuntimeError(\"Cholesky failed in gp_posterior\")\n",
    "\n",
    "    # Posterior mean using u_lower and u_upper\n",
    "    alpha_lower = torch.cholesky_solve(u_lower.unsqueeze(1), L)\n",
    "    alpha_upper = torch.cholesky_solve(u_upper.unsqueeze(1), L)\n",
    "\n",
    "    mu_lower = K_s.t() @ alpha_lower\n",
    "    mu_upper = K_s.t() @ alpha_upper\n",
    "\n",
    "    cov_gp = K_ss - K_s.t() @ torch.cholesky_solve(K_s, L)\n",
    "    return mu_lower.squeeze(), mu_upper.squeeze(), cov_gp\n",
    "\n",
    "# ====== Predict u* and ITE using GP Posterior ======\n",
    "def predict_u_gp_conditioned(model, z_train, z_test):\n",
    "    mu_lower, mu_upper, cov_gp = gp_posterior_u(\n",
    "        model, z_train, z_test,\n",
    "        gp_lengthscale=model.vae.gp_lengthscale,\n",
    "        gp_variance=model.vae.gp_variance,\n",
    "        gp_noise=model.vae.gp_noise\n",
    "    )\n",
    "    return mu_lower, mu_upper, cov_gp\n",
    "\n",
    "def predict_y_bounds_from_encoder_samples(model, z_test, n_samples=100, ci=0.90):\n",
    "    \"\"\"\n",
    "    Encoder-driven ITE interval（成对重采样版）：\n",
    "      1) 采样同一批 u ~ q(u|z)（S 次）\n",
    "      2) 每次 MC 同步重采 ua1 ~ q(ua1|z)（S 次），并与对应的 u 成对使用\n",
    "      3) 用同一对 (u^{(s)}, ua1^{(s)}) 同时计算 y(0), y(1)\n",
    "      4) 对 y0, y1, ITE=y1-y0 的样本分布按置信度取分位数得到上下界\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # ===== z_y 选择逻辑（与原实现一致）=====\n",
    "        if model.causal_head.z_y_dim == z_test.shape[1]:\n",
    "            z_y_test = z_test\n",
    "        elif model.causal_head.z_y_dim:\n",
    "            z_y_test = z_test[:, z_test.shape[1] // 2:]\n",
    "        else:\n",
    "            z_y_test = None\n",
    "\n",
    "        # ===== 编码器后验参数 =====\n",
    "        mu_u, std_u, *_ = model.vae.encoder(z_test)     # (N, d_u)\n",
    "        N, d_u = mu_u.shape\n",
    "        device = mu_u.device\n",
    "\n",
    "        # （可选）给 ua1 单独的温度；没有就回退到 u 的温度\n",
    "        temp_u  = getattr(model.vae, \"sample_temp\", 2.0)\n",
    "        temp_ua = getattr(model.vae, \"sample_temp_ua\", temp_u)\n",
    "\n",
    "        # ===== 采样同一批 u：形状 (S, N, d_u) =====\n",
    "        eps_u     = torch.randn(n_samples, N, d_u, device=device)\n",
    "        u_samples = mu_u.unsqueeze(0) + eps_u * (std_u.unsqueeze(0) * temp_u)\n",
    "\n",
    "        # ===== 采样同一批 ua1：形状 (S, N, d_ua)；若未启用 aux 则为 None =====\n",
    "        mu_ua1, std_ua1, _ = model.vae.get_latent_stats(z_test, var_name='ua1')\n",
    "        if mu_ua1 is not None:\n",
    "            d_ua = mu_ua1.shape[1]\n",
    "            eps_ua1     = torch.randn(n_samples, N, d_ua, device=device)\n",
    "            ua1_samples = mu_ua1.unsqueeze(0) + eps_ua1 * (std_ua1.unsqueeze(0) * temp_ua)\n",
    "            UA1 = ua1_samples.reshape(-1, d_ua)  # (S*N, d_ua)\n",
    "        else:\n",
    "            UA1 = None\n",
    "\n",
    "        # ===== 广播 z_y（如使用）到 (S*N, z_y_dim) =====\n",
    "        if z_y_test is not None:\n",
    "            ZY = z_y_test.unsqueeze(0).expand(n_samples, -1, -1).reshape(-1, z_y_test.shape[1])\n",
    "        else:\n",
    "            ZY = None\n",
    "\n",
    "        # ===== 展平并分别计算 y(0), y(1)（成对：同一 (u^{(s)}, ua1^{(s)})）=====\n",
    "        U  = u_samples.reshape(-1, d_u)     # (S*N, d_u)\n",
    "        t0 = torch.zeros(U.size(0), device=device)\n",
    "        t1 = torch.ones (U.size(0), device=device)\n",
    "\n",
    "        y0_all = model.causal_head(U, t0, z_y=ZY, ua1=UA1).squeeze(-1)  # (S*N,)\n",
    "        y1_all = model.causal_head(U, t1, z_y=ZY, ua1=UA1).squeeze(-1)  # (S*N,)\n",
    "\n",
    "        # ===== 还原形状、构造 ITE 样本 =====\n",
    "        y0_samples = y0_all.view(n_samples, N)   # (S, N)\n",
    "        y1_samples = y1_all.view(n_samples, N)   # (S, N)\n",
    "        ite_samples = y1_samples - y0_samples    # (S, N)\n",
    "\n",
    "        # ===== 分位数区间 =====\n",
    "        lower_q = (1 - ci) / 2\n",
    "        upper_q = 1 - lower_q\n",
    "\n",
    "        y0_lower = y0_samples.quantile(lower_q, dim=0)\n",
    "        y0_upper = y0_samples.quantile(upper_q, dim=0)\n",
    "        y1_lower = y1_samples.quantile(lower_q, dim=0)\n",
    "        y1_upper = y1_samples.quantile(upper_q, dim=0)\n",
    "        ite_lower = ite_samples.quantile(lower_q, dim=0)\n",
    "        ite_upper = ite_samples.quantile(upper_q, dim=0)\n",
    "\n",
    "        return y0_lower, y0_upper, y1_lower, y1_upper, ite_lower, ite_upper\n",
    "\n",
    "\n",
    "# def predict_y_bounds_from_encoder_samples(model, z_test, n_samples=100, ci=0.90):\n",
    "#     model.eval()\n",
    "#     with torch.no_grad():\n",
    "#         # === Infer z_y ===\n",
    "#         if model.causal_head.z_y_dim == z_test.shape[1]:\n",
    "#             z_y_test = z_test\n",
    "#         elif model.causal_head.z_y_dim:\n",
    "#             z_y_test = z_test[:, z_test.shape[1] // 2:]\n",
    "#         else:\n",
    "#             z_y_test = None\n",
    "\n",
    "#         ua1 = model.vae.sample_latent(z_test, var_name='ua1')\n",
    "\n",
    "#         # === Get encoder outputs for u ===\n",
    "#         mu_u, std_u, *_ = model.vae.encoder(z_test)  # (N, latent_dim)\n",
    "#         N, latent_dim = mu_u.shape\n",
    "#         device = mu_u.device\n",
    "\n",
    "#         # === Sample u ~ N(mu_u, std_u^2), shape: (n_samples, N, latent_dim)\n",
    "#         eps = torch.randn(n_samples, N, latent_dim, device=device)\n",
    "#         u_samples = mu_u.unsqueeze(0) + eps * std_u.unsqueeze(0)\n",
    "\n",
    "#         # === Repeat t and any auxiliary inputs ===\n",
    "#         u_all = u_samples.view(-1, latent_dim)  # (n_samples * N, latent_dim)\n",
    "#         t0_all = torch.zeros(n_samples * N, device=device)\n",
    "#         t1_all = torch.ones(n_samples * N, device=device)\n",
    "\n",
    "#         # === Broadcast z_y and ua1 ===\n",
    "#         z_y_all = None\n",
    "#         if z_y_test is not None:\n",
    "#             z_y_all = z_y_test.unsqueeze(0).expand(n_samples, -1, -1).reshape(-1, z_y_test.shape[1])\n",
    "#         ua1_all = None\n",
    "#         if ua1 is not None:\n",
    "#             ua1_all = ua1.unsqueeze(0).expand(n_samples, -1, -1).reshape(-1, ua1.shape[1])\n",
    "\n",
    "#         # === Predict y(0) and y(1) ===\n",
    "#         y0_all = model.causal_head(u_all, t0_all, z_y=z_y_all, ua1=ua1_all).squeeze(-1)\n",
    "#         y1_all = model.causal_head(u_all, t1_all, z_y=z_y_all, ua1=ua1_all).squeeze(-1)\n",
    "\n",
    "#         # === Reshape back to (n_samples, N)\n",
    "#         y0_samples = y0_all.view(n_samples, N)\n",
    "#         y1_samples = y1_all.view(n_samples, N)\n",
    "#         ite_samples = y1_samples - y0_samples\n",
    "\n",
    "#         # === Compute quantiles\n",
    "#         lower_q = (1 - ci) / 2\n",
    "#         upper_q = 1 - lower_q\n",
    "\n",
    "#         y0_lower = y0_samples.quantile(lower_q, dim=0)\n",
    "#         y0_upper = y0_samples.quantile(upper_q, dim=0)\n",
    "#         y1_lower = y1_samples.quantile(lower_q, dim=0)\n",
    "#         y1_upper = y1_samples.quantile(upper_q, dim=0)\n",
    "#         ite_lower = ite_samples.quantile(lower_q, dim=0)\n",
    "#         ite_upper = ite_samples.quantile(upper_q, dim=0)\n",
    "\n",
    "#         return y0_lower, y0_upper, y1_lower, y1_upper, ite_lower, ite_upper\n",
    "\n",
    "def gp_posterior_ite(x_train, y_lower, y_upper, x_test, noise_std, lengthscale, variance):\n",
    "    K = rbf_kernel(x_train, x_train, lengthscale, variance) + noise_std**2 * torch.eye(len(x_train), device=x_train.device)\n",
    "    K_s = rbf_kernel(x_train, x_test, lengthscale, variance)\n",
    "    K_ss = rbf_kernel(x_test, x_test, lengthscale, variance) + 1e-6 * torch.eye(len(x_test), device=x_test.device)\n",
    "\n",
    "    jitter = 1e-4\n",
    "    for _ in range(5):\n",
    "        try:\n",
    "            L = torch.linalg.cholesky(K + jitter * torch.eye(len(x_train), device=K.device))\n",
    "            break\n",
    "        except RuntimeError:\n",
    "            jitter *= 10\n",
    "    else:\n",
    "        raise RuntimeError(\"Cholesky failed in gp_posterior.\")\n",
    "\n",
    "    # Posterior mean using interval bounds\n",
    "    alpha_lower = torch.cholesky_solve(y_lower.unsqueeze(1), L)\n",
    "    alpha_upper = torch.cholesky_solve(y_upper.unsqueeze(1), L)\n",
    "\n",
    "    mu_lower = K_s.t() @ alpha_lower\n",
    "    mu_upper = K_s.t() @ alpha_upper\n",
    "\n",
    "    cov_post = K_ss - K_s.t() @ torch.cholesky_solve(K_s, L)\n",
    "\n",
    "    return mu_lower.squeeze(), mu_upper.squeeze(), cov_post"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a47b322e-06bf-492f-8197-d1ebf9faf188",
   "metadata": {},
   "outputs": [],
   "source": [
    "proxy_func_sets = [\n",
    "    [  # ID-1：\n",
    "        lambda u, ua0, ua1: u,\n",
    "        lambda u, ua0, ua1: torch.sin(u),\n",
    "        lambda u, ua0, ua1: u ** 2,\n",
    "    ],\n",
    "    [  # ID-2：\n",
    "        lambda u, ua0, ua1: torch.tanh(u),\n",
    "        lambda u, ua0, ua1: torch.sin(2 * u),\n",
    "        lambda u, ua0, ua1: torch.log(torch.abs(u) + 1e-3),\n",
    "    ],\n",
    "    [  # ID-3：\n",
    "        lambda u, ua0, ua1: u ** 2 + ua0,\n",
    "        lambda u, ua0, ua1: torch.log1p(torch.abs(u)) + ua1,\n",
    "        lambda u, ua0, ua1: u ** 3 + 0.1 * ua0 * ua1,\n",
    "    ],\n",
    "    [  # ID-4：\n",
    "        lambda u, ua0, ua1: torch.tanh(u) + ua0,\n",
    "        lambda u, ua0, ua1: torch.atan(u) + 0.1 * ua1,\n",
    "        lambda u, ua0, ua1: torch.sin(u) + torch.exp(-torch.abs(ua0)) + u,\n",
    "    ],\n",
    "    [  # ID-5：\n",
    "        lambda u, ua0, ua1: u / (torch.abs(ua0) + 0.1),\n",
    "        lambda u, ua0, ua1: torch.sin(u) / (1 + u ** 2) + 0.05 * ua1,\n",
    "        lambda u, ua0, ua1: torch.log1p(u ** 2) + 0.2 * ua0,\n",
    "    ],\n",
    "    [  # ID-6：\n",
    "        lambda u, ua0, ua1: torch.log1p(torch.exp(u)) + 0.1 * ua0,      \n",
    "        lambda u, ua0, ua1: u ** 3 + 0.1 * ua1,              \n",
    "        lambda u, ua0, ua1: torch.sigmoid(u) + 0.05 * ua0 * ua1,    \n",
    "    ],\n",
    "]\n",
    "\n",
    "# 定义处理函数组合\n",
    "def treat_func_linear_1(u_np, ua0_np):\n",
    "    logits = 1.5 * u_np + 0.8 * ua0_np\n",
    "    probs = 1 / (1 + np.exp(-logits))\n",
    "    return np.random.binomial(1, probs).astype(np.float32)\n",
    "\n",
    "def treat_func_linear_2(u_np, ua0_np):\n",
    "    logits = 0.5 * u_np + 1.2 * ua0_np\n",
    "    probs = 1 / (1 + np.exp(-logits))\n",
    "    return np.random.binomial(1, probs).astype(np.float32)\n",
    "\n",
    "def treat_func_nonlinear(u_np, ua0_np):\n",
    "    logits = np.power(u_np, 2) + 0.5 * ua0_np\n",
    "    probs = 1 / (1 + np.exp(-logits))\n",
    "    return np.random.binomial(1, probs).astype(np.float32)\n",
    "    \n",
    "treatment_func_list = [treat_func_linear_1, treat_func_linear_2]\n",
    "\n",
    "# 定义结果函数组合\n",
    "def outcome_func_simple(u_np, t_np, ua0=None, ua1=None):\n",
    "    base = u_np + t_np.reshape(-1, 1) * (0.5 + 0.3 * u_np)\n",
    "    if ua1 is not None:\n",
    "        base += 0.4 * ua1\n",
    "    noise = np.random.normal(0, 0.1, size=base.shape)\n",
    "    return base + noise\n",
    "\n",
    "def outcome_func_1(u_np, t_np, ua0=None, ua1=None):\n",
    "    base = np.sin(u_np) + t_np.reshape(-1, 1) * u_np\n",
    "    if ua0 is not None:\n",
    "        base += 0.3 * ua0\n",
    "    if ua1 is not None:\n",
    "        base += 0.3 * np.cos(ua1)\n",
    "    noise = np.random.normal(0, 0.1, size=base.shape)\n",
    "    return base + noise\n",
    "\n",
    "def outcome_func_2(u_np, t_np, ua0=None, ua1=None):\n",
    "    base = np.sin(u_np) + t_np.reshape(-1, 1) + 0.5 * u_np * t_np.reshape(-1, 1)\n",
    "    if ua1 is not None:\n",
    "        base += 0.5 * np.cos(ua1)\n",
    "    noise = np.random.normal(0, 0.1, size=base.shape)\n",
    "    return base + noise\n",
    "\n",
    "outcome_func_list = [outcome_func_1, outcome_func_2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8619778-84e1-4182-821e-1665a9ba488a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "import textwrap\n",
    "\n",
    "def print_function_source(title, funcs):\n",
    "    print(f\"{title}:\")\n",
    "    if isinstance(funcs, list):\n",
    "        for i, f in enumerate(funcs):\n",
    "            print(f\"  Z_{i+1}(u):\")\n",
    "            try:\n",
    "                src = inspect.getsource(f).strip()\n",
    "                print(textwrap.indent(src, \"    \"))\n",
    "            except:\n",
    "                print(\"    [source not available]\")\n",
    "    else:\n",
    "        try:\n",
    "            src = inspect.getsource(funcs).strip()\n",
    "            print(textwrap.indent(src, \"  \"))\n",
    "        except:\n",
    "            print(\"  [source not available]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae614525-0259-4069-9f99-0109374b0da4",
   "metadata": {},
   "source": [
    "# Run main code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ac2a11-b403-4d41-9842-bba072c65e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "iteration_id = 0\n",
    "tedvae_ates = []\n",
    "true_ite_ates = []\n",
    "intervalgpvae_ates = []\n",
    "\n",
    "coverage_all = []\n",
    "ate_error_all = []\n",
    "mse_y_infer_u_all = []\n",
    "interval_length_all = []\n",
    "\n",
    "intervalgpvae_pehes = []\n",
    "tedvae_pehes_tests = []\n",
    "tedvae_pehes_trains = []\n",
    "chosen_version = \"u_aux\"\n",
    "for proxy_funcs, treat_func, outcome_func in itertools.product(proxy_func_sets, treatment_func_list, outcome_func_list):\n",
    "    iteration_id += 1\n",
    "    print_function_source(\"Proxy functions\", proxy_funcs)\n",
    "    print_function_source(\"Treatment function\", treat_func)\n",
    "    print_function_source(\"Outcome function\", outcome_func)\n",
    "\n",
    "    \n",
    "\n",
    "    data_train = generate_synthetic_data_with_aux_uas(n=1000,\n",
    "                                                      num_proxies=len(proxy_funcs),\n",
    "                                                      proxy_funcs=deepcopy(proxy_funcs),\n",
    "                                                      treatment_func=deepcopy(treat_func),\n",
    "                                                      outcome_func=deepcopy(outcome_func),\n",
    "                                                      seed=42)\n",
    "\n",
    "    data_test = generate_synthetic_data_with_aux_uas(n=50, \n",
    "                                                     num_proxies=len(proxy_funcs),\n",
    "                                                     proxy_funcs=deepcopy(proxy_funcs),\n",
    "                                                     treatment_func=deepcopy(treat_func),\n",
    "                                                     outcome_func=deepcopy(outcome_func),\n",
    "                                                     seed=100 + iteration_id)\n",
    "    z, t, y, u_true = data_train['z'], data_train['t'], data_train['y'], data_train['u']\n",
    "    proxy_funcs = data_train['proxy_funcs']\n",
    "    outcome_func = data_train['outcome_func']\n",
    "    treatment_func = data_train['treatment_func']\n",
    "    ua0_true = data_train.get('ua0', None)\n",
    "    ua1_true = data_train.get('ua1', None)\n",
    "    z_train_gp, z_test, true_u_train, true_u_test, t_test, y_test, train_indices,true_ua0_train,true_ua1_train,true_ua0_test,true_ua1_test=\\\n",
    "    select_train_from_z_test_from_external(z, u_true, t, y, data_test, n_train=5) #DF: n_train=10, 5, 20\n",
    "    \n",
    "    if true_ua0_train is not None:\n",
    "        ua0_train_np = true_ua0_train.cpu().numpy()\n",
    "        ua0_train_np = ua0_train_np.reshape(-1,1)\n",
    "    else:\n",
    "        ua0_train_np = None\n",
    "    if true_ua1_train is not None:\n",
    "        ua1_train_np = true_ua1_train.cpu().numpy()\n",
    "        ua1_train_np = ua1_train_np.reshape(-1,1)\n",
    "    else:\n",
    "        ua1_train_np = None\n",
    "        \n",
    "    if true_ua0_test is not None:\n",
    "        ua0_test_np = true_ua0_test.cpu().numpy()\n",
    "        ua0_test_np = ua0_test_np.reshape(-1,1)\n",
    "    else:\n",
    "        ua0_test_np = None\n",
    "    if true_ua1_test is not None:\n",
    "        ua1_test_np = true_ua1_test.cpu().numpy()\n",
    "        ua1_test_np = ua1_test_np.reshape(-1,1)\n",
    "    else:\n",
    "        ua1_test_np = None\n",
    "        \n",
    "    device = \"cpu\"\n",
    "    z, t, y, u_true = z.to(device), t.to(device), y.to(device), u_true.to(device)\n",
    "    z_test, true_u_test = z_test.to(device), true_u_test.to(device)\n",
    "    t_test, y_test = t_test.to(device), y_test.to(device)\n",
    "    z_train_gp = z_train_gp.to(device)\n",
    "\n",
    "    \n",
    "    use_z_y = chosen_version in [\"z_to_y\", \"split_z_to_t_and_y\"]\n",
    "    use_aux = chosen_version == \"u_aux\"\n",
    "    \n",
    "    if chosen_version == \"z_to_y\":\n",
    "        z_y_dim = z.shape[1]\n",
    "    elif chosen_version == \"split_z_to_t_and_y\":\n",
    "        z_y_dim = z.shape[1] // 2 \n",
    "    else:\n",
    "        z_y_dim = None\n",
    "    \n",
    "    loader = DataLoader(TensorDataset(z, t, y, u_true), batch_size=128, shuffle=True)\n",
    "    model = CausalGPVAEwithNoise(input_dim=z.shape[1],\n",
    "                                 hidden_dim=64,\n",
    "                                 latent_dim=1,\n",
    "                                 z_y_dim=z_y_dim,\n",
    "                                 use_auxiliary_latents=use_aux,\n",
    "                                 gp_lengthscale=1.0,\n",
    "                                 gp_variance=1.0,\n",
    "                                 gp_noise=1e-4).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "    \n",
    "    for epoch in range(200):\n",
    "        recon_losses, causal_losses = [], []\n",
    "        model.train()\n",
    "        for batch in loader:\n",
    "            z_batch, t_batch, y_batch, _ = [b.to(device) for b in batch]\n",
    "            loss, info = model(z_batch, t_batch, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "            recon_losses.append(info['recon_loss'])\n",
    "            causal_losses.append(info['causal_loss'])\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            model.eval()\n",
    "            mu_u_test = model.vae.encoder(z_test.to(device))[0]\n",
    "            u_test_loss = F.mse_loss(mu_u_test, true_u_test.to(device).view(-1, 1)).item()\n",
    "\n",
    "    \n",
    "    for param in model.vae.parameters():\n",
    "        param.requires_grad = False\n",
    "    \n",
    "    causal_optimizer = torch.optim.Adam(model.causal_head.parameters(), lr=1e-3)\n",
    "    \n",
    "    for epoch in range(100):\n",
    "        causal_losses = []\n",
    "        model.train()\n",
    "        for z_batch, t_batch, y_batch, _ in loader:\n",
    "            z_batch, t_batch, y_batch = z_batch.to(device), t_batch.to(device), y_batch.to(device)\n",
    "            mu_u = model.vae.encoder(z_batch)[0].detach()\n",
    "            ua1 = None \n",
    "            if model.use_aux: \n",
    "                mu_ua1 = model.vae.encoder(z_batch)[9]  \n",
    "                std_ua1 = model.vae.encoder(z_batch)[10]\n",
    "                ua1 = model.vae.reparameterize(mu_ua1, std_ua1).detach()\n",
    "            \n",
    "            if model.z_y_dim is not None: \n",
    "                if model.z_y_dim == z_batch.shape[1]: \n",
    "                    z_y_batch = z_batch\n",
    "                else: \n",
    "                    z_y_batch = z_batch[:, -model.z_y_dim:]\n",
    "            else:\n",
    "                z_y_batch = None\n",
    "                \n",
    "            y_pred = model.causal_head(mu_u, t_batch, z_y=z_y_batch, ua1=ua1)\n",
    "            causal_loss = F.mse_loss(y_pred, y_batch.unsqueeze(-1))\n",
    "            causal_optimizer.zero_grad()\n",
    "            causal_loss.backward()\n",
    "            causal_optimizer.step()\n",
    "            causal_losses.append(causal_loss.item())\n",
    "\n",
    "    \n",
    "    for param in model.causal_head.parameters():\n",
    "        param.requires_grad = False\n",
    "    \n",
    "    for param in model.vae.parameters():\n",
    "        param.requires_grad = True\n",
    "    \n",
    "    vae_optimizer = torch.optim.Adam(model.vae.parameters(), lr=1e-4)\n",
    "    for epoch in range(50):\n",
    "        recon_losses, causal_losses, total_losses = [], [], []\n",
    "        model.train()\n",
    "        for batch in loader:\n",
    "            z_batch, t_batch, y_batch, _ = [b.to(device) for b in batch]\n",
    "            if chosen_version == \"z_to_y\":\n",
    "                z_y_batch = z_batch\n",
    "            elif chosen_version == \"split_z_to_t_and_y\":\n",
    "                z_y_batch = z_batch[:, z.shape[1] // 2:]\n",
    "            else:\n",
    "                z_y_batch = None\n",
    "            \n",
    "            ua1 = model.vae.sample_latent(z_batch, var_name='ua1')\n",
    "            loss_vae, vae_info = model.vae(z_batch)\n",
    "            with torch.no_grad():\n",
    "                y_pred = model.causal_head(vae_info['u'], t_batch, z_y=z_y_batch, ua1=ua1)\n",
    "            \n",
    "            causal_loss = F.mse_loss(y_pred, y_batch.unsqueeze(-1), reduction='sum')\n",
    "            total_loss = loss_vae + 1.0 * causal_loss  # weight can be tuned\n",
    "            vae_optimizer.zero_grad()\n",
    "            total_loss.backward()\n",
    "            vae_optimizer.step()\n",
    "            recon_losses.append(vae_info['recon_loss'])\n",
    "            causal_losses.append(causal_loss.item())\n",
    "            total_losses.append(total_loss.item())\n",
    "            \n",
    "    t0_np = np.zeros_like(true_u_test.unsqueeze(1).cpu().numpy()) \n",
    "    t1_np = np.ones_like(true_u_test.unsqueeze(1).cpu().numpy())\n",
    "    y0_true = call_outcome_func(outcome_func, true_u_test.unsqueeze(1).cpu().numpy(), t0_np, z_test.cpu().numpy(), ua0_test_np, ua1_test_np) \n",
    "    y1_true = call_outcome_func(outcome_func, true_u_test.unsqueeze(1).cpu().numpy(), t1_np, z_test.cpu().numpy(), ua0_test_np, ua1_test_np)\n",
    "    true_ite = (y1_true - y0_true).squeeze()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        z_recon = model.vae(z_test)[1]['z_recon']\n",
    "    \n",
    "    proxy_dim = z_test.shape[1]\n",
    "    lims_list = []\n",
    "    true_g_u = torch.cat([\n",
    "        g(true_u_test.unsqueeze(1), true_ua0_test.unsqueeze(1), true_ua1_test.unsqueeze(1)).detach()\n",
    "        if g.__code__.co_argcount >= 3 else\n",
    "        g(true_u_test.unsqueeze(1)).detach()\n",
    "        for g in proxy_funcs\n",
    "    ], dim=1)\n",
    "    \n",
    "    true_eps = z_test - true_g_u\n",
    "    eps_est = model.vae.encoder(z_test)[3].detach() \n",
    "    eps_dim = true_eps.shape[1]\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        model.eval()\n",
    "        mu_u_test, _, _ = model.vae.get_latent_stats(z_test, var_name='u')\n",
    "        mu_eps_test, std_eps_test, _ = model.vae.get_latent_stats(z_test, var_name='eps')\n",
    "        ua1_test = model.vae.sample_latent(z_test, var_name='ua1')\n",
    "        \n",
    "        if chosen_version == \"z_to_y\":\n",
    "            z_y_test = z_test\n",
    "        elif chosen_version == \"split_z_to_t_and_y\":\n",
    "            z_y_test = z_test[:, z_test.shape[1] // 2:]\n",
    "        else:\n",
    "            z_y_test = None\n",
    "        \n",
    "        y_pred_inferred_u = model.causal_head(mu_u_test, t_test.to(device), z_y=z_y_test.to(device) if z_y_test is not None else None, ua1=ua1_test).squeeze()\n",
    "        \n",
    "    u_mean_post, u_std_post = estimate_latent_u_posterior(model, z_test)\n",
    "    true_u_np = true_u_test.detach().cpu().numpy() if torch.is_tensor(true_u_test) else true_u_test\n",
    "    x = np.arange(len(z_test))\n",
    "    \n",
    "    t0 = torch.zeros(len(z_test), device=device)\n",
    "    t1 = torch.ones(len(z_test), device=device) \n",
    "    \n",
    "    y0_m1, y0_s1 = estimate_counterfactual(model, z_test, t0)\n",
    "    y1_m1, y1_s1 = estimate_counterfactual(model, z_test, t1)\n",
    "    \n",
    "    ite_m1 = y1_m1 - y0_m1\n",
    "    ite_s1 = torch.sqrt(y1_s1**2 + y0_s1**2)\n",
    "    true_ite_np = true_ite if isinstance(true_ite, np.ndarray) else true_ite.detach().cpu().numpy()\n",
    "\n",
    "    ite_m1_tensor = torch.tensor(ite_m1, dtype=torch.float32).squeeze().cpu()\n",
    "    true_ite_tensor = torch.tensor(true_ite_np, dtype=torch.float32).squeeze().cpu()\n",
    "    IntervalGP_VAE_pehe = torch.sqrt(F.mse_loss(ite_m1_tensor, true_ite_tensor)).item()\n",
    "    \n",
    "    \n",
    "    model.vae.gp_lengthscale = 0.5\n",
    "    model.vae.gp_variance = 2.0\n",
    "    mu_lower, mu_upper, cov_gp = predict_u_gp_conditioned(model, z_train_gp, z_test)\n",
    "    mu_lower_np = mu_lower.detach().cpu().numpy()\n",
    "    mu_upper_np = mu_upper.detach().cpu().numpy()\n",
    "    std_gp_np = torch.sqrt(torch.diag(cov_gp)).detach().cpu().numpy()\n",
    "    true_u_test_np = true_u_test.cpu().numpy()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        mu_u_train, std_u_train = model.vae.encoder(z_train_gp)[0:2]\n",
    "        u_train_lower_np = (mu_u_train - std_u_train).squeeze().cpu().numpy()\n",
    "        u_train_upper_np = (mu_u_train + std_u_train).squeeze().cpu().numpy()\n",
    "        true_u_train_np = true_u_train.cpu().numpy()\n",
    "    \n",
    "    x_all = np.concatenate([true_u_train_np, true_u_test_np])\n",
    "    x_all = np.argsort(np.argsort(x_all))  \n",
    "    type_all = np.array(['train'] * len(true_u_train_np) + ['test'] * len(true_u_test_np))\n",
    "    \n",
    "    u_lower_all = np.concatenate([u_train_lower_np, mu_lower_np])\n",
    "    u_upper_all = np.concatenate([u_train_upper_np, mu_upper_np])\n",
    "    \n",
    "    std_train_np = np.full_like(u_train_lower_np, np.sqrt(model.vae.gp_noise))\n",
    "    std_all = np.concatenate([std_train_np, std_gp_np])\n",
    "    true_u_all = np.concatenate([true_u_train_np, true_u_test_np])\n",
    "    \n",
    "    z_test = z_test.to(next(model.parameters()).device)\n",
    "    y0_lower, y0_upper, y1_lower, y1_upper, ite_lower, ite_upper = predict_y_bounds_from_encoder_samples(model=model,\n",
    "                                                                                                         z_test=z_test,\n",
    "                                                                                                         n_samples=100, \n",
    "                                                                                                         ci=0.90)\n",
    "    true_u_test_np = true_u_test.cpu().numpy()\n",
    "    rank_x = np.argsort(np.argsort(true_u_test_np))\n",
    "    ite_midpoint_np = ((y1_lower + y1_upper) / 2 - (y0_lower + y0_upper) / 2).detach().cpu().numpy()\n",
    "    ite_lower_np = ite_lower.detach().cpu().numpy()\n",
    "    ite_upper_np = ite_upper.detach().cpu().numpy()\n",
    "    \n",
    "    if isinstance(true_ite, torch.Tensor):\n",
    "        true_ite_np = true_ite.detach().cpu().numpy()\n",
    "    else:\n",
    "        true_ite_np = true_ite\n",
    "        \n",
    "    mu_u_train, std_u_train = model.vae.encoder(z_train_gp)[:2]\n",
    "    mu_u_train = mu_u_train.view(-1, 1)\n",
    "    std_u_train = std_u_train.view(-1, 1)\n",
    "    \n",
    "    t0 = torch.zeros(mu_u_train.shape[0], device=mu_u_train.device)  \n",
    "    t1 = torch.ones(mu_u_train.shape[0], device=mu_u_train.device)   \n",
    "    \n",
    "    with torch.no_grad():\n",
    "        if model.causal_head.z_y_dim == z_train_gp.shape[1]:\n",
    "            z_y_train = z_train_gp\n",
    "        elif model.causal_head.z_y_dim:\n",
    "            z_y_train = z_train_gp[:, z_train_gp.shape[1] // 2:]\n",
    "        else:\n",
    "            z_y_train = None\n",
    "        \n",
    "        ua1_train = model.vae.sample_latent(z_train_gp, var_name='ua1')\n",
    "        y0_mean_train = model.causal_head(mu_u_train, t0, z_y=z_y_train, ua1=ua1_train).squeeze(-1)\n",
    "        y1_mean_train = model.causal_head(mu_u_train, t1, z_y=z_y_train, ua1=ua1_train).squeeze(-1)\n",
    "        ite_mean_train = y1_mean_train - y0_mean_train\n",
    "    \n",
    "    y0_std_train = std_u_train.squeeze(-1)\n",
    "    y1_std_train = std_u_train.squeeze(-1)\n",
    "    ite_std_train = torch.sqrt(y0_std_train**2 + y1_std_train**2)\n",
    "    \n",
    "    ite_lower_train = ite_mean_train - ite_std_train\n",
    "    ite_upper_train = ite_mean_train + ite_std_train\n",
    "    \n",
    "    outcome_func = data_train['outcome_func']  \n",
    "    true_u_train_np = true_u_train.detach().cpu().numpy().reshape(-1, 1)\n",
    "    t0_np = np.zeros_like(true_u_train_np)\n",
    "    t1_np = np.ones_like(true_u_train_np)\n",
    "    \n",
    "    y0_train_np = call_outcome_func(outcome_func, true_u_train_np, t0_np, z_train_gp.cpu().numpy(),ua0_train_np, ua1_train_np) \n",
    "    y1_train_np = call_outcome_func(outcome_func, true_u_train_np, t1_np, z_train_gp.cpu().numpy(),ua0_train_np, ua1_train_np) \n",
    "    \n",
    "    true_ite_train_np = (y1_train_np - y0_train_np).reshape(-1)\n",
    "    \n",
    "    ite_lower_test_np = ite_lower.detach().cpu().numpy().reshape(-1)\n",
    "    ite_upper_test_np = ite_upper.detach().cpu().numpy().reshape(-1)\n",
    "    ite_midpoint_test_np = ((ite_lower + ite_upper) / 2).detach().cpu().numpy().reshape(-1)\n",
    "    true_ite_test_np = np.asarray(true_ite).reshape(-1)\n",
    "    true_u_test_np = true_u_test.detach().cpu().numpy().reshape(-1)\n",
    "    \n",
    "    ite_lower_train_np = ite_lower_train.detach().cpu().numpy().reshape(-1)\n",
    "    ite_upper_train_np = ite_upper_train.detach().cpu().numpy().reshape(-1)\n",
    "    ite_midpoint_train_np = ((ite_lower_train + ite_upper_train) / 2).detach().cpu().numpy().reshape(-1)\n",
    "    true_u_train_np = true_u_train.detach().cpu().numpy().reshape(-1)\n",
    "    \n",
    "    ite_lower_combined = np.concatenate([ite_lower_train_np, ite_lower_test_np])\n",
    "    ite_upper_combined = np.concatenate([ite_upper_train_np, ite_upper_test_np])\n",
    "    ite_midpoint_combined = np.concatenate([ite_midpoint_train_np, ite_midpoint_test_np])\n",
    "    true_ite_combined = np.concatenate([true_ite_train_np, true_ite_test_np])\n",
    "    true_u_combined = np.concatenate([true_u_train_np, true_u_test_np])\n",
    "\n",
    "    # Plot ITE bound without using GP\n",
    "    N_train = len(true_u_train_np)\n",
    "    N_total = len(true_u_combined)\n",
    "    rank_x   = np.argsort(np.argsort(true_u_combined))    \n",
    "    sort_idx = np.argsort(rank_x)\n",
    "    x_sorted         = rank_x[sort_idx]\n",
    "    ite_lo_sorted    = ite_lower_combined[sort_idx]\n",
    "    ite_up_sorted    = ite_upper_combined[sort_idx]\n",
    "    ite_mid_sorted   = ite_midpoint_combined[sort_idx]\n",
    "    true_ite_sorted  = true_ite_combined[sort_idx]\n",
    "    is_train_mask        = np.zeros(N_total, dtype=bool)\n",
    "    is_train_mask[:N_train] = True\n",
    "    is_train_sorted      = is_train_mask[sort_idx]\n",
    "\n",
    "    x0, x1 = 0, len(x_sorted) - 1\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.xlim(x0, x1)\n",
    "    plt.plot(x_sorted, ite_lo_sorted,  'g--', label='ITE Lower')\n",
    "    plt.plot(x_sorted, ite_up_sorted,  'r--', label='ITE Upper')\n",
    "    plt.fill_between(x_sorted, ite_lo_sorted, ite_up_sorted, alpha=0.3, color='lightblue', label='ITE Interval')\n",
    "    \n",
    "    plt.plot(x_sorted[is_train_sorted], true_ite_sorted[is_train_sorted], 'm^', label='True ITE (Train)', markersize=6)\n",
    "    plt.plot(x_sorted[~is_train_sorted], true_ite_sorted[~is_train_sorted], 'kx', label='True ITE (Test)', markersize=6)\n",
    "    plt.plot(x_sorted[~is_train_sorted], true_ite_sorted[~is_train_sorted], 'k-', alpha=0.6)\n",
    "    \n",
    "    plt.title(\"ITE Interval (Train + Test, ranked by true u)\")\n",
    "    plt.xlabel(\"Rank of true u (train + test)\")\n",
    "    plt.ylabel(\"Individual Treatment Effect (ITE)\")\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"ite_interval_ranku_iter{iteration_id}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    gp_lengthscale = 0.4\n",
    "    gp_variance = 5.0\n",
    "    gp_noise = 1e-4\n",
    "    _,_,_,_,ite_lower_train, ite_upper_train = predict_y_bounds_from_encoder_samples(model, z_train_gp, n_samples=100, ci=0.90)\n",
    "    ite_lower_test, ite_upper_test, cov_post = gp_posterior_ite(z_train_gp, ite_lower_train, ite_upper_train, z_test,\n",
    "                                                                noise_std=torch.tensor(gp_noise, device=z_train_gp.device),\n",
    "                                                                lengthscale=gp_lengthscale, variance=gp_variance)\n",
    "    \n",
    "    ite_lower_train_np = ite_lower_train.detach().cpu().numpy().squeeze()\n",
    "    ite_upper_train_np = ite_upper_train.detach().cpu().numpy().squeeze()\n",
    "    ite_lower_test_np = ite_lower_test.detach().cpu().numpy().squeeze()\n",
    "    ite_upper_test_np = ite_upper_test.detach().cpu().numpy().squeeze()\n",
    "    std_gp_np = torch.sqrt(torch.diag(cov_post)).detach().cpu().numpy().squeeze()\n",
    "    \n",
    "    ite_lower_all = np.concatenate([ite_lower_train_np, ite_lower_test_np])\n",
    "    ite_upper_all = np.concatenate([ite_upper_train_np, ite_upper_test_np])\n",
    "    std_all = np.concatenate([np.full_like(ite_lower_train_np, np.sqrt(gp_noise)), std_gp_np])\n",
    "    type_all = np.array(['train'] * len(ite_lower_train_np) + ['test'] * len(ite_lower_test_np))\n",
    "    \n",
    "    true_u_train_np = true_u_train.detach().cpu().numpy().squeeze()\n",
    "    true_u_test_np = true_u_test.detach().cpu().numpy().squeeze()\n",
    "    t0_train_np = np.zeros_like(true_u_train_np)\n",
    "    t1_train_np = np.ones_like(true_u_train_np)\n",
    "    t0_test_np = np.zeros_like(true_u_test_np)\n",
    "    t1_test_np = np.ones_like(true_u_test_np)\n",
    "    \n",
    "    y0_train_np = call_outcome_func(outcome_func,true_u_train_np, t0_train_np, z_train_gp.numpy(), ua0_train_np, ua1_train_np)\n",
    "    y1_train_np = call_outcome_func(outcome_func,true_u_train_np, t1_train_np, z_train_gp.numpy(), ua0_train_np, ua1_train_np)\n",
    "    y0_test_np =  call_outcome_func(outcome_func, true_u_test_np, t0_test_np, z_test.numpy(), ua0_test_np, ua1_test_np)\n",
    "    y1_test_np =  call_outcome_func(outcome_func, true_u_test_np, t1_test_np, z_test.numpy(), ua0_test_np, ua1_test_np)\n",
    "    \n",
    "    true_ite_train_np = y1_train_np - y0_train_np\n",
    "    true_ite_test_np = y1_test_np - y0_test_np\n",
    "    true_ite_all = np.concatenate([true_ite_train_np.reshape(-1),true_ite_test_np.reshape(-1)])\n",
    "    # 6) 用 true u 的秩作为 x 轴并整体排序\n",
    "    true_u_train_np = true_u_train.detach().cpu().numpy().squeeze()\n",
    "    true_u_test_np  = true_u_test.detach().cpu().numpy().squeeze()\n",
    "    x_all = np.concatenate([true_u_train_np, true_u_test_np]).squeeze()\n",
    "    x_rank = np.argsort(np.argsort(x_all))      \n",
    "    \n",
    "    sort_idx          = np.argsort(x_rank)\n",
    "    x_sorted          = x_rank[sort_idx]\n",
    "    ite_lower_sorted  = ite_lower_all[sort_idx]\n",
    "    ite_upper_sorted  = ite_upper_all[sort_idx]\n",
    "    std_sorted        = std_all[sort_idx]\n",
    "    type_sorted       = type_all[sort_idx]\n",
    "    true_ite_sorted   = true_ite_all[sort_idx]\n",
    "\n",
    "    \n",
    "    plt.figure(figsize=(10, 6))\n",
    "    plt.xlim(x0, x1)\n",
    "    plt.fill_between(x_sorted,\n",
    "                     ite_lower_sorted - 2 * std_sorted,\n",
    "                     ite_upper_sorted + 2 * std_sorted,\n",
    "                     alpha=0.3, label='GP Posterior ± 2σ')\n",
    "    plt.plot(x_sorted, ite_lower_sorted, 'g--', label='GP Lower')\n",
    "    plt.plot(x_sorted, ite_upper_sorted, 'r--', label='GP Upper')\n",
    "    \n",
    "    for xi, lo, up, typ in zip(x_sorted, ite_lower_sorted, ite_upper_sorted, type_sorted):\n",
    "        color = 'purple' if typ == 'train' else 'blue'\n",
    "        plt.plot([xi, xi], [lo, up], color=color, lw=2, alpha=0.7) \n",
    "    plt.plot(x_sorted, true_ite_sorted, 'k*', label='True ITE')\n",
    "    plt.plot([], [], color='blue',   lw=2, label='GP Interval (Test)')\n",
    "    plt.plot([], [], color='purple', lw=2, label='Encoder Interval (Train)')   \n",
    "    plt.title(\"GP Posterior over ITE (Train + Test, Sorted by Rank(true u))\")\n",
    "    plt.xlabel(\"Rank of true u (train + test)\")\n",
    "    plt.ylabel(\"ITE\")\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"ite_gp_posterior_iter{iteration_id}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "    coverage = ((ite_lower <= torch.tensor(true_ite_np)) & (ite_upper >= torch.tensor(true_ite_np))).float().mean().item()\n",
    "    coverage_all.append(coverage)    \n",
    "    interval_length = (ite_upper - ite_lower).mean().item()\n",
    "    interval_length_all.append(interval_length)\n",
    "\n",
    "    # TEDVAE part\n",
    "    z_train, t_train, y_train, u_train_true = data_train['z'], data_train['t'], data_train['y'], data_train['u']\n",
    "    ym, ys = y_train.mean(), y_train.std()\n",
    "    y_train = (y_train - ym) / ys\n",
    "    ua0_train_true = data_train.get('ua0', None)\n",
    "    ua1_train_true = data_train.get('ua1', None)\n",
    "    ua1_train_true_np = ua1_train_true.numpy().reshape(-1, 1)\n",
    "    u_train_true_np = u_train_true.unsqueeze(-1).numpy()\n",
    "    \n",
    "    t1 = np.ones_like(t_train.numpy())\n",
    "    t0 = np.zeros_like(t_train.numpy())\n",
    "    \n",
    "    ua1_train_true_np = ua1_train_true.numpy().reshape(-1, 1)  \n",
    "    true_ua1_test_np = true_ua1_test.numpy().reshape(1, -1)    \n",
    "    \n",
    "    true_ite_train = simple_outcome_func_with_aux_u(u_train_true_np, t1, None, ua1_train_true_np) - simple_outcome_func_with_aux_u(u_train_true_np, t0, None, ua1_train_true_np)\n",
    "    true_ite_test = simple_outcome_func_with_aux_u(true_u_test_np, t1, None, true_ua1_test_np) - simple_outcome_func_with_aux_u(true_u_test_np, t0, None, true_ua1_test_np)\n",
    "    \n",
    "    args = {\"feature_dim\": 3,\n",
    "            \"latent_dim\": 20,\n",
    "            \"latent_dim_t\": 10,\n",
    "            \"latent_dim_y\": 10,\n",
    "            \"hidden_dim\": 500,\n",
    "            \"num_layers\": 4,\n",
    "            \"num_epochs\": 200,\n",
    "            \"batch_size\": 1000,\n",
    "            \"learning_rate\": 1e-3,\n",
    "            \"learning_rate_decay\": 0.01,\n",
    "            \"weight_decay\": 1e-4,\n",
    "            \"seed\": 1234567890}\n",
    "    \n",
    "    pyro.clear_param_store()\n",
    "    tedvae = TEDVAE(feature_dim=args[\"feature_dim\"],\n",
    "                    continuous_dim=z_train.shape[1],\n",
    "                    binary_dim=0,\n",
    "                    latent_dim=args[\"latent_dim\"],\n",
    "                    latent_dim_t=args[\"latent_dim_t\"],\n",
    "                    latent_dim_y=args[\"latent_dim_y\"],\n",
    "                    hidden_dim=args[\"hidden_dim\"],\n",
    "                    num_layers=args[\"num_layers\"],\n",
    "                    num_samples=10)\n",
    "    \n",
    "    tedvae.fit(z_train, t_train, y_train,\n",
    "               num_epochs=args[\"num_epochs\"],\n",
    "               batch_size=args[\"batch_size\"],\n",
    "               learning_rate=args[\"learning_rate\"],\n",
    "               learning_rate_decay=args[\"learning_rate_decay\"],\n",
    "               weight_decay=args[\"weight_decay\"])\n",
    "    \n",
    "    est_ite = tedvae.ite(z_test, ym, ys)\n",
    "    est_ite_train = tedvae.ite(z_train, ym, ys)\n",
    "    \n",
    "    true_ite_test_torch = torch.tensor(true_ite_test.squeeze(), dtype=torch.float32)\n",
    "    est_ite_test_torch = est_ite.squeeze().cpu().float()\n",
    "    TEDVAE_pehe_test = torch.sqrt(F.mse_loss(est_ite_test_torch, true_ite_test_torch)).item()\n",
    "    \n",
    "    true_ite_train_torch = torch.tensor(true_ite_train.squeeze(), dtype=torch.float32)\n",
    "    est_ite_train_torch = est_ite_train.squeeze().cpu().float()\n",
    "    TEDVAE_pehe_train = torch.sqrt(F.mse_loss(est_ite_train_torch, true_ite_train_torch)).item()\n",
    "\n",
    "    print(f\"ItervalGP VAE PEHE: {IntervalGP_VAE_pehe:.4f}\")\n",
    "    print(f\"TEDVAE PEHE (test): {TEDVAE_pehe_test:.4f}\")\n",
    "    intervalgpvae_pehes.append(IntervalGP_VAE_pehe)\n",
    "    tedvae_pehes_tests.append(TEDVAE_pehe_test)\n",
    "    tedvae_pehes_trains.append(TEDVAE_pehe_train)\n",
    "    \n",
    "    true_ate = np.mean(true_ite_np)\n",
    "    tedvae_ate = est_ite.mean().item()\n",
    "    intervalgpvae_ate = ite_m1_tensor.mean().item()\n",
    "\n",
    "    true_ite_ates.append(true_ate)\n",
    "    tedvae_ates.append(tedvae_ate)\n",
    "    intervalgpvae_ates.append(intervalgpvae_ate)\n",
    "    \n",
    "    tedvae_ate_error = abs(tedvae_ate - true_ate)\n",
    "    intervalgpvae_ate_error = abs(intervalgpvae_ate - true_ate)\n",
    "    ate_error_all.append({\"intervalgpvae\": intervalgpvae_ate_error,\n",
    "                          \"tedvae\": tedvae_ate_error})\n",
    "    \n",
    "    interval_length = (ite_upper - ite_lower).mean().item()\n",
    "    interval_length_all.append(interval_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c67da5-a867-43d7-8ae4-3e07bfac9cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average PEHE over all completed runs. An empty result list yields NaN directly\n",
    "# (the same guard the plotting cell applies) instead of calling np.mean on nothing.\n",
    "avg_intervalgpvae = np.mean(intervalgpvae_pehes) if len(intervalgpvae_pehes) else float(\"nan\")\n",
    "avg_tedvae_test = np.mean(tedvae_pehes_tests) if len(tedvae_pehes_tests) else float(\"nan\")\n",
    "print(f\"Avg IntervalGP VAE PEHE: {avg_intervalgpvae:.4f}\")\n",
    "print(f\"Avg TEDVAE PEHE (test): {avg_tedvae_test:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "313b5879-b1dd-467e-9954-778921a851b2",
   "metadata": {},
   "source": [
    "# Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3c2ee9-d6b5-4864-b65a-0c373b98f032",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ======================== Build the plotting data (aligned with the training loop) ========================\n",
    "# Split ate_error_all into one ATE-error series per method.\n",
    "intervalgpvae_errors = [d[\"intervalgpvae\"] for d in ate_error_all]\n",
    "tedvae_errors        = [d[\"tedvae\"]        for d in ate_error_all]\n",
    "\n",
    "# Common length N of every plotted series. An empty series must pull N down to 0:\n",
    "# ignoring it would leave x with N entries next to a zero-length y, and plot()\n",
    "# raises when x and y lengths differ.\n",
    "lengths = [\n",
    "    len(intervalgpvae_errors),\n",
    "    len(tedvae_errors),\n",
    "    len(intervalgpvae_pehes),\n",
    "    len(tedvae_pehes_tests),\n",
    "    len(coverage_all),\n",
    "]\n",
    "N = min(lengths)\n",
    "\n",
    "# Truncate into separate *_plot copies so the result lists filled by the training\n",
    "# loop are never overwritten and this cell gives the same figure when re-run.\n",
    "ate_err_interval_plot = intervalgpvae_errors[:N]\n",
    "ate_err_tedvae_plot   = tedvae_errors[:N]\n",
    "pehe_interval_plot    = intervalgpvae_pehes[:N]\n",
    "pehe_tedvae_plot      = tedvae_pehes_tests[:N]\n",
    "coverage_plot         = coverage_all[:N]\n",
    "\n",
    "# x axis: 0-based bar positions, 1-based tick labels\n",
    "x = np.arange(N)\n",
    "x_labels = [str(i + 1) for i in x]\n",
    "\n",
    "\n",
    "def _mean_or_nan(values):\n",
    "    \"\"\"Return the mean of `values`, or NaN when the sequence is empty.\"\"\"\n",
    "    return np.mean(values) if len(values) else float(\"nan\")\n",
    "\n",
    "\n",
    "def _draw_mean_line(ax, value, color, alpha):\n",
    "    \"\"\"Draw a dashed horizontal line at `value` on `ax`; skipped when `value` is NaN.\"\"\"\n",
    "    if not np.isnan(value):\n",
    "        ax.axhline(value, linestyle='--', color=color, alpha=alpha)\n",
    "\n",
    "\n",
    "def _annotate_means(ax, labelled_means, **text_kwargs):\n",
    "    \"\"\"Write the non-NaN means into a white box in the top-left corner of `ax`.\n",
    "\n",
    "    labelled_means: iterable of (label, value) pairs; NaN values are left out, and\n",
    "    nothing is drawn when no value remains. Extra keyword arguments are forwarded\n",
    "    to `ax.text`.\n",
    "    \"\"\"\n",
    "    lines = [f'{label}: {value:.3f}' for label, value in labelled_means if not np.isnan(value)]\n",
    "    if lines:\n",
    "        ax.text(\n",
    "            0.02, 0.98, \"\\n\".join(lines),\n",
    "            transform=ax.transAxes, ha='left', va='top',\n",
    "            bbox=dict(boxstyle='round,pad=0.25', fc='white', ec='none', alpha=0.75),\n",
    "            fontsize=10, **text_kwargs\n",
    "        )\n",
    "\n",
    "\n",
    "# Per-series means (NaN when there is nothing to plot)\n",
    "avg_intervalgpvae_pehe = _mean_or_nan(pehe_interval_plot)\n",
    "avg_tedvae_pehe_test   = _mean_or_nan(pehe_tedvae_plot)\n",
    "avg_ate_err_interval   = _mean_or_nan(ate_err_interval_plot)\n",
    "avg_ate_err_tedvae     = _mean_or_nan(ate_err_tedvae_plot)\n",
    "avg_coverage           = _mean_or_nan(coverage_plot)\n",
    "\n",
    "# ======================== Plot (x not shared; each panel sets its own xlim) ========================\n",
    "fig, axs = plt.subplots(3, 1, figsize=(12, 14))\n",
    "\n",
    "# ---------- (1) PEHE line chart ----------\n",
    "axs[0].plot(x + 1, pehe_tedvae_plot, marker='s', label='TEDVAE', color='tab:blue')\n",
    "axs[0].plot(x + 1, pehe_interval_plot, marker='o', label='IntervalGP-VAE', color='tab:orange')\n",
    "\n",
    "# Mean lines and the top-left box listing the means\n",
    "_draw_mean_line(axs[0], avg_tedvae_pehe_test, 'tab:blue', 0.5)\n",
    "_draw_mean_line(axs[0], avg_intervalgpvae_pehe, 'tab:orange', 0.5)\n",
    "_annotate_means(axs[0], [('TEDVAE mean', avg_tedvae_pehe_test),\n",
    "                         ('IntervalGP-VAE mean', avg_intervalgpvae_pehe)])\n",
    "\n",
    "axs[0].set_title('PEHE Comparison: IntervalGP-VAE vs. TEDVAE', fontsize=14)\n",
    "axs[0].set_ylabel('PEHE', fontsize=12)\n",
    "axs[0].set_xlim(0.5, (N if N else 1) + 0.5)\n",
    "\n",
    "# Legend in the upper-right corner (same side as in the ATE Error panel)\n",
    "axs[0].legend(loc='upper right', framealpha=0.9)\n",
    "axs[0].grid(True, linestyle='--', alpha=0.5)\n",
    "\n",
    "# ---------- (2) ATE error bar chart ----------\n",
    "bar_width = 0.35\n",
    "axs[1].bar(x + bar_width/2, ate_err_tedvae_plot,   width=bar_width, label='TEDVAE ATE Error',        color='tab:blue')\n",
    "axs[1].bar(x - bar_width/2, ate_err_interval_plot, width=bar_width, label='IntervalGP-VAE ATE Error', color='tab:orange')\n",
    "\n",
    "# Mean lines and the top-left box listing the means\n",
    "_draw_mean_line(axs[1], avg_ate_err_tedvae, 'tab:blue', 0.6)\n",
    "_draw_mean_line(axs[1], avg_ate_err_interval, 'tab:orange', 0.6)\n",
    "_annotate_means(axs[1], [('TEDVAE mean', avg_ate_err_tedvae),\n",
    "                         ('IntervalGP-VAE mean', avg_ate_err_interval)])\n",
    "\n",
    "axs[1].set_ylabel('ATE Error', fontsize=12)\n",
    "axs[1].set_title('ATE Error: IntervalGP-VAE vs. TEDVAE', fontsize=14)\n",
    "axs[1].set_xlim(-0.5, (N-0.5) if N else 0.5)\n",
    "axs[1].set_xticks(x)\n",
    "axs[1].set_xticklabels(x_labels)\n",
    "axs[1].legend()\n",
    "axs[1].grid(axis='y', linestyle='--', alpha=0.5)\n",
    "\n",
    "# ---------- (3) Coverage rate bar chart ----------\n",
    "axs[2].bar(x, coverage_plot, width=0.6, color='tab:green', label='Coverage')\n",
    "\n",
    "# Mean line and the top-left box showing the mean\n",
    "_draw_mean_line(axs[2], avg_coverage, 'tab:green', 0.6)\n",
    "_annotate_means(axs[2], [('mean', avg_coverage)], color='tab:green')\n",
    "\n",
    "# Value label above each bar, lifted slightly and not clipped by the axes\n",
    "for i, v in enumerate(coverage_plot):\n",
    "    axs[2].text(i, v + 0.025, f'{v:.2f}', ha='center', va='bottom',\n",
    "                fontsize=9, clip_on=False)\n",
    "\n",
    "axs[2].set_xlabel('Combination ID', fontsize=12)\n",
    "axs[2].set_ylabel('Coverage Rate', fontsize=12)\n",
    "axs[2].set_title('Coverage Rate of IntervalGP-VAE', fontsize=14)\n",
    "\n",
    "# Head-room above the tallest bar so its label does not touch the frame; at least 1.05\n",
    "ylim_top = max(1.05, float(np.max(coverage_plot)) + 0.06) if len(coverage_plot) else 1.05\n",
    "axs[2].set_ylim(0, ylim_top)\n",
    "\n",
    "axs[2].set_xlim(-0.5, (N-0.5) if N else 0.5)\n",
    "axs[2].set_xticks(x)\n",
    "axs[2].set_xticklabels(x_labels)\n",
    "axs[2].legend()\n",
    "axs[2].grid(axis='y', linestyle='--', alpha=0.6)\n",
    "\n",
    "# Tighten the horizontal margins\n",
    "for ax in axs:\n",
    "    ax.margins(x=0.02)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('syn_results.png', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "338499f9-2a08-4eac-9e87-0ae5da600109",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (InervalGP-VAE)",
   "language": "python",
   "name": "inervalgp-vae"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
