{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf5dbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: import paths, imports, device selection, configs and dataset loading.\n",
    "import csv\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "# Make the project packages importable (paths are relative to the working directory).\n",
    "# Same search order as before with duplicates dropped; the membership check keeps\n",
    "# sys.path from growing every time this cell is re-run.\n",
    "for _rel_path in (\"../..\", \"../../..\", \"../\", \"../../../..\"):\n",
    "    _abs_path = os.path.abspath(_rel_path)\n",
    "    if _abs_path not in sys.path:\n",
    "        sys.path.append(_abs_path)\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from configs.metamat_ds_config import MetamatDsConfig\n",
    "from configs.experiment_config import ExpConfig\n",
    "from configs.gidnet_config import GidNetConfig\n",
    "from data_utils.load_data import get_x_rt_data\n",
    "from data_utils.utils import unscaler\n",
    "from methods.gidnet.gidnet import GidNet\n",
    "from base_models.autoencoder import Encoder, Decoder\n",
    "from base_models.simulator_Nf import ForwardSimulator\n",
    "from data_utils.real_simulator_tmm import srmse_evaluate\n",
    "import commons.semantic_loss as semloss\n",
    "from semantic_loss_pytorch import SemanticLoss\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "TEST_SIZE = 100\n",
    "\n",
    "#### load DS 5 layers #####\n",
    "metamat_config = MetamatDsConfig()\n",
    "metamat_config.print_config()\n",
    "exp_config = ExpConfig()\n",
    "gidnet_config = GidNetConfig()\n",
    "\n",
    "N_TEST = gidnet_config.n_test\n",
    "\n",
    "n_layer = 5\n",
    "n_mat = 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = get_x_rt_data(metamat_config, exp_config.invd_num_test)\n",
    "\n",
    "# NOTE(review): X_test stays on the CPU while the other splits are moved to `device`;\n",
    "# confirm this is what the search/evaluation code that consumes it expects.\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32)\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32).to(device)\n",
    "Y_train = torch.tensor(Y_train, dtype=torch.float32).to(device)\n",
    "Y_test = torch.tensor(Y_test, dtype=torch.float32).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7474fdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained models compared in the benchmark below.\n",
    "from methods.adjoint import neural_adjoint as adj\n",
    "import methods.vae_based.vae_move as vae_move\n",
    "from methods.gidnet.gidnet import GidNet\n",
    "\n",
    "# These two checkpoints are full pickled modules (weights_only=False), which can run\n",
    "# arbitrary code when loaded: only use trusted, locally produced files here.\n",
    "# map_location=device lets checkpoints saved on a GPU load on a CPU-only machine too.\n",
    "vae = torch.load(\"../../trained_models/vae_kl/VAE_LR=0.001_BS=256_E=150.pt\", map_location=device, weights_only=False).to(device)\n",
    "vae_simulator = torch.load(\"../../trained_models/vae_kl/VAE_LR=0.001_BS=256_E=150_SIMULATOR.pt\", map_location=device, weights_only=False).to(device)\n",
    "\n",
    "# Read from the GidNet config. NOTE(review): neither value is passed to GidNet below,\n",
    "# which hard-codes n_movements=32 -- confirm which value is intended.\n",
    "n_movements_ = gidnet_config.n_movements\n",
    "lr_ = gidnet_config.lr\n",
    "\n",
    "encoder = Encoder(n_layer, 5, device).to(device)\n",
    "decoder = Decoder(n_layer, 5, device).to(device)\n",
    "decoder.load_state_dict(torch.load('../../trained_models/autoencoder/decoderLR=0.001_BS=1024_L2W-RW=3-1.pt', map_location=device))\n",
    "encoder.load_state_dict(torch.load('../../trained_models/autoencoder/encoderLR=0.001_BS=1024_L2W-RW=3-1.pt', map_location=device))\n",
    "\n",
    "# Forward simulator used by the neural-adjoint search and by GidNet.\n",
    "mat_simulator = ForwardSimulator(metamat_config.num_lay, 2)\n",
    "mat_simulator.load_state_dict(torch.load(\"../../trained_models/simulator_noscale/Nf_175200train_43800val_5matlay_150e_1024b_0.005lr.pt\", map_location=device))\n",
    "mat_simulator = mat_simulator.to(device)\n",
    "\n",
    "# NOTE(review): `simulator` is not used anywhere else in this notebook and is left on\n",
    "# the CPU; kept for interactive use -- remove it if it is not needed.\n",
    "simulator = ForwardSimulator(metamat_config.num_lay, 2)\n",
    "simulator.load_state_dict(torch.load(\"../../trained_models/simulator/Nf_175200train_43800val_5matlay_150e_1024b_0.005lr.pt\", map_location=device))\n",
    "\n",
    "gidnet = GidNet(\n",
    "    n_layer=n_layer,\n",
    "    n_mat=n_mat,\n",
    "    n_seeds=30,\n",
    "    n_movements=32,\n",
    "    lambda1=gidnet_config.lambda1,\n",
    "    onehot_weight=1,\n",
    "    encoder=encoder,\n",
    "    decoder=decoder,\n",
    "    simulator=mat_simulator,\n",
    "    device=device,\n",
    "    metamat_config=metamat_config,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "13173ed1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Semantic-loss constraints to benchmark. Each entry is a tuple of\n",
    "# (name, .sdd file name, .vtree file name, zero-argument callable returning the formula\n",
    "# that is saved with the SDD manager). Later cells index/unpack these tuples.\n",
    "# NOTE(review): assumes construct_vars takes (n_lay, n_mat) in that order -- confirm.\n",
    "x, manager, vtree = semloss.construct_vars(n_layer, n_mat)\n",
    "\n",
    "losses = [\n",
    "    (\"hyperbolic_2\", \"hyp2.sdd\", \"hyp2.vtree\", lambda: semloss.force_hyperbolic_material(x, n_lay=n_layer, n_mat=n_mat, pattern_len=2)),\n",
    "    (\"hyperbolic_3\", \"hyp3.sdd\", \"hyp3.vtree\", lambda: semloss.force_hyperbolic_material(x, n_lay=n_layer, n_mat=n_mat, pattern_len=3)),\n",
    "    (\"palindrome_2\", \"pal2.sdd\", \"pal2.vtree\", lambda: semloss.force_palindrome_material(x, n_lay=n_layer, n_mat=n_mat, up_to=2)),\n",
    "    (\"no_adjacent\", \"noadj.sdd\", \"noadj.vtree\", lambda: semloss.prevent_same_following_materials(x, n_lay=n_layer, n_mat=n_mat)),\n",
    "    (\"use_all\", \"useall.sdd\", \"useall.vtree\", lambda: semloss.force_use_all_materials(x, n_lay=n_layer, n_mat=n_mat)),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a761a0bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile each constraint and write its .sdd/.vtree files; the benchmark cell\n",
    "# loads them back through SemanticLoss. The output folders may not exist on a\n",
    "# fresh checkout, so create them first.\n",
    "os.makedirs(\"sdds\", exist_ok=True)\n",
    "os.makedirs(\"vtrees\", exist_ok=True)\n",
    "\n",
    "for _loss_name, sdd_file, vtree_file, build_formula in losses:\n",
    "    formula = build_formula()\n",
    "    # The SDD bindings take byte-string paths, hence .encode().\n",
    "    manager.save(f\"sdds/{sdd_file}\".encode(), formula)\n",
    "    vtree.save(f\"vtrees/{vtree_file}\".encode())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "c7e07473",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[VAE] - hyperbolic_2, No loss time: 3.38 s, Sem loss time: 13.67 s\n",
      "[Adjoint] - hyperbolic_2, No loss time: 3.44 s, Sem loss time: 13.53 s\n",
      "[Gidnet] - hyperbolic_2, No loss time: 5.63 s, Sem loss time: 12.72 s\n",
      "[VAE] - hyperbolic_3, No loss time: 3.27 s, Sem loss time: 13.91 s\n",
      "[Adjoint] - hyperbolic_3, No loss time: 3.14 s, Sem loss time: 13.93 s\n",
      "[Gidnet] - hyperbolic_3, No loss time: 5.27 s, Sem loss time: 14.37 s\n",
      "[VAE] - palindrome_2, No loss time: 3.29 s, Sem loss time: 14.29 s\n",
      "[Adjoint] - palindrome_2, No loss time: 3.37 s, Sem loss time: 14.98 s\n",
      "[Gidnet] - palindrome_2, No loss time: 5.64 s, Sem loss time: 16.27 s\n",
      "[VAE] - no_adjacent, No loss time: 3.98 s, Sem loss time: 14.47 s\n",
      "[Adjoint] - no_adjacent, No loss time: 3.68 s, Sem loss time: 14.28 s\n",
      "[Gidnet] - no_adjacent, No loss time: 6.03 s, Sem loss time: 14.06 s\n",
      "[VAE] - use_all, No loss time: 3.05 s, Sem loss time: 13.93 s\n",
      "[Adjoint] - use_all, No loss time: 3.42 s, Sem loss time: 15.97 s\n",
      "[Gidnet] - use_all, No loss time: 6.02 s, Sem loss time: 16.65 s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Wall-clock cost of each search method without and with a semantic loss, for a\n",
    "# single test target, repeated for every constraint in `losses`.\n",
    "EPOCHS = 200\n",
    "NUM_POINTS = 1\n",
    "ADJ_POINT_DIM = 30\n",
    "VAE_LR = 0.01\n",
    "ADJ_LR = 0.05\n",
    "GIDNET_LR = 0.01\n",
    "\n",
    "point_x, point_y = X_test[0:1], Y_test[0:1]\n",
    "_sync_cuda = torch.device(device).type == \"cuda\"\n",
    "\n",
    "\n",
    "def _timed(fn):\n",
    "    \"\"\"Run fn() once and return the elapsed wall-clock time in seconds.\n",
    "\n",
    "    CUDA kernels are launched asynchronously, so the device is synchronized\n",
    "    before the clock starts and before it is read; otherwise GPU work still\n",
    "    queued from one call could be charged to the next measurement.\n",
    "    \"\"\"\n",
    "    if _sync_cuda:\n",
    "        torch.cuda.synchronize(device)\n",
    "    start = time.perf_counter()\n",
    "    fn()\n",
    "    if _sync_cuda:\n",
    "        torch.cuda.synchronize(device)\n",
    "    return time.perf_counter() - start\n",
    "\n",
    "\n",
    "def _report(method, loss_name, plain_s, sem_s):\n",
    "    \"\"\"Print one benchmark line in this notebook's fixed output format.\"\"\"\n",
    "    print(f\"[{method}] - {loss_name}, No loss time: {plain_s:.2f} s, Sem loss time: {sem_s:.2f} s\")\n",
    "\n",
    "\n",
    "for loss_name, sdd_file, vtree_file, _ in losses:\n",
    "    sloss = SemanticLoss(f\"sdds/{sdd_file}\", f\"vtrees/{vtree_file}\")\n",
    "\n",
    "    # ----------------- VAE TESTS ------------------------------\n",
    "    # Each run gets its own clone of the same start point, so an in-place update\n",
    "    # inside the search cannot leak the first run's result into the second one.\n",
    "    initial_point = torch.randn((NUM_POINTS, vae.latent_dim)).to(vae.device)\n",
    "    plain_s = _timed(lambda: vae_move.search_point_multi(vae, vae_simulator, initial_point.clone(), point_y, point_x, learning_rate=VAE_LR, epochs=EPOCHS, metamat_config=metamat_config))\n",
    "    sem_s = _timed(lambda: vae_move.search_point_multi(vae, vae_simulator, initial_point.clone(), point_y, point_x, learning_rate=VAE_LR, epochs=EPOCHS, sloss=sloss, metamat_config=metamat_config))\n",
    "    _report(\"VAE\", loss_name, plain_s, sem_s)\n",
    "\n",
    "    # ------------------ Adjoint TEST -----------------------------\n",
    "    initial_point = torch.randn((NUM_POINTS, ADJ_POINT_DIM)).to(device)\n",
    "    plain_s = _timed(lambda: adj.neural_adjoint_search(mat_simulator, initial_point.clone(), point_x, point_y, lr=ADJ_LR, epochs=EPOCHS, metamat_config=metamat_config))\n",
    "    sem_s = _timed(lambda: adj.neural_adjoint_search(mat_simulator, initial_point.clone(), point_x, point_y, lr=ADJ_LR, epochs=EPOCHS, sloss=sloss, metamat_config=metamat_config))\n",
    "    _report(\"Adjoint\", loss_name, plain_s, sem_s)\n",
    "\n",
    "    # ------------------ Gidnet TEST ------------------------------\n",
    "    # NOTE(review): both runs reuse the same `gidnet` object; if train() keeps state\n",
    "    # between calls, the semantic-loss run does not start from the same state.\n",
    "    plain_s = _timed(lambda: gidnet.train(X_train, Y_train, point_y, EPOCHS, GIDNET_LR, point_x))\n",
    "    sem_s = _timed(lambda: gidnet.train(X_train, Y_train, point_y, EPOCHS, GIDNET_LR, point_x, semloss=sloss))\n",
    "    _report(\"Gidnet\", loss_name, plain_s, sem_s)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
