{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "from models import AutoEncoder, GenNet\n",
    "from helpers.utils import to_torch, MedianHeuristicMMR, gen_params\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(A_dim, Z_dim, params, net, sample_size, A_str):\n",
    "    \"\"\"Simulate one dataset (X, A, Z) from the latent-variable model.\n",
    "\n",
    "    A is drawn uniformly from [-1, 1]^A_dim, the latent is\n",
    "    Z = (A @ params['M']) * A_str + V with V ~ N(0, params['cov_ez']),\n",
    "    and the observation X = net(Z) is divided by its own per-column\n",
    "    standard deviation. Because that scale comes from this sample, two\n",
    "    independently generated datasets are each normalised separately.\n",
    "\n",
    "    Args:\n",
    "        A_dim: number of columns of A; must equal params['M'].shape[0].\n",
    "        Z_dim: dimension of the noise V; must match params['cov_ez'].\n",
    "        params: dict with keys 'M' and 'cov_ez' (e.g. from gen_params).\n",
    "        net: callable applied to the torch-converted latents.\n",
    "        sample_size: number of rows to draw.\n",
    "        A_str: scalar multiplier on the A-dependent part of Z.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (X, A, Z) of numpy arrays with sample_size rows each.\n",
    "    \"\"\"\n",
    "    # Draw order (A first, then V) is fixed so seeded runs stay reproducible.\n",
    "    actions = np.random.uniform(-1, 1, size=(sample_size, A_dim))\n",
    "    noise = np.random.multivariate_normal(\n",
    "        mean=np.zeros(shape=(Z_dim,)),\n",
    "        cov=params['cov_ez'],\n",
    "        size=(sample_size,),\n",
    "    )\n",
    "    latents = (actions @ params['M']) * A_str + noise\n",
    "\n",
    "    observed = net(to_torch(latents)).detach().numpy()\n",
    "    observed = observed / observed.std(axis=0, keepdims=True)\n",
    "\n",
    "    return observed, actions, latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dim = 10  # dimension of the observations produced by gen_net\n",
    "Z_dim = 2   # latent dimension; A is drawn with this same dimension below\n",
    "\n",
    "np.random.seed(1)\n",
    "\n",
    "n_iter = 1000  # used as max_epochs for both trainers\n",
    "sample_size = 1000\n",
    "\n",
    "a_s = 1  # passed to gen_data as A_str (scale of the A-dependent part of Z)\n",
    "\n",
    "params = gen_params(Z_dim, Z_dim)\n",
    "\n",
    "torch.manual_seed(1)\n",
    "gen_net = GenNet(Z_dim, X_dim)\n",
    "gen_net.init_weights()\n",
    "\n",
    "X, A, _ = gen_data(Z_dim, Z_dim, params, gen_net, sample_size, a_s)\n",
    "\n",
    "# Identical model arguments except for lmd (0 for the baseline).\n",
    "# NOTE(review): lmd presumably weights the MMR penalty - confirm in models.AutoEncoder.\n",
    "model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)\n",
    "model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=1e+2)\n",
    "se_callback = MedianHeuristicMMR()\n",
    "\n",
    "# Settings shared by both trainers; only the MMR run gets the callback.\n",
    "trainer_kwargs = dict(max_epochs=n_iter, enable_progress_bar=False,\n",
    "                      enable_checkpointing=False, enable_model_summary=False,\n",
    "                      accelerator=\"cpu\")\n",
    "trainer_MMR = pl.Trainer(callbacks=[se_callback], **trainer_kwargs)\n",
    "trainer_baseline = pl.Trainer(**trainer_kwargs)\n",
    "\n",
    "trainer_baseline.fit(model_baseline)\n",
    "# model_MMR is loaded with the *trained* baseline weights (this runs after\n",
    "# trainer_baseline.fit), so the MMR run starts from the baseline solution.\n",
    "# NOTE(review): confirm this warm start is intended rather than a shared init.\n",
    "model_MMR.load_state_dict(model_baseline.state_dict())\n",
    "trainer_MMR.fit(model_MMR)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out sample from the same generative process (same params and gen_net).\n",
    "n_test = 1000\n",
    "X_new, A_new, Z_new = gen_data(Z_dim, Z_dim, params, gen_net, n_test, a_s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
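   "source": [
    "## Estimated vs. true latents\n",
    "\n",
    "For each encoder (AE-MMR, and the AE-Vanilla baseline trained with `lmd=0`), encode the held-out `X_new` and plot each true latent coordinate against the two estimated coordinates. The red plane is a least-squares linear fit of the true coordinate on the estimates; points lying close to it indicate that the coordinate is linearly recoverable from the learned representation."
   ]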
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the held-out sample with both trained models.\n",
    "predZ_MMR = model_MMR.encode(to_torch(X_new)).detach().numpy()\n",
    "predZ_baseline = model_baseline.encode(to_torch(X_new)).detach().numpy()\n",
    "predZ_by_method = {'MMR': predZ_MMR, 'Vanilla': predZ_baseline}\n",
    "\n",
    "# Hand-picked camera angles (elev, azim) for each panel.\n",
    "init_views = dict(MMR_Z1=lambda ax: ax.view_init(10, 10),\n",
    "                  MMR_Z2=lambda ax: ax.view_init(190, -75),\n",
    "                  Vanilla_Z1=lambda ax: ax.view_init(-170, 15),\n",
    "                  Vanilla_Z2=lambda ax: ax.view_init(0, 120))\n",
    "\n",
    "cols = sns.color_palette('tab10')\n",
    "pad_factor = 1.1  # every axis spans pad_factor times the data range\n",
    "\n",
    "\n",
    "def padded_range(values, factor=pad_factor):\n",
    "    \"\"\"Return (lo, hi) around min/max of `values`, widened so hi - lo = factor * (max - min).\n",
    "\n",
    "    The padding is derived from the data range and added on both sides, so it\n",
    "    always widens the interval. Multiplying min and max by `factor` instead only\n",
    "    works when they straddle zero: for all-positive data (min=2 -> 2.2) it moves\n",
    "    the lower limit inward and clips points.\n",
    "    \"\"\"\n",
    "    lo, hi = float(np.min(values)), float(np.max(values))\n",
    "    pad = (factor - 1) / 2 * (hi - lo)\n",
    "    return lo - pad, hi + pad\n",
    "\n",
    "\n",
    "for method in ['MMR', 'Vanilla']:\n",
    "    Z_pred = predZ_by_method[method]\n",
    "    # Assumes a 2-D encoding (Z_dim == 2): columns are estimated Z_1, Z_2.\n",
    "    X1, X2 = Z_pred[:, 0], Z_pred[:, 1]\n",
    "    x_lim, y_lim = padded_range(X1), padded_range(X2)\n",
    "\n",
    "    # 20 x 20 grid over the padded estimated-latent plane, on which the fit is drawn.\n",
    "    x_surf, y_surf = np.meshgrid(np.linspace(*x_lim, 20), np.linspace(*y_lim, 20))\n",
    "    xy_surf = np.column_stack([x_surf.ravel(), y_surf.ravel()])\n",
    "\n",
    "    fig = plt.figure(figsize=plt.figaspect(0.5))\n",
    "\n",
    "    for z_idx in [0, 1]:\n",
    "        ax = fig.add_subplot(1, 2, z_idx + 1, projection='3d')\n",
    "\n",
    "        Z = Z_new[:, [z_idx]]  # true latent coordinate, shape (n, 1)\n",
    "\n",
    "        # Least-squares plane predicting the true coordinate from both estimates.\n",
    "        lr = LinearRegression()\n",
    "        lr.fit(Z_pred, Z)\n",
    "        out = lr.predict(X=xy_surf).reshape(x_surf.shape)\n",
    "\n",
    "        ax.plot_surface(x_surf, y_surf, out, alpha=0.2, rstride=4, cstride=4,\n",
    "                        color=cols[3], edgecolors='grey', linewidths=0.1)\n",
    "\n",
    "        init_views['{}_Z{}'.format(method, z_idx + 1)](ax)\n",
    "\n",
    "        ax.scatter(X1, X2, Z.ravel(), s=15, alpha=1, edgecolors='white', linewidths=0.2,\n",
    "                   color=cols[0])\n",
    "\n",
    "        ax.set_xlabel(r'Estimated $Z_1$')\n",
    "        ax.set_ylabel(r'Estimated $Z_2$')\n",
    "        ax.set_zlabel(r'True $Z_{}$'.format(z_idx + 1))\n",
    "\n",
    "        ax.set_zlim(*padded_range(Z))\n",
    "        ax.set_xlim(*x_lim)\n",
    "        ax.set_ylim(*y_lim)\n",
    "\n",
    "    # NOTE(review): top > 1 places the subplot area partly above the figure,\n",
    "    # presumably to trim the wide default margins of 3-D axes; confirm the\n",
    "    # suptitle still renders where intended.\n",
    "    fig.subplots_adjust(top=1.25)\n",
    "    fig.suptitle(\"Method = AE-{}\".format(method))\n",
    "    fig.savefig(\"{}_3d_plot.pdf\".format(method))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
