{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbac5d78-30ae-4d0d-8099-6b9893e817b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import json\n",
    "from itertools import product\n",
    "sys.path.append('/workspace/XXXX-1/Finite-groups/src')\n",
    "from model import MLP3\n",
    "from utils import *\n",
    "from group_data import *\n",
    "from train import Parameters\n",
    "from model_viz import model_table, plot_table, plot_indicator_table\n",
    "from plotly import graph_objs as go\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4fecde5e-af24-408b-a829-b39f1852be40",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 400/400 [05:06<00:00,  1.30it/s]\n"
     ]
    }
   ],
   "source": [
    "#path = '/home/XXXX-1/Finite-groups/src/models/2024_07_24_17_31_02_Z_100__Z_2_50_'\n",
    "#path = '/home/XXXX-1/Finite-groups/src/models/2024_07_24_17_51_00_Z_100__clamped_100_'\n",
    "#path = '/home/XXXX-1/Finite-groups/src/models/2024_07_24_18_12_27_Z_48_2__twZ_48_'     #delta_frac 0;0.3\n",
    "#path = '/home/XXXX-1/Finite-groups/src/models/2024_07_24_18_23_16_Z_48_2__twZ_48_'      #delta_frac 0;0.1\n",
    "#path = '/home/XXXX-1/Finite-groups/src/models/2024_07_24_19_14_31_Z_100__Z_50_2_'\n",
    "path = '/workspace/models/2024_08_05_18_31_53_S_5__times_A_5__Z_2__'\n",
    "models, params = load_models(path, sel=slice(400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c3f236c4-7164-4c65-aea2-84d183aa77ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance=5\n",
    "plot_indicator_table(models[-1][instance].to(device), params, instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "44219481-0c3e-43ca-b593-81bfbb11fb06",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models[0].stack([m[instance] for m in models]).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "02f3351d-8be5-4b56-a9be-b859b133eeb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "table = model_table(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1adc33b0-994b-49f6-9759-a5a1bf8b6df2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 400/400 [18:26<00:00,  2.77s/it]\n"
     ]
    }
   ],
   "source": [
    "fig = plot_table(table, params, save=True, instance=instance)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6b56de2c-e386-4289-b3d6-7cabbe64f8c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/workspace/XXXX-1/Finite-groups/src/notebooks'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bef375a-1354-4470-8f6c-c269837980ab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "39ebb04b-e623-41e9-966d-aac283ed6a19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 400/400 [20:10<00:00,  3.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/XXXX-1/Finite-groups/src/notebooks/plots/model_table_20240807_230320.html\n"
     ]
    }
   ],
   "source": [
    "instance = 6\n",
    "model = models[0].stack([m[instance] for m in models]).to(device)\n",
    "table = model_table(model)\n",
    "fig = plot_table(table, params, save=True, instance=instance)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f0608961-15c5-4727-983a-6b73e7c8b240",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_indicator_table(models[-1][instance].to(device), params, instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff95a91f-b068-4593-ae65-32365c07667e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 60/400 [02:45<15:31,  2.74s/it]"
     ]
    }
   ],
   "source": [
    "instance = 44\n",
    "model = models[0].stack([m[instance] for m in models]).to(device)\n",
    "table = model_table(model)\n",
    "fig = plot_table(table, params, save=True, instance=instance)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a5972d1-892a-462e-9343-3cceb374c28b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05e556ac-35ce-4d91-9453-ffd2801ec2a2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45370a2a-1f8e-4b68-abf6-a1fcc3cb144f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74a9a21f-4460-4d20-97e8-8a0854d23ada",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7409fc3a-1177-42c2-965c-ed73f730aa9b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff688a46-5e79-43f7-9710-4fbe546653f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "a178601f-6eb7-40a6-9794-fa0e4c4de3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = Parameters()\n",
    "params.group_string = 'Z(5)'\n",
    "params.instances = 20\n",
    "model = MLP3(5, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "f92d3ac5-2b14-48f4-b111-a8cc444e3c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [model[i] for i in range(params.instances)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "6ca32c00-7052-457e-a9d8-9a4e9c0523fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plot_table(models, params)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "df86adde-342c-4663-b853-759d9c0d96b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = ['a', 'b', 'c', 'd', 'e']\n",
    "y = [1, 2, 3, 4, 5]\n",
    "fig = go.Figure(go.Heatmap(z=table, x=x, y=y, ))\n",
    "fig.update_layout(\n",
    "    # yaxis = dict(\n",
    "    #     scaleanchor = 'x',\n",
    "    #     scaleratio = 1,\n",
    "    # ),\n",
    "    # xaxis = dict(\n",
    "    #     constrain = 'domain'\n",
    "    # ),\n",
    "    width = 500,  # You can adjust this value\n",
    "    height = 500  # Make sure width and height are the same for a square plot\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db50259b-eb32-42b6-878f-2c1618eef8d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
