{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf495bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from aeon.datasets import load_classification\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "sys.path.append(\"../../\")\n",
    "import ot\n",
    "from joblib import Parallel, delayed\n",
    "from pathlib import Path\n",
    "import json\n",
    "from functools import partial\n",
    "import argparse\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import OrdinalEncoder, StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from src.model import primal_fit_to\n",
    "from src.utils import primal_left\n",
    "from src.representation import augment, polynomial_feature_map\n",
    "from src.numpy_metric import hs_metric,operator_metric,eigenvalue_metric,subspace_metric, chordal_metric, martin_metric\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f78b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment identifier: the cells below read from and write to results/exp_<EXP_ID>/.\n",
    "EXP_ID = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8786fdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the settings saved for this experiment; later cells look values up in `args` by key.\n",
    "settings_path = Path(f\"results/exp_{EXP_ID}/settings.json\")\n",
    "with settings_path.open(\"r\") as settings_file:\n",
    "    args = json.load(settings_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1cb0aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets to process, each mapped to the `alpha` value used for that dataset.\n",
    "dataset_alpha_dct = {\n",
    "    \"BasicMotions\": 0.01,\n",
    "    \"Cricket\": 0.01,\n",
    "    \"ERing\": 0.01,\n",
    "    \"EigenWorms\": 0.5,\n",
    "    \"Epilepsy\": 0.01,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10893baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filtering(D, R, L, threshold=1e-2):\n",
    "    \"\"\"Drop the entries of ``D`` whose ``|exp(D)|`` does not exceed ``threshold``.\n",
    "\n",
    "    The same boolean selection is applied to the columns of ``R`` and ``L``,\n",
    "    so column ``k`` of the returned matrices still matches entry ``k`` of the\n",
    "    returned ``D``.\n",
    "\n",
    "    Returns the tuple ``(D_kept, R_kept, L_kept)``.\n",
    "    \"\"\"\n",
    "    keep = np.abs(np.exp(D)) > threshold\n",
    "    return D[keep], R[:, keep], L[:, keep]\n",
    "\n",
    "def set_parameters(n_samples, sampling_ratio=0.2, max_sampfreq=100, max_context_window=50):\n",
    "    \"\"\"Derive ``(sampfreq, context_window)`` from the length of a series.\n",
    "\n",
    "    Both values are based on half the series length (``n_samples // 2``):\n",
    "    ``sampfreq`` is that half scaled by ``sampling_ratio`` and truncated to an\n",
    "    int, ``context_window`` is the half itself. Each is capped by its\n",
    "    ``max_*`` argument.\n",
    "    \"\"\"\n",
    "    half_length = n_samples // 2\n",
    "    scaled_half = int(half_length * sampling_ratio)\n",
    "    return min(max_sampfreq, scaled_half), min(max_context_window, half_length)\n",
    "\n",
    "def metric_cross_matrix(\n",
    "    metric_func: callable,\n",
    "    Ts_lst: list,\n",
    "    Tt_lst: list = None,\n",
    "    n_jobs: int = 1,\n",
    "    metric_kwargs: dict = None\n",
    "    ) -> np.ndarray:\n",
    "    \"\"\"Evaluate ``metric_func`` on pairs of items and collect the values in a matrix.\n",
    "\n",
    "    Every item is unpacked positionally, i.e. one evaluation is\n",
    "    ``metric_func(*Ts_lst[i], *Tt_lst[j], **metric_kwargs)``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    metric_func : callable\n",
    "        Function returning one scalar per pair of unpacked items.\n",
    "    Ts_lst : list\n",
    "        Source items; they index the rows of the result.\n",
    "    Tt_lst : list, optional\n",
    "        Target items; they index the columns of the result. When omitted,\n",
    "        ``Ts_lst`` is compared with itself: only the pairs with ``i <= j`` are\n",
    "        evaluated and the values are mirrored, so ``metric_func`` is treated\n",
    "        as symmetric.\n",
    "    n_jobs : int\n",
    "        Forwarded to ``joblib.Parallel``.\n",
    "    metric_kwargs : dict, optional\n",
    "        Extra keyword arguments for every ``metric_func`` call.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Array of shape ``(len(Ts_lst), len(Tt_lst))``; all zeros when there is\n",
    "        no pair to evaluate.\n",
    "    \"\"\"\n",
    "    # A `{}` default would be one dict shared by every call; build it per call instead.\n",
    "    kwargs = {} if metric_kwargs is None else metric_kwargs\n",
    "    symmetric = Tt_lst is None\n",
    "    targets = Ts_lst if symmetric else Tt_lst\n",
    "    m, n = len(Ts_lst), len(targets)\n",
    "\n",
    "    def compute(i, j):\n",
    "        # The column item has to come from the target list. Indexing Ts_lst\n",
    "        # here (as before) compared sources with sources and could run out of range.\n",
    "        return metric_func(*Ts_lst[i], *targets[j], **kwargs)\n",
    "\n",
    "    if symmetric:\n",
    "        # Upper triangle including the diagonal, in row-major order.\n",
    "        pairs = [(i, j) for i in range(n) for j in range(i, n)]\n",
    "    else:\n",
    "        pairs = [(i, j) for i in range(m) for j in range(n)]\n",
    "\n",
    "    mat = np.zeros((m, n))\n",
    "    if not pairs:\n",
    "        return mat\n",
    "\n",
    "    values = Parallel(n_jobs=n_jobs)(\n",
    "        delayed(compute)(i, j) for i, j in tqdm(pairs)\n",
    "    )\n",
    "    rows = [i for i, _ in pairs]\n",
    "    cols = [j for _, j in pairs]\n",
    "    mat[rows, cols] = np.asarray(values)\n",
    "\n",
    "    if symmetric:\n",
    "        # Mirror the upper triangle; the diagonal would otherwise be counted twice.\n",
    "        mat = mat + mat.T - np.diag(mat.diagonal())\n",
    "    return mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442908ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE embeddings, stored as embedding_dct[dataset][metric label]; filled in below.\n",
    "embedding_dct = {dataset: {} for dataset in dataset_alpha_dct}\n",
    "embedding_path = f\"results/exp_{EXP_ID}/tsne_embedding.json\"\n",
    "\n",
    "for dataset, alpha in dataset_alpha_dct.items():\n",
    "    print(f\"Processing {dataset} with alpha={alpha}\")\n",
    "    X, y = load_classification(dataset)\n",
    "    X = np.swapaxes(X, 1, 2)\n",
    "    # Not named `ord`: that would shadow the builtin ord() for the rest of the kernel session.\n",
    "    label_encoder = OrdinalEncoder()\n",
    "    y = label_encoder.fit_transform(y.reshape(-1, 1))\n",
    "    n_ts, n_samples, n_d = X.shape\n",
    "    sampfreq, context_window = set_parameters(\n",
    "        n_samples, args[\"sampling_ratio\"], args[\"max_sampfreq\"], args[\"max_context_window\"]\n",
    "    )\n",
    "\n",
    "    def compute_T(series):\n",
    "        \"\"\"Fit one series and return its filtered (values, right, left) triple.\"\"\"\n",
    "        X_aug = augment(series, context_window)\n",
    "        Z = polynomial_feature_map(X_aug, order=args[\"poly_order\"])\n",
    "        fit = primal_fit_to(Z, 1/sampfreq, tikhonov_reg=args[\"tikhonov_reg\"], rank=args[\"max_rank\"], symmetry=None)\n",
    "        D, R, L = fit[\"values\"], fit[\"right\"], primal_left(fit, Z)\n",
    "        return filtering(D, R, L, threshold=args[\"eigen_tol\"])\n",
    "\n",
    "    print(\"Computing spectral decompositions...\")\n",
    "    T_lst = Parallel(n_jobs=args[\"n_jobs\"])(delayed(compute_T)(series) for series in tqdm(X))\n",
    "    T_lst = np.array(T_lst, dtype=object)\n",
    "\n",
    "    # Each display label is written next to the function it is paired with, so the\n",
    "    # pairing no longer depends on two separate lists staying in the same order.\n",
    "    # NOTE(review): the pairing itself is unchanged from the positional lists it\n",
    "    # replaces. Going by the names it looks shifted (eigenvalue_metric -> \"Martin\",\n",
    "    # subspace_metric -> \"SOT\", martin_metric -> \"GOT\") -- confirm against\n",
    "    # src.numpy_metric before trusting the figure labels.\n",
    "    same_freq = dict(sampfreqs=sampfreq, sampfreqt=sampfreq)\n",
    "    named_metrics = [\n",
    "        (\"Hilbert-Schmidt\", partial(hs_metric, **same_freq)),\n",
    "        (\"Operator\", partial(operator_metric, **same_freq)),\n",
    "        (\"Martin\", partial(eigenvalue_metric, **same_freq)),\n",
    "        (\"SOT\", partial(subspace_metric, **same_freq)),\n",
    "        (\"GOT\", partial(martin_metric, **same_freq)),\n",
    "        (\"SGOT\", partial(chordal_metric, alpha=alpha)),\n",
    "    ]\n",
    "    # t-SNE perplexity: at most 30 and at most a third of the number of series.\n",
    "    perplexity = min(30, n_ts // 3)\n",
    "\n",
    "    for metric_name, metric_func in named_metrics:\n",
    "        try:\n",
    "            print(f\"Computing distance matrix for metric: {metric_name}\")\n",
    "            dist_mat = metric_cross_matrix(metric_func, T_lst)\n",
    "            tsne = TSNE(n_components=2, metric=\"precomputed\", perplexity=perplexity, random_state=42, init=\"random\")\n",
    "            embedding_dct[dataset][metric_name] = tsne.fit_transform(dist_mat).tolist()\n",
    "\n",
    "            # Rewrite the JSON after every metric so finished embeddings survive a later failure.\n",
    "            with open(embedding_path, \"w\") as f:\n",
    "                json.dump(embedding_dct, f)\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failure and move on to the next metric.\n",
    "            print(f\"Error computing metric {metric_name} for dataset {dataset}: {e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7891928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column order of the figure; the labels are the keys of embedding_dct[dataset].\n",
    "metrics_names = [\n",
    "    \"Hilbert-Schmidt\",\n",
    "    \"Operator\",\n",
    "    \"Martin\",\n",
    "    \"SOT\",\n",
    "    \"GOT\",\n",
    "    \"SGOT\",\n",
    "]\n",
    "\n",
    "# legend.fontsize used to be assigned twice (14, then 12); only the effective value is kept.\n",
    "plt.rcParams.update({\n",
    "    \"font.family\": \"Helvetica\",\n",
    "    \"xtick.labelsize\": 10,\n",
    "    \"ytick.labelsize\": 10,\n",
    "    \"axes.labelsize\": 14,\n",
    "    \"axes.titlesize\": 14,\n",
    "    \"legend.fontsize\": 12,\n",
    "    \"legend.title_fontsize\": 14,\n",
    "})\n",
    "\n",
    "scale = 1.7\n",
    "n_rows, n_cols = len(dataset_alpha_dct), len(metrics_names)\n",
    "# squeeze=False keeps `axes` two-dimensional, so axes[i, j] is valid for any grid\n",
    "# size, including a single row or a single column.\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(scale*n_cols, scale*n_rows), squeeze=False)\n",
    "for i, dataset in enumerate(dataset_alpha_dct):\n",
    "    # Only the class labels are needed here; the points come from embedding_dct.\n",
    "    _, y = load_classification(dataset)\n",
    "    labels = y.flatten()\n",
    "    for j, metric_name in enumerate(metrics_names):\n",
    "        ax = axes[i, j]\n",
    "        try:\n",
    "            points = np.array(embedding_dct[dataset][metric_name])\n",
    "            for name in np.unique(labels):\n",
    "                mask = labels == name\n",
    "                ax.scatter(points[mask, 0], points[mask, 1], s=14, label=name, alpha=0.5)\n",
    "        except Exception as e:\n",
    "            # A missing or malformed embedding leaves this panel empty instead of aborting the figure.\n",
    "            print(f\"Error plotting metric {metric_name} for dataset {dataset}: {e}\")\n",
    "        if i == 0:\n",
    "            ax.set_title(metric_name)\n",
    "        if j == 0:\n",
    "            ax.set_ylabel(dataset)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "\n",
    "plt.subplots_adjust(wspace=0.03, hspace=0.03)\n",
    "plt.savefig(f\"results/exp_{EXP_ID}/tsne.pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fec40532",
   "metadata": {},
   "outputs": [],
   "source": [
    "# legend.fontsize used to be assigned twice (14, then 12); only the effective value is kept.\n",
    "plt.rcParams.update({\n",
    "    \"font.family\": \"Helvetica\",\n",
    "    \"xtick.labelsize\": 10,\n",
    "    \"ytick.labelsize\": 10,\n",
    "    \"axes.labelsize\": 14,\n",
    "    \"axes.titlesize\": 14,\n",
    "    \"legend.fontsize\": 12,\n",
    "    \"legend.title_fontsize\": 14,\n",
    "})\n",
    "\n",
    "scale = 1.7\n",
    "# Subset of datasets (rows) and metric labels (columns) shown in this figure.\n",
    "datasets = [\"EigenWorms\", \"Epilepsy\"]\n",
    "# Named `main_metrics` so the list of metric callables built by the compute cell\n",
    "# (`metrics`) is not replaced by a list of strings.\n",
    "main_metrics = [\"Hilbert-Schmidt\", \"GOT\", \"SGOT\"]\n",
    "n_rows, n_cols = len(datasets), len(main_metrics)\n",
    "# squeeze=False keeps `axes` two-dimensional for any grid size. The previous check\n",
    "# looked at len(dataset_alpha_dct) rather than len(datasets), so a single-dataset\n",
    "# figure would have indexed a 1-D array with [i, j].\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(scale*n_cols, scale*n_rows), squeeze=False)\n",
    "for i, dataset in enumerate(datasets):\n",
    "    # Only the class labels are needed here; the points come from embedding_dct.\n",
    "    _, y = load_classification(dataset)\n",
    "    labels = y.flatten()\n",
    "    for j, metric_name in enumerate(main_metrics):\n",
    "        ax = axes[i, j]\n",
    "        try:\n",
    "            points = np.array(embedding_dct[dataset][metric_name])\n",
    "            for name in np.unique(labels):\n",
    "                mask = labels == name\n",
    "                ax.scatter(points[mask, 0], points[mask, 1], s=14, label=name, alpha=0.5)\n",
    "        except Exception as e:\n",
    "            # A missing or malformed embedding leaves this panel empty instead of aborting the figure.\n",
    "            print(f\"Error plotting metric {metric_name} for dataset {dataset}: {e}\")\n",
    "        if i == 0:\n",
    "            ax.set_title(metric_name)\n",
    "        if j == 0:\n",
    "            ax.set_ylabel(dataset)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "    #ax.legend(title=\"Class\", bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "\n",
    "plt.subplots_adjust(wspace=0.03, hspace=0.03)\n",
    "plt.savefig(f\"results/exp_{EXP_ID}/main_tsne.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "KooPOT",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
