{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759e150c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f9d897",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "\n",
    "# Make the parent directory importable (presumably where the `neural_fields` package lives).\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "# Select the GPU before torch is imported in the next cell; CUDA reads this at initialization.\n",
    "# setdefault keeps a device already chosen via the environment; \"6\" is only the fallback.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"6\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f84380e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from copy import deepcopy\n",
    "\n",
    "from neural_fields.nf_utils import compress_weights\n",
    "from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader\n",
    "from neural_fields.nf_train import train_nf, eval_diagnose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d03981",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule lengths\n",
    "EPOCHS = 20  # T_max of the main cosine schedule\n",
    "AUX_EPOCHS = 100  # T_max of the auxiliary (PINN) cosine schedule\n",
    "\n",
    "# Feature switches\n",
    "PINN_TRAINING = True\n",
    "FIELD = True  # not referenced elsewhere in this notebook\n",
    "FLUX_FIELDS = False\n",
    "CHEAT_INTEGRAL = False\n",
    "SPECTRAL = False\n",
    "KY_MODES = [0, 1, 2, [3, 4, 5]]  # only used by a commented-out dataset option\n",
    "\n",
    "# Linearly decaying values from 1.0 to 0.2, one per entry of EPOCHS.\n",
    "# NOTE(review): the (commented-out) finetune that consumes this runs AUX_EPOCHS epochs.\n",
    "FIELD_SUBSAMPLE = np.linspace(start=1.0, stop=0.2, num=EPOCHS)\n",
    "\n",
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fc5fa37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative trajectories: iteration_115.h5, iteration_131.h5, iteration_134.h5\n",
    "TRAJECTORY = \"iteration_13.h5\"\n",
    "\n",
    "data = CycloneNFDataset(\n",
    "    trajectory=TRAJECTORY,\n",
    "    timesteps=200,\n",
    "    normalize=\"zscore\",\n",
    "    normalize_coords=False,\n",
    "    spatial_fft=SPECTRAL,\n",
    "    flux_fields=FLUX_FIELDS,\n",
    "    realpotens=True,\n",
    "    # separate_ky_modes=KY_MODES,\n",
    ")\n",
    "loader = CycloneNFDataLoader(\n",
    "    data, 2048, preload=True, shuffle=True, pin_memory=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b0be16",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "from neural_fields.models.mlp import MLPNF\n",
    "from neural_fields.models.siren import SIREN\n",
    "from neural_fields.models.wire import WIRE\n",
    "\n",
    "\n",
    "# model = SIREN(\n",
    "#     data.ndim,\n",
    "#     data.nchannels,\n",
    "#     n_layers=5,\n",
    "#     dim=64,\n",
    "#     skips=True,\n",
    "#     first_w0=1.0,\n",
    "#     hidden_w0=3.0,\n",
    "#     readout_w0=3.0,\n",
    "#     embed_type=\"discrete\",\n",
    "#     clip_out=False,\n",
    "# )\n",
    "\n",
    "# model = WIRE(  # TODO does not work with mode separation (complex channel mismatch)\n",
    "#     data.ndim,\n",
    "#     data.nchannels // 2,\n",
    "#     n_layers=1,\n",
    "#     dim=96,\n",
    "#     s0=1.0,\n",
    "#     first_w0=0.5,\n",
    "#     hidden_w0=1.0,\n",
    "#     readout_w0=1.0,\n",
    "#     complex_out=False,\n",
    "#     real_out=True,\n",
    "#     skips=False,\n",
    "#     learnable_w0_s0=True,\n",
    "#     embed_type=\"discrete\",\n",
    "# )\n",
    "\n",
    "# Active model: MLPNF with SiLU activations, skips and a discrete embedding.\n",
    "model = MLPNF(\n",
    "    data.ndim,\n",
    "    data.nchannels,\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    skips=True,\n",
    "    act_fn=nn.SiLU,\n",
    "    embed_type=\"discrete\",\n",
    "    use_checkpoint=True,\n",
    ")\n",
    "\n",
    "# model = torch.compile(model)\n",
    "\n",
    "# Report parameter count and the ratio of data bytes to parameter bytes.\n",
    "params = list(model.parameters())\n",
    "n_params = sum(p.numel() for p in params)\n",
    "model_size = sum(p.nbytes for p in params)\n",
    "compression = data.full_df.nbytes / model_size\n",
    "print(f\"Params: {n_params / 1e3:.2f}k, compression: {compression:.2f}x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e49bc31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main optimizer, cosine-annealed over EPOCHS scheduler steps.\n",
    "optim = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-7)\n",
    "sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=EPOCHS, eta_min=1e-12)\n",
    "\n",
    "# Auxiliary optimizer over the same parameters, only built when PINN_TRAINING is on.\n",
    "aux_optim = aux_sched = None\n",
    "if PINN_TRAINING:\n",
    "    aux_optim = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-9)\n",
    "    aux_sched = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
    "        aux_optim, T_max=AUX_EPOCHS, eta_min=1e-12\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cc7c008",
   "metadata": {},
   "source": [
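    "## Normal training\n",
    "\n",
    "Train the neural field with the field loss only (`field_loss=True`, `physical_loss=False`), then inspect the result with `eval_diagnose`."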
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a6e81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, _, _ = train_nf(\n",
    "    model,\n",
    "    optim=optim,\n",
    "    sched=sched,\n",
    "    # Match the cosine schedule, which was built with T_max=EPOCHS\n",
    "    # (assumes train_nf steps `sched` once per epoch -- TODO confirm).\n",
    "    n_epochs=EPOCHS,\n",
    "    data=data,\n",
    "    loader=loader,\n",
    "    device=device,\n",
    "    use_flux_fields=FLUX_FIELDS,\n",
    "    use_spectral=SPECTRAL,\n",
    "    cheat_integral=CHEAT_INTEGRAL,\n",
    "    field_loss=True,\n",
    "    physical_loss=False,\n",
    ")\n",
    "# Snapshot of the field-loss-only model; the diagnostics cell below evaluates `model_pre`.\n",
    "# deepcopy keeps it independent of any further in-place training of `model`.\n",
    "model_pre = deepcopy(model)\n",
    "# # finetune\n",
    "# if PINN_TRAINING:\n",
    "#     model, _, _ = train_nf(\n",
    "#         model,\n",
    "#         optim=optim,\n",
    "#         sched=sched,\n",
    "#         n_epochs=AUX_EPOCHS,\n",
    "#         data=data,\n",
    "#         loader=loader,\n",
    "#         device=device,\n",
    "#         field_subsamples=FIELD_SUBSAMPLE,\n",
    "#         use_flux_fields=FLUX_FIELDS,\n",
    "#         use_spectral=SPECTRAL,\n",
    "#         cheat_integral=CHEAT_INTEGRAL,\n",
    "#         aux_optim=aux_optim,\n",
    "#         aux_sched=aux_sched,\n",
    "#         field_loss=False,\n",
    "#         physical_loss=True,\n",
    "#         integral_loss_weight={\"flux\": 0.1, \"phi\": 0.1},\n",
    "#         physical_loss_weight={\n",
    "#             \"kyspec\": 1.0,\n",
    "#             \"qspec\": 1.0,\n",
    "#             \"kyspec monotonicity\": 1.0,\n",
    "#             \"qspec monotonicity\": 1.0,\n",
    "#             \"mass\": 0.0,\n",
    "#         },\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54db717c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagnostics for `model_pre` (requires `model_pre = deepcopy(model)` in the training cell).\n",
    "_ = eval_diagnose(\n",
    "    model=model_pre,\n",
    "    data=data,\n",
    "    device=device,\n",
    "    T=None,\n",
    "    cheat_integral=CHEAT_INTEGRAL,\n",
    "    use_spectral=SPECTRAL,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "310c631d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagnostics for the current `model`.\n",
    "_ = eval_diagnose(\n",
    "    model=model,\n",
    "    data=data,\n",
    "    device=device,\n",
    "    T=None,\n",
    "    cheat_integral=CHEAT_INTEGRAL,\n",
    "    use_spectral=SPECTRAL,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2f790fe",
   "metadata": {},
   "source": [
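    "## Initialization experiments\n",
    "\n",
    "Compare freshly constructed (untrained) SIREN and MLP fields that differ only in their coordinate embedding (`linear`, `sincos_discrete`, `discrete`), using their sampled spectra and diagnostics."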
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f584c714",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh dataset for the initialization study: 160 timesteps, other options left at defaults.\n",
    "data = CycloneNFDataset(\n",
    "    trajectory=\"iteration_13.h5\",\n",
    "    normalize=\"zscore\",\n",
    "    normalize_coords=False,\n",
    "    timesteps=160,\n",
    ")\n",
    "loader = CycloneNFDataLoader(\n",
    "    data,\n",
    "    2048,\n",
    "    shuffle=True,\n",
    "    preload=True,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c5f1d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled experiment: double the number of points along axis -2 of `data.grid` by\n",
    "# interleaving each point with a copy offset by `step` in every coordinate.\n",
    "# NOTE(review): `step` is half a grid spacing only if the spacing is 1 / grid.shape[-2]\n",
    "# (e.g. coordinates in [0, 1)) -- TODO confirm against the dataset.\n",
    "# grid = data.grid\n",
    "\n",
    "# step = 1 / (2 * grid.shape[-2])\n",
    "\n",
    "# shifted = grid + step\n",
    "\n",
    "# grid = torch.stack([grid, shifted], dim=-2)\n",
    "# grid = grid.flatten(start_dim=-3, end_dim=-2)\n",
    "\n",
    "# data.grid = grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2854ad19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "from neural_fields.models.mlp import MLPNF\n",
    "from neural_fields.models.siren import SIREN\n",
    "from neural_fields.models.wire import WIRE\n",
    "\n",
    "# Shared hyperparameters; the models below differ only in class and `embed_type`.\n",
    "SIREN_KWARGS = dict(\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    skips=True,\n",
    "    first_w0=1.0,\n",
    "    hidden_w0=3.0,\n",
    "    readout_w0=3.0,\n",
    "    clip_out=False,\n",
    ")\n",
    "MLP_KWARGS = dict(\n",
    "    n_layers=5,\n",
    "    dim=64,\n",
    "    act_fn=nn.SiLU,\n",
    "    use_checkpoint=False,\n",
    "    skips=True,\n",
    ")\n",
    "\n",
    "# (name, class, shared kwargs, embedding), constructed in this order.\n",
    "MODEL_SPECS = [\n",
    "    (\"siren\", SIREN, SIREN_KWARGS, \"linear\"),\n",
    "    (\"siren sincos\", SIREN, SIREN_KWARGS, \"sincos_discrete\"),\n",
    "    (\"siren nn.Embedding\", SIREN, SIREN_KWARGS, \"discrete\"),\n",
    "    (\"mlp\", MLPNF, MLP_KWARGS, \"linear\"),\n",
    "    (\"mlp sincos\", MLPNF, MLP_KWARGS, \"sincos_discrete\"),\n",
    "    (\"mlp nn.Embedding\", MLPNF, MLP_KWARGS, \"discrete\"),\n",
    "]\n",
    "\n",
    "MODELS = {}\n",
    "for name, model_cls, shared_kwargs, embed in MODEL_SPECS:\n",
    "    model = model_cls(data.ndim, data.nchannels, embed_type=embed, **shared_kwargs)\n",
    "    MODELS[name] = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a7e2c19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops import rearrange\n",
    "\n",
    "# NOTE(review): earlier cells import `neural_fields.nf_utils`; if this top-level `nf_utils`\n",
    "# is the same file, it is loaded twice under two module names. Confirm which is intended.\n",
    "from nf_utils import plot_diag, sample_field, plotND, phi_fft\n",
    "from gk_losses import spectra_losses, get_integrals\n",
    "\n",
    "\n",
    "def to_complex(x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Fold a leading (real, imag) axis of size 2 into a complex tensor.\n",
    "\n",
    "    The result is `.squeeze()`-d, so any other size-1 axes are dropped as well.\n",
    "    \"\"\"\n",
    "    assert x.shape[0] == 2, x.shape\n",
    "    x = rearrange(x, \"c ... -> ... c\").contiguous()\n",
    "    return torch.view_as_complex(x).squeeze()\n",
    "\n",
    "\n",
    "def to_real(x: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Unfold a complex tensor into a leading (real, imag) axis; also squeezes size-1 axes.\"\"\"\n",
    "    return rearrange(torch.view_as_real(x), \"... c -> c ...\").squeeze()\n",
    "\n",
    "\n",
    "def df_fft(df: torch.Tensor, norm: str = \"forward\"):\n",
    "    \"\"\"FFT of a channel-first distribution function over its last five axes.\n",
    "\n",
    "    A 4-channel input is first reduced to 2 channels (channels 2, 3 are added to 0, 1).\n",
    "    The zero frequency of axis -2 is shifted to the centre. Returns (real, imag) channels.\n",
    "    \"\"\"\n",
    "    if df.shape[0] == 4:\n",
    "        df = df[[0, 1]] + df[[2, 3]]\n",
    "    df = to_complex(df)\n",
    "    df = torch.fft.fftn(df, dim=(-5, -4, -3, -2, -1), norm=norm)\n",
    "    df = torch.fft.fftshift(df, dim=(-2,))\n",
    "    return to_real(df)\n",
    "\n",
    "\n",
    "def df_ifft(df: torch.Tensor, norm: str = \"forward\"):\n",
    "    \"\"\"Inverse of `df_fft`: undo the axis -2 shift and invert the FFT over the last five axes.\n",
    "\n",
    "    With the same `norm` as the forward call, the round trip preserves scale.\n",
    "    \"\"\"\n",
    "    if df.shape[0] == 4:\n",
    "        df = df[[0, 1]] + df[[2, 3]]\n",
    "    df = to_complex(df)\n",
    "    df = torch.fft.ifftshift(df, dim=(-2,))\n",
    "    df = torch.fft.ifftn(df, dim=(-5, -4, -3, -2, -1), norm=norm)\n",
    "    return to_real(df)\n",
    "\n",
    "\n",
    "# The ground truth does not depend on the model: move it and compute its integrals once.\n",
    "data.to(device)\n",
    "gt_df = data.full_df.to(device)\n",
    "gt_phi, (_, gt_eflux, _) = get_integrals(gt_df, data, flux_fields=True, spectral_df=False)\n",
    "\n",
    "gt_diagz, pred_diagz = [], []\n",
    "pred_dfz = {}  # model name -> df spectrum as (real, imag) channels\n",
    "pred_phiz = {}  # model name -> phi spectrum as (real, imag) channels\n",
    "for m, nf in MODELS.items():\n",
    "    nf.to(device)\n",
    "    with torch.no_grad():\n",
    "        pred_df = sample_field(nf, data, device, timestep=None).to(device)\n",
    "        # Keep the spectrum for plotting and continue in real space\n",
    "        # (the round trip also reduces a 4-channel field to 2 channels).\n",
    "        pred_dfz[m] = df_fft(pred_df)\n",
    "        pred_df = df_ifft(pred_dfz[m])\n",
    "\n",
    "    pred_phi, (pred_pflux, pred_eflux, _) = get_integrals(\n",
    "        pred_df,\n",
    "        data,\n",
    "        flux_fields=True,\n",
    "        spectral_df=False,\n",
    "        phi_integral=False,\n",
    "    )\n",
    "    pred_phiz[m] = to_real(phi_fft(pred_phi))\n",
    "    spec_losses, (gt_diag, pred_diag) = spectra_losses(\n",
    "        pred_df.cpu(),\n",
    "        pred_phi.cpu(),\n",
    "        pred_eflux.cpu(),\n",
    "        gt_df.cpu(),\n",
    "        gt_phi.cpu(),\n",
    "        gt_eflux.cpu(),\n",
    "        data.ds,\n",
    "    )\n",
    "    gt_diagz.append(gt_diag)\n",
    "    pred_diagz.append(pred_diag)\n",
    "    # Plot the spectra computed above rather than transforming the fields again.\n",
    "    plotND(pred_dfz[m], n=5, title=m)\n",
    "    plotND(torch.log(torch.abs(pred_phiz[m]) ** 2), n=3, title=m)\n",
    "fig_diag = plot_diag(gt_diagz, pred_diagz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0657318",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45c47e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "assert pred_dfz, \"pred_dfz is empty: run the model comparison cell first\"\n",
    "n_models = len(pred_dfz)\n",
    "# squeeze=False keeps `ax` 2-D even for a single model, so `ax[i, j]` always works.\n",
    "fig, ax = plt.subplots(n_models, 3, figsize=(15, 3 * n_models), squeeze=False)\n",
    "for i, (m, df_spec) in enumerate(pred_dfz.items()):\n",
    "    # NOTE(review): `.sum(0)` adds the real and imaginary channels rather than taking a\n",
    "    # magnitude -- confirm that is the intended quantity.\n",
    "    # df: reduce axes 0-3 of the channel-summed spectrum; the title assumes the rest is ky.\n",
    "    x = df_spec.sum(0).sum((0, 1, 2, 3))\n",
    "    ax[i, 0].set_title(f\"{m} df(ky)\")\n",
    "    ax[i, 0].plot(x.cpu().numpy())\n",
    "    # phi: reduce axes 0-1 of the channel-summed spectrum (computed once, used twice).\n",
    "    phi_ky = pred_phiz[m].sum(0).sum((0, 1))\n",
    "    ax[i, 1].set_title(f\"{m} phi(ky)\")\n",
    "    ax[i, 1].plot(phi_ky.cpu().numpy())\n",
    "    ax[i, 2].set_title(f\"{m} phi(ky)^2\")\n",
    "    ax[i, 2].plot((phi_ky**2).cpu().numpy())\n",
    "    for j in [1, 2]:\n",
    "        ax[i, j].set_xscale(\"log\")\n",
    "        ax[i, j].set_yscale(\"log\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0ab6968",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mhd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
