{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2225b337",
   "metadata": {
    "id": "2225b337"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Import libraries, the code is built on PyTorch\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms.functional as TF\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "\n",
    "from sklearn import svm\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.model_selection import GridSearchCV, PredefinedSplit\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.cross_decomposition import CCA\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "# Suppress every Python warning for the rest of the kernel session\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Typeset all figures in a Times-style serif font with Computer Modern math text\n",
    "plt.rcParams.update({\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.serif\": \"Times New Roman\",\n",
    "    \"mathtext.fontset\": \"cm\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9050cf8f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9050cf8f",
    "outputId": "d3ed0cf9-36a6-4f93-f9c2-f440f51311f8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "# NOTE(review): `device` is not referenced by any later cell in this notebook.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "82d8be39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-generated dataset splits and the noise parameters used to build them.\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which can execute\n",
    "# code -- only open trusted files. It appears to be required because the *_X / *_Y entries\n",
    "# are passed to torch.stack below, i.e. they presumably hold torch tensors (object arrays).\n",
    "\n",
    "def _stack_to_numpy(tensors):\n",
    "    \"\"\"Stack a sequence of torch tensors along a new axis 0 and return a CPU NumPy array.\n",
    "\n",
    "    ``detach()`` keeps this working when the stored tensors still require grad\n",
    "    (``Tensor.numpy()`` raises on such tensors).\n",
    "    \"\"\"\n",
    "    return torch.stack(list(tensors), dim=0).detach().to('cpu').numpy()\n",
    "\n",
    "\n",
    "# Each `with` closes the .npz archive once every array below has been read into memory\n",
    "with np.load('data.npz', allow_pickle=True) as loaded_data:\n",
    "    tr_X = _stack_to_numpy(loaded_data['train_X'])\n",
    "    tr_Y = _stack_to_numpy(loaded_data['train_Y'])\n",
    "    tr_labels = loaded_data['train_labels']\n",
    "    tr_theta = loaded_data['train_theta']\n",
    "    tr_scale = loaded_data['train_scale']\n",
    "    tr_noise = loaded_data['train_noise']\n",
    "\n",
    "    ts_X = _stack_to_numpy(loaded_data['test_X'])\n",
    "    ts_Y = _stack_to_numpy(loaded_data['test_Y'])\n",
    "    ts_labels = loaded_data['test_labels']\n",
    "    ts_theta = loaded_data['test_theta']\n",
    "    ts_scale = loaded_data['test_scale']\n",
    "    ts_noise = loaded_data['test_noise']\n",
    "\n",
    "    val_X = _stack_to_numpy(loaded_data['validation_X'])\n",
    "    val_Y = _stack_to_numpy(loaded_data['validation_Y'])\n",
    "    val_labels = loaded_data['validation_labels']\n",
    "    val_theta = loaded_data['validation_theta']\n",
    "    val_scale = loaded_data['validation_scale']\n",
    "    val_noise = loaded_data['validation_noise']\n",
    "\n",
    "with np.load('noisy_params.npz', allow_pickle=True) as loaded_data_params:\n",
    "    theta = loaded_data_params['theta']\n",
    "    scales = loaded_data_params['scales']\n",
    "    noise_factor = loaded_data_params['noise_factor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "207b3420",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CCA: fit a dz-dimensional canonical projection on the training pairs (tr_X, tr_Y).\n",
    "# NOTE(review): scikit-learn requires n_components <= min(n_samples, n_features of X, n_features of Y).\n",
    "dz = 256\n",
    "cca = CCA(n_components=dz).fit(tr_X, tr_Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "23sApNIykAMI",
   "metadata": {
    "id": "23sApNIykAMI"
   },
   "outputs": [],
   "source": [
    "# Project all three splits with the projection learned on the training set\n",
    "(zxctr, zyctr), (zxcts, zycts), (zxcv, zycv) = [\n",
    "    cca.transform(X, Y) for X, Y in ((tr_X, tr_Y), (ts_X, ts_Y), (val_X, val_Y))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1c2a7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotting_function(z, z_name, *attributes):\n",
    "    \"\"\"Scatter the first two columns of an embedding, one panel per colouring attribute.\n",
    "\n",
    "    NOTE(review): this helper was called below but never defined in the notebook (the saved\n",
    "    run failed with NameError). This is a reconstruction inferred from the call signature;\n",
    "    confirm it matches the intended plot.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    z : array-like, shape (n_samples, >= 2)\n",
    "        Embedding to plot; only columns 0 and 1 are shown.\n",
    "    z_name : str\n",
    "        Name of the embedding, used in the panel titles.\n",
    "    *attributes : alternating values, title arguments\n",
    "        Each ``values`` (one entry per sample) colours the points of one panel titled ``title``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.figure.Figure\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``attributes`` is empty or has an odd length.\n",
    "    \"\"\"\n",
    "    if not attributes or len(attributes) % 2:\n",
    "        raise ValueError(\"attributes must be one or more (values, title) pairs\")\n",
    "    pairs = list(zip(attributes[0::2], attributes[1::2]))\n",
    "    z = np.asarray(z)\n",
    "\n",
    "    fig, axs = plt.subplots(1, len(pairs), figsize=(4.5 * len(pairs), 4), squeeze=False)\n",
    "    for ax, (values, title) in zip(axs[0], pairs):\n",
    "        points = ax.scatter(z[:, 0], z[:, 1], c=np.ravel(values), s=5, alpha=0.7, cmap='viridis')\n",
    "        ax.set_title(f\"{z_name} - {title}\")\n",
    "        ax.set_xlabel(\"Component 1\")\n",
    "        ax.set_ylabel(\"Component 2\")\n",
    "        fig.colorbar(points, ax=ax)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46945c59",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "46945c59",
    "outputId": "2ebd6a5b-1b00-48cd-fe80-96a139cd40de"
   },
   "outputs": [],
   "source": [
    "plotting_function(zxcv, 'zx', val_labels, 'Labels', val_theta, 'Angles', val_scale, 'Scales', val_noise, 'Noise Factors')\n",
    "plotting_function(zycv, 'zy', val_labels, 'Labels', val_theta, 'Angles', val_scale, 'Scales', val_noise, 'Noise Factors');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18004ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run a t-SNE visualization of the validation-set CCA projections.\n",
    "# Both views share the same settings: 2-D output, fixed seed, all CPU cores.\n",
    "tsne_settings = dict(n_components=2, random_state=42, n_jobs=-1)\n",
    "tsne_x = TSNE(**tsne_settings)\n",
    "tsne_y = TSNE(**tsne_settings)\n",
    "\n",
    "# Fit each model and map its dz-dimensional input to 2-D\n",
    "tsne_result_x, tsne_result_y = (\n",
    "    tsne_x.fit_transform(zxcv),\n",
    "    tsne_y.fit_transform(zycv),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd72fb37",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tsne_plot(savename, tsne_result, lbl, c=val_labels, n_classes=10):\n",
    "    \"\"\"Plot a 2-D t-SNE embedding coloured by class label and save it as a PDF.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    savename : str\n",
    "        Inserted into the output file name \"t-SNE Visualization - <savename>.pdf\".\n",
    "    tsne_result : array, shape (n_samples, 2)\n",
    "        2-D coordinates, e.g. the output of ``TSNE.fit_transform``.\n",
    "    lbl : str\n",
    "        Suffix for the plot title.\n",
    "    c : array-like, shape (n_samples,), default val_labels\n",
    "        Class labels in [0, n_classes) used to colour the points. (Previously this\n",
    "        argument was ignored and ``val_labels`` was always plotted.)\n",
    "    n_classes : int, default 10\n",
    "        Number of discrete colours and colorbar ticks.\n",
    "    \"\"\"\n",
    "    # One discrete colour per class. plt.get_cmap replaces plt.cm.get_cmap, which was\n",
    "    # deprecated in Matplotlib 3.7 and removed in 3.9.\n",
    "    cmap = plt.get_cmap('viridis', n_classes)\n",
    "    # Centre each integer label in its own colour bin so label k always gets colour k,\n",
    "    # even if some classes are missing from `c`; the colorbar shares this scale.\n",
    "    norm = plt.Normalize(vmin=-0.5, vmax=n_classes - 0.5)\n",
    "\n",
    "    # Wide left panel for the scatter, narrow right panel for the colorbar\n",
    "    fig, (ax, cax) = plt.subplots(1, 2, figsize=(4.5, 4), gridspec_kw={'width_ratios': [20, 1]})\n",
    "\n",
    "    ax.scatter(tsne_result[:, 0], tsne_result[:, 1], alpha=0.7, c=c, cmap=cmap, norm=norm)\n",
    "    ax.set_title(\"t-SNE Visualization - \" + str(lbl))\n",
    "    ax.set_xlabel(\"Dimension 1\")\n",
    "    ax.set_ylabel(\"Dimension 2\")\n",
    "\n",
    "    # Opaque colorbar (the scatter itself is drawn with alpha=0.7), one tick per class\n",
    "    fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax,\n",
    "                 orientation='vertical', ticks=np.arange(n_classes))\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"t-SNE Visualization - \" + str(savename) + '.pdf', dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15c9e177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per CCA view: (file-name suffix, t-SNE coordinates, plot-title suffix)\n",
    "for _savename, _coords, _title in (\n",
    "    (r'$Z_X - CCA$', tsne_result_x, r'$Z_X$ - CCA'),\n",
    "    (r'$Z_Y - CCA$', tsne_result_y, r'$Z_Y$ - CCA'),\n",
    "):\n",
    "    tsne_plot(_savename, _coords, _title)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:torchEnv]",
   "language": "python",
   "name": "conda-env-torchEnv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
