{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric.loader import DataLoader\n",
    "from sklearn.random_projection import GaussianRandomProjection\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import font_manager as fm\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib import colors\n",
    "import e3nn\n",
    "from IPython.display import Image, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.kfold import KFoldDataset\n",
    "from models import TFNModel\n",
    "from utils.seed import fix_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: fold, num_normal and num_res are forwarded to\n",
    "# KFoldDataset via get_emb; max_ell is the highest degree l used by get_model/draw_all.\n",
    "fold = 3\n",
    "num_normal = 6\n",
    "num_res = 49\n",
    "max_ell = 11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================\n",
    "# Setup\n",
    "# =============================\n",
    "seed = 42\n",
    "fix_seed(seed)\n",
    "# Global 2-D random projection reused by the plotting utilities below.\n",
    "rp = GaussianRandomProjection(n_components=2, random_state=seed)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "plt.style.use('seaborn-v0_8-darkgrid')\n",
    "# Register every bundled font so matplotlib can resolve them by family name.\n",
    "font_dir = Path(\"Assets/Fonts\")\n",
    "font_files = list(font_dir.glob(\"*.ttf\")) + list(font_dir.glob(\"*.otf\"))\n",
    "for font_file in font_files:\n",
    "    fm.fontManager.addfont(str(font_file))\n",
    "# Use the bundled Arial when it exists; otherwise keep matplotlib's default family\n",
    "# instead of raising FileNotFoundError on machines without Assets/Fonts/arial.ttf.\n",
    "arial_path = font_dir / \"arial.ttf\"\n",
    "if arial_path.is_file():\n",
    "    prop = fm.FontProperties(fname=str(arial_path))\n",
    "    plt.rcParams[\"font.family\"] = prop.get_name()\n",
    "else:\n",
    "    print(f\"Font not found: {arial_path}; keeping the default font family.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================\n",
    "# Data and model\n",
    "# =============================\n",
    "def get_model(max_ell=11):\n",
    "    \"\"\"Build a one-layer TFNModel that outputs one irrep per degree l = 0..max_ell.\n",
    "\n",
    "    Degree l is given parity (-1)**l ('e' for even l, 'o' for odd l), so the\n",
    "    output irreps read '0e+1o+2e+...'. The model is moved to the global\n",
    "    ``device`` and put in eval mode before being returned.\n",
    "    \"\"\"\n",
    "    irreps_terms = [f\"{ell}{'eo'[ell % 2]}\" for ell in range(max_ell + 1)]\n",
    "    output_irreps = e3nn.o3.Irreps('+'.join(irreps_terms))\n",
    "    hyperparams = dict(\n",
    "        max_ell=max_ell,\n",
    "        num_layer=1,\n",
    "        hidden_dim=64,\n",
    "        irreps_channels=8,\n",
    "        node_input_dim=1,\n",
    "        output_irreps=output_irreps,\n",
    "    )\n",
    "    model = TFNModel(**hyperparams)\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def get_emb(model, fold=3, num_normal=6, num_res=7):\n",
    "    \"\"\"Run ``model`` over a KFoldDataset and collect outputs with their labels.\n",
    "\n",
    "    Args:\n",
    "        model: network applied to each batch; expected to already live on the\n",
    "            global ``device`` (as returned by ``get_model``).\n",
    "        fold, num_normal, num_res: forwarded to ``KFoldDataset``.\n",
    "\n",
    "    Returns:\n",
    "        ``(emb, normal_idx, angle)`` as NumPy arrays, each concatenated along\n",
    "        dim 0 in dataset order.\n",
    "    \"\"\"\n",
    "    dataset = KFoldDataset(fold=fold, num_normal=num_normal, num_res=num_res)\n",
    "    # shuffle=False keeps the rows in dataset order.\n",
    "    loader = DataLoader(dataset=dataset, batch_size=1 << 12, shuffle=False, drop_last=False, num_workers=4)\n",
    "\n",
    "    # Gather outputs and labels in ONE pass over the loader. Iterating it once per\n",
    "    # field re-collated the whole dataset (and re-spawned workers) three times.\n",
    "    embs, normal_idxs, angles = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in loader:\n",
    "            # Read the labels before the batch is moved to `device`.\n",
    "            normal_idxs.append(batch.normal_idx.cpu())\n",
    "            angles.append(batch.angle.cpu())\n",
    "            embs.append(model(batch.to(device)).detach().cpu())\n",
    "\n",
    "    emb = torch.cat(embs, dim=0).numpy()\n",
    "    normal_idx = torch.cat(normal_idxs, dim=0).numpy()\n",
    "    angle = torch.cat(angles, dim=0).numpy()\n",
    "    return emb, normal_idx, angle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================\n",
    "# Plotting utilities\n",
    "# =============================\n",
    "def normalization(X):\n",
    "    \"\"\"Center embeddings to mean 0 and scale them to unit average row norm.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array of shape (n_samples, n_dims).\n",
    "\n",
    "    Returns:\n",
    "        A new array: ``X`` minus its column means, divided by the mean row norm\n",
    "        plus 1e-4 (guards against division by zero for constant input). ``X``\n",
    "        itself is not modified, and integer input is promoted to float.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    centered = X - X.mean(axis=0, keepdims=True)\n",
    "    return centered / (np.linalg.norm(centered, axis=1).mean(axis=0) + 1e-4)\n",
    "\n",
    "def draw_ell(emb_2d, ell, normal_idx, angle, ax=None, markers=None, xlim=None, ylim=None):\n",
    "    \"\"\"Scatter the 2-D projected embeddings of one degree ``ell`` on ``ax``.\n",
    "\n",
    "    Points are grouped by ``normal_idx`` (one marker shape per unique value) and\n",
    "    colored by ``angle`` with viridis on a fixed [0, 1] normalization, so every\n",
    "    panel drawn by this function shares the same color mapping.\n",
    "\n",
    "    Args:\n",
    "        emb_2d: array whose columns 0 and 1 are plotted as x and y.\n",
    "        ell: degree shown in the panel title \"l = {ell}\".\n",
    "        normal_idx: per-point labels selecting the marker shape.\n",
    "        angle: per-point values mapped to color.\n",
    "        ax: target axes; a new 5x4-inch figure is created when None.\n",
    "        markers: marker shapes cycled over the unique ``normal_idx`` values.\n",
    "        xlim, ylim: optional axis limits.\n",
    "\n",
    "    Returns:\n",
    "        The collection from the last ``scatter`` call (usable for a colorbar),\n",
    "        or None when ``normal_idx`` is empty.\n",
    "    \"\"\"\n",
    "    if markers is None:\n",
    "        markers = ['o', 's', '^', 'D', 'P', '*', 'X', 'v', '<', '>']\n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots(figsize=(5, 4))\n",
    "\n",
    "    unique_idx = np.unique(normal_idx)\n",
    "    marker_map = {idx: markers[i % len(markers)] for i, idx in enumerate(unique_idx)}\n",
    "    norm = colors.Normalize(vmin=0, vmax=1)\n",
    "\n",
    "    sc = None\n",
    "    for idx in unique_idx:\n",
    "        mask = (normal_idx == idx)\n",
    "        sc = ax.scatter(\n",
    "            emb_2d[mask, 0], emb_2d[mask, 1],\n",
    "            c=angle[mask],\n",
    "            cmap='viridis',\n",
    "            norm=norm,\n",
    "            marker=marker_map[idx],\n",
    "            alpha=0.6,\n",
    "            edgecolor='k',\n",
    "            linewidth=0.5\n",
    "        )\n",
    "\n",
    "    ax.text(\n",
    "        0.03, 0.97, f\"l = {ell}\",\n",
    "        transform=ax.transAxes,\n",
    "        ha=\"left\", va=\"top\",\n",
    "        fontsize=28, fontweight=\"bold\", color=\"black\"\n",
    "    )\n",
    "\n",
    "    # Degeneracy annotation appended to the title for selected degrees:\n",
    "    # degree -> (text, color). Other degrees get no annotation.\n",
    "    degeneracy_labels = {\n",
    "        0: (\" (Full Deg.)\", \"blueviolet\"),\n",
    "        1: (\" (Full Deg.)\", \"blueviolet\"),\n",
    "        2: (\" (Axial Deg.)\", \"yellowgreen\"),\n",
    "        4: (\" (Axial Deg.)\", \"yellowgreen\"),\n",
    "    }\n",
    "    if ell in degeneracy_labels:\n",
    "        label, label_color = degeneracy_labels[ell]\n",
    "        ax.text(\n",
    "            0.25, 0.97, label,\n",
    "            transform=ax.transAxes,\n",
    "            ha=\"left\", va=\"top\",\n",
    "            fontsize=24, fontweight=\"bold\", color=label_color\n",
    "        )\n",
    "\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.6)\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(xlim)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(ylim)\n",
    "\n",
    "    return sc\n",
    "\n",
    "def draw_all(emb_all, max_ell, normal_idx, angle, save_path, formats=('png',), projector=None):\n",
    "    \"\"\"Draw every degree l = 0..max_ell in a 4-column grid and save the figure.\n",
    "\n",
    "    For each l, columns ``emb_all[:, l**2:(l+1)**2]`` (2l+1 values; assumed to be\n",
    "    the degree-l block of the '0e+1o+2e+...' output) are projected to 2-D with\n",
    "    ``projector``, passed through ``normalization`` and plotted with ``draw_ell``.\n",
    "    All panels share axis limits and one colorbar; a marker legend for the\n",
    "    ``normal_idx`` values is placed above the grid.\n",
    "\n",
    "    Args:\n",
    "        emb_all: 2-D array with at least (max_ell + 1)**2 columns.\n",
    "        max_ell: highest degree to plot.\n",
    "        normal_idx, angle: per-row labels forwarded to ``draw_ell``.\n",
    "        save_path: output path without extension.\n",
    "        formats: extensions to save, e.g. ('png', 'pdf'). Defaults to PNG only.\n",
    "        projector: object with ``fit_transform``; defaults to the global ``rp``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``emb_all`` has fewer than (max_ell + 1)**2 columns.\n",
    "    \"\"\"\n",
    "    if projector is None:\n",
    "        projector = rp\n",
    "    needed_cols = (max_ell + 1) ** 2\n",
    "    if emb_all.shape[1] < needed_cols:\n",
    "        # Without this check high-l slices are silently empty and fail later inside the projector.\n",
    "        raise ValueError(f'emb_all has {emb_all.shape[1]} columns but max_ell={max_ell} needs {needed_cols}')\n",
    "\n",
    "    # One marker list shared by the panels and the legend so they always match.\n",
    "    markers = ['o', 's', '^', 'D', 'P', '*', 'X', 'v', '<', '>']\n",
    "    ncols = 4\n",
    "    nrows = int(np.ceil((max_ell + 1) / ncols))\n",
    "\n",
    "    # Project and normalize every degree first so all panels can share limits.\n",
    "    emb_all_2d = []\n",
    "    for ell in range(max_ell + 1):\n",
    "        emb = emb_all[:, ell**2:(ell+1)**2]\n",
    "        emb_all_2d.append(normalization(projector.fit_transform(emb)))\n",
    "    all_coords = np.vstack(emb_all_2d)\n",
    "    # Padding is larger at the top, presumably to leave room for the panel titles.\n",
    "    xlim = (all_coords[:, 0].min() - 0.75, all_coords[:, 0].max() + 0.75)\n",
    "    ylim = (all_coords[:, 1].min() - 0.5, all_coords[:, 1].max() + 0.75)\n",
    "\n",
    "    fig = plt.figure(figsize=(4*ncols + 1, 4*nrows))\n",
    "    try:\n",
    "        # The last grid column is a thin strip reserved for the colorbar.\n",
    "        gs = gridspec.GridSpec(nrows, ncols + 1, figure=fig, width_ratios=[1]*ncols + [0.05], wspace=0.03, hspace=0.03)\n",
    "\n",
    "        sc = None\n",
    "        for ell, emb_2d in enumerate(emb_all_2d):\n",
    "            ax = fig.add_subplot(gs[ell // ncols, ell % ncols])\n",
    "            sc = draw_ell(emb_2d, ell, normal_idx, angle, ax=ax, markers=markers, xlim=xlim, ylim=ylim)\n",
    "\n",
    "        cbar_ax = fig.add_subplot(gs[:, -1])\n",
    "        cbar = fig.colorbar(sc, cax=cbar_ax)\n",
    "        cbar.set_label(\"Rotation Rate Around the Axis\", fontsize=28)\n",
    "        cbar.ax.tick_params(labelsize=24)\n",
    "\n",
    "        # Marker legend: one sample marker per axis type, drawn on a hidden axes.\n",
    "        unique_idx = np.unique(normal_idx)\n",
    "        marker_map = {idx: markers[i % len(markers)] for i, idx in enumerate(unique_idx)}\n",
    "        legend_ax = fig.add_axes([0.55, 0.88, 0.3, 0.05])\n",
    "        legend_ax.axis(\"off\")\n",
    "        for i, idx in enumerate(unique_idx):\n",
    "            legend_ax.scatter(\n",
    "                i * 1.2, 0,\n",
    "                marker=marker_map[idx],\n",
    "                s=200, c=(0.127568, 0.566949, 0.550556, 1.0), alpha=0.6, edgecolors=\"k\", linewidths=1.2\n",
    "            )\n",
    "        legend_ax.text(len(unique_idx) * 1.2 + 0.2, 0, \"Axis Type\",\n",
    "                       ha=\"left\", va=\"center\", fontsize=24, fontweight=\"bold\")\n",
    "        legend_ax.set_xlim(-1, len(unique_idx) * 1.3 + 3)\n",
    "        legend_ax.set_ylim(-1, 1)\n",
    "\n",
    "        fig.patch.set_alpha(0.0)\n",
    "        for ext in formats:\n",
    "            fig.savefig(f'{save_path}.{ext}', dpi=300, bbox_inches='tight')\n",
    "    finally:\n",
    "        # Always release the figure, even if drawing or saving fails.\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================\n",
    "# Embeddings (cached) and figure\n",
    "# =============================\n",
    "cache_path = 'cache/full_axis_deg.npz'\n",
    "save_path = 'full-axis-deg'\n",
    "# Settings the cached arrays depend on. They are saved alongside the arrays so that\n",
    "# a cache produced with different settings is recomputed rather than silently reused.\n",
    "cache_config = {'fold': fold, 'num_normal': num_normal, 'num_res': num_res, 'max_ell': max_ell}\n",
    "\n",
    "cached = None\n",
    "if os.path.exists(cache_path):\n",
    "    # `with` closes the file handle that np.load keeps open for .npz archives.\n",
    "    with np.load(cache_path) as cache:\n",
    "        stored_config = {key: cache[key].item() for key in cache_config if key in cache.files}\n",
    "        if stored_config == cache_config:\n",
    "            cached = (cache[\"emb\"], cache[\"normal_idx\"], cache[\"angle\"])\n",
    "\n",
    "if cached is not None:\n",
    "    emb, normal_idx, angle = cached\n",
    "    print(f'Load emb, normal_idx, angle from cache: {cache_path}')\n",
    "else:\n",
    "    model = get_model(max_ell=max_ell)\n",
    "\n",
    "    emb, normal_idx, angle = get_emb(\n",
    "        model=model,\n",
    "        fold=fold,\n",
    "        num_normal=num_normal,\n",
    "        num_res=num_res,\n",
    "    )\n",
    "\n",
    "    os.makedirs(os.path.dirname(cache_path), exist_ok=True)\n",
    "    np.savez(cache_path, emb=emb, normal_idx=normal_idx, angle=angle, **cache_config)\n",
    "    print(f'Save emb, normal_idx, angle to cache: {cache_path}')\n",
    "\n",
    "draw_all(emb, max_ell=max_ell, normal_idx=normal_idx, angle=angle, save_path=save_path)\n",
    "\n",
    "display(Image(filename=save_path + '.png'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
