{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fcf5dbe4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 444/444 [00:03<00:00, 119.57it/s]\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append(os.path.abspath(\"../../\")) \n",
    "sys.path.append(os.path.abspath(\"../../..\")) \n",
    "sys.path.append(os.path.abspath(\"../\")) \n",
    "\n",
    "import importlib\n",
    "import torch\n",
    "import src.utils.data as data\n",
    "import commons.utils as utils\n",
    "import commons.semantic_loss as sem_loss\n",
    "from semantic_loss_pytorch import SemanticLoss\n",
    "import importlib\n",
    "import numpy as np\n",
    "\n",
    "importlib.reload(sem_loss)\n",
    "importlib.reload(data)\n",
    "importlib.reload(utils)\n",
    "\n",
    "\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "TEST_SIZE = 100\n",
    "train_data, val_data, test_data = data.get_x_y_data(invd_steps=TEST_SIZE, device=device, val_split=None)\n",
    "\n",
    "x_train, y_train = train_data\n",
    "x_test, y_test = test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7474fdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.adjoint import neural_adjoint as adj\n",
    "import src.models.vae.vae_move as vae_move\n",
    "from models.gidnet.configs.metamat_ds_config import MetamatDsConfig\n",
    "from models.gidnet.configs.experiment_config import ExpConfig\n",
    "from models.gidnet.configs.gidnet_config import GidNetConfig\n",
    "import src.models.gidnet.gidnet as gd\n",
    "from src.models.base.autoencoder import Encoder,Decoder\n",
    "from src.models.base.simulator_Nf2 import ForwardSimulator\n",
    "importlib.reload(gd)\n",
    "\n",
    "vae_simulator = torch.load(\"../trained_models/vae_simulator.pt\", weights_only=False).to(device)\n",
    "simulator = torch.load(\"../trained_models/simulator.pt\", weights_only=False).to(device)\n",
    "vae = torch.load(\"../trained_models/vae.pt\", weights_only=False).to(device)\n",
    "\n",
    "metamat_config = MetamatDsConfig()\n",
    "exp_config = ExpConfig()\n",
    "gidnet_config = GidNetConfig()\n",
    "n_layer = 10\n",
    "n_mat = 7\n",
    "n_movements_ = gidnet_config.n_movements\n",
    "lr_ = gidnet_config.lr\n",
    "\n",
    "encoder = Encoder(n_layer, device).to(device)\n",
    "decoder = Decoder(n_layer, device).to(device)\n",
    "encoder.load_state_dict(torch.load('../trained_models/encoder_10layers.pt', weights_only=False))\n",
    "decoder.load_state_dict(torch.load('../trained_models/decoder_10layers.pt', weights_only=False))\n",
    "\n",
    "gidnet = gd.GidNet(\n",
    "    n_layer=n_layer,\n",
    "    n_mat=n_mat, #Changed this\n",
    "    n_seeds=30,\n",
    "    n_movements=32,\n",
    "    lambda1=gidnet_config.lambda1,\n",
    "    onehot_weight=1,\n",
    "    encoder=encoder,\n",
    "    decoder=decoder,\n",
    "    simulator=simulator,\n",
    "    device=device,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "13173ed1",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, manager, vtree = sem_loss.construct_vars(10, 7)\n",
    "\n",
    "losses = [\n",
    "    (\"hyperbolic_2\", \"hyp2.sdd\", \"hyp2.vtree\", lambda: sem_loss.force_hyperbolic_material(x, n_lay=10, n_mat=7, pattern_len=2)),\n",
    "    (\"hyperbolic_3\", \"hyp3.sdd\", \"hyp3.vtree\", lambda: sem_loss.force_hyperbolic_material(x, n_lay=10, n_mat=7, pattern_len=3)),\n",
    "    (\"hyperbolic_4\", \"hyp4.sdd\", \"hyp4.vtree\", lambda: sem_loss.force_hyperbolic_material(x, n_lay=10, n_mat=7, pattern_len=4)),\n",
    "    (\"palindrome_2\", \"pal2.sdd\", \"pal2.vtree\", lambda: sem_loss.force_palindrome_material(x, n_lay=10, n_mat=7, up_to=2)),\n",
    "    (\"palindrome_3\", \"pal3.sdd\", \"pal3.vtree\", lambda: sem_loss.force_palindrome_material(x, n_lay=10, n_mat=7, up_to=3)),\n",
    "    (\"palindrome_4\", \"pal4.sdd\", \"pal4.vtree\", lambda: sem_loss.force_palindrome_material(x, n_lay=10, n_mat=7, up_to=4)),\n",
    "    (\"no_adjacent\",  \"noadj.sdd\", \"noadj.vtree\", lambda: sem_loss.prevent_same_following_materials(x, n_lay=10, n_mat=7)),\n",
    "    (\"use_all\",     \"useall.sdd\", \"useall.vtree\", lambda: sem_loss.force_use_all_materials(x, n_lay=10, n_mat=7))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a761a0bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create .sdd and .vtree losses data\n",
    "for sloss_data in losses:\n",
    "    formula = sloss_data[3]()\n",
    "    manager.save(f\"sdds/{sloss_data[1]}\".encode(), formula)\n",
    "    vtree.save(f\"vtrees/{sloss_data[2]}\".encode())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c7e07473",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[VAE] - hyperbolic_2, No loss time: 1.14 s, Sem loss time: 35.35 s\n",
      "[Adjoint] - hyperbolic_2, No loss time: 1.24 s, Sem loss time: 34.34 s\n",
      "[Gidnet] - hyperbolic_2, No loss time: 3.61 s, Sem loss time: 35.83 s\n",
      "[VAE] - hyperbolic_3, No loss time: 1.12 s, Sem loss time: 65.68 s\n",
      "[Adjoint] - hyperbolic_3, No loss time: 1.21 s, Sem loss time: 65.09 s\n",
      "[Gidnet] - hyperbolic_3, No loss time: 3.57 s, Sem loss time: 68.33 s\n",
      "[VAE] - hyperbolic_4, No loss time: 1.13 s, Sem loss time: 110.18 s\n",
      "[Adjoint] - hyperbolic_4, No loss time: 1.24 s, Sem loss time: 109.33 s\n",
      "[Gidnet] - hyperbolic_4, No loss time: 3.59 s, Sem loss time: 113.00 s\n",
      "[VAE] - palindrome_2, No loss time: 1.13 s, Sem loss time: 27.06 s\n",
      "[Adjoint] - palindrome_2, No loss time: 1.24 s, Sem loss time: 26.43 s\n",
      "[Gidnet] - palindrome_2, No loss time: 3.56 s, Sem loss time: 27.08 s\n",
      "[VAE] - palindrome_3, No loss time: 1.13 s, Sem loss time: 56.53 s\n",
      "[Adjoint] - palindrome_3, No loss time: 1.23 s, Sem loss time: 55.95 s\n",
      "[Gidnet] - palindrome_3, No loss time: 3.57 s, Sem loss time: 58.45 s\n",
      "[VAE] - palindrome_4, No loss time: 1.13 s, Sem loss time: 187.61 s\n",
      "[Adjoint] - palindrome_4, No loss time: 1.24 s, Sem loss time: 186.59 s\n",
      "[Gidnet] - palindrome_4, No loss time: 3.57 s, Sem loss time: 204.47 s\n",
      "[VAE] - no_adjacent, No loss time: 1.15 s, Sem loss time: 45.42 s\n",
      "[Adjoint] - no_adjacent, No loss time: 1.23 s, Sem loss time: 44.04 s\n",
      "[Gidnet] - no_adjacent, No loss time: 3.63 s, Sem loss time: 46.40 s\n",
      "[VAE] - use_all, No loss time: 1.17 s, Sem loss time: 245.37 s\n",
      "[Adjoint] - use_all, No loss time: 1.38 s, Sem loss time: 261.85 s\n",
      "[Gidnet] - use_all, No loss time: 4.01 s, Sem loss time: 263.23 s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import importlib\n",
    "importlib.reload(vae_move)\n",
    "importlib.reload(adj)\n",
    "\n",
    "point_x, point_y = x_test[0:1], y_test[0:1]\n",
    "\n",
    "for sloss_data in losses:\n",
    "    sloss = SemanticLoss(f\"sdds/{sloss_data[1]}\", f\"vtrees/{sloss_data[2]}\")\n",
    "\n",
    "    # ----------------- VAE TESTS ------------------------------\n",
    "\n",
    "    NUM_POINTS = 1\n",
    "    initial_point = torch.randn((NUM_POINTS, vae.latent_dim)).to(vae.device)\n",
    "\n",
    "    no_loss_start_time = time.time()\n",
    "    _ = vae_move.search_point_multi(vae, vae_simulator, initial_point, point_y, point_x, learning_rate=0.01, epochs=200)\n",
    "    no_loss_end_time = time.time()\n",
    "\n",
    "    sloss_start_time = time.time()\n",
    "    _ = vae_move.search_point_multi(vae, vae_simulator, initial_point, point_y, point_x, learning_rate=0.01, epochs=200, sloss=sloss)\n",
    "    sloss_end_time = time.time()\n",
    "\n",
    "\n",
    "    print(f\"[VAE] - {sloss_data[0]}, No loss time: {(no_loss_end_time - no_loss_start_time):.2f} s, Sem loss time: {(sloss_end_time - sloss_start_time):.2f} s\")\n",
    "\n",
    "\n",
    "    # ------------------ Adjoint TEST -----------------------------\n",
    "\n",
    "    NUM_POINTS = 1\n",
    "    point_dim = 80\n",
    "    initial_point = torch.randn((NUM_POINTS, point_dim)).to(device)\n",
    "\n",
    "    no_loss_start_time = time.time()\n",
    "    _, _ = adj.neural_adjoint_search(simulator, initial_point, point_y, lr=0.05, epochs=200)\n",
    "    no_loss_end_time = time.time()\n",
    "\n",
    "    sloss_start_time = time.time()\n",
    "    _, _ = adj.neural_adjoint_search(simulator, initial_point, point_y, lr=0.05, epochs=200, sloss=sloss)\n",
    "    sloss_end_time = time.time()\n",
    "\n",
    "    print(f\"[Adjoint] - {sloss_data[0]}, No loss time: {(no_loss_end_time - no_loss_start_time):.2f} s, Sem loss time: {(sloss_end_time - sloss_start_time):.2f} s\")\n",
    "\n",
    "    # ------------------ Gidnet TEST ------------------------------\n",
    "\n",
    "    no_loss_start_time = time.time()\n",
    "    _ = gidnet.train(x_train, y_train, point_y, 200, 0.01, point_x)\n",
    "    no_loss_end_time = time.time()\n",
    "    \n",
    "    sloss_start_time = time.time()\n",
    "    _ = gidnet.train(x_train, y_train, point_y, 200, 0.01, point_x, semloss=sloss)\n",
    "    sloss_end_time = time.time()\n",
    "\n",
    "    print(f\"[Gidnet] - {sloss_data[0]}, No loss time: {(no_loss_end_time - no_loss_start_time):.2f} s, Sem loss time: {(sloss_end_time - sloss_start_time):.2f} s\")\n",
    "\n",
    "    print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d00deb34",
   "metadata": {},
   "outputs": [],
   "source": [
    "#[VAE]     - hyperbolic_2, No loss time: 4.88 s, Sem loss time: 41.60 s\n",
    "#[Adjoint] - hyperbolic_2, No loss time: 4.20 s, Sem loss time: 42.19 s\n",
    "#[Gidnet]  - hyperbolic_2, No loss time: 3.54 s, Sem loss time: 37.34 s\n",
    "\n",
    "#[VAE]     - hyperbolic_3, No loss time: 4.83 s, Sem loss time: 73.85 s\n",
    "#[Adjoint] - hyperbolic_3, No loss time: 4.20 s, Sem loss time: 72.08 s\n",
    "#[Gidnet]  - hyperbolic_3, No loss time: 3.55 s, Sem loss time: 69.46 s\n",
    "\n",
    "#[VAE]     - hyperbolic_4, No loss time: 4.83 s, Sem loss time: 120.17 s\n",
    "#[Adjoint] - hyperbolic_4, No loss time: 4.22 s, Sem loss time: 135.19 s\n",
    "#[Gidnet]  - hyperbolic_4, No loss time: 4.02 s, Sem loss time: 146.38 s\n",
    "\n",
    "#[VAE]     - palindrome_2, No loss time: 4.99 s, Sem loss time: 40.16 s\n",
    "#[Adjoint] - palindrome_2, No loss time: 4.45 s, Sem loss time: 38.18 s\n",
    "#[Gidnet]  - palindrome_2, No loss time: 4.01 s, Sem loss time: 35.91 s\n",
    "\n",
    "#[VAE]     - palindrome_3, No loss time: 4.98 s, Sem loss time: 80.52 s\n",
    "#[Adjoint] - palindrome_3, No loss time: 4.46 s, Sem loss time: 78.62 s\n",
    "#[Gidnet]  - palindrome_3, No loss time: 4.01 s, Sem loss time: 75.68 s\n",
    "\n",
    "#[VAE]     - palindrome_4, No loss time: 4.80 s, Sem loss time: 199.12 s\n",
    "#[Adjoint] - palindrome_4, No loss time: 4.19 s, Sem loss time: 193.86 s\n",
    "#[Gidnet]  - palindrome_4, No loss time: 3.52 s, Sem loss time: 192.24 s\n",
    "\n",
    "#[VAE]     - no_adjacent, No loss time: 5.21 s, Sem loss time: 62.93 s\n",
    "#[Adjoint] - no_adjacent, No loss time: 4.46 s, Sem loss time: 61.35 s\n",
    "#[Gidnet]  - no_adjacent, No loss time: 4.01 s, Sem loss time: 58.54 s\n",
    "\n",
    "#[VAE]     - use_all, No loss time: 4.99 s, Sem loss time: 268.53 s\n",
    "#[Adjoint] - use_all, No loss time: 4.24 s, Sem loss time: 235.48 s\n",
    "#[Gidnet]  - use_all, No loss time: 3.59 s, Sem loss time: 229.93 s"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
