{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-09-24T13:29:27.716051Z",
     "start_time": "2024-09-24T13:29:24.941555Z"
    }
   },
   "source": [
    "import torch\n",
    "from hydra import compose, initialize\n",
    "from hydra.utils import instantiate\n",
    "from omegaconf import open_dict\n",
    "import numpy as np\n",
    "from model.pl_modules.ddm import DDM"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:29:27.972387Z",
     "start_time": "2024-09-24T13:29:27.737348Z"
    }
   },
   "cell_type": "code",
   "source": [
    "with initialize(version_base=None, config_path=\"../conf\"):\n",
    "    cfg = compose(config_name=\"default\")\n",
    "\n",
    "# Inference overrides: a single fast dev run, logger disabled, and the GPU when one is\n",
    "# available (falls back to CPU so the notebook also runs on machines without CUDA).\n",
    "with open_dict(cfg):\n",
    "    cfg.trainer.fast_dev_run = True\n",
    "    cfg.logger.mode = 'disabled'\n",
    "    cfg.trainer.accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    cfg.trainer.devices = \"auto\"\n",
    "    # Size the denoising network's input to the dataset's feature list.\n",
    "    cfg.denoising_network.n_features = len(cfg.dataset.feature_names)"
   ],
   "id": "2b19def4befd6a3",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:29:28.363745Z",
     "start_time": "2024-09-24T13:29:28.086318Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Build both components from the Hydra config; they are handed to\n",
    "# DDM.load_from_checkpoint when the model is restored.\n",
    "diffusion, denoising_network = (\n",
    "    instantiate(cfg.model.diffusion),\n",
    "    instantiate(cfg.denoising_network),\n",
    ")"
   ],
   "id": "3535978f17a5259a",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:29:29.174330Z",
     "start_time": "2024-09-24T13:29:28.840390Z"
    }
   },
   "cell_type": "code",
   "source": [
    "device = cfg.trainer.accelerator\n",
    "batch_size = 1\n",
    "# Take the feature count from the config (set in the config cell) instead of hard-coding 36,\n",
    "# so it cannot drift from the network that is instantiated and loaded here.\n",
    "n_features = cfg.denoising_network.n_features\n",
    "max_lag = 2  # NOTE(review): must match the lag order of the trained network - confirm\n",
    "CHECKPOINT_PATH = '../../storage/acfjajzq/checkpoints/epoch=050.ckpt'\n",
    "\n",
    "# map_location restores tensors onto the selected device even if the checkpoint was saved elsewhere.\n",
    "# The loss modules are passed as None - presumably they are only needed for training.\n",
    "model: DDM = DDM.load_from_checkpoint(\n",
    "    CHECKPOINT_PATH,\n",
    "    map_location=device,\n",
    "    diffusion=diffusion, denoising_network=denoising_network,\n",
    "    loss_fn_sparsification=None,\n",
    "    loss_fn_fourier=None,\n",
    "    loss_fn_conv=None,\n",
    "    loss_fn_dtw=None,\n",
    ").to(device)\n",
    "model = model.eval()"
   ],
   "id": "f814d3904c61483b",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:29:29.798824Z",
     "start_time": "2024-09-24T13:29:29.786870Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_valmatrix_graph_strongest(coefficients: np.ndarray, keep_top: float, n_features, max_lag) -> np.ndarray:\n",
    "    \"\"\"Summarise lagged coefficients into a signed, thresholded lag graph.\n",
    "\n",
    "    Args:\n",
    "        coefficients: array of shape [n_features, n_features*max_lag, seq_len-max_lag]\n",
    "            (layout as noted by the author; confirm against the model output).\n",
    "        keep_top: fraction of entries kept in each tail. Entries above the (1 - keep_top)\n",
    "            quantile become +1, entries below the keep_top quantile become -1, the rest 0.\n",
    "        n_features: number of variables.\n",
    "        max_lag: number of lags packed into the second axis of ``coefficients``.\n",
    "\n",
    "    Returns:\n",
    "        val_matrix of shape [n_features, n_features, max_lag + 1] with values in {-1., 0., 1.}.\n",
    "        The lag-0 slice is zero-filled before thresholding, and those zeros also take part\n",
    "        in the threshold quantiles.\n",
    "    \"\"\"\n",
    "    # Collapse the time axis to the dominant tail of each coefficient: the 95th percentile when\n",
    "    # it exceeds the magnitude of the 5th percentile, otherwise the (signed) 5th percentile.\n",
    "    c_pos = np.quantile(coefficients, q=.95, axis=-1)\n",
    "    c_neg = np.quantile(coefficients, q=.05, axis=-1)\n",
    "    c = np.where(c_pos > np.abs(c_neg), c_pos, c_neg)\n",
    "    # Unpack the lag axis, swap the two feature axes and reverse the lag order\n",
    "    # (NOTE(review): presumably to match the expected val_matrix [i, j, lag] layout - confirm).\n",
    "    c = c.reshape(n_features, n_features, max_lag).transpose((1, 0, 2))\n",
    "    c = np.flip(c, axis=-1)\n",
    "    # The lag-0 slot is filled with zeros.\n",
    "    z = np.zeros((n_features, n_features, 1))\n",
    "    val_matrix = np.concatenate([z, c], axis=-1)\n",
    "    # Keep only the strongest keep_top fraction at each tail, as signed indicators.\n",
    "    threshold_pos = np.quantile(val_matrix, q=1 - keep_top)\n",
    "    threshold_neg = np.quantile(val_matrix, q=keep_top)\n",
    "    val_matrix = np.where(val_matrix > threshold_pos, 1., np.where(val_matrix < threshold_neg, -1., 0.))\n",
    "    # val_matrix.shape = [n_features, n_features, max_lag + 1]\n",
    "    return val_matrix"
   ],
   "id": "e41163429707e4da",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:30:43.954525Z",
     "start_time": "2024-09-24T13:30:43.947746Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Seed torch so the initial noise (and any later torch random draws) is reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "x_t = torch.randn((batch_size, cfg.dataset.seq_len, cfg.denoising_network.n_features)).to(device)"
   ],
   "id": "d580f0d462812703",
   "outputs": [],
   "execution_count": 19
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-09-24T13:31:55.508076Z",
     "start_time": "2024-09-24T13:31:53.997241Z"
    }
   },
   "cell_type": "code",
   "source": [
    "%%time\n",
    "# Reverse diffusion starting from the noise in x_t. Sampling needs no gradients; without\n",
    "# no_grad, autograd would record a graph through every chained denoising step.\n",
    "# A copy of x_t is denoised so re-running this cell restarts from the same initial noise.\n",
    "x_sample = x_t.clone()\n",
    "with torch.no_grad():\n",
    "    for i in reversed(range(cfg.model.diffusion.diffusion_timesteps)):\n",
    "        t = torch.full((batch_size,), i, device=device, dtype=torch.long)\n",
    "        pred, coefficients = model.forward_denoising_network(x_sample, t)\n",
    "        x_sample = model.diffusion.backward_diffusion(x_sample, pred, t, i)\n",
    "# val_matrix = get_valmatrix_graph_strongest(coefficients.detach().cpu().numpy(), .3, n_features, max_lag)"
   ],
   "id": "3158bde1962e8c30",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.44 s, sys: 70 ms, total: 1.51 s\n",
      "Wall time: 1.5 s\n"
     ]
    }
   ],
   "execution_count": 34
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "19b1efc374b52e3b"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
