{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import argparse\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric.loader import DataLoader\n",
    "from sklearn.random_projection import GaussianRandomProjection\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import font_manager as fm\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib import colors\n",
    "import e3nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.kfold import KFoldDataset\n",
    "from models import TFNModel\n",
    "from utils.seed import fix_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration shared by the data, model, and plotting cells below.\n",
    "fold = 3        # dataset fold passed to KFoldDataset\n",
    "num_normal = 6  # passed to KFoldDataset as num_normal\n",
    "num_res = 49    # base resolution; must be odd, the 'w' row uses twice this value\n",
    "max_ell = 11    # highest irrep degree (l) in the model output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================\n",
    "# Setup\n",
    "# =============================\n",
    "seed = 42\n",
    "fix_seed(seed)\n",
    "# With an integer random_state every fit_transform re-draws the same projection\n",
    "# for a given input width, so the 's' and 'w' embeddings of one degree share axes.\n",
    "rp = GaussianRandomProjection(n_components=2, random_state=seed)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "plt.style.use('seaborn-v0_8-darkgrid')\n",
    "# Register every bundled font, then switch to Arial only if it is actually present.\n",
    "font_dir = Path(\"Assets/Fonts\")\n",
    "font_files = list(font_dir.glob(\"*.ttf\")) + list(font_dir.glob(\"*.otf\"))\n",
    "for font_file in font_files:\n",
    "    fm.fontManager.addfont(str(font_file))\n",
    "arial_path = font_dir / \"arial.ttf\"\n",
    "if arial_path.exists():\n",
    "    prop = fm.FontProperties(fname=str(arial_path))\n",
    "    plt.rcParams[\"font.family\"] = prop.get_name()\n",
    "else:\n",
    "    print(f\"Font not found: {arial_path}; using matplotlib's default font\")\n",
    "\n",
    "# Two-colour map for the per-point angle: BoundaryNorm splits values at 0.5,\n",
    "# the lower band is drawn royalblue and the upper band gold.\n",
    "cmap = colors.ListedColormap([\"royalblue\", \"gold\"])\n",
    "bounds = [0, 0.5, 1.0]\n",
    "norm = colors.BoundaryNorm(bounds, cmap.N)\n",
    "\n",
    "# Degrees plotted as columns; draw_ell's title/colour lookup is keyed on these values.\n",
    "ell_list = [4, 7, 10]\n",
    "# Globals filled in later by draw_all and by the cache/compute cell.\n",
    "fig, gs, sc = None, None, None\n",
    "all_emb, all_normal_idx, all_angle = None, None, None\n",
    "xlim, ylim = None, None\n",
    "# draw_ell picks one marker per distinct normal_idx from this list.\n",
    "markers = ['o', 's', '^', 'D', 'P', '*', 'X', 'v', '<', '>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# =============================\n",
    "# Data and model\n",
    "# =============================\n",
    "def get_model(max_ell=11):\n",
    "    \"\"\"Build a one-layer TFNModel on `device` and switch it to eval mode.\n",
    "\n",
    "    The output irreps cover every degree 0..max_ell, with parity 'e' for even\n",
    "    degrees and 'o' for odd ones. No weights are loaded here.\n",
    "    \"\"\"\n",
    "    irrep_terms = []\n",
    "    for ell in range(max_ell + 1):\n",
    "        parity = 'e' if ell % 2 == 0 else 'o'\n",
    "        irrep_terms.append(f'{ell}{parity}')\n",
    "    output_irreps = e3nn.o3.Irreps('+'.join(irrep_terms))\n",
    "\n",
    "    model = TFNModel(\n",
    "        max_ell=max_ell,\n",
    "        num_layer=1,\n",
    "        hidden_dim=64,\n",
    "        irreps_channels=8,\n",
    "        node_input_dim=1,\n",
    "        output_irreps=output_irreps,\n",
    "    )\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def get_emb(max_ell, model, fold=3, num_normal=6, num_res=49):\n",
    "    \"\"\"Embed one KFoldDataset split with `model` and project each degree to 2-D.\n",
    "\n",
    "    Returns (raw_emb, emb_list, normal_idx, angle):\n",
    "      raw_emb    -- concatenated model outputs as a numpy array, before slicing\n",
    "      emb_list   -- one normalized 2-D projection (via the global `rp`) per degree 0..max_ell\n",
    "      normal_idx -- `batch.normal_idx` of every batch, concatenated\n",
    "      angle      -- `batch.angle` of every batch, concatenated\n",
    "    \"\"\"\n",
    "    dataset = KFoldDataset(fold=fold, num_normal=num_normal, num_res=num_res)\n",
    "    loader = partial(DataLoader, batch_size=1 << 12, drop_last=False, num_workers=4)\n",
    "    myloader = loader(dataset=dataset, shuffle=False)\n",
    "\n",
    "    # Collect embeddings and labels in a single pass: every iteration of the loader\n",
    "    # rebuilds all batches (and restarts its worker processes).\n",
    "    emb_parts, normal_idx_parts, angle_parts = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in myloader:\n",
    "            # Read the labels before batch.to(device), which may move them as well.\n",
    "            normal_idx_parts.append(batch.normal_idx.cpu())\n",
    "            angle_parts.append(batch.angle.cpu())\n",
    "            emb_parts.append(model(batch.to(device)).cpu())\n",
    "\n",
    "    raw_emb = torch.cat(emb_parts, dim=0).numpy()\n",
    "    normal_idx = torch.cat(normal_idx_parts, dim=0).numpy()\n",
    "    angle = torch.cat(angle_parts, dim=0).numpy()\n",
    "\n",
    "    # Degree ell occupies columns ell**2 .. (ell+1)**2 - 1, i.e. 2*ell + 1 components.\n",
    "    degree_blocks = [raw_emb[:, ell**2:(ell + 1)**2] for ell in range(max_ell + 1)]\n",
    "    emb_list = [normalization(rp.fit_transform(block)) for block in degree_blocks]\n",
    "\n",
    "    return raw_emb, emb_list, normal_idx, angle\n",
    "\n",
    "def get_all_emb(max_ell, fold=3, num_normal=6, num_res=49):\n",
    "    \"\"\"Embed the dataset at resolution num_res ('s') and 2 * num_res ('w').\n",
    "\n",
    "    Builds one model via get_model and returns three dicts keyed by 's'/'w':\n",
    "    the per-degree 2-D embeddings, the normal_idx labels and the angle values.\n",
    "    Raises AssertionError when num_res is even.\n",
    "    \"\"\"\n",
    "    assert num_res % 2 == 1, f\"num_res must be odd, got {num_res}\"\n",
    "    model = get_model(max_ell=max_ell)\n",
    "\n",
    "    get_emb_partial = partial(get_emb, max_ell=max_ell, model=model, fold=fold, num_normal=num_normal)\n",
    "\n",
    "    # get_emb's first return value (the raw, unprojected embedding) is not needed here.\n",
    "    _, emb_s, normal_idx_s, angle_s = get_emb_partial(num_res=num_res)\n",
    "    _, emb_w, normal_idx_w, angle_w = get_emb_partial(num_res=2 * num_res)\n",
    "\n",
    "    all_emb = {'s': emb_s, 'w': emb_w}\n",
    "    all_normal_idx = {'s': normal_idx_s, 'w': normal_idx_w}\n",
    "    all_angle = {'s': angle_s, 'w': angle_w}\n",
    "\n",
    "    return all_emb, all_normal_idx, all_angle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# =============================\n",
    "# Plotting utilities\n",
    "# =============================\n",
    "def normalization(X, eps=1e-4):\n",
    "    \"\"\"Normalize embeddings to mean 0 and (approximately) unit average norm.\n",
    "\n",
    "    Centers the rows of X, then divides by the mean row norm plus `eps`\n",
    "    (which guards against division by zero). Returns a new array and leaves\n",
    "    X unchanged; integer input is promoted to float.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    centered = X - X.mean(axis=0, keepdims=True)\n",
    "    scale = np.linalg.norm(centered, axis=1).mean(axis=0) + eps\n",
    "    return centered / scale\n",
    "\n",
    "def draw_ell(k, ell, emb, normal_idx, angle, ax, res=None):\n",
    "    \"\"\"Scatter one degree's 2-D embedding into `ax` and return the last scatter artist.\n",
    "\n",
    "    Points are coloured by `angle` through the global cmap/norm and get one\n",
    "    marker per distinct `normal_idx`. The 'w' row gets an 'l = ...' label and the\n",
    "    first column of ell_list gets a 'res = ...' label. `res` defaults to num_res\n",
    "    for k == 's' and 2 * num_res otherwise, matching get_all_emb.\n",
    "    Returns None when normal_idx is empty.\n",
    "    \"\"\"\n",
    "    unique_idx = np.unique(normal_idx)\n",
    "    marker_map = {idx: markers[i % len(markers)] for i, idx in enumerate(unique_idx)}\n",
    "\n",
    "    sc = None\n",
    "    for idx in unique_idx:\n",
    "        mask = (normal_idx == idx)\n",
    "        sc = ax.scatter(\n",
    "            emb[mask, 0], emb[mask, 1],\n",
    "            c=angle[mask],\n",
    "            cmap=cmap,\n",
    "            norm=norm,\n",
    "            marker=marker_map[idx],\n",
    "            alpha=0.6,\n",
    "            edgecolor='k',\n",
    "            linewidth=0.5\n",
    "        )\n",
    "\n",
    "    # Named panels for the degrees in ell_list; any other degree gets a plain label.\n",
    "    title = {4: 'Axial Deg.', 7: 'Normal', 10: 'Half Deg.'}.get(ell, f\"l = {ell}\")\n",
    "    color = {4: 'yellowgreen', 7: 'black', 10: 'teal'}.get(ell, 'black')\n",
    "\n",
    "    ax.text(\n",
    "        0.03, 0.97, title,\n",
    "        transform=ax.transAxes,\n",
    "        ha=\"left\", va=\"top\",\n",
    "        fontsize=28, fontweight=\"bold\", color=color\n",
    "    )\n",
    "\n",
    "    if k == 'w':\n",
    "        ax.text(\n",
    "            0.5, -0.1,\n",
    "            f\"l = {ell}\",\n",
    "            ha='center', va='center',\n",
    "            fontsize=24, fontweight='bold',\n",
    "            transform=ax.transAxes\n",
    "        )\n",
    "\n",
    "    # Derive the row label from the configured resolution instead of hard-coding it.\n",
    "    if res is None:\n",
    "        res = num_res if k == 's' else 2 * num_res\n",
    "    if ell == ell_list[0]:\n",
    "        ax.text(\n",
    "            -0.1, 0.5,\n",
    "            f\"res = {res}\",\n",
    "            ha='center', va='center',\n",
    "            rotation=90,\n",
    "            fontsize=24, fontweight='bold',\n",
    "            transform=ax.transAxes\n",
    "        )\n",
    "\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.6)\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "\n",
    "    # Shared limits computed by draw_all so all panels are comparable.\n",
    "    ax.set_xlim(xlim)\n",
    "    ax.set_ylim(ylim)\n",
    "\n",
    "    return sc\n",
    "\n",
    "def draw_ell_all(k):\n",
    "    \"\"\"Draw one grid row (row 1 for 'w', row 0 otherwise), one panel per degree in ell_list.\n",
    "\n",
    "    Uses the global fig/gs created by draw_all and returns whatever draw_ell\n",
    "    returned for the last panel.\n",
    "    \"\"\"\n",
    "    grid_row = {'w': 1}.get(k, 0)\n",
    "    for grid_col, degree in enumerate(ell_list):\n",
    "        panel = fig.add_subplot(gs[grid_row, grid_col])\n",
    "        last_artist = draw_ell(k, degree, all_emb[k][degree], all_normal_idx[k], all_angle[k], panel)\n",
    "    return last_artist\n",
    "\n",
    "def draw_all(save_path, save_pdf=False):\n",
    "    \"\"\"Render the 2 x len(ell_list) embedding grid with colourbar and marker legend.\n",
    "\n",
    "    Sets the globals fig, gs, xlim and ylim (read by draw_ell_all / draw_ell),\n",
    "    writes `save_path + '.png'` at 300 dpi (plus `save_path + '.pdf'` when\n",
    "    save_pdf is True), then closes the figure.\n",
    "    \"\"\"\n",
    "    ncols = len(ell_list)\n",
    "    nrows = 2\n",
    "\n",
    "    global fig, gs\n",
    "    fig = plt.figure(figsize=(4*ncols + 1, 4*nrows))\n",
    "    # The extra narrow column on the right holds the colourbar.\n",
    "    gs = gridspec.GridSpec(nrows, ncols + 1, figure=fig, width_ratios=[1]*ncols + [0.05], wspace=0.03, hspace=0.03)\n",
    "\n",
    "    # One set of axis limits for every panel, computed over all plotted points.\n",
    "    all_coord = np.vstack(\n",
    "        [all_emb['s'][ell] for ell in ell_list]\n",
    "        + [all_emb['w'][ell] for ell in ell_list]\n",
    "    )\n",
    "\n",
    "    global xlim, ylim\n",
    "    # Padding is 0.75 above but 0.5 below; presumably headroom for the panel titles.\n",
    "    xlim = (all_coord[:, 0].min() - 0.75, all_coord[:, 0].max() + 0.75)\n",
    "    ylim = (all_coord[:, 1].min() - 0.5, all_coord[:, 1].max() + 0.75)\n",
    "\n",
    "    for k in ['s', 'w']:\n",
    "        sc = draw_ell_all(k)\n",
    "\n",
    "    cbar_ax = fig.add_subplot(gs[:, -1])\n",
    "    cbar = fig.colorbar(sc, cax=cbar_ax)\n",
    "    cbar.set_label(\"Rotation Rate Around the Axis\", fontsize=28)\n",
    "    cbar.ax.tick_params(labelsize=24)\n",
    "\n",
    "    # Marker legend: same index -> marker assignment as draw_ell, from the global markers list.\n",
    "    unique_idx = np.unique(all_normal_idx['s'])\n",
    "    marker_map = {idx: markers[i % len(markers)] for i, idx in enumerate(unique_idx)}\n",
    "\n",
    "    legend_ax = fig.add_axes([0.55, 0.90, 0.3, 0.05])\n",
    "    legend_ax.axis(\"off\")\n",
    "\n",
    "    for i, idx in enumerate(unique_idx):\n",
    "        legend_ax.scatter(\n",
    "            i * 1.2, 0,\n",
    "            marker=marker_map[idx],\n",
    "            s=200, c=(0.127568, 0.566949, 0.550556, 1.0), alpha=0.6, edgecolors=\"k\", linewidths=1.2\n",
    "        )\n",
    "\n",
    "    legend_ax.text(len(unique_idx) * 1.2 + 0.2, 0, \"Axis Type\",\n",
    "                ha=\"left\", va=\"center\", fontsize=24, fontweight=\"bold\")\n",
    "\n",
    "    legend_ax.set_xlim(-1, len(unique_idx) * 1.3 + 3)\n",
    "    legend_ax.set_ylim(-1, 1)\n",
    "\n",
    "    fig.patch.set_alpha(0.0)\n",
    "    if save_pdf:\n",
    "        fig.savefig(save_path + '.pdf', dpi=300, bbox_inches='tight')\n",
    "    fig.savefig(save_path + '.png', dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "# Key the cache on every parameter that determines its contents, so changing\n",
    "# fold / num_normal / num_res / max_ell never silently reuses stale embeddings.\n",
    "cache_path = f'cache/half_deg_fold{fold}_normal{num_normal}_res{num_res}_ell{max_ell}.npy'\n",
    "save_path = 'half-deg'\n",
    "\n",
    "if os.path.exists(cache_path):\n",
    "    # allow_pickle is needed because the cache stores Python dicts. Only load cache\n",
    "    # files this notebook wrote itself: unpickling untrusted data can execute code.\n",
    "    all_emb, all_normal_idx, all_angle = np.load(cache_path, allow_pickle=True)\n",
    "    print(f'Load all_emb, all_normal_idx, all_angle from cache: {cache_path}')\n",
    "else:\n",
    "    all_emb, all_normal_idx, all_angle = get_all_emb(\n",
    "        max_ell=max_ell,\n",
    "        fold=fold,\n",
    "        num_normal=num_normal,\n",
    "        num_res=num_res,\n",
    "    )\n",
    "    os.makedirs(os.path.dirname(cache_path), exist_ok=True)\n",
    "    np.save(cache_path, [all_emb, all_normal_idx, all_angle], allow_pickle=True)\n",
    "    print(f'Save all_emb, all_normal_idx, all_angle to cache: {cache_path}')\n",
    "\n",
    "draw_all(save_path)\n",
    "display(Image(filename=save_path + '.png'))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
