{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target solution is always [1, 0.]. ERM (arithmetic_mean) fails on all, and-mask works\n",
    "\n",
    "import torch\n",
    "from torch.autograd import grad\n",
    "\n",
    "def example_A(n=10000, env=1):\n",
    "    x = torch.randn(n, 1) * env\n",
    "    h = torch.randn(n, 1) * env\n",
    "    y = x + h\n",
    "    z = h + torch.randn(n, 1)\n",
    "    return torch.cat((x, z), 1), y.sum(1, keepdim=True)\n",
    "\n",
    "def example_B(n=10000, env=1):\n",
    "    x = torch.randn(n, 1) * env\n",
    "    y = x + torch.randn(n, 1) * env\n",
    "    z = y + torch.randn(n, 1)\n",
    "    return torch.cat((x, z), 1), y.sum(1, keepdim=True)\n",
    "\n",
    "def example_C(n=10000, env=1):\n",
    "    z = env * torch.randn(n, 1)\n",
    "    x = torch.randn(n, 1) + z\n",
    "    y = x + torch.randn(n, 1) + z\n",
    "    # returns only x and zeros as a dummy variable to keep dims consistent with _B and _C\n",
    "    return torch.cat((x, torch.zeros_like(z)), 1), y.sum(1, keepdim=True)\n",
    "\n",
    "\n",
    "# CHANGE THESE TWO HERE <-----------------------\n",
    "dataloader = example_A # _A, _B, _C\n",
    "method = 'arithm_mean' # 'arithm_mean' or 'and_mask'\n",
    "\n",
    "\n",
    "weights = torch.nn.Parameter(torch.ones(2, 1))\n",
    "\n",
    "lr = 1e-3\n",
    "weight_decay = 0.0001\n",
    "opt = torch.optim.Adam([weights], lr=lr)\n",
    "mse = torch.nn.MSELoss(reduction=\"none\")\n",
    "\n",
    "environments = [dataloader(env=0.1),\n",
    "                dataloader(env=1)]\n",
    "\n",
    "for iteration in range(50000):\n",
    "    penalty = 0\n",
    "    opt.zero_grad()\n",
    "    grads = []\n",
    "    for x_e, y_e in environments:\n",
    "        p = torch.randperm(len(x_e))\n",
    "        error_e = mse(x_e[p] @ weights, y_e[p])\n",
    "        error = 1e-5 * error_e.mean()\n",
    "        \n",
    "        grads.append(torch.autograd.grad(error, weights)[0])\n",
    "    \n",
    "    grad = torch.stack(grads, dim=-1)\n",
    "            \n",
    "    if method == 'and_mask':\n",
    "        signs = torch.sign(grad) \n",
    "        mask = torch.abs(signs.mean(dim=-1)) == 1\n",
    "        avg_grad = grad.mean(dim=-1) * mask\n",
    "        final_grads = avg_grad\n",
    "    elif method == 'arithm_mean':\n",
    "        avg_grad = grad.mean(dim=-1)\n",
    "        final_grads = avg_grad\n",
    "    else:\n",
    "        raise ValueError()\n",
    "    \n",
    "    weights.grad = final_grads + weight_decay * lr * weights    \n",
    "    opt.step()\n",
    "\n",
    "    if iteration % 1000 == 0:\n",
    "        print(iteration, weights)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
