{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic experiments"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension so that edits to imported project modules are\n",
    "# picked up live, without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import os\n",
    "from typing import Callable, Dict, List, Tuple\n",
    "\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import Tensor, cosine_similarity\n",
    "from torch.nn.functional import mse_loss, l1_loss\n",
    "\n",
    "# Project helpers. The wildcard imports stay after the explicit imports above and\n",
    "# before the torchmetrics import below, so any name shadowing resolves as before.\n",
    "from latent_invariances.utils.relreps import *\n",
    "from latent_invariances.utils.transforms import *\n",
    "from latent_invariances.utils.data_generation import *\n",
    "from latent_invariances.utils.similarity_functions import *\n",
    "\n",
    "from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef\n",
    "\n",
    "\n",
    "# Fix the random seed once, up front, so that repeated runs of the notebook are repeatable.\n",
    "seed_everything(0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choose the distance metrics you want to use to generate relative representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps a projection name to the callable that builds that representation; the cells\n",
    "# below invoke every entry as f(anchors=..., points=...).\n",
    "projection_dict: Dict[str, Callable[[Tensor, Tensor], Tensor]] = {\n",
    "    # Baseline: hand the points back untouched (any other keyword argument is ignored).\n",
    "    \"Absolute\": lambda points, **kwargs: points,\n",
    "    # Anchor-relative projections: abs_to_rel bound to a distance function and an\n",
    "    # optional normalizing function.\n",
    "    **{\n",
    "        name: functools.partial(\n",
    "            abs_to_rel, dist_func=dist_func, normalizing_func=normalizing_func\n",
    "        )\n",
    "        for name, dist_func, normalizing_func in (\n",
    "            (\"Cosine\", custom_cosine, None),\n",
    "            (\"Center Cosine\", custom_cosine, custom_normalize),\n",
    "            (\"Euclidean\", custom_euclidean, None),\n",
    "            (\"Wasserstein\", custom_wasserstein, None),\n",
    "        )\n",
    "    },\n",
    "    \"CoB Lstsq\": basis_change_lstsq,\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choose the transformations to apply wrt the anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps a transformation name to a callable invoked below as f(data_points=...).\n",
    "transformations_dict: Dict[str, Callable[[Tensor], Tensor]] = {\n",
    "    # Baseline: return the data points unchanged.\n",
    "    \"Original\": lambda data_points, **kwargs: data_points,\n",
    "    # Every other entry applies exactly one transformation through apply_transformations.\n",
    "    **{\n",
    "        name: functools.partial(apply_transformations, transf_funcs=[transf_func])\n",
    "        for name, transf_func in (\n",
    "            (\"Isotropic Scale\", custom_isotropic_scale),\n",
    "            (\"Orthogonal\", custom_orthogonal),\n",
    "            (\"Translation\", custom_translation),\n",
    "            (\"Permutation\", permute_indices),\n",
    "            (\"Affine Transformation\", affine_transformation),\n",
    "            (\"Linear Transformation\", linear_transformation),\n",
    "        )\n",
    "    },\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dictionary of the possible ways to initialize the latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps a layout name to a generator bound to that spatial pattern. The plotting cell\n",
    "# calls the chosen entry with npoints=... and unpacks (data_points, anchor_indices).\n",
    "init_latent_space: Dict[str, Callable[[int], Tuple[torch.Tensor, List[int]]]] = {\n",
    "    name: functools.partial(generate_data_points, space_pattern=pattern)\n",
    "    for name, pattern in (\n",
    "        (\"Grid\", generate_grid),\n",
    "        (\"Diagonal\", generate_diagonal),\n",
    "        (\"Spiral\", generate_spiral),\n",
    "        (\"Square\", generate_square),\n",
    "        (\"Node\", generate_node),\n",
    "        (\"Circle\", generate_circle),\n",
    "        (\"Ellipse\", generate_ellipse),\n",
    "        (\"Clustered\", generate_cluster),\n",
    "        (\"Random\", generate_random),\n",
    "    )\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate, Transform and Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure layout: one row per projection, one column per transformation.\n",
    "nrows = len(projection_dict)\n",
    "num_plots = len(transformations_dict)\n",
    "num_subplots = nrows * num_plots\n",
    "\n",
    "fig, axes = plt.subplots(\n",
    "    nrows,\n",
    "    num_plots,\n",
    "    figsize=(5 * num_plots, num_subplots),\n",
    "    sharex=False,\n",
    "    sharey=False,\n",
    "    squeeze=False,  # keep `axes` two-dimensional even with a single row or column\n",
    ")\n",
    "\n",
    "# Generate original space\n",
    "npoints = 24\n",
    "init_form = \"Spiral\"\n",
    "\n",
    "data_points, anchor_indices = init_latent_space[init_form](npoints=npoints)\n",
    "\n",
    "# Generate colors for points: min-max scale each point's coordinate sum to [0, 1] and\n",
    "# look it up in the colormap, so a point keeps its color in every subplot.\n",
    "coordinates_sum = data_points.sum(axis=1)\n",
    "normalized_coordinates_sum = (coordinates_sum - coordinates_sum.min()) / (\n",
    "    coordinates_sum.max() - coordinates_sum.min()\n",
    ")\n",
    "# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the pyplot accessor still exists.\n",
    "cmap = plt.get_cmap(\"viridis\")\n",
    "colors = cmap(normalized_coordinates_sum.tolist())\n",
    "\n",
    "anchor_colors = [\"darkviolet\", \"fuchsia\"]\n",
    "\n",
    "assert len(anchor_indices) == len(\n",
    "    anchor_colors\n",
    "), \"number of anchors and anchor colors must be the same\"\n",
    "\n",
    "for row, (projection_name, function) in enumerate(projection_dict.items()):\n",
    "    axes[row, 0].set_ylabel(projection_name, fontsize=20)\n",
    "\n",
    "    for col, (transf_name, transformation) in enumerate(transformations_dict.items()):\n",
    "        ax = axes[row, col]\n",
    "\n",
    "        # Transform the space, then re-express it relative to the transformed anchors.\n",
    "        transformed_points = transformation(data_points=data_points)\n",
    "        transformed_anchors = transformed_points[anchor_indices]\n",
    "        transformed_relative = function(\n",
    "            anchors=transformed_anchors, points=transformed_points\n",
    "        )\n",
    "\n",
    "        # Only the first two coordinates of the representation are drawn.\n",
    "        x, y = transformed_relative[:, 0], transformed_relative[:, 1]\n",
    "        ax.scatter(x, y, c=colors, marker=\".\", s=200)\n",
    "        # Draw the anchors on top as stars.\n",
    "        ax.scatter(\n",
    "            x[anchor_indices], y[anchor_indices], c=anchor_colors, marker=\"*\", s=300\n",
    "        )\n",
    "        ax.set_aspect(\"equal\", \"box\")\n",
    "        if row == 0:\n",
    "            # Column titles only on the top row.\n",
    "            ax.set_title(transf_name, fontsize=20)\n",
    "\n",
    "\n",
    "figure_path = (\n",
    "    PROJECT_ROOT / \"notebooks\" / \"synthetic\" / \"results\" / f\"{init_form}_appendix.svg\"\n",
    ")\n",
    "# Create the results directory if needed so that saving works on a fresh checkout.\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(figure_path, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !rsvg-convert -f pdf -o 'results/Grid_appendix.pdf' 'results/Grid_appendix.svg'\n",
    "# !rm 'results/Grid_appendix'.svg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !rsvg-convert -f pdf -o 'results/Spiral_main.pdf' 'results/Spiral_main.svg'\n",
    "# !rm 'results/Spiral_main'.svg"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate similarities dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "similarities = [\n",
    "    \"Projection\",\n",
    "    \"Transformation\",\n",
    "    \"mse_mean\",\n",
    "    \"mse_std\",\n",
    "    \"l1_mean\",\n",
    "    \"l1_std\",\n",
    "    \"cosine_mean\",\n",
    "    \"cosine_std\",\n",
    "    \"pearson_mean\",\n",
    "    \"pearson_std\",\n",
    "]\n",
    "info = {key: [] for key in similarities}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the identity baselines (absolute projection, original representation) before\n",
    "# computing similarities. `pop` with a default keeps this cell idempotent: running it\n",
    "# a second time no longer raises KeyError. The dictionaries are still modified in\n",
    "# place, so re-running the plotting cell afterwards omits the baseline row and column.\n",
    "for mapping, baseline_key in (\n",
    "    (projection_dict, \"Absolute\"),  # remove absolute projection\n",
    "    (transformations_dict, \"Original\"),  # remove original representation\n",
    "):\n",
    "    mapping.pop(baseline_key, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from empty columns so that re-running this cell does not append duplicate rows.\n",
    "info = {key: [] for key in similarities}\n",
    "\n",
    "for projection_name, function in projection_dict.items():\n",
    "    # Reference representation of the untransformed points.\n",
    "    relative = function(anchors=data_points[anchor_indices], points=data_points)\n",
    "\n",
    "    for transf_name, transformation in transformations_dict.items():\n",
    "        # Transform the space, then re-express it relative to the transformed anchors.\n",
    "        transformed_points = transformation(data_points=data_points)\n",
    "        transformed_anchors = transformed_points[anchor_indices]\n",
    "        transformed_relative = function(\n",
    "            anchors=transformed_anchors, points=transformed_points\n",
    "        )\n",
    "\n",
    "        scores = {\n",
    "            # reduction=\"none\" keeps element-wise errors so both mean and std can be taken.\n",
    "            \"mse\": mse_loss(relative, transformed_relative, reduction=\"none\"),\n",
    "            \"l1\": l1_loss(relative, transformed_relative, reduction=\"none\"),\n",
    "            \"cosine\": cosine_similarity(relative, transformed_relative),\n",
    "            # NOTE(review): both inputs are transposed before the correlation; presumably\n",
    "            # this gives one coefficient per point, computed across the anchor dimension.\n",
    "            # With only two anchors such a coefficient would be degenerate - confirm.\n",
    "            \"pearson\": pearson_corrcoef(relative.T, transformed_relative.T),\n",
    "        }\n",
    "\n",
    "        # Record the row only after every metric has been computed, so a failure above\n",
    "        # cannot leave the columns with different lengths.\n",
    "        info[\"Projection\"].append(projection_name)\n",
    "        info[\"Transformation\"].append(transf_name)\n",
    "        for metric_name, values in scores.items():\n",
    "            info[f\"{metric_name}_mean\"].append(values.mean().item())\n",
    "            info[f\"{metric_name}_std\"].append(values.std().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-(projection, transformation) statistics into one table and save it.\n",
    "results = pd.DataFrame(info)\n",
    "output_path = PROJECT_ROOT / \"notebooks\" / \"synthetic\" / \"results\" / f\"{init_form}.csv\"\n",
    "# `to_csv` does not create missing parent directories, so make sure the folder exists.\n",
    "output_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "results.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the table from disk so the summary shows exactly what was saved, then display\n",
    "# the metrics per (projection, transformation) pair rounded to three decimals.\n",
    "df = pd.read_csv(output_path)\n",
    "df.groupby(by=[\"Projection\", \"Transformation\"]).mean().round(decimals=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
