{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cf4f3952-dd30-43c9-882e-73d663caa660",
   "metadata": {},
   "source": [
    "# Hyper-Low-rank-PINN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81421aa-9e54-421b-ad93-4d3219003caf",
   "metadata": {},
   "source": [
    "## Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f80927-c505-4eb1-8a9a-699a8c210d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "sys.path.append(\"../src/\")\n",
    "from models.hyper_lr_pinn import LR_PINN\n",
    "from loss_functions.poisson import loss_fn, get_source_function\n",
    "from tools.tools import print_trainable_parameters, sample_parameters_from_folder, plot_parameter_points\n",
    "from tools.train_loops import pre_train\n",
    "# -- matplotlib styling: \"fast\" style, usetex text rendering, 12 pt serif (lmodern) font\n",
    "plt.style.use(\"fast\")\n",
    "plt.rc(\"text\", usetex=True)\n",
    "plt.rc(\"font\", family=\"serif\", serif=\"lmodern\", size=12)  # adjust the font size here if needed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4970ec08-a8c0-4cf5-b70e-b1e5fb43bce4",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab3fb202-caeb-4dea-b886-c5281fe96d2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -- fall back to the CPU so the notebook also runs on machines without a GPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "dtype = torch.float32\n",
    "dataset_path = \"../dataset/poisson/0\"\n",
    "output_dir = \"./hyper_lr_pinn\"\n",
    "# -- only parameters strictly below this value are kept (7.5 is the alternative noted by the author)\n",
    "parameter_max = 5.0\n",
    "\n",
    "\n",
    "def select_below(parameters_org, files_org, upper_bound):\n",
    "    \"\"\"Return the parameters strictly below `upper_bound` (as a tensor) and their matching files.\"\"\"\n",
    "    parameters = torch.tensor(parameters_org)\n",
    "    mask = parameters < upper_bound\n",
    "    files = [f for f, m in zip(files_org, mask) if m]\n",
    "    return parameters[mask], files\n",
    "\n",
    "\n",
    "# -- testing parameters, restricted to the shorter range\n",
    "parameter_test_org, file_test_org = sample_parameters_from_folder(dataset_path, N=30, idx_start=1, type=\"test\")\n",
    "parameter_test, file_test = select_below(parameter_test_org, file_test_org, parameter_max)\n",
    "print(\"[Test] \", file_test)\n",
    "\n",
    "# -- training parameters, restricted to the same range\n",
    "parameter_train_org, file_train_org = sample_parameters_from_folder(dataset_path, N=12, idx_start=0, type=\"train\")\n",
    "parameter_train, file_train = select_below(parameter_train_org, file_train_org, parameter_max)\n",
    "print(\"[Train] \", file_train)\n",
    "\n",
    "# -- plot parameters\n",
    "plot_parameter_points(parameter_train, parameter_test, output_dir, \"parameter_variation.png\")\n",
    "\n",
    "# -- load the actual data: x/y come from the first train/test file, each file is opened once\n",
    "train_data = np.load(f\"{dataset_path}/{file_train[0]}\")\n",
    "test_data = np.load(f\"{dataset_path}/{file_test[0]}\")\n",
    "X_train = torch.tensor(train_data['x'], dtype=dtype).to(device)\n",
    "Y_train = torch.tensor(train_data['y'], dtype=dtype).to(device)\n",
    "X_test = torch.tensor(test_data['x'], dtype=dtype).to(device)\n",
    "Y_test = torch.tensor(test_data['y'], dtype=dtype).to(device)\n",
    "# -- one reference field 'u' per retained test file\n",
    "solutions = [torch.tensor(np.load(f\"{dataset_path}/{f}\")['u'], dtype=dtype).to(device) for f in file_test]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef5d97e7-9e4e-4a2e-ba30-3a38e1155560",
   "metadata": {},
   "source": [
    "### Dataset Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cb1a79-5fa5-460c-8971-58d67dbb7d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -- visual check of one reference solution from the test set\n",
    "idx = 4\n",
    "field = solutions[idx].detach().cpu().numpy()\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(str(file_test[idx]))\n",
    "contours = ax.contourf(field, levels=50, cmap='rainbow')\n",
    "fig.colorbar(contours, ax=ax)\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed2539b6-2cac-4511-ad92-7f76f35fb9be",
   "metadata": {},
   "source": [
    "## Pre-training (phase 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0da65b40-450e-484e-941c-8605dd0271b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(3)\n",
    "# ------------ Inputs -------------------------------\n",
    "# X_train / Y_train were moved to `device` when they were loaded, so no extra transfer is needed\n",
    "X = X_train\n",
    "Y = Y_train\n",
    "solution_shape = X.shape\n",
    "# -- flatten both grids into column vectors\n",
    "x_flat = X.reshape(-1, 1)\n",
    "y_flat = Y.reshape(-1, 1)\n",
    "inputs = [x_flat, y_flat]\n",
    "parameters = parameter_train.to(device)  # training parameters kept within the selected range\n",
    "source_function = get_source_function(\"desai_song\")\n",
    "# ---------- Model Init -----------------------------\n",
    "hidden_dim = 48\n",
    "# NOTE: this cell used to declare `rank = 28` while the model was built with a hard-coded rank=16.\n",
    "# The variable now holds the value that was actually used; change it here if 28 was intended.\n",
    "rank = 16\n",
    "model = LR_PINN(hidden_dim=hidden_dim, rank=rank, phase='phase1').to(device)\n",
    "# ---------- Optimization Parameters ----------------\n",
    "lr = 0.5e-4\n",
    "inner_epochs = 109\n",
    "outer_epochs = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d9cdcb-d630-42aa-8275-bb1f92cb2c09",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# -- run pre_train and rebind `model` to the object it returns\n",
    "model = pre_train(\n",
    "    model,\n",
    "    inputs,\n",
    "    source_function,\n",
    "    parameters,\n",
    "    solution_shape,\n",
    "    loss_fn,\n",
    "    inner_epochs,\n",
    "    outer_epochs,\n",
    "    lr=lr,\n",
    "    print_interval=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ddbfbac-471f-4d3b-b385-b1bd101b5a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -- create the target folder first so a long training run is not lost to a missing directory\n",
    "weights_path = Path('../pre-trained/phase1_weights3.pth')\n",
    "weights_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "torch.save(model.state_dict(), weights_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc4606f-48e1-4a9a-be77-4c7157a7f928",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "231e9a08-86f7-4252-8a40-2b1975c81632",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
