{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3972b7a0",
   "metadata": {},
   "source": [
    "## Problem\n",
    "\n",
    "The two-bit-env problem with confounded-descendant:\n",
    "\n",
    "\\begin{align*}\n",
    "    & Y \\sim \\text{Rad}(c_e)\\\\\n",
    "    & X_1 = Y \\text{Rad}(a)\\\\\n",
    "    & X_2 = Y \\text{Rad}(b_e)\n",
    "\\end{align*}\n",
    "\n",
    "Use $(a, b_e, c_e)$ to denote an environment. I use $a = 0.25$ for all environments. We have 4 training domains, with $b_e = 0.8, 0.9, 0.85, 0.95$ and $c_e = 0.9, 0.1, 0.7, 0.3$ respectively. For test domain, we flip the relationship between $X_2$ and $Y$: $b_e = 0.1$. We use $c_e = 0.5$ \n",
    "\n",
    "\n",
    "Since $X_1, X_2, Y$ takes values in $\\{1, -1\\}$, we use $x_{01}$ to denote the prediction $f(1, -1)$ etc, where first coordinate is for $x_1$ and the second component is for $x_2$. I use $p_{011}$ to denote $P(X_1 = 1, X_2 = -1, Y = -1)$, etc. \n",
    "\n",
    "## obj and predictor\n",
    "\n",
    "We use $\\text{obj} := E[\\text{loss}(f(X), Y)] + \\lambda E[g(f(X), Y)]$ as the objective. Here $g$ is the penalty for IRMv1 and gIRMv1 respectively. We computed obj and find the minimizer over $(x_{00}, x_{10}, x_{01}, x_{11})$. We know the optimal invariant solution should be of the form $(v, -v, v, -v)$.\n",
    "\n",
    "\n",
    "* Note: I implement gIRM by first performing importance sampling on both the loss and the penalty. (Another way is to do importance sampling only on the penalty. But the two approaches give the same result here.)"
   ]
  },
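  {
   "cell_type": "markdown",
   "id": "7c4e1b92",
   "metadata": {},
   "source": [
    "The objective above can be written out per environment. This is a sketch read off the `IRMparts` and `IRMv1` functions defined below; the symbols $R^e$, $D^e$ and $w$ are notation introduced here for illustration.\n",
    "\n",
    "\\begin{align*}\n",
    "    R^e(f) &= \\sum_{x_1, x_2, y} P_e(x_1, x_2, y)\\, \\text{loss}(f(x_1, x_2), y)\\\\\n",
    "    D^e(f) &= \\frac{\\partial}{\\partial w} R^e(w \\cdot f) \\Big|_{w = 1}\\\\\n",
    "    \\text{obj} &= \\sum_e \\left[ R^e(f) + \\lambda\\, D^e(f)^2 \\right]\n",
    "\\end{align*}\n",
    "\n",
    "For the square loss $\\frac{1}{2}(f - y)^2$, the derivative is $D^e(f) = E_e[(f(X) - Y)\\, f(X)]$. In the generalized variant, $D^e$ is computed from the risk with $c_e$ replaced by $0.5$, while $R^e$ keeps the original $c_e$."
   ]
  },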
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e526a434",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import *\n",
    "import numpy as np\n",
    "from scipy.optimize import minimize as scipy_min\n",
    "np.random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5cf0e2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sqls(zh, z):\n",
    "    return 0.5 * Pow(zh - z, 2)\n",
    "\n",
    "def logls(zh, z):\n",
    "    return log(1 + exp(-z * zh))\n",
    "\n",
    "def zols(zh, z):\n",
    "    if Eq(sign(zh), z):\n",
    "        return 0\n",
    "    return 1\n",
    "\n",
    "def IRMparts(e, ls, x00, x10, x01, x11, generalized = False):\n",
    "    a, b, c = e\n",
    "   \n",
    "    \n",
    "    w = symbols('w')\n",
    "    \n",
    "    p000 = (1 - c)*(1 - a)*(1 - b)\n",
    "    p001 = c*a*b;\n",
    "    p100 = (1 - c)*a*(1 - b)\n",
    "    p101 = c*(1 - a)*b\n",
    "    p010 = (1 - c)*(1 - a)*b\n",
    "    p011 = c*a*(1 - b)\n",
    "    p110 = (1 - c)*a*b\n",
    "    p111 = c * (1 - a) * (1 - b);\n",
    "\n",
    "    R = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \\\n",
    "    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\\\n",
    "    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\\\n",
    "    p110 * ls(x11, 1) + p111 * ls(x11, -1)\n",
    "    \n",
    "    if not generalized:\n",
    "        dR = R.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)\n",
    "        return [R, dR]\n",
    "     \n",
    "    c = 0.5 ## after reweighting it's equivalent to the case without prior shift\n",
    "    p000 = (1 - c)*(1 - a)*(1 - b)\n",
    "    p001 = c*a*b;\n",
    "    p100 = (1 - c)*a*(1 - b)\n",
    "    p101 = c*(1 - a)*b\n",
    "    p010 = (1 - c)*(1 - a)*b\n",
    "    p011 = c*a*(1 - b)\n",
    "    p110 = (1 - c)*a*b\n",
    "    p111 = c * (1 - a) * (1 - b);\n",
    "\n",
    "    R_w = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \\\n",
    "    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\\\n",
    "    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\\\n",
    "    p110 * ls(x11, 1) + p111 * ls(x11, -1)\n",
    "\n",
    "    dR = R_w.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)\n",
    "\n",
    "    return [R, dR]\n",
    "\n",
    "def IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized = False):\n",
    "    R, dR = IRMparts(e, ls, x00, x10, x01, x11, generalized)\n",
    "    return R + lam * dR**2\n",
    "\n",
    "def IRMv1(envs, lam, ls, x00, x10, x01, x11, generalized = False):\n",
    "    results = 0\n",
    "    for e in envs:\n",
    "        results += IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized)\n",
    "        \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "db0617f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prior_shift(alpha, lam, betalist, clist, ls, inits, generalized = False):\n",
    "    envs = []\n",
    "    for i in range(2):\n",
    "        envs += [[alpha, betalist[i], clist[i]]]\n",
    "\n",
    "    x00, x10, x01, x11 = symbols('x00 x10 x01 x11')\n",
    "    obj = IRMv1(envs, lam ,ls, x00, x10, x01, x11, generalized)\n",
    "\n",
    "    func = lambdify([x00, x10, x01, x11], obj)\n",
    "    def func_np(params):\n",
    "        x00, x10, x01, x11 = params\n",
    "        return func(x00, x10, x01, x11)\n",
    "\n",
    "    n_init = inits.shape[0]\n",
    "    sols = np.zeros((n_init, 4))\n",
    "    objs = []\n",
    "    for i in range(n_init):\n",
    "        res = scipy_min(func_np, inits[i, :], method='Nelder-Mead', tol=1e-10)\n",
    "        sols[i, :] = res.x\n",
    "        objs.append(res.fun)\n",
    "\n",
    "    obj_min = min(objs)\n",
    "    sol_min = sols[objs.index(obj_min), :]\n",
    "    print(\"obj,   sol\")\n",
    "    print((round(obj_min, 2), sol_min.round(2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2e82774",
   "metadata": {},
   "source": [
    "## Without prior shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "027cc222",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(0.75, array([ 0.52, -0.5 ,  0.5 , -0.48]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(1.12, array([ 1.16, -1.11,  1.08, -1.03]))\n"
     ]
    }
   ],
   "source": [
    "n_init = 500\n",
    "inits = np.random.uniform(low = -2, high = 2, size = (n_init, 4))\n",
    "\n",
    "lam = 1e+20\n",
    "alpha = 0.25\n",
    "# betalist = [0.1, 0.2]\n",
    "# clist = [0.5, 0.5]\n",
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.5, 0.5, 0.5, 0.5]\n",
    "\n",
    "\n",
    "print(\"square loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, sqls, inits)\n",
    "\n",
    "print(\"logistic loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, logls, inits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bafbb3c",
   "metadata": {},
   "source": [
    "## With prior shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1804a3f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(0.44, array([ 0.93,  0.15, -0.4 , -0.85]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(0.72, array([ 2.53, -0.08, -0.93, -3.19]))\n"
     ]
    }
   ],
   "source": [
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.9, 0.1, 0.7, 0.3]\n",
    "\n",
    "print(\"square loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, sqls, inits)\n",
    "\n",
    "print(\"logistic loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, logls, inits)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a88bdf0e",
   "metadata": {},
   "source": [
    "## g-IRM fixes the issue with prior-shift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4ff7c7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "square loss\n",
      "obj,   sol\n",
      "(0.75, array([ 0.52, -0.5 ,  0.5 , -0.48]))\n",
      "logistic loss\n",
      "obj,   sol\n",
      "(1.12, array([ 1.16, -1.11,  1.08, -1.03]))\n"
     ]
    }
   ],
   "source": [
    "betalist = [0.2, 0.1, 0.15, 0.05]\n",
    "clist = [0.9, 0.1, 0.7, 0.3]\n",
    "\n",
    "print(\"square loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, sqls, inits, generalized = True)\n",
    "\n",
    "print(\"logistic loss\")\n",
    "prior_shift(alpha, lam, betalist, clist, logls, inits, generalized = True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
