{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4ffacd7c-a96a-4aa2-8362-e027a842ba29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch as t\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from itertools import product\n",
    "sys.path.append('/workspace/XXXX-1/Finite-groups/src')\n",
    "from model import MLP3, MLP4, InstancedModule\n",
    "from utils import *\n",
    "from model_utils import *\n",
    "from group_utils import *\n",
    "from group_data import *\n",
    "from jaxtyping import Float\n",
    "from typing import Union\n",
    "from einops import repeat\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "import plotly.graph_objects as go\n",
    "import copy\n",
    "import math\n",
    "from itertools import product\n",
    "from llc import *\n",
    "import gc\n",
    "from collections import defaultdict\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a82bbcf-2e0f-4072-b21f-e8fe66f9e9fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "if t.cuda.is_available():\n",
    "    device = t.device(\"cuda\")\n",
    "else:\n",
    "    device = t.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0622276b-a17d-407f-b74f-ef0713e41c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '/workspace/models/2024-08-22_00-54-00_SL2_5_MLP3_512'\n",
    "models, params = load_models(path, sel='final')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "341026f6-ffa4-4d69-94ff-6992a3d715b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intersection size: 14400/14400 (1.00)\n",
      "Added 14400 elements from intersection\n",
      "Added 0 elements from group 0: smallgrp(120, 5)\n",
      "Taking random subset: 5760/14400 (0.40)\n",
      "Train set size: 5760/14400 (0.40)\n"
     ]
    }
   ],
   "source": [
    "data = GroupData(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "1a0fbc1c-82de-4b94-ab80-de427e75a9c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate and index the same checkpoint. Evaluating models[-1] while indexing\n",
    "# models[0] only agrees when a single checkpoint was loaded.\n",
    "final_models = models[-1].to(device)\n",
    "loss_dict = test_loss(final_models, data)\n",
    "# The instance is pinned by hand. To pick the most accurate instance instead,\n",
    "# use loss_dict['G0_acc'].argmax().item().\n",
    "instance = 42\n",
    "model = final_models[instance]\n",
    "instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "5ce622a4-af42-4d1f-8389-8974a4e80d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_dict = test_loss(model.to(device), data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "4b6eeabd-21f2-468c-b9e7-88a0abe5ec4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'G0_loss': tensor([0.0731], device='cuda:0'),\n",
       " 'G0_acc': tensor([0.9999], device='cuda:0')}"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "dd3f5951-2e6a-4149-9d86-6d3115ba7bf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every ordered pair (a, b) with 0 <= a, b < 120, in row-major order\n",
    "# (row index = a * 120 + b), built directly on the target device.\n",
    "element_ids = t.arange(120, device=device)\n",
    "test_inputs = t.cartesian_prod(element_ids, element_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "277c2fe5-62c5-44f6-b623-5d97af28d6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the 2-D Cayley table row-major so that labels[i] lines up with\n",
    "# test_inputs[i]. Tensor.flatten is used instead of einops.rearrange because the\n",
    "# import cell only brings in `repeat` from einops; the bare `einops` name was\n",
    "# presumably leaking in through one of the wildcard imports.\n",
    "labels = data.groups[0].cayley_table.flatten().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "2a6f1fb1-57ca-4b45-a60e-408c3866a682",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: disable autograd so the logits for all input pairs do not\n",
    "# keep a computation graph alive.\n",
    "with t.no_grad():\n",
    "    out = model(test_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "7e33367a-d752-4975-8402-d14554d43b0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.4049, device='cuda:0', grad_fn=<SelectBackward0>),\n",
       " tensor(0.4725, device='cuda:0', grad_fn=<SelectBackward0>))"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Probabilities assigned to classes 117 and 97 for input row 948.\n",
    "probs_948 = t.softmax(out[948, 0], dim=0)\n",
    "probs_948[117], probs_948[97]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "3a3374d0-d978-49d4-933b-097f478fd79e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the logits to predicted class ids and flatten to 1-D.\n",
    "# NOTE: this rebinds `out`, so the cell is not idempotent; re-run the forward\n",
    "# pass before running it again.\n",
    "out = t.flatten(t.argmax(out, dim=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "9b5a9510-8a50-4cdd-8b70-cee612afed37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(14399, device='cuda:0')"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(out == labels).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "cc999c39-9862-4606-a2e1-e9595e7259a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5996"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flat indices of every misclassified pair. Calling .item() on the whole index\n",
    "# tensor raises unless exactly one prediction is wrong, so take the first\n",
    "# mismatch explicitly and fail with a clear message when there is none.\n",
    "wrong_idxs = (out != labels).nonzero().flatten()\n",
    "assert wrong_idxs.numel() > 0, 'model classifies every pair correctly'\n",
    "wrong = wrong_idxs[0].item()\n",
    "wrong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "66d2ea41-29b4-4f4e-a3ac-60a8eb555276",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 49, 116], device='cuda:0'),\n",
       " tensor(57, device='cuda:0'),\n",
       " tensor(82, device='cuda:0'))"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_inputs[wrong], out[wrong], labels[wrong]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "8601af85-070b-44cf-b7b7-9ad1962fd76f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ 0*Z(5), Z(5) ], [ Z(5), Z(5)^2 ] ]"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `g` and `iso` are otherwise first assigned further down the notebook, so this\n",
    "# cell would raise NameError on a fresh kernel. Build them here.\n",
    "# NOTE(review): the later cells rebuild `iso`; confirm IsomorphismGroups returns\n",
    "# the same map on repeated calls before comparing matrices across cells.\n",
    "g = data.groups[0].gap_repr\n",
    "iso = gap.IsomorphismGroups(g, gap.SL(2, 5))\n",
    "# Image under `iso` of element 49, the left input of the misclassified pair.\n",
    "gap.Image(iso, g.Elements()[49])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "5fdf7cdf-ba19-4802-9c2e-b097959b668e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ Z(5)^0, Z(5)^2 ], [ Z(5)^2, Z(5) ] ]"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gap.Image(iso, g.Elements()[116])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1af221-dd83-4b7e-b9b1-eac68fa3abde",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef5f8bf1-7324-4536-bfc2-bb3a1a7faa6a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "8da89a2f-213d-41b9-8564-171c3afbdf7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7, 108)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split flat pair index 948 into its (left, right) element indices.\n",
    "divmod(948, 120)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e24ee7b4-f7e6-4bc9-a085-4475de90dd35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'( 1, 2, 4, 8)( 3,17, 9,12)( 5,15, 6,14)( 7,19,13,18)(10,16,11,21)(20,23,24,22)'"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.groups[0].elements[7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "8289d904-a04c-44b3-8403-36fd2aeca5d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'( 1,22,16, 4,23,21)( 2,24,18, 8,20,19)( 3,12,10, 9,17,11)( 5,14,13, 6,15, 7)'"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.groups[0].elements[108]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "94730eb7-463c-499c-96ce-4be195e66d45",
   "metadata": {},
   "outputs": [],
   "source": [
    "g = data.groups[0].gap_repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "15dc3895-b4ce-4dcc-9484-fdada837e74d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "true"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.IsSL()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "e32c3ae6-9d19-432e-82f4-49e7dd802085",
   "metadata": {},
   "outputs": [],
   "source": [
    "sl = gap.SL(2, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "80e4fa58-bada-43aa-b518-8b2e382f7002",
   "metadata": {},
   "outputs": [],
   "source": [
    "iso = gap.IsomorphismGroups(g, sl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "820a8aae-5ec3-431d-9547-97f87573f8b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ Z(5), 0*Z(5) ], [ Z(5), Z(5)^3 ] ]"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gap.Image(iso, g.Elements()[7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "9f7bcc11-eaa4-492f-b84c-ced7657acc60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ 0*Z(5), Z(5) ], [ Z(5), Z(5)^0 ] ]"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gap.Image(iso, g.Elements()[108])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "b1d4d5a8-3d4e-453a-856f-877b39b81529",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ 0*Z(5), Z(5)^2 ], [ Z(5)^0, Z(5) ] ]"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gap.Image(iso, g.Elements()[117])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "f07442b3-8fd4-42a5-8291-1498bc41d069",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ [ 0*Z(5), Z(5)^0 ], [ Z(5)^2, Z(5)^3 ] ]"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gap.Image(iso, g.Elements()[97])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "ae88abe9-8c7d-48c2-a1d1-c3e5c9cc21f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(117, device='cuda:0')"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels[948]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "02c4bf6c-c355-486b-ae94-e240b4b33771",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(97, device='cuda:0')"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out[948]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "5c4215db-2ee8-4853-9090-1488a02b216d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([  7, 108], device='cuda:0')"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_inputs[948]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "bfa9a488-b425-4ea7-a7cd-449a1f76fff1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'( 1, 2, 4, 8)( 3,17, 9,12)( 5,15, 6,14)( 7,19,13,18)(10,16,11,21)(20,23,24,22)'"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.groups[0].elements[7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "6e8feced-e3d4-4ae4-8b0e-0cb8060f05b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1,2,4,8)(3,17,9,12)(5,15,6,14)(7,19,13,18)(10,16,11,21)(20,23,24,22)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.Elements()[7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "1e8d74bb-3df1-4a5d-ba40-96ac8c40b58b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1,24,16,3,11)(2,23,18,5,7)(4,20,21,9,10)(6,13,8,22,19)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.Elements()[7] * g.Elements()[108]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "41ed5dd8-19ac-4ff6-bd2e-f32ac1764685",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1,24,16,3,11)(2,23,18,5,7)(4,20,21,9,10)(6,13,8,22,19)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.Elements()[117]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "7c373534-1866-4c60-95a7-2e63cd6c896b",
   "metadata": {},
   "outputs": [],
   "source": [
    "gap.eval('Print(Z(5));')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "47d73712-3cee-4ef1-9f69-1767a52a6c13",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0760], device='cuda:0', grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "full_train_loss(model, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd5027c4-3c4b-48ef-9ee1-1e05dc82e0ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
