{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Add the three nearest ancestor directories (resolved against the current working\n",
    "# directory) to the module search path so the project imports in the next cells resolve.\n",
    "# Paths already on sys.path are skipped, so re-running this cell does not pile up duplicates.\n",
    "for _relative_root in (\"../../..\", \"../../\", \"../\"):\n",
    "    _root = os.path.abspath(_relative_root)\n",
    "    if _root not in sys.path:\n",
    "        sys.path.append(_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from configs.metamat_ds_config import MetamatDsConfig\n",
    "from configs.experiment_config import ExpConfig\n",
    "from data_utils.load_data import get_x_rt_data\n",
    "from data_utils.utils import create_train_val_split, scaler, unscaler\n",
    "from base_models.autoencoder import Encoder, Decoder\n",
    "from base_models.simulator_Nf import ForwardSimulator\n",
    "import importlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single seed shared by the torch RNG and the train/validation split.\n",
    "SEED = 42\n",
    "# Seed torch so that torch-based randomness in the later cells is repeatable across\n",
    "# Restart & Run All. Other RNGs (numpy, random) are not touched here.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "metamat_config = MetamatDsConfig()\n",
    "metamat_config.print_config()\n",
    "exp_config = ExpConfig()\n",
    "\n",
    "n_layer = metamat_config.num_lay\n",
    "n_mat = metamat_config.num_mat\n",
    "\n",
    "# Load the four data splits and turn each one into a float32 tensor on the selected device.\n",
    "X_train, X_test, Y_train, Y_test = (\n",
    "    torch.tensor(split, dtype=torch.float32, device=device)\n",
    "    for split in get_x_rt_data(metamat_config, exp_config.invd_num_test)\n",
    ")\n",
    "\n",
    "x_train, x_val, y_train, y_val = create_train_val_split(X_train, Y_train, seed=SEED)\n",
    "\n",
    "# NOTE(review): the return values are discarded, so scaler() presumably rescales its\n",
    "# argument in place - confirm in data_utils.utils.\n",
    "scaler(x_train, n_layer)\n",
    "scaler(x_val, n_layer)\n",
    "scaler(X_test, n_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset_5_layer.methods.vae_based.vae as mvae\n",
    "\n",
    "# Reload the module so that edits made to it since the last import are picked up\n",
    "# without restarting the kernel.\n",
    "importlib.reload(mvae)\n",
    "\n",
    "# Encoder/decoder pair sized from the dataset's layer and material counts, wrapped in the VAE.\n",
    "encoder, decoder = Encoder(n_layer, n_mat), Decoder(n_layer, n_mat)\n",
    "vae = mvae.VAE(encoder, decoder, device)\n",
    "\n",
    "# The simulator below is built with the same latent size the VAE reports.\n",
    "latent_dim = vae.latent_dim\n",
    "\n",
    "# Train Nf from scratch: a newly constructed forward simulator, moved to the compute device.\n",
    "simulator = ForwardSimulator(n_layer, layer_config=2, latent_dim=latent_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchinfo import summary\n",
    "from torchview import draw_graph\n",
    "\n",
    "# Optional text summary of the model (currently disabled).\n",
    "#summary(vae, (10, 30))\n",
    "\n",
    "# Build the graph of the VAE for an input of size (10, 30) and show it as the cell output.\n",
    "# NOTE(review): (10, 30) is presumably (batch, input features) - confirm against the encoder.\n",
    "vae_graph = draw_graph(vae, input_size=(10,30))\n",
    "vae_graph.visual_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output folder before training so that the save at the end cannot fail on a\n",
    "# missing directory after the whole run has finished (os is imported in the first cell).\n",
    "MODEL_DIR = \"../../trained_models/vae_kl\"\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "\n",
    "# Run training; the simulator is passed in together with the train and validation pairs.\n",
    "# The returned value is used as the base name of the saved file below.\n",
    "# NOTE(review): the keyword C_WEIGTH is spelled exactly as in the original call - presumably it\n",
    "# matches train_model's signature, so rename it there first if the spelling is ever fixed.\n",
    "filename = vae.train_model(\n",
    "    simulator,\n",
    "    (x_train, y_train),\n",
    "    (x_val, y_val),\n",
    "    epochs=150,\n",
    "    L2_WEIGHT=1,\n",
    "    SIM_WEIGHT=1,\n",
    "    KL_WEIGHT=0.00001,\n",
    "    C_WEIGTH=0.5,\n",
    "    learning_rate=0.001,\n",
    "    batch_size=256,\n",
    "    log=False,\n",
    ")\n",
    "\n",
    "#torch.save(vae, f\"{MODEL_DIR}/{filename}.pt\")\n",
    "# Only the simulator is persisted. torch.save on a whole module pickles the object, so\n",
    "# loading it later needs the same class definitions to be importable.\n",
    "torch.save(simulator, f\"{MODEL_DIR}/{filename}_SIMULATOR.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
