{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Source Code for Trained Transformer Evaluation (Main Result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sas7bdat import SAS7BDAT\n",
    "from collections import OrderedDict\n",
    "import re\n",
    "import os\n",
    "from models import *\n",
    "from eval import *\n",
    "from samplers import *\n",
    "from tasks import *\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "from eval import get_run_metrics, read_run_dir, get_model_from_run\n",
    "# from plot_utils import *\n",
    "import pickle\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def evaluate_models(run_path, n_list, batch_size=64, n_dims=15, scale_Theta=1,\n",
    "                    normalize_w=True, normalize_Theta=False, mode='linear', q_truncated=None, num_task=500,\n",
    "                    delta=5, seed=1):\n",
    "    \"\"\"Evaluate prediction and coefficient error versus the number of in-context samples.\n",
    "\n",
    "    Args:\n",
    "        run_path: Run directory handed to ``get_model_from_run``.\n",
    "        n_list: Iterable of in-context sample sizes; each is evaluated with ``n + 1`` points.\n",
    "        batch_size: Batch size for the baselines; models whose name contains \"gpt\" use a quarter of it.\n",
    "        n_dims: Dimension handed to the data and task samplers.\n",
    "        scale_Theta: Forwarded to ``gen_standard_iv`` as its last positional argument.\n",
    "        normalize_w: Forwarded to the task sampler.\n",
    "        normalize_Theta: Forwarded to the data sampler.\n",
    "        mode: Forwarded to the data sampler; \"multicollinearity-heavy\" swaps the baselines for their ridge variants.\n",
    "        q_truncated: Forwarded to the data sampler.\n",
    "        num_task: Number of sampled tasks averaged per (model, n) pair.\n",
    "        delta: Forwarded to ``eval_coef``.\n",
    "        seed: Seed for torch's global RNG.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``metrics[model.name][n]`` is ``{'pred_mse': float, 'coef_mse': float}``.\n",
    "    \"\"\"\n",
    "    torch.random.manual_seed(seed)\n",
    "    p = int(1/3 * n_dims)\n",
    "\n",
    "    # Load the trained transformer (step=-1) and put it on the GPU in eval mode\n",
    "    gptmodel, _ = get_model_from_run(run_path, step=-1)\n",
    "    gptmodel = gptmodel.cuda().eval()\n",
    "\n",
    "    # Models to evaluate\n",
    "    if mode == \"multicollinearity-heavy\":\n",
    "        all_model = [RidgeIVRegressionModel(), RidgeLeastSquaresModel(), gptmodel]\n",
    "    else:\n",
    "        all_model = [IVRegressionModel(), LeastSquaresModel(), gptmodel]\n",
    "\n",
    "    metrics = {}\n",
    "\n",
    "    for model in all_model:\n",
    "        metrics[model.name] = {}\n",
    "        # Per-model batch size: the caller's batch_size is left untouched so the\n",
    "        # reduction can neither compound nor leak into models evaluated afterwards.\n",
    "        model_batch_size = int(batch_size / 4) if \"gpt\" in model.name else batch_size\n",
    "\n",
    "        for n in tqdm(n_list):\n",
    "            n_points = n + 1  # Include test sample\n",
    "            all_pred_mse = []\n",
    "            all_coef_mse = []\n",
    "\n",
    "            for _ in range(num_task):\n",
    "                data_sampler = get_data_sampler(\"iv\", n_dims, normalize_Theta=normalize_Theta, mode=mode, q_truncated=q_truncated)\n",
    "                task_sampler = get_task_sampler(\"iv_regression\", n_dims, model_batch_size, normalize_w=normalize_w)\n",
    "                U = torch.randn(model_batch_size, n_points, p)\n",
    "                xzs, xs_p = gen_standard_iv(data_sampler, n_points, model_batch_size, U, scale_Theta)\n",
    "                pred_mse = eval_batch(model, task_sampler, xzs, xs_p, U)\n",
    "                coef_mse = eval_coef(model, task_sampler, xzs, U, delta=delta)\n",
    "\n",
    "                # Keep the last column only, averaged over the first axis\n",
    "                all_pred_mse.append(pred_mse[:, -1].mean().item())\n",
    "                all_coef_mse.append(coef_mse[:, -1].mean().item())\n",
    "\n",
    "            metrics[model.name][n] = {\n",
    "                'pred_mse': np.mean(all_pred_mse),\n",
    "                'coef_mse': np.mean(all_coef_mse),\n",
    "            }\n",
    "\n",
    "    return metrics\n",
    "\n",
    "def eval_ivstrength(run_path, r_list, n=50, batch_size=64, n_dims=15,\n",
    "                    normalize_w=True, normalize_Theta=False, mode=\"linear\",\n",
    "                    q_truncated=None, num_task=500, delta=5, seed=1):\n",
    "    \"\"\"Evaluate prediction and coefficient error versus instrument (IV) strength.\n",
    "\n",
    "    Args:\n",
    "        run_path: Run directory handed to ``get_model_from_run``.\n",
    "        r_list: Iterable of strengths; each ``r`` is passed to ``gen_standard_iv`` as its last positional argument.\n",
    "        n: Number of in-context samples; prompts contain ``n + 1`` points.\n",
    "        batch_size: Batch size for the baselines; models whose name contains \"gpt\" use a quarter of it.\n",
    "        n_dims: Dimension handed to the data and task samplers.\n",
    "        normalize_w: Forwarded to the task sampler.\n",
    "        normalize_Theta: Forwarded to the data sampler.\n",
    "        mode: Forwarded to the data sampler (previously accepted but silently ignored).\n",
    "        q_truncated: Forwarded to the data sampler.\n",
    "        num_task: Number of sampled tasks averaged per (model, r) pair.\n",
    "        delta: Forwarded to ``eval_coef``.\n",
    "        seed: Seed for torch's global RNG.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``metrics[model.name][r]`` is ``{'pred_mse': float, 'coef_mse': float}``.\n",
    "    \"\"\"\n",
    "    torch.random.manual_seed(seed)\n",
    "    p = int(1/3 * n_dims)\n",
    "    n_points = n + 1  # Include test sample\n",
    "\n",
    "    # Load the trained transformer (step=-1) and put it on the GPU in eval mode\n",
    "    gptmodel, _ = get_model_from_run(run_path, step=-1)\n",
    "    gptmodel = gptmodel.cuda().eval()\n",
    "\n",
    "    # Models to evaluate\n",
    "    all_model = [IVRegressionModel(), LeastSquaresModel(), gptmodel]\n",
    "\n",
    "    metrics = {}\n",
    "\n",
    "    for model in all_model:\n",
    "        metrics[model.name] = {}\n",
    "        # Per-model batch size: the caller's batch_size is left untouched so the\n",
    "        # reduction can neither compound nor leak into models evaluated afterwards.\n",
    "        model_batch_size = int(batch_size / 4) if \"gpt\" in model.name else batch_size\n",
    "\n",
    "        # Sweep over the IV strengths\n",
    "        for r in tqdm(r_list):\n",
    "            all_pred_mse = []\n",
    "            all_coef_mse = []\n",
    "\n",
    "            for _ in range(num_task):\n",
    "                # Data and task sampling\n",
    "                data_sampler = get_data_sampler(\"iv\", n_dims, normalize_Theta=normalize_Theta, mode=mode, q_truncated=q_truncated)\n",
    "                task_sampler = get_task_sampler(\"iv_regression\", n_dims, model_batch_size, normalize_w=normalize_w)\n",
    "\n",
    "                # Generate random input data\n",
    "                U = torch.randn(model_batch_size, n_points, p)\n",
    "                xzs, xs_p = gen_standard_iv(data_sampler, n_points, model_batch_size, U, r)\n",
    "\n",
    "                # Evaluate model predictions and coefficients\n",
    "                pred_mse = eval_batch(model, task_sampler, xzs, xs_p, U)\n",
    "                coef_mse = eval_coef(model, task_sampler, xzs, U, delta=delta)\n",
    "\n",
    "                # Keep the last column only, averaged over the first axis\n",
    "                all_pred_mse.append(pred_mse[:, -1].mean().item())\n",
    "                all_coef_mse.append(coef_mse[:, -1].mean().item())\n",
    "\n",
    "            # Store the mean errors for this model and strength r\n",
    "            metrics[model.name][r] = {\n",
    "                'pred_mse': np.mean(all_pred_mse),\n",
    "                'coef_mse': np.mean(all_coef_mse),\n",
    "            }\n",
    "\n",
    "    return metrics\n",
    "\n",
    "def eval_endostrength(run_path, e_list, n=50, batch_size=64, n_dims=15,\n",
    "                      normalize_w=True, normalize_Theta=False, mode=\"linear\",\n",
    "                      q_truncated=None, num_task=500, delta=5, seed=1):\n",
    "    \"\"\"Evaluate prediction and coefficient error versus endogeneity strength.\n",
    "\n",
    "    Args:\n",
    "        run_path: Run directory handed to ``get_model_from_run``.\n",
    "        e_list: Iterable of scales; each one multiplies the standard-normal tensor ``U``.\n",
    "        n: Number of in-context samples; prompts contain ``n + 1`` points.\n",
    "        batch_size: Batch size for the baselines; models whose name contains \"gpt\" use a quarter of it.\n",
    "        n_dims: Dimension handed to the data and task samplers.\n",
    "        normalize_w: Forwarded to the task sampler.\n",
    "        normalize_Theta: Forwarded to the data sampler.\n",
    "        mode: Forwarded to the data sampler (previously accepted but silently ignored).\n",
    "        q_truncated: Forwarded to the data sampler.\n",
    "        num_task: Number of sampled tasks averaged per (model, scale) pair.\n",
    "        delta: Forwarded to ``eval_coef``.\n",
    "        seed: Seed for torch's global RNG.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``metrics[model.name][scale_u]`` is ``{'pred_mse': float, 'coef_mse': float}``.\n",
    "    \"\"\"\n",
    "    torch.random.manual_seed(seed)\n",
    "    p = int(1/3 * n_dims)\n",
    "    n_points = n + 1  # Include test sample\n",
    "\n",
    "    # Load the trained transformer (step=-1) and put it on the GPU in eval mode\n",
    "    gptmodel, _ = get_model_from_run(run_path, step=-1)\n",
    "    gptmodel = gptmodel.cuda().eval()\n",
    "\n",
    "    # Models to evaluate\n",
    "    all_model = [IVRegressionModel(), LeastSquaresModel(), gptmodel]\n",
    "\n",
    "    metrics = {}\n",
    "\n",
    "    for model in all_model:\n",
    "        metrics[model.name] = {}\n",
    "        # Per-model batch size: the caller's batch_size is left untouched so the\n",
    "        # reduction can neither compound nor leak into models evaluated afterwards.\n",
    "        model_batch_size = int(batch_size / 4) if \"gpt\" in model.name else batch_size\n",
    "\n",
    "        # Sweep over the endogeneity scales\n",
    "        for scale_u in tqdm(e_list):\n",
    "            all_pred_mse = []\n",
    "            all_coef_mse = []\n",
    "\n",
    "            for _ in range(num_task):\n",
    "                # Data and task sampling\n",
    "                data_sampler = get_data_sampler(\"iv\", n_dims, normalize_Theta=normalize_Theta, mode=mode, q_truncated=q_truncated)\n",
    "                task_sampler = get_task_sampler(\"iv_regression\", n_dims, model_batch_size, normalize_w=normalize_w)\n",
    "\n",
    "                # Generate random input data; U is rescaled by the current endogeneity scale\n",
    "                U = torch.randn(model_batch_size, n_points, p) * scale_u\n",
    "                xzs, xs_p = gen_standard_iv(data_sampler, n_points, model_batch_size, U)\n",
    "\n",
    "                # Evaluate model predictions and coefficients\n",
    "                pred_mse = eval_batch(model, task_sampler, xzs, xs_p, U)\n",
    "                coef_mse = eval_coef(model, task_sampler, xzs, U, delta=delta)\n",
    "\n",
    "                # Keep the last column only, averaged over the first axis\n",
    "                all_pred_mse.append(pred_mse[:, -1].mean().item())\n",
    "                all_coef_mse.append(coef_mse[:, -1].mean().item())\n",
    "\n",
    "            # Store the mean errors for this model and scale\n",
    "            metrics[model.name][scale_u] = {\n",
    "                'pred_mse': np.mean(all_pred_mse),\n",
    "                'coef_mse': np.mean(all_coef_mse),\n",
    "            }\n",
    "\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def _display_name(model_name, distinguish_ridge=False):\n",
    "    \"\"\"Map an internal model name to its legend label; unrecognised names pass through unchanged.\"\"\"\n",
    "    prefix = \"Ridge \" if distinguish_ridge and \"Ridge\" in model_name else \"\"\n",
    "    if \"iv\" in model_name:\n",
    "        return prefix + \"2SLS\"\n",
    "    if \"OLS\" in model_name:\n",
    "        return prefix + \"OLS\"\n",
    "    if \"gpt\" in model_name:\n",
    "        return \"Transformer\"\n",
    "    return model_name\n",
    "\n",
    "\n",
    "def _plot_error_curves(metrics, x_values, xlabel, default_title, legend_loc, inset_window,\n",
    "                       distinguish_ridge=False, colors=None, ylim=None, ylim_inset=None,\n",
    "                       title=None, save_path=None):\n",
    "    \"\"\"Draw prediction (solid) and coefficient (dashed) error curves with a zoomed inset.\n",
    "\n",
    "    Shared implementation behind the three public ``plot_mse_vs_*`` functions.\n",
    "\n",
    "    Args:\n",
    "        metrics: ``{model_name: {x: {'pred_mse': ..., 'coef_mse': ...}}}``.\n",
    "        x_values: Indexable sequence of x positions; every value must be a key of each model's dict.\n",
    "        xlabel: Label of the x axis.\n",
    "        default_title: Title used when ``title`` is None.\n",
    "        legend_loc: Legend location on the main axes.\n",
    "        inset_window: ``(left, right)`` margins; the inset spans\n",
    "            ``x_values[-1] - left`` to ``x_values[-1] + right``.\n",
    "        distinguish_ridge: Whether legend labels separate ridge variants.\n",
    "        colors: Colour cycle, one colour per distinct legend label.\n",
    "        ylim: y limits of the main axes; defaults to ``(0, 0.7 * largest error)``.\n",
    "        ylim_inset: y limits of the inset; defaults to ``(0, 1)``.\n",
    "        title: Figure title.\n",
    "        save_path: If truthy, the figure is saved there (parent directory is created\n",
    "            when missing); otherwise it is shown.\n",
    "    \"\"\"\n",
    "    # Define a color cycle for the models if not provided\n",
    "    if colors is None:\n",
    "        colors = ['blue', 'green', 'red', 'orange', 'cyan']\n",
    "\n",
    "    # Gather every model's curves once; they are drawn on both the main axes and the inset\n",
    "    color_map = {}\n",
    "    curves = []\n",
    "    for model_name, points in metrics.items():\n",
    "        label = _display_name(model_name, distinguish_ridge)\n",
    "        if label not in color_map:\n",
    "            color_map[label] = colors[len(color_map) % len(colors)]\n",
    "        pred_mses = [points[x]['pred_mse'] for x in x_values]\n",
    "        coef_mses = [points[x]['coef_mse'] for x in x_values]\n",
    "        curves.append((label, pred_mses, coef_mses))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(14, 10))\n",
    "\n",
    "    for label, pred_mses, coef_mses in curves:\n",
    "        ax.plot(x_values, pred_mses, label=f'{label} ICPE', color=color_map[label], marker='', linewidth=3, markersize=10)\n",
    "        ax.plot(x_values, coef_mses, label=f'{label} MSE', linestyle='--', color=color_map[label], marker='', linewidth=3, markersize=10)\n",
    "\n",
    "    if ylim is None:\n",
    "        ylim = (0, 0.7 * max(max(pred, coef) for _, preds, coefs in curves for pred, coef in zip(preds, coefs)))\n",
    "    if title is None:\n",
    "        title = default_title\n",
    "\n",
    "    ax.set_title(title, fontsize=28, pad=20)\n",
    "    ax.set_xlabel(xlabel, fontsize=28)\n",
    "    ax.set_ylabel(r'Error$^2$', fontsize=28)\n",
    "    ax.set_ylim(ylim)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=22)\n",
    "    ax.legend(loc=legend_loc, fontsize=24)\n",
    "    ax.grid(True)\n",
    "\n",
    "    # Inset for the zoom-in, anchored in figure coordinates towards the lower-right corner\n",
    "    ax_inset = inset_axes(ax, width=\"40%\", height=\"40%\", loc='lower right',\n",
    "                          bbox_to_anchor=(0.65, 0.1, 0.5, 0.5), bbox_transform=fig.transFigure)\n",
    "\n",
    "    for label, pred_mses, coef_mses in curves:\n",
    "        ax_inset.plot(x_values, coef_mses, linestyle='--', color=color_map[label], marker='', linewidth=3, markersize=10)\n",
    "        ax_inset.plot(x_values, pred_mses, color=color_map[label], marker='', linewidth=3, markersize=10)\n",
    "\n",
    "    # The inset zooms in on the right-hand end of the x range\n",
    "    if ylim_inset is None:\n",
    "        ylim_inset = (0, 1)\n",
    "    left_margin, right_margin = inset_window\n",
    "    ax_inset.set_xlim(x_values[-1] - left_margin, x_values[-1] + right_margin)\n",
    "    ax_inset.set_ylim(ylim_inset)\n",
    "    ax_inset.grid(True)\n",
    "    ax_inset.tick_params(axis='both', which='major', labelsize=12)\n",
    "\n",
    "    # Save or display the plot\n",
    "    if save_path:\n",
    "        # savefig does not create missing folders, so make sure the target directory exists\n",
    "        if isinstance(save_path, (str, os.PathLike)):\n",
    "            out_dir = os.path.dirname(save_path)\n",
    "            if out_dir:\n",
    "                os.makedirs(out_dir, exist_ok=True)\n",
    "        fig.savefig(save_path, bbox_inches='tight')\n",
    "    else:\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "def plot_mse_vs_sample_size(metrics, n_list, colors=None, ylim=None, ylim_inset=None, title=None, save_path=None):\n",
    "    \"\"\"Plot both errors against the number of in-context samples.\n",
    "\n",
    "    ``metrics`` is keyed as ``metrics[model_name][n]`` with 'pred_mse' and 'coef_mse' entries.\n",
    "    Ridge variants get their own legend labels. The inset covers the last 7 units of ``n_list``.\n",
    "    The figure is saved to ``save_path`` when given, otherwise shown.\n",
    "    \"\"\"\n",
    "    _plot_error_curves(\n",
    "        metrics, n_list,\n",
    "        xlabel='Number of In-context Samples',\n",
    "        default_title=r'Error$^2$ vs In-context Sample Size',\n",
    "        legend_loc='upper right',\n",
    "        inset_window=(7, 0.5),\n",
    "        distinguish_ridge=True,\n",
    "        colors=colors, ylim=ylim, ylim_inset=ylim_inset, title=title, save_path=save_path,\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_mse_vs_ivstrength(metrics, r_list, colors=None, ylim=None, ylim_inset=None, title=None, save_path=None):\n",
    "    \"\"\"Plot both errors against instrument strength.\n",
    "\n",
    "    ``metrics`` is keyed as ``metrics[model_name][r]`` with 'pred_mse' and 'coef_mse' entries.\n",
    "    The inset covers the last 0.35 units of ``r_list``. The figure is saved to ``save_path``\n",
    "    when given, otherwise shown.\n",
    "    \"\"\"\n",
    "    _plot_error_curves(\n",
    "        metrics, r_list,\n",
    "        xlabel='Instrumental Strength',\n",
    "        default_title=r'Error$^2$ vs IV Strength',\n",
    "        legend_loc='upper right',\n",
    "        inset_window=(0.35, 0.03),\n",
    "        colors=colors, ylim=ylim, ylim_inset=ylim_inset, title=title, save_path=save_path,\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_mse_vs_endostrength(metrics, e_list, colors=None, ylim=None, ylim_inset=None, title=None, save_path=None):\n",
    "    \"\"\"Plot both errors against the endogeneity level.\n",
    "\n",
    "    ``metrics`` is keyed as ``metrics[model_name][scale]`` with 'pred_mse' and 'coef_mse' entries.\n",
    "    The inset covers the last 0.35 units of ``e_list``. The figure is saved to ``save_path``\n",
    "    when given, otherwise shown.\n",
    "    \"\"\"\n",
    "    _plot_error_curves(\n",
    "        metrics, e_list,\n",
    "        xlabel='Endogeneity Level',\n",
    "        default_title=r'Error$^2$ vs Endogeneity Level',\n",
    "        legend_loc='upper left',\n",
    "        inset_window=(0.35, 0.03),\n",
    "        colors=colors, ylim=ylim, ylim_inset=ylim_inset, title=title, save_path=save_path,\n",
    "    )\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set Model Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the trained run: <run_dir>/<run_name>/<run_id>\n",
    "run_dir = \"../TrainedTF\"\n",
    "run_name = \"iv_regression\"\n",
    "run_id = \"Loop=10_N=51_d=15_nhead=12_D=80_L=2\"\n",
    "\n",
    "run_path = os.path.join(run_dir, run_name, run_id)\n",
    "# In-context sample sizes to evaluate: 20, 22, ..., 50\n",
    "n_list = range(20, 51, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation (May take a few hours) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n_quadratic.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, mode = \"quadratic\", num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n_non-linear.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, mode = \"non-linear\", num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n_underdetermined.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, q_truncated = 3, num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n_multicollinearity.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, mode = \"multicollinearity\", num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_n_multicollinearity-heavy.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = evaluate_models(run_path = run_path, n_list = n_list, mode = \"multicollinearity-heavy\", num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# IV strengths to sweep: 0.1, 0.2, ..., 2.0\n",
    "r_list = np.arange(0.1, 2.1, 0.1)\n",
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_ivstrength.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = eval_ivstrength(run_path = run_path, r_list = r_list, num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endogeneity scales to sweep: 0.1, 0.2, ..., 2.0\n",
    "e_list = np.arange(0.1, 2.1, 0.1)\n",
    "# Create the results directory before the long-running evaluation so the\n",
    "# metrics are not lost to a missing folder when they are written.\n",
    "store_path = os.path.join(run_path, \"results/mse_vs_endostrength.pkl\")\n",
    "os.makedirs(os.path.dirname(store_path), exist_ok=True)\n",
    "metrics = eval_endostrength(run_path = run_path, e_list = e_list, num_task=500)\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(metrics, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics and draw the error-vs-sample-size figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,6), save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics of the quadratic run and draw the figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n_quadratic.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n_quadratic.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,12), ylim_inset=(0,0.5), title=r\"Error$^2$ vs In-Context Sample Size (Quadratic IV)\", save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics of the non-linear run and draw the figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n_non-linear.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n_non-linear.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,12), ylim_inset=(0,0.6), title=r\"Error$^2$ vs In-Context Sample Size (Non-Linear IV)\", save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics of the q_truncated=3 run and draw the figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n_underdetermined.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n_underdetermined.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,5), ylim_inset=(0,1), title=r\"Error$^2$ vs In-Context Sample Size (q<p)\", save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics of the multicollinearity run and draw the figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n_multicollinearity.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n_multicollinearity.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,12.8), ylim_inset=(0,1), title=r\"Error$^2$ vs In-Context Sample Size (Multicollinearity)\", save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached metrics of the heavy-multicollinearity run and draw the figure.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_n_multicollinearity-heavy.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_n_multicollinearity-heavy.png\")\n",
    "plot_mse_vs_sample_size(metrics, n_list, ylim=(0,6), ylim_inset=(0,1), title=r\"Error$^2$ vs In-Context Sample Size (Heavy Multicollinearity)\", save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Must reproduce the grid used during evaluation: the metrics dict is keyed by these exact values.\n",
    "r_list = np.arange(0.1, 2.1, 0.1)\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_ivstrength.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_ivstrength.png\")\n",
    "plot_mse_vs_ivstrength(metrics, r_list, ylim=(0,4.8), save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Endogeneity grid (named e_list as in the evaluation cell). It must reproduce the grid\n",
    "# used during evaluation: the metrics dict is keyed by these exact values.\n",
    "e_list = np.arange(0.1, 2.1, 0.1)\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.\n",
    "load_path = os.path.join(run_path, \"results/mse_vs_endostrength.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "save_path = os.path.join(run_path, \"figures/mse_vs_endostrength.png\")\n",
    "plot_mse_vs_endostrength(metrics, e_list, ylim=(0,6), save_path=save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Real World Dataset Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sas7bdat import SAS7BDAT\n",
    "\n",
    "def calculate_beta_gpt(df, gptmodel, delta=1):\n",
    "    \"\"\"Estimate the effect of KIDCOUNT on WEEKSM by finite-differencing the model's prediction.\n",
    "\n",
    "    All rows of ``df`` form a single prompt. The model is queried at the last row twice,\n",
    "    once unchanged and once with that row's KIDCOUNT increased by ``delta``; the\n",
    "    difference of the two predictions divided by ``delta`` is returned.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with numeric 'WEEKSM', 'SEX2ND' and 'KIDCOUNT' columns.\n",
    "        gptmodel: Model called as ``gptmodel(xz, y, inds=[last_index])``.\n",
    "        delta: Size of the perturbation applied to KIDCOUNT of the last row.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor holding ``(y_hat2 - y_hat1) / delta``.\n",
    "    \"\"\"\n",
    "    # Run on the device the model lives on; fall back to CUDA (the previous hard-coded\n",
    "    # behaviour) when the model exposes no parameters.\n",
    "    try:\n",
    "        device = next(gptmodel.parameters()).device\n",
    "    except (AttributeError, StopIteration):\n",
    "        device = torch.device(\"cuda\")\n",
    "\n",
    "    # Variables of interest\n",
    "    y = torch.tensor(df['WEEKSM'].to_numpy(), dtype=torch.float32)  # Outcome: mother's weeks worked (rescaled by 1/52 in the preprocessing)\n",
    "    z = torch.tensor(df['SEX2ND'].to_numpy(), dtype=torch.float32)  # Indicator for whether the first and second child have the same sex\n",
    "    x = torch.tensor(df['KIDCOUNT'].to_numpy(), dtype=torch.float32)  # Number of children\n",
    "\n",
    "    # 15 feature columns: x in column 0, z in column 5, everything else zero\n",
    "    # (presumably the n_dims=15 layout the model was trained on -- TODO confirm)\n",
    "    xz = torch.zeros((len(df), 15), dtype=torch.float32)\n",
    "    xz[:, 0] = x\n",
    "    xz[:, 5] = z\n",
    "\n",
    "    # Add a leading batch dimension of size 1\n",
    "    xz = xz.unsqueeze(0)\n",
    "    y = y.unsqueeze(0)\n",
    "    last_index = y.shape[1] - 1\n",
    "\n",
    "    # Perturbed copy: shift x of the last row by delta\n",
    "    xz_shifted = xz.clone()\n",
    "    xz_shifted[:, -1, 0] += delta\n",
    "\n",
    "    # Inference only, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        y_hat1 = gptmodel(xz.to(device), y.to(device), inds=[last_index]).detach()\n",
    "        y_hat2 = gptmodel(xz_shifted.to(device), y.to(device), inds=[last_index]).detach()\n",
    "\n",
    "    # Calculate beta_gpt\n",
    "    beta_gpt = (y_hat2.to(\"cpu\") - y_hat1.to(\"cpu\")) / delta\n",
    "    return beta_gpt\n",
    "\n",
    "def estimate_beta(df, gptmodel, n=50, delta=5, repetitions=100, seed=1):\n",
    "    \"\"\"Repeat ``calculate_beta_gpt`` on random subsamples of ``df`` and summarise with the median.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame to subsample (without replacement); needs at least ``n + 1`` rows.\n",
    "        gptmodel: Model forwarded to ``calculate_beta_gpt``.\n",
    "        n: Number of in-context samples; each subsample has ``n + 1`` rows.\n",
    "        delta: Perturbation size forwarded to ``calculate_beta_gpt``.\n",
    "        repetitions: Number of subsamples drawn.\n",
    "        seed: ``random_state`` of the first subsample; incremented by one per repetition.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(beta_gpt, beta_gpt_list)`` -- the element-wise median over repetitions and\n",
    "        the array of per-repetition estimates (entry ``[0, 0]`` of each result).\n",
    "    \"\"\"\n",
    "    n_points = n + 1  # Include test sample\n",
    "    beta_gpt_list = []\n",
    "\n",
    "    for rep in range(repetitions):\n",
    "        # Draw exactly n_points rows. The sample size used to be hard-coded to 51,\n",
    "        # which ignored n and silently capped the prompt at 51 rows for n > 50.\n",
    "        sampled_df = df.sample(n=n_points, random_state=seed + rep, replace=False)\n",
    "        beta_gpt_list.append(calculate_beta_gpt(sampled_df, gptmodel, delta=delta))\n",
    "\n",
    "    # Median (not mean) of the estimates across all repetitions\n",
    "    beta_gpt = np.median(beta_gpt_list, axis=0)\n",
    "    beta_gpt_list = np.array(list(beta_gpt_list))[:, 0, 0]\n",
    "    return beta_gpt, beta_gpt_list\n",
    "\n",
    "# Load model\n",
    "gptmodel, conf = get_model_from_run(run_path, step=-1)\n",
    "gptmodel = gptmodel.cuda().eval()\n",
    "delta = 5\n",
    "\n",
    "# Set the data path\n",
    "file_path = os.path.join(run_path, \"AngEv98/m_d_806.sas7bdat\")\n",
    "\n",
    "# Read only the first 10000 rows\n",
    "with SAS7BDAT(file_path) as file:\n",
    "    chunk = []\n",
    "    for i, row in enumerate(file):\n",
    "        if i == 0:\n",
    "            # Use the first row as column names\n",
    "            columns = row\n",
    "        else:\n",
    "            chunk.append(row)\n",
    "        if len(chunk) == 20000: # Read only the first 20000 rows\n",
    "            break\n",
    "\n",
    "# Preprocessing\n",
    "data = pd.DataFrame(chunk, columns=columns)\n",
    "data = data[(data['WEEKSM'] != '00') & (data['SEX2ND'] != '') & (data['STATE']=='01')]\n",
    "data['SEX2ND'] = data['SEX2ND'].astype(float).to_numpy() \n",
    "data['KIDCOUNT'] = data['KIDCOUNT'].astype(float).to_numpy() \n",
    "data['WEEKSM'] = data['WEEKSM'].astype(float).to_numpy() / 52\n",
    "\n",
    "z = torch.tensor(data[\"SEX2ND\"].to_numpy(), dtype=torch.float32)  # Indicator for whether the first and second child have the same sex\n",
    "x = torch.tensor(data['KIDCOUNT'].to_numpy(), dtype=torch.float32)  # Number of children\n",
    "y = torch.tensor(data['WEEKSM'].to_numpy(), dtype=torch.float32)  # Mother's number of working weeks per year\n",
    "\n",
    "z_with_intercept = torch.cat([torch.ones_like(z).view(-1, 1), z.view(-1, 1)], dim=1)\n",
    "Theta = torch.linalg.lstsq(z_with_intercept, x.view(-1, 1)).solution\n",
    "x_hat = z_with_intercept @ Theta\n",
    "x_hat_with_intercept = torch.cat([torch.ones_like(x_hat).view(-1, 1), x_hat.view(-1, 1)], dim=1)\n",
    "beta_2sls = torch.linalg.lstsq(x_hat_with_intercept, y.view(-1, 1)).solution[1]\n",
    "beta_ols = torch.linalg.lstsq(torch.cat([torch.ones_like(x.view(-1, 1)), x.view(-1, 1)], dim=1), y.view(-1, 1)).solution[1]\n",
    "\n",
    "beta_gpt, beta_gpt_list = estimate_beta(data, gptmodel, n=50, delta=5, repetitions=500)\n",
    "\n",
    "\n",
    "store_path = os.path.join(run_path, \"results/labor_supply_estimate.pkl\")\n",
    "with open(store_path, \"wb\") as file:\n",
    "    pickle.dump(beta_gpt_list, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_path = os.path.join(run_path, \"results/labor_supply_estimate.pkl\")\n",
    "with open(load_path, \"rb\") as file:\n",
    "    metrics = pickle.load(file)\n",
    "\n",
    "save_path = os.path.join(run_path, \"figures/labor_supply_estimate_boxplot.png\")\n",
    "\n",
    "sns.boxplot(data=metrics, color='grey')\n",
    "\n",
    "# Enhancing labels and title\n",
    "# plt.xlabel('Number of In-Context Samples (n=50)', fontsize=14, labelpad=10)\n",
    "plt.ylabel(r'Estimated $\\beta$', fontsize=14, labelpad=10)\n",
    "plt.title('Boxplot of Estimated Coefficient', fontsize=16, pad=15)\n",
    "plt.axhline(y=beta_ols, color='green', linestyle='--', linewidth=1.5, label=fr'$\\beta_{{OLS}} = {float(beta_ols):.3f}$')\n",
    "plt.axhline(y=beta_2sls, color='blue', linestyle='--', linewidth=1.5, label=fr'$\\beta_{{2SLS}} = {float(beta_2sls):.3f}$')\n",
    "plt.axhline(y=np.median(metrics), color='red', linestyle='--', linewidth=1.5, label=fr'$\\beta_{{GPT}} = {np.median(metrics):.3f}$')\n",
    "\n",
    "# Display the plot\n",
    "plt.ylim(-0.17,0.03)\n",
    "plt.legend(loc=\"upper right\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(save_path)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ivtransformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
