{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Two-sample IV regression with SGD\n",
    "\n",
    "An instrument $Z \\sim \\mathcal{N}(0, I)$ drives a regressor $X$ through a link $f(\\Gamma^\\top Z)$. A confounder $h$ enters both $X$ and the outcome $Y = \\theta_\\star^\\top X + h_1 + e_1$. The linear model $g(X) = \\theta^\\top X$ is fit by SGD with the two-sample IV gradient $(\\theta^\\top X - Y)\\,X'$, where $X'$ is a second draw of $X$ for the same $Z$ with independent confounder and noise.\n",
    "\n",
    "The experiment below records $\\lVert \\theta_t - \\theta_\\star \\rVert^2$ over 50 seeds and summarizes its mean and standard deviation per iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GZJJ20vpGm1v"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sBHfRmidGm1x"
   },
   "outputs": [],
   "source": [
    "# helper methods\n",
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Two-layer perceptron: Linear -> ReLU -> Linear.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.fc1 = torch.nn.Linear(input_size, hidden_size)\n",
    "        self.fc2 = torch.nn.Linear(hidden_size, output_size)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def step(x):\n",
    "    \"\"\"Elementwise step function: 1 where x > 0, 0 where x <= 0.\"\"\"\n",
    "    return (torch.sign(x) + 1) // 2\n",
    "\n",
    "\n",
    "def identity(x):\n",
    "    \"\"\"Return x unchanged (linear Z -> X link).\"\"\"\n",
    "    return x\n",
    "\n",
    "\n",
    "def quadratic(x):\n",
    "    \"\"\"Elementwise square of x.\"\"\"\n",
    "    return x**2\n",
    "\n",
    "\n",
    "def power(k):\n",
    "    \"\"\"Return a function that raises its input to the k-th power elementwise.\"\"\"\n",
    "    def helper(x):\n",
    "        return x**k\n",
    "    return helper\n",
    "\n",
    "\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed the global NumPy and PyTorch RNGs for reproducible runs.\"\"\"\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "\n",
    "def z_oracle(dz):\n",
    "    \"\"\"Draw one instrument Z ~ N(0, I) as a (dz, 1) column vector.\"\"\"\n",
    "    return torch.normal(0, 1, (dz, 1))\n",
    "\n",
    "\n",
    "def multiple_z_oracle(dz, n):\n",
    "    \"\"\"Draw n instruments Z ~ N(0, I) as the columns of a (dz, n) matrix.\"\"\"\n",
    "    return torch.normal(0, 1, (dz, n))\n",
    "\n",
    "\n",
    "def two_sample_oracle(f, dx, theta_star, Gamma, noise_std):\n",
    "    \"\"\"Build a sampler that returns two draws of X for the same Z, plus the outcome Y.\n",
    "\n",
    "    Args:\n",
    "        f: elementwise Z -> X link, applied to Gamma.T @ Z.\n",
    "        dx: dimension of X.\n",
    "        theta_star: (dx, 1) ground-truth outcome coefficients.\n",
    "        Gamma: (dz, dx) instrument loading matrix.\n",
    "        noise_std: scale applied to the confounders and all noise terms.\n",
    "\n",
    "    Returns:\n",
    "        helper(Z) -> (X, X_prime, Y). For Z of shape (dz, 1), X and X_prime are (dx, 1)\n",
    "        and Y is (1, 1). X and X_prime have independent confounders/noise; only X's\n",
    "        confounder enters Y.\n",
    "    \"\"\"\n",
    "    def helper(Z):\n",
    "        # confounders (mean 1 before scaling) and noise; keep this draw order so seeded runs reproduce\n",
    "        h = noise_std * torch.normal(1, 1, (dx, 1))\n",
    "        h_prime = noise_std * torch.normal(1, 1, (dx, 1))\n",
    "        e2 = noise_std * torch.normal(0, 1, (dx, 1))\n",
    "        e2_prime = noise_std * torch.normal(0, 1, (dx, 1))\n",
    "        e1 = noise_std * torch.normal(0, 1, (1, 1))\n",
    "\n",
    "        # samples: X and X' share the instrument signal; Y is confounded through h[0]\n",
    "        signal = f(Gamma.T @ Z)\n",
    "        X = signal + h + e2\n",
    "        X_prime = signal + h_prime + e2_prime\n",
    "        Y = theta_star.T @ X + h[0] + e1\n",
    "\n",
    "        return X, X_prime, Y\n",
    "    return helper\n",
    "\n",
    "\n",
    "def linear_g_model(theta, X):\n",
    "    \"\"\"Linear outcome model g(X) = theta^T X.\"\"\"\n",
    "    return theta.T @ X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RZIl5UZBGm1y"
   },
   "source": [
    "## Linear X to Y\n",
    "\n",
    "Linear outcome model $g(X) = \\theta^\\top X$ trained with the two-sample IV gradient. The $Z \\to X$ link is selected by `zx_model_type`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L5b5ZNjGGm1z"
   },
   "outputs": [],
   "source": [
    "def model_train(seed, dz, dx, horiz, noise_std, alpha, zx_model_type=\"linear\"):\n",
    "    \"\"\"Fit theta_star by two-sample IV SGD on synthetic data and record the error trajectory.\n",
    "\n",
    "    Args:\n",
    "        seed: seed for the NumPy/PyTorch RNGs (ground truth and every sample).\n",
    "        dz, dx: dimensions of the instrument Z and the regressor X.\n",
    "        horiz: number of SGD steps.\n",
    "        noise_std: noise scale passed to the two-sample oracle.\n",
    "        alpha: constant step size.\n",
    "        zx_model_type: Z -> X link: \"step\", \"abs\", \"linear\", \"sin\", \"quadratic\", \"cubic\" or \"quartic\".\n",
    "\n",
    "    Returns:\n",
    "        [mse_history, theta_history, theta_star, Gamma]: mse_history has shape (horiz,) with\n",
    "        mse_history[t] = ||theta_t - theta_star||^2 after step t; theta_history has shape\n",
    "        (horiz, dx, 1) and holds the iterate after each step.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if zx_model_type is not one of the names above.\n",
    "    \"\"\"\n",
    "    # Z -> X link\n",
    "    zx_models = {\n",
    "        \"step\": step,\n",
    "        \"abs\": torch.abs,\n",
    "        \"linear\": identity,\n",
    "        \"sin\": torch.sin,\n",
    "        \"quadratic\": quadratic,\n",
    "        \"cubic\": power(3),\n",
    "        \"quartic\": power(4),\n",
    "    }\n",
    "    if zx_model_type not in zx_models:\n",
    "        raise NotImplementedError(f\"unknown zx_model_type: {zx_model_type!r}\")\n",
    "    f = zx_models[zx_model_type]\n",
    "\n",
    "    # groundtruth\n",
    "    set_seed(seed)\n",
    "    theta_star = torch.randn(dx, 1)\n",
    "    Gamma = torch.randn(dz, dx)\n",
    "\n",
    "    # two sample oracle\n",
    "    double_sample_oracle = two_sample_oracle(f, dx, theta_star, Gamma, noise_std)\n",
    "\n",
    "    # theta initialization\n",
    "    theta = torch.zeros(dx, 1)\n",
    "    mse_history, theta_history = [], []\n",
    "\n",
    "    # model training\n",
    "    for _ in range(horiz):\n",
    "        # generate Z, X, X', and Y\n",
    "        Z = z_oracle(dz)\n",
    "        X, X_prime, Y = double_sample_oracle(Z)\n",
    "\n",
    "        # two-sample IV gradient: residual on X, direction from the independent copy X'\n",
    "        grad_g = X_prime\n",
    "        v = (linear_g_model(theta, X) - Y) * grad_g\n",
    "\n",
    "        # rebind theta to a new tensor: mutating theta.data in place made every\n",
    "        # theta_history entry alias one tensor, so all recorded iterates showed the final one\n",
    "        theta = theta - alpha * v\n",
    "\n",
    "        # evaluate MSE\n",
    "        error = torch.norm(theta - theta_star)**2\n",
    "\n",
    "        # update history\n",
    "        mse_history.append(error)\n",
    "        theta_history.append(theta)\n",
    "\n",
    "    return [torch.stack(mse_history), torch.stack(theta_history), theta_star, Gamma]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment: linear link, 50 seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VppIb7uNGm10"
   },
   "outputs": [],
   "source": [
    "dx, dz = 4, 8\n",
    "noise = 1.0\n",
    "horiz = 1000\n",
    "a = 0.01\n",
    "alpha = a * math.log(horiz) / horiz  # constant step size, scaled as a * log(T) / T for horizon T\n",
    "n_seeds = 50\n",
    "iteration = list(range(horiz))  # iteration indices 0 .. horiz - 1\n",
    "\n",
    "result_list = []\n",
    "for seed in range(n_seeds):\n",
    "    mse, theta, theta_star, Gamma = model_train(seed, dz, dx, horiz, noise, alpha, zx_model_type=\"linear\")\n",
    "    result_list.append(mse)\n",
    "\n",
    "# stack per-seed curves into (n_seeds, horiz) and summarize across seeds\n",
    "result_list = torch.stack(result_list)\n",
    "mse_mean = torch.mean(result_list, dim=0)\n",
    "mse_std = torch.std(result_list, dim=0)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "pytorch_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}