{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import plotly.graph_objs as go\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy\n",
    "import scipy.linalg"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%time\n",
    "\n",
    "# Generate the A_k matrices and b_k vectors of the per-domain quadratic losses.\n",
    "\n",
    "from scipy.stats import ortho_group\n",
    "np.random.seed(42)\n",
    "\n",
    "dim = 500      # length of the parameter vector x\n",
    "n_domains = 2  # number of domains (one quadratic loss each)\n",
    "\n",
    "# Shared spectrum: eigenvalues i ** -alpha for i = 1..dim, one diagonal copy per domain.\n",
    "alpha = 0.7\n",
    "a_diag = np.tile(np.diag(np.arange(1, dim + 1) ** -alpha), (n_domains, 1, 1))\n",
    "\n",
    "# Simpler alternative eigenbasis (Haar-random orthogonal matrices):\n",
    "# bases = np.array([ortho_group.rvs(dim=dim) for _ in range(n_domains)])\n",
    "\n",
    "# Eigenvectors (complicated): random orthogonal bases built from a banded generator.\n",
    "def get_vecs(size=None, bandwidth=3):\n",
    "    \"\"\"Sample a random orthogonal matrix of shape (size, size).\n",
    "\n",
    "    A Gaussian matrix is damped elementwise by exp(-|i - j| / bandwidth), so only\n",
    "    entries near the diagonal are sizeable. Its skew-symmetric part G - G^T is mapped\n",
    "    through the matrix exponential; expm of a real skew-symmetric matrix is orthogonal.\n",
    "    Each call consumes one size x size draw from the global NumPy RNG.\n",
    "\n",
    "    Args:\n",
    "        size: matrix dimension; defaults to the global `dim`, read at call time.\n",
    "        bandwidth: decay length of the band damping (3 reproduces the original setup).\n",
    "\n",
    "    Returns:\n",
    "        (size, size) ndarray with orthonormal columns.\n",
    "    \"\"\"\n",
    "    if size is None:\n",
    "        size = dim\n",
    "    cols, rows = np.meshgrid(np.arange(size), np.arange(size))\n",
    "    damping = np.exp(-np.abs(cols - rows) / bandwidth)\n",
    "    # Alternative (Gaussian band): damping = np.exp(-(cols - rows) ** 2)\n",
    "    gen = np.random.normal(size=[size, size]) * damping\n",
    "    skew = gen - gen.T\n",
    "    return scipy.linalg.expm(skew)\n",
    "\n",
    "# Alternative: keep domain 0 axis-aligned and rotate only domain 1:\n",
    "#   bases = np.array([np.eye(dim), get_vecs()])\n",
    "bases = np.stack([get_vecs() for _ in range(n_domains)])  # (n_domains, dim, dim)\n",
    "\n",
    "# A_k = C_k^T diag(lambda) C_k. Equivalent to the single contraction\n",
    "# np.einsum('kba,kbc,kcd->kad', bases, a_diag, bases), split into two\n",
    "# pairwise einsums because that is faster.\n",
    "a = np.einsum(\n",
    "    'kac,kcd->kad',\n",
    "    np.einsum('kba,kbc->kac', bases, a_diag),\n",
    "    bases,\n",
    ")\n",
    "\n",
    "b = np.random.normal(size=[n_domains, dim])  # per-domain centres b_k"
   ],
   "id": "717903f7c826b80f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# define formulae for loss functions and their derivatives\n",
    "\n",
    "def flow(x_, a_, b_, t_=1):\n",
    "    \"\"\"Closed-form flow of dx/dt = -a_ (x - b_) for time t_.\n",
    "\n",
    "    For symmetric a_ this is the gradient flow of 0.5 (x - b_)^T a_ (x - b_):\n",
    "    x(t_) = b_ + expm(-a_ t_) (x(0) - b_). x_ holds one point per row (or is a\n",
    "    single 1-D point); a_ is a square matrix and b_ a vector of matching size.\n",
    "    \"\"\"\n",
    "    propagator = scipy.linalg.expm(-a_ * t_)\n",
    "    offset = (x_ - b_).T\n",
    "    return (propagator @ offset).T + b_\n",
    "\n",
    "def loss(x_, a_, b_):\n",
    "    \"\"\"Per-domain quadratic losses 0.5 (x - b_k)^T A_k (x - b_k).\n",
    "\n",
    "    Args:\n",
    "        x_: (s, d) points.\n",
    "        a_: (k, d, d) stack of matrices A_k.\n",
    "        b_: (k, d) stack of centres b_k.\n",
    "\n",
    "    Returns:\n",
    "        (k, s) array whose [k, s] entry is domain k's loss at point s.\n",
    "    \"\"\"\n",
    "    residual = x_[None, :, :] - b_[:, None, :]  # (k, s, d)\n",
    "    return np.einsum('kij,ksi,ksj->ks', a_, residual, residual) / 2\n",
    "\n",
    "def grad(x_, a_, b_):\n",
    "    \"\"\"Row-wise a_ (x - b_): the gradient of 0.5 (x - b_)^T a_ (x - b_) for symmetric a_.\n",
    "\n",
    "    x_ is (s, d), a_ is (d, d), b_ broadcasts against x_; returns (s, d).\n",
    "    \"\"\"\n",
    "    residual = x_ - b_\n",
    "    return np.einsum('ij,sj->si', a_, residual)\n",
    "\n",
    "def comm(x_, a_, b_):\n",
    "    \"\"\"Commutator of the two domain gradient fields, evaluated row-wise at x_.\n",
    "\n",
    "    With g_k(x) = A_k (x - b_k) this returns A_1 A_2 (x - b_2) - A_2 A_1 (x - b_1),\n",
    "    i.e. Dg_1 g_2 - Dg_2 g_1 (a Lie bracket of the two fields, up to sign convention).\n",
    "    Expects a_ and b_ to hold two domains each (four unpacked items in total).\n",
    "    \"\"\"\n",
    "    a1_, a2_, b1_, b2_ = [*a_, *b_]\n",
    "    forward = a1_ @ a2_ @ (x_ - b2_).T\n",
    "    backward = a2_ @ a1_ @ (x_ - b1_).T\n",
    "    return (forward - backward).T"
   ],
   "id": "4be2cc37320d7e6f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Commutation experiment for the table in the paper.\n",
    "#\n",
    "# For every (dt, t): pretrain random points by flowing them for time t on the joint\n",
    "# objective, take one dt-step on domain 1 followed by one on domain 0 (x10), and\n",
    "# compare against a single 2*dt step on the joint objective (x_base). The excess\n",
    "# loss is divided by the second-order commutator prediction\n",
    "#     <comm(x), grad L_base(x)> * dt**2 / 2\n",
    "# and printed as a LaTeX table row: t, dt, median ratio +- half the 10%-90% quantile range.\n",
    "\n",
    "np.random.seed(123)\n",
    "n_samples = 100\n",
    "\n",
    "# NOTE(review): `a_base` / `b_base` were used here but never defined in this notebook, so\n",
    "# the cell raised NameError on a fresh kernel. Reconstructed as the quadratic of the\n",
    "# *average* domain loss: its gradient a_base (x - b_base) equals mean_k A_k (x - b_k), and\n",
    "# one 2*dt step on it matches one dt step per domain to first order, which the dt**2\n",
    "# commutator score relies on. TODO confirm this matches the setup used for the paper.\n",
    "a_base = a.mean(0)\n",
    "b_base = np.linalg.solve(a_base, np.einsum('kij,kj->ki', a, b).mean(0))\n",
    "\n",
    "for dt, t in itertools.product([.001, .01, .1], [.1, .3, 1, 3]):\n",
    "    x = np.random.normal(size=[n_samples, dim])\n",
    "    x = flow(x, a_base, b_base, t_=t)  # pretrain on the joint objective\n",
    "\n",
    "    # one dt-step on domain 1, then one dt-step on domain 0\n",
    "    x10 = flow(flow(x, a[1], b[1], t_=dt), a[0], b[0], t_=dt)\n",
    "    # reference: a single 2*dt step on the joint objective\n",
    "    x_base = flow(x, a_base, b_base, t_=2 * dt)\n",
    "\n",
    "    # loss(...) is (n_domains, n_samples); averaged over domains below\n",
    "    excess_loss10 = loss(x10, a, b) - loss(x_base, a, b)\n",
    "    comm_score = np.einsum('ki,ki->k', comm(x, a, b), grad(x, a_base, b_base)) * dt ** 2 / 2\n",
    "\n",
    "    ratio = excess_loss10.mean(0) / comm_score  # one value per sample\n",
    "    spread = (np.quantile(ratio, .9) - np.quantile(ratio, .1)) / 2\n",
    "    # backslashes are escaped so Python emits the literal LaTeX commands (no invalid escapes)\n",
    "    print(f'{t:5.1f} & {dt:8.3f} & '\n",
    "          f'${np.median(ratio):8.3f} \\\\pm '\n",
    "          f'{spread:8.3f}$ \\\\\\\\')"
   ],
   "id": "ba97f80962726a82"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "c8223ffa813722cc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "4e82d5aa193e9048"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "39cfa989feabf21"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "3a31ded42dd0eaaa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "c54d72ce01b14dd4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "fa4c696d47c948c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "6e568b8876b44e94"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
