{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e0124b37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "from utils.metrics import distance_correlation, id_correlation\n",
    "from utils.utils import cat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7cf0da94-8abf-4fa4-a574-2cf303c3ce91",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "N_points = 5000\n",
    "\n",
    "# Synthetic cylinder: (X, Z) lie on the unit circle and Y is an independent height.\n",
    "# The angle is sampled uniformly so points are spread evenly around the circle;\n",
    "# normalizing a uniform sample from a square instead would make the diagonal\n",
    "# directions about twice as dense as the axis directions.\n",
    "theta = torch.as_tensor(np.random.uniform(-np.pi, np.pi, (N_points, 1)), dtype=torch.float64)\n",
    "X = torch.cos(theta)\n",
    "Z = torch.sin(theta)\n",
    "Y = torch.as_tensor(np.random.uniform(-1, 1, (N_points, 1)), dtype=torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028892ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distance correlation of each coordinate against the other two (combined via `cat`).\n",
    "for label, target, rest in [\n",
    "    (\"X\", X, [Y, Z]),\n",
    "    (\"Y\", Y, [X, Z]),\n",
    "    (\"Z\", Z, [X, Y]),\n",
    "]:\n",
    "    print(f\"Distance correlation between {label} and rest: \", distance_correlation(target, cat(rest)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d903c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Id correlation of each coordinate against the other two (combined via `cat`).\n",
    "for label, target, rest in [\n",
    "    (\"X\", X, [Y, Z]),\n",
    "    (\"Y\", Y, [X, Z]),\n",
    "    (\"Z\", Z, [X, Y]),\n",
    "]:\n",
    "    print(f\"Id correlation between {label} and rest: \", id_correlation(target, cat(rest)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90168c1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "# 3-D view of the cylinder point cloud.\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "ax.view_init(15, -110)  # elevation, azimuth in degrees\n",
    "ax.scatter(X, Y, Z, s=50)\n",
    "\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "ax.set_zticks([])\n",
    "fig.savefig('results/cylinder.svg', dpi=200, bbox_inches='tight', format='svg')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c64c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "# X-Y projection of the cylinder, coloured by Z.\n",
    "fig, ax = plt.subplots(figsize=(13, 10))\n",
    "points = ax.scatter(X, Y, s=100, c=Z, cmap='viridis')\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "fig.colorbar(points, ax=ax, ticks=[])\n",
    "fig.savefig('results/cylinder_xy.svg', dpi=200, bbox_inches='tight', format='svg')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d269886",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "\n",
    "# X-Z projection of the cylinder (the circular cross-section), coloured by Y.\n",
    "fig, ax = plt.subplots(figsize=(13, 10))\n",
    "points = ax.scatter(X, Z, s=100, c=Y, cmap='viridis')\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "fig.colorbar(points, ax=ax, ticks=[])\n",
    "fig.savefig('results/cylinder_xz.svg', dpi=200, bbox_inches='tight', format='svg')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
