{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simulations.simulations import path_wise_dataset_1, treatment_col_index, confounder_col_index, mediator1_col_index, mediator2_col_index, ModelWrapper, calculate_true_cate_but_mediator2\n",
    "from path_wise.path_wise import compute_path_wise_effects_doubly_robust\n",
    "import numpy as np\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_shap_comparison_table(results_gt, result_model, decimal_places=2):\n",
    "    \"\"\"\n",
    "    Build a two-row table comparing ground-truth and model SHAP attributions.\n",
    "\n",
    "    Every cell is rendered as gt(model), e.g. 0.79(0.78). When the model value\n",
    "    is None or NaN the cell becomes gt(-). Row one is PW-SHAP*, row two is\n",
    "    PE-SHAP*; within a row, cells are separated by \" & \".\n",
    "\n",
    "    Args:\n",
    "        results_gt: Ground-truth results; must contain every key in row_spec.\n",
    "        result_model: Model results; must contain every key in row_spec.\n",
    "        decimal_places: Digits shown after the decimal point (default: 2).\n",
    "\n",
    "    Returns:\n",
    "        The two table rows joined by a newline (no trailing newline).\n",
    "    \"\"\"\n",
    "    # Columns follow the key suffixes _t_y, _t_m1_m2_y, _t_m1_y, _t_m2_y.\n",
    "    # NOTE(review): the PW-SHAP* row takes its first column from \"pishap_t_y\"\n",
    "    # rather than a path_wise_shap_* key; confirm this is intentional.\n",
    "    row_spec = (\n",
    "        (\"PW-SHAP*\", (\"pishap_t_y\", \"path_wise_shap_t_m1_m2_y\",\n",
    "                      \"path_wise_shap_t_m1_y\", \"path_wise_shap_t_m2_y\")),\n",
    "        (\"PE-SHAP*\", (\"pishap_t_y\", \"pishap_t_m1_m2_y\",\n",
    "                      \"pishap_t_m1_y\", \"pishap_t_m2_y\")),\n",
    "    )\n",
    "\n",
    "    def render_cell(key):\n",
    "        gt_val = results_gt[key]\n",
    "        model_val = result_model[key]\n",
    "        if model_val is None or np.isnan(model_val):\n",
    "            model_text = \"-\"\n",
    "        else:\n",
    "            model_text = f\"{model_val:.{decimal_places}f}\"\n",
    "        return f\"{gt_val:.{decimal_places}f}({model_text})\"\n",
    "\n",
    "    lines = []\n",
    "    for label, keys in row_spec:\n",
    "        joined = \" & \".join(render_cell(key) for key in keys)\n",
    "        lines.append(f\"| {label} | {joined} |\")\n",
    "    return \"\\n\".join(lines)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| PW-SHAP* | 0.79(0.78) & 0.26(0.26) & 0.32(0.33) & 0.09(0.08) |\n",
      "| PE-SHAP* | 0.79(0.78) & -0.26(-0.26) & -0.17(-0.18) & 0.06(0.07) |\n"
     ]
    }
   ],
   "source": [
    "# Scenario: sample [1, 0.2, 0.6, 1]. results_gt is computed with ModelWrapper(),\n",
    "# result_model with an MLPRegressor fitted on the same simulated data.\n",
    "# NOTE(review): no seed is passed to path_wise_dataset_1, so the generated data\n",
    "# may depend on which cells ran before this one.\n",
    "X_test1, y_test1 = path_wise_dataset_1(num_samples=30000)\n",
    "\n",
    "scenario_sample = [1, 0.2, 0.6, 1]\n",
    "column_indices = (treatment_col_index, mediator1_col_index, mediator2_col_index, confounder_col_index)\n",
    "\n",
    "# list(...) gives each estimator call its own copy of the sample vector.\n",
    "results_gt = compute_path_wise_effects_doubly_robust(list(scenario_sample), ModelWrapper(), X_test1, *column_indices)\n",
    "\n",
    "mlp = MLPRegressor(random_state=0)\n",
    "mlp.fit(X_test1, y_test1)\n",
    "\n",
    "result_model = compute_path_wise_effects_doubly_robust(list(scenario_sample), mlp, X_test1, *column_indices)\n",
    "\n",
    "# Print the gt(model) comparison table\n",
    "table = create_shap_comparison_table(results_gt, result_model)\n",
    "print(table)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| PW-SHAP* | 0.75(0.75) & 0.22(0.22) & 0.17(0.18) & 0.07(0.07) |\n",
      "| PE-SHAP* | 0.75(0.75) & -0.22(-0.22) & -0.15(-0.15) & -0.05(-0.04) |\n"
     ]
    }
   ],
   "source": [
    "# Scenario: sample [1, 0.5, 0.6, 1]. results_gt is computed with ModelWrapper(),\n",
    "# result_model with an MLPRegressor fitted on the same simulated data.\n",
    "# NOTE(review): no seed is passed to path_wise_dataset_1, so the generated data\n",
    "# may depend on which cells ran before this one.\n",
    "X_test1, y_test1 = path_wise_dataset_1(num_samples=30000)\n",
    "\n",
    "scenario_sample = [1, 0.5, 0.6, 1]\n",
    "column_indices = (treatment_col_index, mediator1_col_index, mediator2_col_index, confounder_col_index)\n",
    "\n",
    "# list(...) gives each estimator call its own copy of the sample vector.\n",
    "results_gt = compute_path_wise_effects_doubly_robust(list(scenario_sample), ModelWrapper(), X_test1, *column_indices)\n",
    "\n",
    "mlp = MLPRegressor(random_state=0)\n",
    "mlp.fit(X_test1, y_test1)\n",
    "\n",
    "result_model = compute_path_wise_effects_doubly_robust(list(scenario_sample), mlp, X_test1, *column_indices)\n",
    "\n",
    "# Print the gt(model) comparison table\n",
    "table = create_shap_comparison_table(results_gt, result_model)\n",
    "print(table)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| PW-SHAP* | 0.85(0.85) & 0.33(0.32) & 0.28(0.28) & 0.02(0.01) |\n",
      "| PE-SHAP* | 0.85(0.85) & -0.33(-0.32) & -0.31(-0.31) & -0.05(-0.04) |\n"
     ]
    }
   ],
   "source": [
    "# Scenario: sample [1, 0.5, 0.3, 1]. results_gt is computed with ModelWrapper(),\n",
    "# result_model with an MLPRegressor fitted on the same simulated data.\n",
    "# NOTE(review): no seed is passed to path_wise_dataset_1, so the generated data\n",
    "# may depend on which cells ran before this one.\n",
    "X_test1, y_test1 = path_wise_dataset_1(num_samples=30000)\n",
    "\n",
    "scenario_sample = [1, 0.5, 0.3, 1]\n",
    "column_indices = (treatment_col_index, mediator1_col_index, mediator2_col_index, confounder_col_index)\n",
    "\n",
    "# list(...) gives each estimator call its own copy of the sample vector.\n",
    "results_gt = compute_path_wise_effects_doubly_robust(list(scenario_sample), ModelWrapper(), X_test1, *column_indices)\n",
    "\n",
    "mlp = MLPRegressor(random_state=0)\n",
    "mlp.fit(X_test1, y_test1)\n",
    "\n",
    "result_model = compute_path_wise_effects_doubly_robust(list(scenario_sample), mlp, X_test1, *column_indices)\n",
    "\n",
    "# Print the gt(model) comparison table\n",
    "table = create_shap_comparison_table(results_gt, result_model)\n",
    "print(table)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def monte_carlo_evaluation_all_models_doubly_robust(sample, model_types=None,\n",
    "                                                  num_runs=100, num_samples=1000, base_seed=42):\n",
    "    \"\"\"\n",
    "    Perform Monte Carlo evaluation for all specified model types using doubly robust estimator.\n",
    "\n",
    "    For each (model_type, propensity_model_type) pair, num_runs datasets are simulated\n",
    "    with seeds base_seed, base_seed + 1, ..., the doubly robust estimator is applied to\n",
    "    `sample` on each, and the estimates of \"cate_mediator1_confounder\" are compared with\n",
    "    calculate_true_cate_but_mediator2(sample).\n",
    "\n",
    "    Args:\n",
    "        sample: The sample for which to evaluate CATE\n",
    "        model_types: Sequence of (model_type, propensity_model_type) tuples (required)\n",
    "        num_runs: Number of Monte Carlo runs\n",
    "        num_samples: Number of samples per run for training the estimator\n",
    "        base_seed: Base seed for reproducibility (each run gets base_seed + run_number)\n",
    "\n",
    "    Returns:\n",
    "        Dictionary keyed by \"{model_type}_{propensity_model_type}\"; each value holds\n",
    "        'true_value', 'estimates' (np.ndarray), 'mae' and 'monte_carlo_error'.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: If model_types is not provided.\n",
    "        ValueError: If the estimator returns None or NaN in any run.\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        raise TypeError(\"model_types must be a sequence of (model_type, propensity_model_type) tuples\")\n",
    "\n",
    "    results = {}\n",
    "    true_value = calculate_true_cate_but_mediator2(sample)\n",
    "\n",
    "    print(f\"True CATE without mediator2: {true_value:.4f}\")\n",
    "    print(f\"Testing {len(model_types)} model types: {model_types}\")\n",
    "    print(f\"Running {num_runs} Monte Carlo simulations for each model...\")\n",
    "\n",
    "    for model_type, propensity_model_type in model_types:\n",
    "        key = f\"{model_type}_{propensity_model_type}\"\n",
    "        print(f\"\\n{'='*60}\")\n",
    "        print(f\"EVALUATING MODEL: {model_type}\")\n",
    "        print(f\"{'='*60}\")\n",
    "\n",
    "        estimates = []\n",
    "\n",
    "        for run in range(num_runs):\n",
    "            seed = base_seed + run\n",
    "            # Generate new dataset for this run\n",
    "            X_test, y_test = path_wise_dataset_1(num_samples=num_samples, seed=seed)\n",
    "            # Get estimate from the doubly robust estimator with specific model type\n",
    "            result = compute_path_wise_effects_doubly_robust(\n",
    "                sample,\n",
    "                ModelWrapper(),\n",
    "                X_test,\n",
    "                treatment_col_index,\n",
    "                mediator1_col_index,\n",
    "                mediator2_col_index,\n",
    "                confounder_col_index,\n",
    "                model_type=model_type,\n",
    "                propensity_model_type=propensity_model_type\n",
    "            )\n",
    "\n",
    "            estimate = result[\"cate_mediator1_confounder\"]  # Extract scalar value\n",
    "\n",
    "            # NaN compares unequal to everything (`estimate == np.nan` is always False),\n",
    "            # so np.isnan is needed to detect a failed identification.\n",
    "            if estimate is None or np.any(np.isnan(estimate)):\n",
    "                raise ValueError(f\"Failed to identify cate_mediator1_confounder for {key} (run {run}, seed {seed})\")\n",
    "            estimates.append(estimate)\n",
    "\n",
    "            if (run + 1) % 20 == 0:\n",
    "                print(f\"  {key}: Completed {run + 1}/{num_runs} runs\")\n",
    "\n",
    "        estimates = np.array(estimates)\n",
    "\n",
    "        # Calculate only MAE and Monte Carlo Error\n",
    "        errors = estimates - true_value\n",
    "        mae = np.mean(np.abs(errors))\n",
    "\n",
    "        # Monte Carlo Error (standard error of the mean); np.std defaults to ddof=0\n",
    "        monte_carlo_error = np.std(estimates) / np.sqrt(num_runs)\n",
    "\n",
    "        results[key] = {\n",
    "            'true_value': true_value,\n",
    "            'estimates': estimates,\n",
    "            'mae': mae,\n",
    "            'monte_carlo_error': monte_carlo_error\n",
    "        }\n",
    "\n",
    "        print(f\"  {key} Results:\")\n",
    "        print(f\"    MAE: {mae:.4f}\")\n",
    "        print(f\"    Monte Carlo Error: {monte_carlo_error:.4f}\")\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_multiple_samples_averaged_doubly_robust(samples, model_types=None,\n",
    "                                                   num_runs=5, num_samples=10000, base_seed=42):\n",
    "    \"\"\"\n",
    "    Evaluate multiple samples and return averaged results across all samples for doubly robust estimator.\n",
    "\n",
    "    Sample i (0-based) is evaluated by monte_carlo_evaluation_all_models_doubly_robust\n",
    "    with base_seed + i; the per-run errors are then pooled across samples for every\n",
    "    (model_type, propensity_model_type) pair.\n",
    "\n",
    "    Args:\n",
    "        samples: List of samples to evaluate\n",
    "        model_types: Non-empty sequence of (model_type, propensity_model_type) tuples\n",
    "        num_runs: Number of Monte Carlo runs per sample\n",
    "        num_samples: Number of samples per run\n",
    "        base_seed: Base seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        Tuple (averaged_results, sample_results, df_comparison):\n",
    "        averaged_results maps \"{model_type}_{propensity_model_type}\" to\n",
    "        'mae_combined', 'mce_combined' and 'total_estimates'; sample_results maps\n",
    "        \"sample_{i+1}\" to the per-model results for that sample; df_comparison is a\n",
    "        pandas DataFrame with one formatted row per model pair.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: If model_types is not provided.\n",
    "        ValueError: If model_types is empty (no best model can be selected).\n",
    "    \"\"\"\n",
    "    import pandas as pd\n",
    "    import numpy as np\n",
    "\n",
    "    if model_types is None:\n",
    "        raise TypeError(\"model_types must be a sequence of (model_type, propensity_model_type) tuples\")\n",
    "    if len(model_types) == 0:\n",
    "        raise ValueError(\"model_types must contain at least one (model_type, propensity_model_type) pair\")\n",
    "\n",
    "    def key_of(model_type, propensity_model_type):\n",
    "        # Same key format monte_carlo_evaluation_all_models_doubly_robust uses for its results\n",
    "        return f\"{model_type}_{propensity_model_type}\"\n",
    "\n",
    "    sample_results = {}\n",
    "\n",
    "    print(f\"Evaluating {len(samples)} samples with {len(model_types)} model types each (Doubly Robust)\")\n",
    "    print(f\"Total evaluations: {len(samples)} × {len(model_types)} = {len(samples) * len(model_types)}\")\n",
    "    print(\"=\"*80)\n",
    "\n",
    "    # Evaluate each sample\n",
    "    for i, sample in enumerate(samples):\n",
    "        print(f\"\\nSAMPLE {i+1}: {sample}\")\n",
    "        print(f\"True CATE: {calculate_true_cate_but_mediator2(sample):.4f}\")\n",
    "        print(\"-\" * 60)\n",
    "\n",
    "        # Run evaluation for this sample (different seed for each sample)\n",
    "        sample_result = monte_carlo_evaluation_all_models_doubly_robust(\n",
    "            sample,\n",
    "            model_types=model_types,\n",
    "            num_runs=num_runs,\n",
    "            num_samples=num_samples,\n",
    "            base_seed=base_seed + i\n",
    "        )\n",
    "        sample_results[f\"sample_{i+1}\"] = sample_result\n",
    "\n",
    "        # Print results for this sample\n",
    "        for model_type, propensity_model_type in model_types:\n",
    "            key = key_of(model_type, propensity_model_type)\n",
    "            mae = sample_result[key]['mae']\n",
    "            mce = sample_result[key]['monte_carlo_error']\n",
    "            print(f\"  {key}: MAE={mae:.4f}, MCE={mce:.4f}\")\n",
    "\n",
    "    # Calculate averaged results\n",
    "    print(\"\\n\" + \"=\"*80)\n",
    "    print(\"AVERAGED RESULTS ACROSS ALL SAMPLES (DOUBLY ROBUST)\")\n",
    "    print(\"=\"*80)\n",
    "\n",
    "    averaged_results = {}\n",
    "    for model_type, propensity_model_type in model_types:\n",
    "        key = key_of(model_type, propensity_model_type)\n",
    "\n",
    "        # Pool the per-run errors (estimate - true value) of this model pair across samples\n",
    "        all_errors = []\n",
    "        for per_sample in sample_results.values():\n",
    "            entry = per_sample[key]\n",
    "            all_errors.extend(np.asarray(entry['estimates']) - entry['true_value'])\n",
    "        all_errors = np.array(all_errors)\n",
    "\n",
    "        # Compute MAE for all samples combined\n",
    "        combined_mae = np.mean(np.abs(all_errors))\n",
    "\n",
    "        # Monte Carlo error = standard error of the mean error. Using errors rather than\n",
    "        # raw estimates keeps the spread of the true CATE across samples out of the\n",
    "        # standard error; for a single sample std(errors) == std(estimates), matching\n",
    "        # the per-sample monte_carlo_error.\n",
    "        combined_mce = np.std(all_errors) / np.sqrt(len(all_errors))\n",
    "\n",
    "        averaged_results[key] = {\n",
    "            'mae_combined': combined_mae,\n",
    "            'mce_combined': combined_mce,\n",
    "            'total_estimates': len(all_errors)\n",
    "        }\n",
    "\n",
    "    # Create comparison table\n",
    "    comparison_data = []\n",
    "    for model_type, propensity_model_type in model_types:\n",
    "        key = key_of(model_type, propensity_model_type)\n",
    "        result = averaged_results[key]\n",
    "        comparison_data.append({\n",
    "            'Model': key,\n",
    "            'MAE (Combined)': f\"{result['mae_combined']:.4f}\",\n",
    "            'Monte Carlo Error (Combined)': f\"{result['mce_combined']:.4f}\",\n",
    "            'Total Estimates': result['total_estimates']\n",
    "        })\n",
    "\n",
    "    df_comparison = pd.DataFrame(comparison_data)\n",
    "    print(df_comparison.to_string(index=False))\n",
    "\n",
    "    # Find best performing model by combined MAE (first one wins on ties).\n",
    "    # The key is built before the f-strings: reusing double quotes inside a nested\n",
    "    # f-string is only valid syntax on Python 3.12+.\n",
    "    best_key = min((key_of(m, p) for m, p in model_types),\n",
    "                   key=lambda k: averaged_results[k]['mae_combined'])\n",
    "    best = averaged_results[best_key]\n",
    "    print(f\"\\nBest performing model (by combined MAE): {best_key}\")\n",
    "    print(f\"Combined MAE: {best['mae_combined']:.4f}\")\n",
    "    print(f\"Total estimates used: {best['total_estimates']}\")\n",
    "\n",
    "    return averaged_results, sample_results, df_comparison\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing 3 samples with all model types and computing averaged results (Doubly Robust)...\n",
      "================================================================================\n",
      "Evaluating 5 samples with 6 model types each (Doubly Robust)\n",
      "Total evaluations: 5 × 6 = 30\n",
      "================================================================================\n",
      "\n",
      "SAMPLE 1: [1, 0.2, 0.3, 1]\n",
      "True CATE: 1.1930\n",
      "------------------------------------------------------------\n",
      "True CATE without mediator2: 1.1930\n",
      "Testing 6 model types: [('xgb', 'logistic'), ('mlp', 'logistic'), ('poly2', 'logistic'), ('linear', 'logistic'), ('linear', 'logistic_invalid'), ('poly2', 'logistic_invalid')]\n",
      "Running 10 Monte Carlo simulations for each model...\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: xgb\n",
      "============================================================\n",
      "  xgb_logistic Results:\n",
      "    MAE: 0.1113\n",
      "    Monte Carlo Error: 0.0406\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: mlp\n",
      "============================================================\n",
      "  mlp_logistic Results:\n",
      "    MAE: 0.1099\n",
      "    Monte Carlo Error: 0.0404\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic Results:\n",
      "    MAE: 0.1114\n",
      "    Monte Carlo Error: 0.0406\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic Results:\n",
      "    MAE: 0.1090\n",
      "    Monte Carlo Error: 0.0400\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic_invalid Results:\n",
      "    MAE: 0.1483\n",
      "    Monte Carlo Error: 0.0570\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic_invalid Results:\n",
      "    MAE: 0.1114\n",
      "    Monte Carlo Error: 0.0406\n",
      "  xgb_logistic: MAE=0.1113, MCE=0.0406\n",
      "  mlp_logistic: MAE=0.1099, MCE=0.0404\n",
      "  poly2_logistic: MAE=0.1114, MCE=0.0406\n",
      "  linear_logistic: MAE=0.1090, MCE=0.0400\n",
      "  linear_logistic_invalid: MAE=0.1483, MCE=0.0570\n",
      "  poly2_logistic_invalid: MAE=0.1114, MCE=0.0406\n",
      "\n",
      "SAMPLE 2: [1, 0.2, 0.4, 1]\n",
      "True CATE: 1.1430\n",
      "------------------------------------------------------------\n",
      "True CATE without mediator2: 1.1430\n",
      "Testing 6 model types: [('xgb', 'logistic'), ('mlp', 'logistic'), ('poly2', 'logistic'), ('linear', 'logistic'), ('linear', 'logistic_invalid'), ('poly2', 'logistic_invalid')]\n",
      "Running 10 Monte Carlo simulations for each model...\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: xgb\n",
      "============================================================\n",
      "  xgb_logistic Results:\n",
      "    MAE: 0.1346\n",
      "    Monte Carlo Error: 0.0506\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: mlp\n",
      "============================================================\n",
      "  mlp_logistic Results:\n",
      "    MAE: 0.1329\n",
      "    Monte Carlo Error: 0.0502\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic Results:\n",
      "    MAE: 0.1346\n",
      "    Monte Carlo Error: 0.0506\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic Results:\n",
      "    MAE: 0.1387\n",
      "    Monte Carlo Error: 0.0518\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic_invalid Results:\n",
      "    MAE: 0.1595\n",
      "    Monte Carlo Error: 0.0560\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic_invalid Results:\n",
      "    MAE: 0.1346\n",
      "    Monte Carlo Error: 0.0506\n",
      "  xgb_logistic: MAE=0.1346, MCE=0.0506\n",
      "  mlp_logistic: MAE=0.1329, MCE=0.0502\n",
      "  poly2_logistic: MAE=0.1346, MCE=0.0506\n",
      "  linear_logistic: MAE=0.1387, MCE=0.0518\n",
      "  linear_logistic_invalid: MAE=0.1595, MCE=0.0560\n",
      "  poly2_logistic_invalid: MAE=0.1346, MCE=0.0506\n",
      "\n",
      "SAMPLE 3: [1, 0.2, 0.5, 1]\n",
      "True CATE: 1.0930\n",
      "------------------------------------------------------------\n",
      "True CATE without mediator2: 1.0930\n",
      "Testing 6 model types: [('xgb', 'logistic'), ('mlp', 'logistic'), ('poly2', 'logistic'), ('linear', 'logistic'), ('linear', 'logistic_invalid'), ('poly2', 'logistic_invalid')]\n",
      "Running 10 Monte Carlo simulations for each model...\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: xgb\n",
      "============================================================\n",
      "  xgb_logistic Results:\n",
      "    MAE: 0.1866\n",
      "    Monte Carlo Error: 0.0673\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: mlp\n",
      "============================================================\n",
      "  mlp_logistic Results:\n",
      "    MAE: 0.1856\n",
      "    Monte Carlo Error: 0.0669\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic Results:\n",
      "    MAE: 0.1866\n",
      "    Monte Carlo Error: 0.0673\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic Results:\n",
      "    MAE: 0.1889\n",
      "    Monte Carlo Error: 0.0683\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic_invalid Results:\n",
      "    MAE: 0.1898\n",
      "    Monte Carlo Error: 0.0688\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic_invalid Results:\n",
      "    MAE: 0.1866\n",
      "    Monte Carlo Error: 0.0673\n",
      "  xgb_logistic: MAE=0.1866, MCE=0.0673\n",
      "  mlp_logistic: MAE=0.1856, MCE=0.0669\n",
      "  poly2_logistic: MAE=0.1866, MCE=0.0673\n",
      "  linear_logistic: MAE=0.1889, MCE=0.0683\n",
      "  linear_logistic_invalid: MAE=0.1898, MCE=0.0688\n",
      "  poly2_logistic_invalid: MAE=0.1866, MCE=0.0673\n",
      "\n",
      "SAMPLE 4: [1, 0.2, 0.6, 1]\n",
      "True CATE: 1.0430\n",
      "------------------------------------------------------------\n",
      "True CATE without mediator2: 1.0430\n",
      "Testing 6 model types: [('xgb', 'logistic'), ('mlp', 'logistic'), ('poly2', 'logistic'), ('linear', 'logistic'), ('linear', 'logistic_invalid'), ('poly2', 'logistic_invalid')]\n",
      "Running 10 Monte Carlo simulations for each model...\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: xgb\n",
      "============================================================\n",
      "  xgb_logistic Results:\n",
      "    MAE: 0.1639\n",
      "    Monte Carlo Error: 0.0643\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: mlp\n",
      "============================================================\n",
      "  mlp_logistic Results:\n",
      "    MAE: 0.1655\n",
      "    Monte Carlo Error: 0.0650\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic Results:\n",
      "    MAE: 0.1638\n",
      "    Monte Carlo Error: 0.0643\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic Results:\n",
      "    MAE: 0.1706\n",
      "    Monte Carlo Error: 0.0665\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic_invalid Results:\n",
      "    MAE: 0.1664\n",
      "    Monte Carlo Error: 0.0678\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic_invalid Results:\n",
      "    MAE: 0.1638\n",
      "    Monte Carlo Error: 0.0643\n",
      "  xgb_logistic: MAE=0.1639, MCE=0.0643\n",
      "  mlp_logistic: MAE=0.1655, MCE=0.0650\n",
      "  poly2_logistic: MAE=0.1638, MCE=0.0643\n",
      "  linear_logistic: MAE=0.1706, MCE=0.0665\n",
      "  linear_logistic_invalid: MAE=0.1664, MCE=0.0678\n",
      "  poly2_logistic_invalid: MAE=0.1638, MCE=0.0643\n",
      "\n",
      "SAMPLE 5: [1, 0.2, 0.7, 1]\n",
      "True CATE: 0.9930\n",
      "------------------------------------------------------------\n",
      "True CATE without mediator2: 0.9930\n",
      "Testing 6 model types: [('xgb', 'logistic'), ('mlp', 'logistic'), ('poly2', 'logistic'), ('linear', 'logistic'), ('linear', 'logistic_invalid'), ('poly2', 'logistic_invalid')]\n",
      "Running 10 Monte Carlo simulations for each model...\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: xgb\n",
      "============================================================\n",
      "  xgb_logistic Results:\n",
      "    MAE: 0.1910\n",
      "    Monte Carlo Error: 0.0748\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: mlp\n",
      "============================================================\n",
      "  mlp_logistic Results:\n",
      "    MAE: 0.1912\n",
      "    Monte Carlo Error: 0.0750\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic Results:\n",
      "    MAE: 0.1910\n",
      "    Monte Carlo Error: 0.0749\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic Results:\n",
      "    MAE: 0.1988\n",
      "    Monte Carlo Error: 0.0780\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: linear\n",
      "============================================================\n",
      "  linear_logistic_invalid Results:\n",
      "    MAE: 0.2784\n",
      "    Monte Carlo Error: 0.0948\n",
      "\n",
      "============================================================\n",
      "EVALUATING MODEL: poly2\n",
      "============================================================\n",
      "  poly2_logistic_invalid Results:\n",
      "    MAE: 0.1910\n",
      "    Monte Carlo Error: 0.0749\n",
      "  xgb_logistic: MAE=0.1910, MCE=0.0748\n",
      "  mlp_logistic: MAE=0.1912, MCE=0.0750\n",
      "  poly2_logistic: MAE=0.1910, MCE=0.0749\n",
      "  linear_logistic: MAE=0.1988, MCE=0.0780\n",
      "  linear_logistic_invalid: MAE=0.2784, MCE=0.0948\n",
      "  poly2_logistic_invalid: MAE=0.1910, MCE=0.0749\n",
      "\n",
      "================================================================================\n",
      "AVERAGED RESULTS ACROSS ALL SAMPLES (DOUBLY ROBUST)\n",
      "================================================================================\n",
      "                  Model MAE (Combined) Monte Carlo Error (Combined)  Total Estimates\n",
      "           xgb_logistic         0.1575                       0.0304               50\n",
      "           mlp_logistic         0.1570                       0.0304               50\n",
      "         poly2_logistic         0.1575                       0.0305               50\n",
      "        linear_logistic         0.1612                       0.0310               50\n",
      "linear_logistic_invalid         0.1885                       0.0339               50\n",
      " poly2_logistic_invalid         0.1575                       0.0305               50\n",
      "\n",
      "Best performing model (by combined MAE): mlp_logistic\n",
      "Combined MAE: 0.1570\n",
      "Total estimates used: 50\n"
     ]
    }
   ],
   "source": [
    "# Evaluate several samples (differing only in the third entry) and compute averaged results for the doubly robust estimator\n",
    "samples = [\n",
    "    [1, 0.2, 0.3, 1],\n",
    "    [1, 0.2, 0.4, 1],\n",
    "    [1, 0.2, 0.5, 1],\n",
    "    [1, 0.2, 0.6, 1],\n",
    "    [1, 0.2, 0.7, 1],\n",
    "]\n",
    "\n",
    "# Derive the count from the list so the message cannot drift from the data\n",
    "print(f\"Testing {len(samples)} samples with all model types and computing averaged results (Doubly Robust)...\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "# Run comprehensive evaluation for all samples; each pair is (outcome model, propensity model)\n",
    "averaged_results, sample_results, comparison_df = evaluate_multiple_samples_averaged_doubly_robust(\n",
    "    samples,\n",
    "    model_types=[\n",
    "        (\"xgb\", \"logistic\"),\n",
    "        (\"mlp\", \"logistic\"),\n",
    "        (\"poly2\", \"logistic\"),\n",
    "        (\"linear\", \"logistic\"),\n",
    "        (\"linear\", \"logistic_invalid\"),\n",
    "        (\"poly2\", \"logistic_invalid\"),\n",
    "    ],\n",
    "    num_runs=10,\n",
    "    num_samples=1000,\n",
    "    base_seed=0\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions of ModelWrapper() (the model used for results_gt above) on every row of X_test1.\n",
    "# NOTE(review): X_test1 is whichever dataset the most recently executed scenario cell\n",
    "# generated, so the result depends on cell execution order.\n",
    "y_hat_model = ModelWrapper().predict(X_test1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column 0 of X_test1 is used as the treatment indicator for the split below.\n",
    "# NOTE(review): confirm column 0 matches treatment_col_index (imported in the first cell)\n",
    "# and consider indexing with that constant instead of a hard-coded 0.\n",
    "T = X_test1[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the rows of X_test1 and the matching y_hat_model predictions into treated (T == 1)\n",
    "# and control (T == 0) groups; rows with any other value of T fall into neither group.\n",
    "X_treated, y_treated = X_test1[T == 1], y_hat_model[T == 1]\n",
    "X_control, y_control = X_test1[T == 0], y_hat_model[T == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: #000;\n",
       "  --sklearn-color-text-muted: #666;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: flex;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "  align-items: start;\n",
       "  justify-content: space-between;\n",
       "  gap: 0.5em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label .caption {\n",
       "  font-size: 0.6rem;\n",
       "  font-weight: lighter;\n",
       "  color: var(--sklearn-color-text-muted);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 0.5em;\n",
       "  text-align: center;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MLPRegressor(max_iter=500, random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>MLPRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.neural_network.MLPRegressor.html\">?<span>Documentation for MLPRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>MLPRegressor(max_iter=500, random_state=42)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "MLPRegressor(max_iter=500, random_state=42)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit one MLP per group on the ModelWrapper predictions: mu1 on treated rows, mu0 on control rows.\n",
    "# Both share hyperparameters and seed so they differ only in their training data.\n",
    "# (MLPRegressor is imported in the first cell; fit() returns the fitted estimator.)\n",
    "mlp_params = dict(hidden_layer_sizes=(100,), max_iter=500, random_state=42)\n",
    "model_mu1 = MLPRegressor(**mlp_params).fit(X_treated, y_treated)\n",
    "model_mu0 = MLPRegressor(**mlp_params).fit(X_control, y_control)\n",
    "model_mu0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(175.7287851551348)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In-sample sum of absolute residuals of mu1 on the treated rows it was trained on\n",
    "residuals_mu1 = model_mu1.predict(X_treated) - y_treated\n",
    "np.abs(residuals_mu1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: #000;\n",
       "  --sklearn-color-text-muted: #666;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-2 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-2 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: flex;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "  align-items: start;\n",
       "  justify-content: space-between;\n",
       "  gap: 0.5em;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label .caption {\n",
       "  font-size: 0.6rem;\n",
       "  font-weight: lighter;\n",
       "  color: var(--sklearn-color-text-muted);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-2 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-2 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 0.5em;\n",
       "  text-align: center;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-2 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LinearRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>LinearRegression</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.linear_model.LinearRegression.html\">?<span>Documentation for LinearRegression</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>LinearRegression()</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "LinearRegression()"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Fit separate linear outcome models on (X_treated, y_treated) and (X_control, y_control).\n",
    "# fit() returns the fitted estimator itself, so construction and fitting are chained.\n",
    "model_mu1 = LinearRegression().fit(X_treated, y_treated)\n",
    "model_mu0 = LinearRegression().fit(X_control, y_control)\n",
    "model_mu0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(3558.615715189507)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum of absolute residuals of model_mu1 on X_treated / y_treated (the data it was fitted on)\n",
    "np.abs(model_mu1.predict(X_treated) - y_treated).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.        ,  1.00204839, -1.3135332 ,  0.96349818])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fitted coefficients of model_mu1 (the model fitted on X_treated, y_treated).\n",
    "# NOTE(review): the first coefficient is exactly 0 in the saved output -- presumably a\n",
    "# column that is constant within X_treated (e.g. the treatment indicator); confirm.\n",
    "model_mu1.coef_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| PW-SHAP* | 0.79(0.80) | 0.26(0.27) | 0.32(0.33) | 0.08(0.09) |\n",
      "| PE-SHAP* | 0.79(0.80) | -0.26(-0.27) | -0.18(-0.18) | 0.06(0.06) |\n"
     ]
    }
   ],
   "source": [
    "def create_shap_comparison_table(results_gt, result_model, decimal_places=2):\n",
    "    \"\"\"\n",
    "    Create a simplified comparison table for PW-SHAP* and PE-SHAP* methods.\n",
    "\n",
    "    Each table entry is rendered as gt(model): the ground-truth value followed by\n",
    "    the model estimate in parentheses. A missing model value (None or NaN) or a\n",
    "    missing ground-truth value (None) is shown as '-', e.g. 0.79(-).\n",
    "\n",
    "    Columns follow the key suffixes t_y, t_m1_m2_y, t_m1_y, t_m2_y, in that order.\n",
    "    NOTE(review): both rows read the first column from 'pishap_t_y'; the PW-SHAP*\n",
    "    row does not use a 'path_wise_shap_t_y' key. Confirm this is intended.\n",
    "\n",
    "    Args:\n",
    "        results_gt: Ground truth results dictionary\n",
    "        result_model: Model results dictionary\n",
    "        decimal_places: Number of decimal places to round to (default: 2)\n",
    "\n",
    "    Returns:\n",
    "        String with one pipe-delimited row per method (no header row),\n",
    "        rows separated by a newline\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if either dictionary lacks one of the keys listed in rows.\n",
    "    \"\"\"\n",
    "    def format_number(value):\n",
    "        return f\"{value:.{decimal_places}f}\"\n",
    "\n",
    "    def format_value(gt_val, model_val):\n",
    "        \"\"\"Format an entry as gt_val(model_val); missing values render as '-'.\"\"\"\n",
    "        # A None ground truth previously crashed the f-string formatting.\n",
    "        gt_text = \"-\" if gt_val is None else format_number(gt_val)\n",
    "        model_missing = model_val is None or np.isnan(model_val)\n",
    "        model_text = \"-\" if model_missing else format_number(model_val)\n",
    "        return f\"{gt_text}({model_text})\"\n",
    "\n",
    "    # (row label, result keys for the four columns). Each key is looked up in\n",
    "    # both the ground-truth and the model dictionaries.\n",
    "    rows = [\n",
    "        (\"PW-SHAP*\", [\"pishap_t_y\", \"path_wise_shap_t_m1_m2_y\",\n",
    "                      \"path_wise_shap_t_m1_y\", \"path_wise_shap_t_m2_y\"]),\n",
    "        (\"PE-SHAP*\", [\"pishap_t_y\", \"pishap_t_m1_m2_y\",\n",
    "                      \"pishap_t_m1_y\", \"pishap_t_m2_y\"]),\n",
    "    ]\n",
    "\n",
    "    table_lines = []\n",
    "    for label, keys in rows:\n",
    "        cells = [format_value(results_gt[key], result_model[key]) for key in keys]\n",
    "        table_lines.append(f\"| {label} | \" + \" | \".join(cells) + \" |\")\n",
    "    return \"\\n\".join(table_lines)\n",
    "\n",
    "# Create and display the table\n",
    "table = create_shap_comparison_table(results_gt, result_model)\n",
    "print(table)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Table with 2 decimal places:\n",
      "| PW-SHAP* | 0.79(0.80) | 0.26(0.27) | 0.32(0.33) | 0.08(0.09) |\n",
      "| PE-SHAP* | 0.79(0.80) | -0.26(-0.27) | -0.18(-0.18) | 0.06(0.06) |\n",
      "\n",
      "Table with 3 decimal places:\n",
      "| PW-SHAP* | 0.786(0.795) | 0.261(0.267) | 0.320(0.330) | 0.083(0.092) |\n",
      "| PE-SHAP* | 0.786(0.795) | -0.261(-0.267) | -0.178(-0.175) | 0.059(0.063) |\n"
     ]
    }
   ],
   "source": [
    "# Example usage: render the same comparison table at two precisions\n",
    "for header, places in ((\"Table with 2 decimal places:\", 2),\n",
    "                       (\"\\nTable with 3 decimal places:\", 3)):\n",
    "    print(header)\n",
    "    print(create_shap_comparison_table(results_gt, result_model, decimal_places=places))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
