{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration\n",
    "# Symmetry orders: one k-fold star environment pair is built per entry.\n",
    "k_list = [2, 3, 4, 6]\n",
    "# check_ell sweeps the TFN output irrep order ell over 0..max_ell.\n",
    "max_ell = 11\n",
    "# True: the second environment is rotated about the y-axis (3D);\n",
    "# False: about the z-axis (2D).\n",
    "if_3d = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import math\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report library versions so results can be tied to a specific environment.\n",
    "for template, module in [\n",
    "    (\"PyTorch version {}\", torch),\n",
    "    (\"PyG version {}\", torch_geometric),\n",
    "    (\"e3nn version {}\", e3nn),\n",
    "]:\n",
    "    print(template.format(module.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.train_utils import run_experiment\n",
    "from models import TFNModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG this notebook relies on so Restart & Run All is reproducible:\n",
    "# Python `random` picks the rotation angle in create_rotsym_envs, torch drives\n",
    "# weight initialisation and DataLoader shuffling. (CUDA kernels may still be\n",
    "# nondeterministic; this only fixes the random draws.)\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA\n",
    "\n",
    "# Set the device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TFNReadout(torch.nn.Module):\n",
    "    \"\"\"TFN encoder followed by a linear classification head.\n",
    "\n",
    "    ``self.tfn`` produces features whose last dimension is ``output_irreps.dim``;\n",
    "    ``self.readout`` maps them to ``num_classes`` logits. The most recent TFN\n",
    "    output is kept on ``self.emb`` for inspection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_ell, output_irreps, num_layer=1, hidden_dim=64,\n",
    "                 irreps_channels=8, num_classes=2):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            max_ell: forwarded to ``TFNModel`` as ``max_ell``.\n",
    "            output_irreps: e3nn ``Irreps`` of the TFN output; its ``dim`` is the\n",
    "                readout's input size.\n",
    "            num_layer, hidden_dim, irreps_channels: forwarded to ``TFNModel``;\n",
    "                the defaults reproduce the original single-layer configuration.\n",
    "            num_classes: number of output logits (2: one per environment label).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.tfn = TFNModel(\n",
    "            max_ell=max_ell,\n",
    "            num_layer=num_layer,\n",
    "            hidden_dim=hidden_dim,\n",
    "            irreps_channels=irreps_channels,\n",
    "            node_input_dim=1,  # one scalar input feature per node\n",
    "            output_irreps=output_irreps,\n",
    "        )\n",
    "        self.readout = torch.nn.Linear(output_irreps.dim, num_classes)\n",
    "\n",
    "    def forward(self, batch):\n",
    "        \"\"\"Encode ``batch`` once with the TFN, cache it on ``self.emb``, return logits.\"\"\"\n",
    "        # NOTE(review): assumes TFNModel pools to one row per graph, i.e. output\n",
    "        # shape (num_graphs, output_irreps.dim) -- TODO confirm in models.TFNModel.\n",
    "        emb = self.tfn(batch)\n",
    "        self.emb = emb\n",
    "        # Reuse the cached embedding instead of running the TFN a second time.\n",
    "        return self.readout(emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_rotsym_envs(fold=3, use_3d=None, rng=None):\n",
    "    \"\"\"Build train/val/test loaders over two rotated copies of a k-fold star.\n",
    "\n",
    "    Environment 0 (label 0): a centre node at the origin joined to ``fold``\n",
    "    unit-length spokes in the xy-plane, spaced 2*pi/fold apart and starting at\n",
    "    (1, 0, 0). Environment 1 (label 1): the same graph with every position\n",
    "    rotated by q = 2*pi / (fold + r), with r drawn uniformly from 1..fold-1, so\n",
    "    0 < q < 2*pi/fold. The rotation is about the y-axis when ``use_3d`` is true\n",
    "    and about the z-axis otherwise.\n",
    "\n",
    "    Args:\n",
    "        fold: number of spokes; must be >= 2 (the range 1..fold-1 for r is\n",
    "            empty otherwise).\n",
    "        use_3d: rotation-axis selector; ``None`` reads the notebook-level\n",
    "            ``if_3d`` flag at call time.\n",
    "        rng: object with a ``randint`` method (e.g. ``random.Random(seed)``)\n",
    "            used to draw r; ``None`` uses the global ``random`` module.\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, val_loader, test_loader): DataLoaders over the same two\n",
    "        graphs with batch_size=1; only the train loader shuffles.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``fold < 2``.\n",
    "    \"\"\"\n",
    "    if fold < 2:\n",
    "        raise ValueError(f\"fold must be >= 2, got {fold}\")\n",
    "    if use_3d is None:\n",
    "        use_3d = if_3d\n",
    "    if rng is None:\n",
    "        rng = random\n",
    "\n",
    "    # Every node carries the same scalar feature (0), so only geometry differs.\n",
    "    node_feat = torch.zeros(fold + 1, 1)\n",
    "    # Directed centre -> spoke edges; each Data object gets the undirected version.\n",
    "    edge_index = torch.tensor([[0] * fold, list(range(1, fold + 1))], dtype=torch.long)\n",
    "\n",
    "    # Environment 0: spoke k is (1, 0, 0) rotated about z by 2*pi*k/fold.\n",
    "    spoke = torch.tensor([1.0, 0.0, 0.0])\n",
    "    positions = [torch.zeros(3), spoke]\n",
    "    for count in range(1, fold):\n",
    "        R = e3nn.o3.matrix_z(torch.tensor([2 * math.pi / fold * count])).squeeze(0)\n",
    "        positions.append(spoke @ R.T)\n",
    "    node_pos = torch.stack(positions)\n",
    "\n",
    "    # Environment 1: rotate by an angle strictly below the symmetry angle 2*pi/fold.\n",
    "    q = 2 * math.pi / (fold + rng.randint(1, fold - 1))\n",
    "    assert q < 2 * math.pi / fold\n",
    "    rotation = e3nn.o3.matrix_y if use_3d else e3nn.o3.matrix_z\n",
    "    Q = rotation(torch.tensor([q])).squeeze(0)\n",
    "    rotated_pos = node_pos @ Q.T\n",
    "\n",
    "    dataset = [\n",
    "        Data(\n",
    "            node_feat=node_feat,\n",
    "            edge_index=to_undirected(edge_index),\n",
    "            node_pos=pos,\n",
    "            y=torch.tensor([label], dtype=torch.long),\n",
    "        )\n",
    "        for pos, label in ((node_pos, 0), (rotated_pos, 1))\n",
    "    ]\n",
    "\n",
    "    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "    val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "    test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "    return dataloader, val_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_ell(fold=3, max_ell=max_ell, n_epochs=200, n_times=10):\n",
    "    \"\"\"Sweep the TFN output irrep order ell on the ``fold``-fold environments.\n",
    "\n",
    "    For each ell in 0..max_ell, a ``TFNReadout`` with ``max_ell=ell`` and output\n",
    "    irreps ``{ell}e`` is trained/evaluated by ``run_experiment`` on the loaders\n",
    "    returned by ``create_rotsym_envs(fold)``.\n",
    "\n",
    "    Args:\n",
    "        fold: symmetry order of the star environments.\n",
    "        max_ell: highest ell to sweep. The default is the notebook-level\n",
    "            ``max_ell`` captured when this cell runs, not when it is called.\n",
    "        n_epochs, n_times: forwarded to ``run_experiment``.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: one ``\"mean ± std\"`` test-accuracy string per ell.\n",
    "    \"\"\"\n",
    "    dataloader, val_loader, test_loader = create_rotsym_envs(fold)\n",
    "    res = []\n",
    "    for ell in range(max_ell + 1):\n",
    "        output_irreps = e3nn.o3.Irreps(f'{ell}e')\n",
    "        # NOTE(review): one model instance is shared by all n_times runs; this\n",
    "        # assumes run_experiment re-initialises it per run -- TODO confirm.\n",
    "        model = TFNReadout(max_ell=ell, output_irreps=output_irreps)\n",
    "\n",
    "        print('*' * 100)\n",
    "        print(f'{fold}-fold | output_irreps = {output_irreps}')\n",
    "\n",
    "        # Only test accuracy is summarised; best-val accuracy and timing are unused.\n",
    "        # Presumably test_acc holds one accuracy per run -- verify in run_experiment.\n",
    "        _, test_acc, _ = run_experiment(\n",
    "            model,\n",
    "            dataloader,\n",
    "            val_loader,\n",
    "            test_loader,\n",
    "            n_epochs=n_epochs,\n",
    "            n_times=n_times,\n",
    "            device=device,\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        res.append(f'{np.mean(test_acc):.1f} ± {np.std(test_acc):.1f}')\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep every symmetry order and rewrite the JSON file after each one, so\n",
    "# results for finished folds are on disk even if a later sweep is interrupted.\n",
    "json_name = \"rot_3d.json\" if if_3d else \"rot_2d.json\"\n",
    "res = {}\n",
    "for k in k_list:\n",
    "    res[f'{k}-fold'] = check_ell(k)\n",
    "    with open(json_name, \"w\") as json_file:\n",
    "        json.dump(res, json_file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mace",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
