{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "sys.path.append(os.path.abspath(\"../../\")) \n",
    "sys.path.append(os.path.abspath(\"../\")) \n",
    "sys.path.append(os.path.abspath(\"../../../\")) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from configs.metamat_ds_config import MetamatDsConfig\n",
    "from configs.experiment_config import ExpConfig\n",
    "from data_utils.load_data import get_x_rt_data\n",
    "from data_utils.utils import create_train_val_split, scaler, unscaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metamat_config = MetamatDsConfig()\n",
    "metamat_config.print_config()\n",
    "exp_config = ExpConfig()\n",
    "n_layer = metamat_config.num_lay\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "      \n",
    "X_train, X_test, y_train, y_test = get_x_rt_data(metamat_config, exp_config.invd_num_test)\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32).to(device)\n",
    "Y_train = torch.tensor(y_train, dtype=torch.float32).to(device)\n",
    "\n",
    "x_train, x_val, y_train, y_val = create_train_val_split(X_train, Y_train, seed=42)\n",
    "\n",
    "scaler(x_train, n_layer)\n",
    "scaler(x_val, n_layer)\n",
    "\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32, device=device)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import base_models.simulator_Nf as nf\n",
    "import importlib\n",
    "#from torchinfo import summary\n",
    "#from torchview import draw_graph\n",
    "importlib.reload(nf)\n",
    "\n",
    "simulator = nf.ForwardSimulator(num_layer_material=metamat_config.num_lay, layer_config=2).to(device)\n",
    "simulator.device\n",
    "\n",
    "# 128 is a sample batch size\n",
    "#summary(simulator, (128, simulator.n, ))\n",
    "\n",
    "#model_graph = draw_graph(simulator, torch.randn(1, 30))\n",
    "#model_graph.visual_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_result, txt_file_name = simulator.train_model(x_train, y_train, x_val, y_val, num_epochs=150, batch_size=256, learning_rate=0.005, log_file=True)\n",
    "\n",
    "train_loss, test_loss, train_head_losses, test_head_losses = train_result\n",
    "torch.save(simulator.state_dict(), f'../../trained_models/simulator_noscale{txt_file_name}.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_utils.utils import plot_wave_charts\n",
    "\n",
    "start = 10\n",
    "end = start + 1\n",
    "\n",
    "# Collect predictions, concatenate into a single tensor and convert to numpy\n",
    "predicted_y = simulator(x_val[start:end])\n",
    "predicted_y = predicted_y.cpu()\n",
    "predicted_y = predicted_y.detach().numpy()\n",
    "\n",
    "plot_wave_charts(y_val[start].cpu(), predicted_y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_wave_losses_charts(losses, test_losses):\n",
    "    n_epochs = losses.shape[0]\n",
    "    n_graph = 0\n",
    "\n",
    "    # rs, rp, ts, tp\n",
    "    labels = []\n",
    "    labels.append([f\"reflectance S @ {i}°\" for i in [25, 45, 65]])\n",
    "    labels.append([f\"reflectance P @ {i}°\" for i in [25, 45, 65]])\n",
    "    labels.append([f\"transmittance S @ {i}°\" for i in [25, 45, 65]])\n",
    "    labels.append([f\"transmittance P @ {i}°\" for i in [25, 45, 65]])\n",
    "  \n",
    "    plt.figure(figsize=(20,15))\n",
    "    while n_graph < 12:\n",
    "        lbl = labels[n_graph // 3][n_graph % 3]\n",
    "       \n",
    "        plt.subplot(4, 3, n_graph + 1)\n",
    "        plt.plot(range(n_epochs), losses[:,n_graph], linewidth=1, color='blue', label=\"Train loss\")\n",
    "        plt.xlabel(lbl)\n",
    "\n",
    "        if test_losses is not None:\n",
    "            plt.plot(range(n_epochs), test_losses[:,n_graph], linewidth=1, color='orange', label=\"Test loss\")\n",
    "            plt.xlabel(lbl)\n",
    "\n",
    "        n_graph += 1\n",
    "\n",
    "        plt.legend()\n",
    "        \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_wave_losses_charts(train_head_losses, test_head_losses)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
