{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "96c8f9f5",
   "metadata": {},
   "source": [
    "## Problem\n",
    "\n",
    "The two-bit-env problem with confounded-descendant:\n",
    "\n",
    "\\begin{align*}\n",
    "    & X_1 \\sim \\text{Rad}(c_e)\\\\\n",
    "    & Y = X_1 \\text{Rad}(a)\\\\\n",
    "    & X_2 = Y \\text{Rad}(b_e)\n",
    "\\end{align*}\n",
    "\n",
    "Use $(a, b_e, c_e)$ to denote an environment. I use $a = 0.25$ for all environments. We have 4 training domains, with $b_e = 0.8, 0.9, 0.85, 0.95$ and $c_e = 0.9, 0.1, 0.7, 0.3$ respectively. For test domain, we flip the relationship between $X_2$ and $Y$: $b_e = 0.1$. We use $c_e = 0.5$ \n",
    "\n",
    "\n",
    "Since $X_1, X_2, Y$ takes values in $\\{1, -1\\}$, we use $x_{01}$ to denote the prediction $f(1, -1)$ etc, where first coordinate is for $x_1$ and the second component is for $x_2$. I use $p_{011}$ to denote $P(X_1 = 1, X_2 = -1, Y = -1)$, etc. \n",
    "\n",
    "## obj and predictor\n",
    "\n",
    "We use $\\text{obj} := E[\\text{loss}(f(X), Y)] + \\lambda E[g(f(X), Y)]$ as the objective. Here $g$ is the penalty for IRMv1 and gIRMv1 respectively. We computed obj and find the minimizer over $(x_{00}, x_{10}, x_{01}, x_{11})$. We know the optimal invariant solution should be of the form $(v, -v, v, -v)$. \n",
    "\n",
    "* Note: I implement gIRM by first performing importance sampling on both the loss and the penalty. (Another way is to do importance sampling only on the penalty. But the two approaches give the same result here.)"
   ]
  },
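  {
   "cell_type": "markdown",
   "id": "irmv1-objective-note",
   "metadata": {},
   "source": [
    "To make the objective concrete, here is a sketch of what `IRMv1e` and `IRMv1` in the code below assemble when `generalized = False`. The notation $R^e$ and $y_k$ is introduced here only for illustration; this is a restatement of the code, not an independent derivation:\n",
    "\n",
    "\\begin{align*}\n",
    "    & R^e(f) = \\sum_{i,j,k} p^e_{ijk}\\, \\text{loss}(x_{ij}, y_k), \\quad y_0 = 1,\\ y_1 = -1\\\\\n",
    "    & \\text{obj} = \\sum_e \\left[ R^e(f) + \\lambda \\left( \\frac{\\partial}{\\partial w} R^e(w f) \\Big|_{w = 1} \\right)^2 \\right]\n",
    "\\end{align*}\n",
    "\n",
    "Here $w f$ means that every prediction $x_{ij}$ is multiplied by the scalar $w$, so the penalty vanishes exactly when $w = 1$ is a stationary point of the environment risk."
   ]
  },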
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e526a434",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import *\n",
    "import numpy as np\n",
    "from scipy.optimize import minimize as scipy_min\n",
    "import pdb\n",
    "np.random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5cf0e2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sqls(zh, z):\n",
    "    return 0.5 * Pow(zh - z, 2)\n",
    "\n",
    "def logls(zh, z):\n",
    "    return log(1 + exp(-z * zh))\n",
    "\n",
    "def zols(zh, z):\n",
    "    if Eq(sign(zh), z):\n",
    "        return 0\n",
    "    return 1\n",
    "\n",
    "def IRMparts(e, ls, x00, x10, x01, x11, generalized = False):\n",
    "    a, b, c = e\n",
    "    w = symbols('w')\n",
    "    \n",
    "    p000 = (1 - c)*(1 - a)*(1 - b)\n",
    "    p001 = (1 - c) * a * b\n",
    "    p100 = c * a * (1- b)\n",
    "    p101 = c*(1 - a)*b\n",
    "    p010 = (1 - c)*(1 - a)*b\n",
    "    p011 = (1-c)*a*(1 - b)\n",
    "    p110 = c * a * b\n",
    "    p111 = c * (1 - a) * (1 - b);\n",
    "    \n",
    "    R = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \\\n",
    "    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\\\n",
    "    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\\\n",
    "    p110 * ls(x11, 1) + p111 * ls(x11, -1)\n",
    "    \n",
    "    if not generalized:\n",
    "        dR = R.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)\n",
    "        return [R, dR]\n",
    "    \n",
    "    marg = (1 - c) * (1 - a) + c * a ## P(Y = 1)\n",
    "    p000 *= 0.5 / marg\n",
    "    p010 *= 0.5 / marg\n",
    "    p100 *= 0.5 / marg\n",
    "    p110 *= 0.5 / marg\n",
    "\n",
    "    p001 *= 0.5 / (1 - marg)\n",
    "    p011 *= 0.5 / (1 - marg)\n",
    "    p101 *= 0.5 / (1 - marg)\n",
    "    p111 *= 0.5 / (1 - marg)\n",
    "\n",
    "    R_w = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \\\n",
    "    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\\\n",
    "    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\\\n",
    "    p110 * ls(x11, 1) + p111 * ls(x11, -1)\n",
    "    \n",
    "    dR = R_w.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)\n",
    "    \n",
    "    return [R, dR]\n",
    "\n",
    "def IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized = False):\n",
    "    R, dR = IRMparts(e, ls, x00, x10, x01, x11, generalized)\n",
    "    return R + lam * dR**2\n",
    "\n",
    "def IRMv1(envs, lam, ls, x00, x10, x01, x11, generalized = False):\n",
    "    results = 0\n",
    "    for e in envs:\n",
    "        results += IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized)\n",
    "        \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "db0617f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conf_desc(alpha, lam, betalist, clist, ls, inits, generalized = False):\n",
    "    envs = []\n",
    "    for i in range(len(betalist)):\n",
    "        envs += [[alpha, betalist[i], clist[i]]]\n",
    "\n",
    "    x00, x10, x01, x11 = symbols('x00 x10 x01 x11')\n",
    "    obj = IRMv1(envs, lam ,ls, x00, x10, x01, x11, generalized)\n",
    "\n",
    "    func = lambdify([x00, x10, x01, x11], obj)\n",
    "    def func_np(params):\n",
    "        x00, x10, x01, x11 = params\n",
    "        return func(x00, x10, x01, x11)\n",
    "    \n",
    "\n",
    "    n_init = inits.shape[0]\n",
    "    sols = np.zeros((n_init, 4))\n",
    "    objs = []\n",
    "    for i in range(n_init):\n",
    "        res = scipy_min(func_np, inits[i, :], method='Nelder-Mead', tol=1e-12)\n",
    "        sols[i, :] = res.x\n",
    "        objs.append(res.fun)\n",
    "\n",
    "    obj_min = min(objs)\n",
    "    sol_min = sols[objs.index(obj_min), :]\n",
    "    print(\"obj,   sol\")\n",
    "    print((round(obj_min, 2), sol_min.round(2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2e82774",
   "metadata": {},
   "source": [
    "## Without covariate shift in $X_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "027cc222",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(1.5, array([ 0.52, -0.52,  0.48, -0.48]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(2.25, array([ 1.01, -1.11,  1.08, -1.18]))\n"
     ]
    }
   ],
   "source": [
    "## global parameters\n",
    "lam = 1e+20\n",
    "alpha = 0.25\n",
    "### initialization (the same for all runs)\n",
    "n_init = 500\n",
    "inits = np.random.uniform(low = -2, high = 2, size = (n_init, 4))\n",
    "\n",
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.5, 0.5, 0.5, 0.5]\n",
    "\n",
    "print(\"square loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, sqls, inits)\n",
    "print(\"logistic loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, logls, inits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bafbb3c",
   "metadata": {},
   "source": [
    "## With covariate shift in$X_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1804a3f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(1.5, array([ 0.5, -0.5,  0.5, -0.5]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(2.25, array([ 1.1, -1.1,  1.1, -1.1]))\n"
     ]
    }
   ],
   "source": [
    "alpha = 0.25\n",
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.9, 0.1, 0.7, 0.3]\n",
    "\n",
    "print(\"square loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, sqls, inits)\n",
    "print(\"logistic loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, logls, inits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a88bdf0e",
   "metadata": {},
   "source": [
    "## g-IRMv1 with covariate shift in $X_1$"
   ]
  },
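  {
   "cell_type": "markdown",
   "id": "girm-weights-note",
   "metadata": {},
   "source": [
    "As a sketch of the reweighting that `IRMparts` applies when `generalized = True` (a restatement of the code; the tilde notation is introduced here only for illustration), each joint probability is rescaled so that both labels carry mass $1/2$ in every environment. The third index of $p^e_{ijk}$ refers to $Y$, with $0$ meaning $Y = 1$:\n",
    "\n",
    "\\begin{align*}\n",
    "    & P^e(Y = 1) = (1 - c_e)(1 - a) + c_e a\\\\\n",
    "    & \\tilde{p}^e_{ij0} = p^e_{ij0} \\cdot \\frac{0.5}{P^e(Y = 1)}\\\\\n",
    "    & \\tilde{p}^e_{ij1} = p^e_{ij1} \\cdot \\frac{0.5}{1 - P^e(Y = 1)}\n",
    "\\end{align*}\n",
    "\n",
    "Summing over $i, j$ gives $\\sum_{i,j} \\tilde{p}^e_{ij0} = \\sum_{i,j} \\tilde{p}^e_{ij1} = 0.5$. In the code, the penalty is computed from the risk under $\\tilde{p}^e$."
   ]
  },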
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a4ff7c7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(4.53, array([0., 0., 0., 0.]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(1282.48, array([0., 0., 0., 0.]))\n"
     ]
    }
   ],
   "source": [
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.9, 0.1, 0.7, 0.15]\n",
    "\n",
    "print(\"square loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, sqls, inits, generalized = True)\n",
    "print(\"logistic loss\")\n",
    "conf_desc(alpha, lam, betalist, clist, logls, inits, generalized = True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
