{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ba112ba-598e-442e-8c75-6d34eadd2c2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchvision.transforms.v2 as v2\n",
    "from pytorch_msssim import SSIM\n",
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "from models.Models import *\n",
    "from models.UNet import UNet\n",
    "from models.MotionRNN import RNN\n",
    "from data_metric_provider import set_seed, eval, DataProvider\n",
    "\n",
    "import cartopy\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from imageio import mimwrite\n",
    "import global_land_mask as glm\n",
    "from datetime import date, timedelta\n",
    "from matplotlib import pyplot as plt\n",
    "from visualization import video, fastvideo, cartovideo\n",
    "from matplotlib.backends.backend_agg import FigureCanvasAgg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8813f8dd-5a50-4540-bca8-5f7775aaf5c1",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 298,
   "id": "6ab8e318-1419-4b4d-b1e1-c329875a16ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input rasters, in the order the provider receives them, grouped by the\n",
    "# prefix of the file name.\n",
    "source_files = [\n",
    "    # sentinel-*\n",
    "    'data3/sentinel-HV_2014-10-25_2023-12-31.npy',\n",
    "    'data3/sentinel-HH_2014-10-25_2023-12-31.npy',\n",
    "    # glorys-*\n",
    "    'data3/glorys-bottomT_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-mlotst_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-so_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-thetao_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-uo_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-vo_2015-09-01_2023-10-24.npy',\n",
    "    'data3/glorys-zos_2015-09-01_2023-10-24.npy',\n",
    "    # meteo-*\n",
    "    'data3/meteo-f_2014-01-01_2023-12-31.npy',\n",
    "    'data3/meteo-P_2014-01-01_2023-12-31.npy',\n",
    "    'data3/meteo-T_2014-01-01_2023-12-31.npy',\n",
    "    'data3/meteo-u_2014-01-01_2023-12-31.npy',\n",
    "    'data3/meteo-v_2014-01-01_2023-12-31.npy',\n",
    "]\n",
    "dp = DataProvider(cuda=0, grid_file='data3/grid.npy', files=source_files)\n",
    "\n",
    "# Loaded once here; passed to eval() as `cell_areas` in the evaluation cell.\n",
    "cell_areas = torch.Tensor(np.load('data3/cell_areas.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "5b610fe3-9b27-4993-b569-07547a024772",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Result containers, filled in by later cells.\n",
    "metrics = {}\n",
    "rmseday = {}\n",
    "rmsemonth = {}\n",
    "\n",
    "# Day offsets, counted from 2022-10-01, of the first day of each of the eleven\n",
    "# months 2022-11 through 2023-09.\n",
    "split_inds = [\n",
    "    (month_start - date(2022, 10, 1)).days\n",
    "    for month_start in (\n",
    "        [date(2022, m, 1) for m in (11, 12)]\n",
    "        + [date(2023, m, 1) for m in range(1, 10)]\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "id": "48332694-5723-4ef0-8beb-24a8bf1a4b67",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# set_seed runs before anything else in this cell is constructed.\n",
    "set_seed(0)\n",
    "kwargs = {'scale': 1, 'w_w': 6, 'i_w': 0, 'e_w': 3, 'starter': 'pers'}\n",
    "\n",
    "ch = dp.all_data.shape[1] + 1\n",
    "\n",
    "# Zero-argument factories: only the entry selected by `name` is instantiated.\n",
    "# Some entries override a single key of `kwargs`.\n",
    "models = {\n",
    "    'Persistence': lambda: Seq2seq(**kwargs),\n",
    "    'Linear': lambda: Seq2seq(model=E(Lin(32, wn=False, bn=False)), **kwargs),\n",
    "    'DMVFN': lambda: VFN(ch, **kwargs),\n",
    "    'IAM4VP': lambda: S2SDU2(IAM(ch, **{**kwargs, 'w_w': 9}), (64, 64)),\n",
    "    'Neural ODE': lambda: Seq2seq(\n",
    "        model=MALI(DC(32, 32)), outgate=Conv(32, 1, 1), **{**kwargs, 'scale': 2}),\n",
    "    'MotionRNN': lambda: Seq2seq(\n",
    "        model=RNN(440, 200, ch, device=dp.DEVICE, layer_norm=0), **{**kwargs, 'scale': 2}),\n",
    "    'Vid-ODE': lambda: VidODE(ch, **kwargs),\n",
    "    'UNet': lambda: FUNet(ch, **kwargs),\n",
    "    'rUNet': lambda: Seq2seq(model=RUNet(32), **kwargs),\n",
    "}\n",
    "name = 'rUNet'\n",
    "model = models[name]()\n",
    "model = model.to(dp.DEVICE)\n",
    "# Training resumes from the checkpoint stored under the model's name.\n",
    "model.load_state_dict(torch.load(f'weights4/{name}.pt'))\n",
    "\n",
    "# Batches of LOAD_BATCH samples are loaded; fit() steps the optimizer once\n",
    "# every OPT_P batches.\n",
    "LOAD_BATCH = 1\n",
    "TRAIN_BATCH = 32\n",
    "OPT_P = max(TRAIN_BATCH // LOAD_BATCH, 1)\n",
    "LR = 2e-4 * LOAD_BATCH / 32\n",
    "\n",
    "opt = torch.optim.AdamW(\n",
    "    list(filter(lambda p: p.requires_grad, model.parameters())), lr=LR)\n",
    "sch = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)\n",
    "\n",
    "train_aug, train, val, test = dp.get_loaders(\n",
    "    LOAD_BATCH, augment=1, sic_augment=3, val_batch=2, **kwargs)\n",
    "\n",
    "# Histories appended to by fit(); tr_loss and grads start with a placeholder 0.\n",
    "tr_loss = [0]\n",
    "grads = [0]\n",
    "tr_mse = []\n",
    "vl_mse = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 211,
   "id": "db226ef2-a0ea-484a-98e8-69e069766077",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Shared objects used by loss(), fit() and plot().\n",
    "# Per-position weights: 0.5 repeated model.i_w times, then 1 repeated model.e_w times.\n",
    "weight = torch.Tensor([0.5] * model.i_w + [1] * model.e_w)\n",
    "weight = weight.to(dp.DEVICE)\n",
    "ssim_loss = SSIM(channel=model.e_w, data_range=1, size_average=True)\n",
    "# Target mask on the training device (striding by model.scale is left disabled).\n",
    "ctm = dp.target_mask.to(dp.DEVICE)\n",
    "\n",
    "def loss(pred, target, mask, weight=None, crop_loss=False):\n",
    "    \"\"\"Masked, position-weighted MSE plus an SSIM penalty for seq2seq output.\n",
    "\n",
    "    Args:\n",
    "        pred: model prediction (layout BTHW per the original note). A 6-D\n",
    "            tensor is treated as a stack of multi-scale voxel-flow outputs\n",
    "            along dim 0.\n",
    "        target: ground truth; after masking it is expanded to ``pred``'s shape.\n",
    "        mask: multiplicative mask applied to both ``pred`` and ``target``.\n",
    "        weight: optional weights multiplied with the squared error after it\n",
    "            has been averaged over dim 0 and the last two dims; ``None``\n",
    "            means uniform weights.\n",
    "        crop_loss: when true, both sides are additionally multiplied by the\n",
    "            module-level ``ctm`` mask.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: weighted MSE + 0.2 * (1 - SSIM), with SSIM computed on\n",
    "        the last ``model.e_w`` channels.\n",
    "    \"\"\"\n",
    "    # Predictions are zeroed wherever the target is not strictly positive.\n",
    "    pred_ = mask * pred * (target > 0)\n",
    "    target_ = (mask * target).expand(*pred.shape)\n",
    "    if crop_loss:\n",
    "        # Out-of-place on purpose: target_ may be a broadcast view created by\n",
    "        # expand(), and such views cannot be written to in place.\n",
    "        pred_ = pred_ * ctm\n",
    "        target_ = target_ * ctm\n",
    "\n",
    "    se = (pred_ - target_) ** 2\n",
    "    if pred.dim() == 6:  # for multi-scale voxel flow\n",
    "        # Weight the squared error of scale k by 0.8**(8 - k). The weights are\n",
    "        # created on the error tensor's device and shaped to broadcast along\n",
    "        # the leading (scale) axis before that axis is summed out.\n",
    "        decay = 0.8 ** (8 - torch.arange(se.shape[0], device=se.device, dtype=se.dtype))\n",
    "        se = (se * decay.reshape(-1, *([1] * (se.dim() - 1)))).sum(0)\n",
    "    m1se = torch.mean(se, dim=(0, -2, -1))\n",
    "    if weight is None:\n",
    "        weight = torch.ones_like(m1se)\n",
    "    wse = torch.sum(m1se * weight) / weight.sum()\n",
    "\n",
    "    e_w = model.e_w\n",
    "    ssim = ssim_loss(\n",
    "        pred_.reshape(-1, *target.shape[1:])[:, -e_w:],\n",
    "        target_.reshape(-1, *target.shape[1:])[:, -e_w:])\n",
    "    return wse + 0.2 * (1 - ssim)\n",
    "\n",
    "def fit(epochs, crop_loss=False):\n",
    "    \"\"\"Train the global ``model`` and record per-epoch statistics.\n",
    "\n",
    "    Epoch numbers run from ``len(tr_loss)`` through ``epochs`` inclusive, so a\n",
    "    repeated call continues where the histories left off. Gradients are\n",
    "    accumulated and the optimizer steps once every ``OPT_P`` batches; the\n",
    "    batch counter is not reset between epochs. After each epoch the model is\n",
    "    evaluated on ``val`` and ``test``, checkpoints are written to\n",
    "    ``weights3/best.pt`` (on a new lowest validation MSE) and\n",
    "    ``weights3/last.pt``, and the global histories are extended.\n",
    "\n",
    "    Args:\n",
    "        epochs: last epoch number to run.\n",
    "        crop_loss: forwarded to ``loss``.\n",
    "    \"\"\"\n",
    "    epoch_bar = trange(len(tr_loss), epochs + 1)\n",
    "    n_batches = 0\n",
    "    for epoch in epoch_bar:\n",
    "        grad_norms = []\n",
    "        batch_losses = []\n",
    "\n",
    "        model.train()\n",
    "        batch_bar = tqdm(train_aug, leave=False)\n",
    "        for X, y, m in batch_bar:\n",
    "            batch_loss = loss(model(X), y, m, weight, crop_loss)\n",
    "            batch_losses.append(batch_loss.detach().cpu().item())\n",
    "            batch_loss.backward()\n",
    "            n_batches += 1\n",
    "            if n_batches % OPT_P != 0:\n",
    "                continue  # keep accumulating gradients\n",
    "\n",
    "            # Mean gradient norm over parameters that received a gradient,\n",
    "            # divided by the number of accumulated batches.\n",
    "            norms = [\n",
    "                p.grad.norm().cpu().item()\n",
    "                for p in model.parameters()\n",
    "                if p.grad is not None\n",
    "            ]\n",
    "            mean_norm = np.mean(norms) / OPT_P\n",
    "            grad_norms.append(mean_norm)\n",
    "            batch_bar.set_postfix(gn=mean_norm)\n",
    "            opt.step()\n",
    "            opt.zero_grad()\n",
    "        if sch:\n",
    "            sch.step()\n",
    "\n",
    "        epoch_loss = np.mean(batch_losses)\n",
    "        val_mse = eval(model, val, True, False, False, False, target_mask=dp.target_mask)['mse']\n",
    "        test_mse = eval(model, test, True, False, False, False, target_mask=dp.target_mask)['mse']\n",
    "        if not vl_mse or val_mse < min(vl_mse):\n",
    "            torch.save(model.state_dict(), 'weights3/best.pt')\n",
    "        torch.save(model.state_dict(), 'weights3/last.pt')\n",
    "        epoch_grad = np.mean(grad_norms)\n",
    "\n",
    "        tr_loss.append(epoch_loss)\n",
    "        tr_mse.append(test_mse)\n",
    "        vl_mse.append(val_mse)\n",
    "        grads.append(epoch_grad)\n",
    "        epoch_bar.set_postfix(l=epoch_loss, gn=epoch_grad, vm=val_mse)\n",
    "\n",
    "def plot(ylim=None):\n",
    "    \"\"\"Show the recorded training curves in a two-row figure.\n",
    "\n",
    "    Top row: the ``tr_mse`` and ``vl_mse`` histories plus a fixed reference\n",
    "    line at 1.11e-2. Bottom row: the ``grads`` and ``tr_loss`` histories.\n",
    "\n",
    "    Args:\n",
    "        ylim: optional ``(low, high)`` pair applied to the top row's y-axis.\n",
    "    \"\"\"\n",
    "    _, (mse_ax, opt_ax) = plt.subplots(2, 1, figsize=(12, 6))\n",
    "\n",
    "    mse_ax.grid(True)\n",
    "    mse_ax.plot(tr_mse, label=f'ts mse: {tr_mse[-1]:.2e}', color='tab:green')\n",
    "    mse_ax.plot(vl_mse, label=f'val mse: {vl_mse[-1]:.2e}', color='tab:orange')\n",
    "\n",
    "    linear_val = 1.11e-2\n",
    "    mse_ax.hlines(linear_val, 0, len(tr_loss) - 1, linestyle='-.', color='tab:orange',\n",
    "                  label=f'linear val: {linear_val:.2e}')\n",
    "\n",
    "    if ylim is not None:\n",
    "        mse_ax.set_ylim(*ylim)\n",
    "    mse_ax.legend(ncols=2)\n",
    "\n",
    "    opt_ax.grid(True)\n",
    "    opt_ax.plot(grads, label='gradients', color='tab:green')\n",
    "    opt_ax.plot(tr_loss, label=f'loss: {tr_loss[-1]:.2e}', color='tab:red')\n",
    "    opt_ax.legend()\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f58f51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training up to epoch number 200 (crop_loss keeps its default, False).\n",
    "fit(epochs=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "id": "f6222698-4792-4edd-91ed-30e4d53e3edb",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/181 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'mse': 0.007592081608194572,\n",
       " '1-ssim': 0.008336842060089111,\n",
       " '1-ms_ssim': 0.0045871734619140625,\n",
       " 'iiee@0.05': 0.08987987041473389,\n",
       " 'iiee@0.15': 0.09982864558696747,\n",
       " 'iiee@0.3': 0.08953135460615158,\n",
       " 'iiee@0.5': 0.09797116369009018,\n",
       " 'iiee@0.75': 0.06038009375333786}"
      ]
     },
     "execution_count": 205,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Optional: restore the best checkpoint first with\n",
    "# model.load_state_dict(torch.load('weights3/best.pt'))\n",
    "metric = eval(model, test, target_mask=dp.target_mask, ms_ssim=True,\n",
    "              iiee_at=[0.05, 0.15, 0.30, 0.50, 0.75], cell_areas=cell_areas)\n",
    "# `metrics` is created as a plain dict, which has no `.loc`; store by key.\n",
    "# pd.DataFrame.from_dict(metrics, orient='index') gives one row per model.\n",
    "metrics[name] = metric\n",
    "torch.save(model.state_dict(), f'weights4/{name}.pt')\n",
    "metric"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
