{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from numba import jit\n",
    "\n",
    "def lloss(w1,w2,a,b,mu,var):\n",
    "    loss = (1-a)*(1-b)*(w1+w2-1)**2 + a*b*(w1+w2+1)**2 + (1-a)*b*(w1-w2-1)**2 + a*(1-b)*(w1-w2+1)**2 + (w1+w2)**2 * (mu**2+var)\n",
    "    return 0.5 * loss\n",
    "\n",
    "@jit(nopython=True)\n",
    "def irmv1_loss(w1,w2,a,b,lam,mu,var):\n",
    "    loss = (1-a)*(1-b)*(w1+w2-1)**2 + a*b*(w1+w2+1)**2 + (1-a)*b*(w1-w2-1)**2 + a*(1-b)*(w1-w2+1)**2 + (w1+w2)**2 * (mu**2+var)\n",
    "    re = (1-a)*(1-b)*(w1+w2-1)*(w1+w2) + a*b*(w1+w2+1)*(w1+w2) + (1-a)*b*(w1-w2-1)*(w1-w2) + a*(1-b)*(w1-w2+1)*(w1-w2) + (w1+w2)**2 * (mu**2+var)\n",
    "    return 0.5 * loss + lam*(re**2)\n",
    "\n",
    "@jit(nopython=True)\n",
    "def rex_loss(w1,w2,a,b1,b2,mu1,var1,mu2,var2):\n",
    "    a = 1-2*a\n",
    "    b1 = 1-2*b1\n",
    "    b2 = 1-2*b2\n",
    "    \n",
    "    loss1 = (w1**2+w2**2)*(1+mu1**2+var1) + 1 + 2*w1*w2*(a*b1+mu1**2+var1) - 2*w1*a - 2*w2*b1\n",
    "    loss2 = (w1**2+w2**2)*(1+mu2**2+var2) + 1 + 2*w1*w2*(a*b2+mu2**2+var2) - 2*w1*a - 2*w2*b2\n",
    "\n",
    "    return lam*(loss1-loss2)**2\n",
    "\n",
    "\n",
    "@jit(nopython=True)\n",
    "def icorr_loss(w1,w2,a,b1,b2,lam,mu1,var1,mu2,var2):\n",
    "    loss1 = (1-a)*(1-b1)*(w1+w2-1)**2 + a*b1*(w1+w2+1)**2 + (1-a)*b1*(w1-w2-1)**2 + a*(1-b1)*(w1-w2+1)**2 + (w1+w2)**2 * (mu1**2+var1)\n",
    "    re1 = (1-a)*(1-b1)*(w1+w2) - a*b1*(w1+w2) + (1-a)*b1*(w1-w2) - a*(1-b1)*(w1-w2)\n",
    "    \n",
    "    loss2 = (1-a)*(1-b2)*(w1+w2-1)**2 + a*b2*(w1+w2+1)**2 + (1-a)*b2*(w1-w2-1)**2 + a*(1-b2)*(w1-w2+1)**2 + (w1+w2)**2 * (mu2**2+var2)\n",
    "    re2 = (1-a)*(1-b2)*(w1+w2) - a*b2*(w1+w2) + (1-a)*b2*(w1-w2) - a*(1-b2)*(w1-w2)\n",
    "    return 0.5 * (loss1+loss2) + lam*(re1-re2)**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.15142036000000003\n",
      "0.16172780000000003\n",
      "0.25449476\n",
      "0.29572452\n"
     ]
    }
   ],
   "source": [
    "#ERM\n",
    "LL = 100\n",
    "ww1 = 0\n",
    "ww2 = 0\n",
    "for ii in range(-1000,1000):\n",
    "    for jj in range(-1000,1000):\n",
    "        w1 = ii/1000\n",
    "        w2 = jj/1000\n",
    "        ll = lloss(w1,w2,0.1,0.2,0.2,0.01) + lloss(w1,w2,0.1,0.25,0.1,0.02)\n",
    "        if ll<LL:\n",
    "            LL = ll\n",
    "            ww1 = w1\n",
    "            ww2 = w2\n",
    "            \n",
    "print(lloss(ww1,ww2,0.1,0.2,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.25,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.7,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.9,0.,0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.18048050000000004\n",
      "0.18048050000000004\n",
      "0.18048050000000002\n",
      "0.18048050000000004\n"
     ]
    }
   ],
   "source": [
    "#IRM\n",
    "LL = 100\n",
    "ww1 = 0\n",
    "ww2 = 0\n",
    "for ii in range(-1000,1000):\n",
    "    w1 = ii/1000\n",
    "    w2 = 0\n",
    "    ll = lloss(w1,w2,0.1,0.2,0.2,0.01) + lloss(w1,w2,0.1,0.25,0.1,0.02)\n",
    "    if ll<LL:\n",
    "        LL = ll\n",
    "        ww1 = w1\n",
    "        ww2 = w2\n",
    "\n",
    "print(lloss(ww1,ww2,0.1,0.2,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.25,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.7,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.9,0.,0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5000000000000001\n",
      "0.5\n",
      "0.5\n",
      "0.5\n"
     ]
    }
   ],
   "source": [
    "#IRMv1\n",
    "LL = 100\n",
    "ww1 = 0\n",
    "ww2 = 0\n",
    "lam = 100000000\n",
    "for ii in range(1000):\n",
    "    for jj in range(1000):\n",
    "        w1 = ii/1000\n",
    "        w2 = jj/1000\n",
    "        ll = irmv1_loss(w1,w2,0.1,0.2,lam,0.2,0.01) + irmv1_loss(w1,w2,0.1,0.25,lam,0.1,0.02)\n",
    "        if ll<LL:\n",
    "            LL = ll\n",
    "            ww1 = w1\n",
    "            ww2 = w2\n",
    "            \n",
    "print(lloss(ww1,ww2,0.1,0.2,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.25,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.7,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.9,0.,0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.18048050000000004\n",
      "0.18048050000000004\n",
      "0.18048050000000002\n",
      "0.18048050000000004\n"
     ]
    }
   ],
   "source": [
    "#ICorr\n",
    "LL = 100\n",
    "ww1 = 0\n",
    "ww2 = 0\n",
    "lam = 100000000\n",
    "for ii in range(1000):\n",
    "    for jj in range(1000):\n",
    "        w1 = ii/1000\n",
    "        w2 = jj/1000\n",
    "        ll = icorr_loss(w1,w2,0.1,0.2,0.25,lam,0.2,0.01,0.1,0.02)\n",
    "        if ll<LL:\n",
    "            LL = ll\n",
    "            ww1 = w1\n",
    "            ww2 = w2\n",
    "            \n",
    "print(lloss(ww1,ww2,0.1,0.2,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.25,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.7,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.9,0.,0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5000000000000001\n",
      "0.5\n",
      "0.5\n",
      "0.5\n"
     ]
    }
   ],
   "source": [
    "#VREx\n",
    "LL = 100\n",
    "ww1 = 0\n",
    "ww2 = 0\n",
    "lam = 100000000\n",
    "for ii in range(1000):\n",
    "    for jj in range(1000):\n",
    "        w1 = ii/1000\n",
    "        w2 = jj/1000\n",
    "        ll = rex_loss(w1,w2,0.1,0.2,0.25,0.2,0.01,0.1,0.02)\n",
    "        if ll<LL:\n",
    "            LL = ll\n",
    "            ww1 = w1\n",
    "            ww2 = w2\n",
    "            \n",
    "print(lloss(ww1,ww2,0.1,0.2,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.25,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.7,0.,0.))\n",
    "print(lloss(ww1,ww2,0.1,0.9,0.,0.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
