{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4058ab6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scanpy as sc\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt \n",
    "import torch\n",
    "import argparse\n",
    "import os\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import EarlyStopping\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from numba import njit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e224d8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from make_dataset import PerturbDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b9219880",
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_data_path='/XXXX-13/XXXX-14/XXXX-15/XXXX-7/peturb/data/ifn/IFN.h5ad'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9aea03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "file=processed_data_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e98017c",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> make dataset use new one, and train test split remain unchanged:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "00e888f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 87436/87436 [00:03<00:00, 28042.50it/s]\n",
      "100%|██████████| 87436/87436 [00:03<00:00, 23093.54it/s]\n"
     ]
    }
   ],
   "source": [
    "train_dataset = PerturbDataset(file, n_genetic=1000, fraction_regimes_to_ignore=0.2)\n",
    "regimes_to_ignore = train_dataset.regimes_to_ignore\n",
    "test_dataset = PerturbDataset(file, n_genetic=1000, regimes_to_ignore=regimes_to_ignore, load_ignored=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7c68e166",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_size = int(0.8 * len(train_dataset))\n",
    "val_size = len(train_dataset) - train_size\n",
    "train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b931897",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> model configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "c43d95b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_nodes = test_dataset.dim\n",
    "lr=0.001\n",
    "reg_coeff=0.001\n",
    "num_modules=20\n",
    "constraint_mode='spectral_radius'\n",
    "num_gpus=1\n",
    "num_train_epochs=2\n",
    "train_batch_size=128\n",
    "num_fine_epochs=2\n",
    "n_condition=None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad7334ac",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> initialize the model, will use the MLPModuleGaussianModel built under pytorch lighting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "57e4dc38",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MLPModuleGaussianModel(\n",
    "            nb_nodes,\n",
    "            2,\n",
    "            num_modules,\n",
    "            16,\n",
    "            lr_init=lr,\n",
    "            reg_coeff=reg_coeff,\n",
    "            constraint_mode=constraint_mode,\n",
    "            num_conditions=n_condition)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "9b8444de",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLPModuleGaussianModel(\n",
       "  (module): MLPModularGaussianModule(\n",
       "    (weights_node2module): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 20x16x964]\n",
       "        (1): Parameter containing: [torch.float32 of size 20x16x16]\n",
       "        (2): Parameter containing: [torch.float32 of size 20x1x16]\n",
       "    )\n",
       "    (weights_module2node): ParameterList(  (0): Parameter containing: [torch.float32 of size 964x1x20])\n",
       "    (biases_node2module): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 20x16]\n",
       "        (1): Parameter containing: [torch.float32 of size 20x16]\n",
       "        (2): Parameter containing: [torch.float32 of size 20x1]\n",
       "    )\n",
       "    (biases_module2node): ParameterList(  (0): Parameter containing: [torch.float32 of size 964x1])\n",
       "    (gumbel_innout): GumbelInNOut()\n",
       "    (linear_uv1): Linear(in_features=964, out_features=964, bias=True)\n",
       "    (linear_uv2): Linear(in_features=964, out_features=57840, bias=True)\n",
       "    (linear_mean): Linear(in_features=20, out_features=20, bias=True)\n",
       "    (linear_var): Linear(in_features=20, out_features=20, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b1ae1f",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> Train the model, I trained two epoches, but the kl loss seems unchanged(-0.4167) in every batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "7562ac1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(accelerator=\"cpu\",max_epochs=num_train_epochs,val_check_interval=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "33fb9668",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Set SLURM handle signals.\n",
      "\n",
      "  | Name   | Type                     | Params\n",
      "----------------------------------------------------\n",
      "0 | module | MLPModularGaussianModule | 57.1 M\n",
      "----------------------------------------------------\n",
      "57.1 M    Trainable params\n",
      "0         Non-trainable params\n",
      "57.1 M    Total params\n",
      "228.561   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12ae66ab5efa487483e834cb50f12896",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.fit(model,\n",
    "        DataLoader(train_dataset, batch_size=train_batch_size, num_workers=4),\n",
    "        DataLoader(val_dataset, num_workers=8, batch_size=256),)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60f61e55",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> MLPModularGaussianModule is the basic module of dcdfg, here I changed it into a vae and define the kl_loss for w:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "0f1b843c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLPModularGaussianModule(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_vars,\n",
    "        num_layers,\n",
    "        num_modules,\n",
    "        hid_dim,\n",
    "        nonlin=\"leaky_relu\",\n",
    "        constraint_mode=\"spectral_radius\",\n",
    "        num_conditions=None):\n",
    "        \"\"\"\n",
    "        Simplification for the \"perfect known\" context and the MLP framework\n",
    "        :param int num_vars: number of variables in the system\n",
    "        :param int num_layers: number of hidden layers\n",
    "        :param int num_modules number of modules\n",
    "        :param int hid_dim: number of hidden units per layer\n",
    "        :param int num_params: number of parameters per conditional *outputted by MLP*\n",
    "        :param str nonlin: which nonlinearity to use\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.num_vars = num_vars\n",
    "        self.num_layers = num_layers\n",
    "        self.num_modules = num_modules\n",
    "        self.hid_dim = hid_dim\n",
    "        self.nonlin = nonlin\n",
    "        self.constraint_mode = constraint_mode\n",
    "\n",
    "        self.weights_node2module = nn.ParameterList()\n",
    "        self.weights_module2node = nn.ParameterList()\n",
    "        self.biases_node2module = nn.ParameterList()\n",
    "        self.biases_module2node = nn.ParameterList()\n",
    "        self.log_stds = nn.Parameter(data=torch.zeros((self.num_vars,)))\n",
    "\n",
    "        self.gumbel_innout = GumbelInNOut(self.num_vars, self.num_modules)\n",
    "        \n",
    "        self.linear_uv1=nn.Linear(self.num_vars,self.num_vars)\n",
    "        self.linear_uv2=nn.Linear(self.num_vars,self.num_vars*self.num_modules*3)\n",
    "        self.linear_mean=nn.Linear(self.num_modules,self.num_modules)\n",
    "        self.linear_var=nn.Linear(self.num_modules,self.num_modules)\n",
    "        \n",
    "        \n",
    "        self.zero_weights_ratio = 0.0\n",
    "        self.numel_weights = 0\n",
    "        self.deterministic = False\n",
    "        \n",
    "        if num_conditions:\n",
    "            self.num_conditions=num_conditions\n",
    "        else:\n",
    "            self.num_conditions=None\n",
    "\n",
    "        # Instantiate the parameters of each layer in the model of each variable\n",
    "        # Here, features -> modules is a MLP but modules -> gene is linear\n",
    "        for weights, biases, num_out_nodes, num_in_nodes in (\n",
    "            (\n",
    "                self.weights_node2module,\n",
    "                self.biases_node2module,\n",
    "                self.num_modules,\n",
    "                self.num_vars,\n",
    "            ),\n",
    "        ):\n",
    "            for i in range(self.num_layers + 1):\n",
    "                in_dim = num_in_nodes if i == 0 else self.hid_dim\n",
    "                #cheange the dimension of the encoder to the number of factors\n",
    "                out_dim = 1 if i == self.num_layers else self.hid_dim\n",
    "\n",
    "                weights.append(nn.Parameter(torch.zeros(num_out_nodes, out_dim, in_dim)))\n",
    "                biases.append(nn.Parameter(torch.zeros(num_out_nodes, out_dim)))\n",
    "                self.numel_weights += self.num_vars * out_dim * in_dim\n",
    "\n",
    "            # init params\n",
    "            with torch.no_grad():\n",
    "                for node in range(num_out_nodes):\n",
    "                    for i, w in enumerate(weights):\n",
    "                        w = w[node]\n",
    "                        nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain(self.nonlin))\n",
    "                    for i, b in enumerate(biases):\n",
    "                        b = b[node]\n",
    "                        b.zero_()\n",
    "\n",
    "            # separate for linear decoding model\n",
    "            self.weights_module2node.append(nn.Parameter(torch.zeros(self.num_vars, 1, self.num_modules)))\n",
    "            self.biases_module2node.append(nn.Parameter(torch.zeros(self.num_vars, 1)))\n",
    "            with torch.no_grad():\n",
    "                for node in range(num_out_nodes):\n",
    "                    nn.init.xavier_uniform_(self.weights_module2node[0][node],gain=nn.init.calculate_gain(self.nonlin))\n",
    "                    self.biases_module2node[0][node].zero_()\n",
    "\n",
    "        # Initialization for spectral radius constraint\n",
    "        w_adj = self.get_w_adj()\n",
    "        self.register_buffer(\"u\", torch.zeros(w_adj.shape[0]))\n",
    "        self.register_buffer(\"v\", torch.zeros(w_adj.shape[0]))\n",
    "        a, b = -3, 3\n",
    "        with torch.no_grad():\n",
    "            nn.init.trunc_normal_(self.u, a=a, b=b)\n",
    "            nn.init.trunc_normal_(self.v, a=a, b=b)\n",
    "\n",
    "        # Initialization for block spectral radius constraint\n",
    "        self.register_buffer(\"u_v\", torch.zeros(self.num_vars))\n",
    "        self.register_buffer(\"u_f\", torch.zeros(self.num_modules))\n",
    "        self.register_buffer(\"v_v\", torch.zeros(self.num_vars))\n",
    "        self.register_buffer(\"v_f\", torch.zeros(self.num_modules))\n",
    "        a, b = -3, 3\n",
    "        with torch.no_grad():\n",
    "            nn.init.trunc_normal_(self.u_v, a=a, b=b)\n",
    "            nn.init.trunc_normal_(self.u_f, a=a, b=b)\n",
    "            nn.init.trunc_normal_(self.v_v, a=a, b=b)\n",
    "            nn.init.trunc_normal_(self.v_f, a=a, b=b)\n",
    "\n",
    "        # get scaling factor and normalization factor for penalty\n",
    "        with torch.no_grad():\n",
    "            mat = w_adj\n",
    "            self.base_radius = self.spectral_radius_adj(mat, n_iter=100)\n",
    "            self.constraint_norm = self.compute_dag_constraint(mat).item()\n",
    "            if np.isinf(self.constraint_norm):\n",
    "                raise ValueError(\"Error: constraint normalization is infinite\")\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        :param x: batch_size x num_vars\n",
    "        :return: batch_size x num_vars * num_params, the parameters of each variable conditional\n",
    "        \"\"\"\n",
    "        num_zero_weights = 0\n",
    "        \n",
    "        #encode u and v\n",
    "        mask_node2module, mask_module2node,w_dis,y_soft = self.encode_uv(x)\n",
    "        \n",
    "        if self.num_conditions:\n",
    "            mask_module2node[:,-self.num_conditions:,:]=0\n",
    "        \n",
    "        mask_module2node = torch.transpose(mask_module2node, 1, 2)\n",
    "        \n",
    "        #vae part\n",
    "        mu,std=self.encode(x,mask_node2module)\n",
    "        z=self.reparameterize(mu,std)\n",
    "        x=self.decode(z,mask_module2node)\n",
    "        \n",
    "        self.zero_weights_ratio = num_zero_weights / float(self.numel_weights)\n",
    "        return x,w_dis,y_soft\n",
    "    \n",
    "    def encode_uv(self,x):\n",
    "        num_batch = x.size(0)\n",
    "\n",
    "        w=F.leaky_relu(self.linear_uv1(x))\n",
    "        bn_layer_w_1 = nn.BatchNorm1d(w.shape[1])\n",
    "        w=bn_layer_w_1(w)\n",
    "        \n",
    "        w=F.leaky_relu(self.linear_uv2(w))\n",
    "        bn_layer_w_2 = nn.BatchNorm1d(w.shape[1])\n",
    "        w=bn_layer_w_2(w)\n",
    "        \n",
    "        w=w.view(-1,self.num_vars,self.num_modules,3)\n",
    "        \n",
    "        #suppose a uniform distribution for w, prior of w\n",
    "        size_of_w = w.size()\n",
    "        w_dis_logits=torch.rand(size_of_w)\n",
    "        w_dis=torch.log(F.gumbel_softmax(w_dis_logits, tau=1, hard=False)).to(x.device)\n",
    "        #w_dis=torch.rand(size_of_w)\n",
    "        \n",
    "        if not self.deterministic:\n",
    "            \n",
    "            #output y_soft, the posterior of w after gumbel_softmax_var\n",
    "            design,y_soft = self.gumbel_softmax_var(w)\n",
    "            \n",
    "            node2module = design[:, :, :, 0]\n",
    "            module2node = design[:, :, :, 1]\n",
    "        else:\n",
    "            node2module = self.freeze_node2module.unsqueeze(0)\n",
    "            module2node = self.freeze_module2node.unsqueeze(0)\n",
    "            \n",
    "        return node2module,module2node,w_dis,y_soft\n",
    "        \n",
    "    def encode(self,x,mask_node2module):\n",
    "        num_batch = x.size(0)\n",
    "        num_zero_weights = 0\n",
    "        for weights, biases, mask in [(self.weights_node2module, self.biases_node2module, mask_node2module)]:\n",
    "            num_layers = len(weights) - 1\n",
    "            for layer in range(num_layers + 1):\n",
    "                if layer == 0:\n",
    "                    x = (torch.einsum(\"tij,bjt,bj->bti\", weights[layer], mask, x)+ biases[layer])\n",
    "                else:\n",
    "                    x = torch.einsum(\"tij,btj->bti\", weights[layer], x) + biases[layer]\n",
    "                num_zero_weights += weights[layer].numel() - weights[layer].nonzero().size(0)\n",
    "                if layer != num_layers:\n",
    "                    x = (F.leaky_relu(x) if self.nonlin == \"leaky_relu\" else torch.sigmoid(x))\n",
    "                else:\n",
    "                    x = x.squeeze()\n",
    "                    \n",
    "        mu=F.leaky_relu(self.linear_mean(x))\n",
    "        bn_layer_mu = nn.BatchNorm1d(mu.shape[1])\n",
    "        mu=bn_layer_mu(mu)\n",
    "        \n",
    "        var=F.leaky_relu(self.linear_var(x))\n",
    "        bn_layer_var = nn.BatchNorm1d(var.shape[1])\n",
    "        var=bn_layer_var(var)\n",
    "        \n",
    "        return mu,var\n",
    "    \n",
    "    def reparameterize(self, mu, var):\n",
    "        \n",
    "        epsilon = torch.randn_like(var)\n",
    "        z = mu + var*epsilon \n",
    "        \n",
    "        return z\n",
    "\n",
    "    def decode(self,x,mask_module2node):\n",
    "        num_batch = x.size(0)\n",
    "        num_zero_weights = 0\n",
    "        \n",
    "        for weights, biases, mask in [(self.weights_module2node, self.biases_module2node, mask_module2node)]:\n",
    "            num_layers = len(weights) - 1\n",
    "            for layer in range(num_layers + 1):\n",
    "                if layer == 0:\n",
    "                    x = (torch.einsum(\"tij,bjt,bj->bti\", weights[layer], mask, x)+ biases[layer])\n",
    "                else:\n",
    "                    x = torch.einsum(\"tij,btj->bti\", weights[layer], x) + biases[layer]\n",
    "                num_zero_weights += weights[layer].numel() - weights[layer].nonzero().size(0)\n",
    "                if layer != num_layers:\n",
    "                    x = (F.leaky_relu(x) if self.nonlin == \"leaky_relu\" else torch.sigmoid(x))\n",
    "                else:\n",
    "                    x = x.squeeze()\n",
    "        return x\n",
    "    \n",
    "    def gumbel_softmax_var(self,x,tau=1,hard=True):\n",
    "        gumbel_noise = -torch.empty_like(x).exponential_().log()\n",
    "        gumbels = x + gumbel_noise\n",
    "        y_soft = gumbels.softmax(-1)\n",
    "\n",
    "        if hard:\n",
    "            # Straight through.\n",
    "            index = y_soft.max(-1, keepdim=True)[1]\n",
    "            y_hard = torch.zeros_like(gumbels).scatter_(-1, index, 1.0)\n",
    "            ret = y_hard.detach() - y_soft.detach() + y_soft\n",
    "        else:\n",
    "            # Reparametrization trick.\n",
    "            ret = y_soft\n",
    "        \n",
    "        #output y_soft, which is also a distribution, and it's the posterior of w\n",
    "        \n",
    "        return ret,y_soft\n",
    "    \n",
    "    def log_likelihood(self, x):\n",
    "        \"\"\"\n",
    "        Return log-likelihood of the model for each example.\n",
    "        WARNING: This is really a joint distribution only if the DAGness constraint on the mask is satisfied.\n",
    "                 Otherwise the joint does not integrate to one.\n",
    "        :param x: (batch_size, num_vars)\n",
    "        :return: (batch_size, num_vars) log-likelihoods\n",
    "        \"\"\"\n",
    "        #1.9: return p and q after forward\n",
    "        density_params,w_dis,y_soft = self.forward(x)\n",
    "        stds = torch.sqrt(torch.exp(self.log_stds) + 1e-4)\n",
    "        \n",
    "        #1.9: add return, p and q\n",
    "        return torch.distributions.Normal(density_params, stds.unsqueeze(0)).log_prob(x),w_dis,y_soft\n",
    "\n",
    "    def losses(self, x, mask):\n",
    "        \"\"\"\n",
    "        Compute the loss. If intervention is perfect and known, remove\n",
    "        the intervened targets from the loss with a mask.\n",
    "        \"\"\"\n",
    "        \n",
    "        #if mask==False:\n",
    "        if isinstance(mask, bool):\n",
    "            \n",
    "            #1.9 add [0], since log_likelihood is only the first term\n",
    "            log_likelihood_x,w_dis,y_soft=self.log_likelihood(x)\n",
    "            log_likelihood = torch.sum(log_likelihood_x,dim=0) / x.shape[0]\n",
    "            #log_likelihood = torch.sum(self.log_likelihood(x),dim=0) / x.shape[0]\n",
    "        else:\n",
    "            #1.9: same\n",
    "            log_likelihood_x,w_dis,y_soft=self.log_likelihood(x)\n",
    "            log_likelihood = torch.sum(log_likelihood_x* mask,dim=0) /mask.size(0)\n",
    "            #log_likelihood = torch.sum(self.log_likelihood(x) * mask, dim=0) / mask.size(0)\n",
    "        # constraint related\n",
    "        adj = self.get_w_adj()\n",
    "        h = (\n",
    "            self.compute_dag_constraint(adj)\n",
    "            # / self.constraint_norm\n",
    "        )\n",
    "        a, b = self.gumbel_innout.get_proba_()\n",
    "        reg = 0.5 * (a.sum() + b.sum()) / a.numel()\n",
    "        \n",
    "        kl_loss_w=self.kl_losses(w_dis,y_soft)\n",
    "        #print(kl_loss_w)\n",
    "        losses = (-torch.mean(log_likelihood)+kl_loss_w, h, reg)\n",
    "        return losses\n",
    "    \n",
    "    #kl_losses for w\n",
    "    def kl_losses(self,w_dis,y_soft):\n",
    "        \n",
    "        #node2module,module2node,w_dis,y_soft=self.encode_uv(x)\n",
    "        #modify kl loss\n",
    "        kl_loss = nn.KLDivLoss()\n",
    "        kl_loss_w = kl_loss(w_dis,y_soft)\n",
    "        \n",
    "        return kl_loss_w\n",
    "\n",
    "    def get_w_adj(self):\n",
    "        return self.gumbel_innout.get_proba_features()\n",
    "    \n",
    "    def spectral_radius_adj(self, w_adj, n_iter=5):\n",
    "        \"\"\"\n",
    "        Compute the spectral norm of w_adj with a power iteration.\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            for _ in range(n_iter):\n",
    "                self.v = F.normalize(w_adj.T @ self.v, dim=0)\n",
    "                self.u = F.normalize(w_adj @ self.u, dim=0)\n",
    "        return self.v.T @ w_adj @ self.u / (self.v @ self.u)\n",
    "\n",
    "    def spectral_radius_block(self, A, B, n_iter=5):\n",
    "        \"\"\"\n",
    "        A is shape n_var x n_module\n",
    "        B is shape n_module x n_var\n",
    "        Compute the spectral norm of U,V with a power iteration.\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            for _ in range(n_iter):\n",
    "                self.u_f = F.normalize(B @ self.u_v, dim=0)\n",
    "                self.u_v = F.normalize(A @ self.u_f, dim=0)\n",
    "                self.v_f = F.normalize(A.T @ self.v_v, dim=0)\n",
    "                self.v_v = F.normalize(B.T @ self.v_f, dim=0)\n",
    "        numerator = self.v_f.T @ B @ self.u_v + self.v_v.T @ A @ self.u_f\n",
    "        denominator = self.v_f.T @ self.u_f + self.v_v.T @ self.u_v\n",
    "        return numerator / denominator\n",
    "\n",
    "    def spectral_radius_iteration(self, node2module, module2node, n_iter=5):\n",
    "        \"\"\"\n",
    "        Compute the spectral norm of w_adj with a power iteration.\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            for _ in range(n_iter):\n",
    "                # w_adj = node2module @ module2node.T - diag\n",
    "                # v update: v+ \\propsto w_adj v\n",
    "                diag_term = self.v * torch.sum(node2module * module2node, 1)\n",
    "                self.v = F.normalize(node2module @ module2node.T @ self.v - diag_term, dim=0)\n",
    "                # u update: y+ \\propsto w_adj.T u\n",
    "                diag_term = self.u * torch.sum(node2module * module2node, 1)\n",
    "                self.u = F.normalize(module2node @ node2module.T @ self.u - diag_term, dim=0)\n",
    "        numerator = self.v.T @ node2module @ module2node.T @ self.u\n",
    "        numerator -= self.v @ (self.u * torch.sum(node2module * module2node, 1))\n",
    "        return numerator / (self.v @ self.u)\n",
    "\n",
    "    def compute_dag_constraint(self, adj):\n",
    "        \"\"\"\n",
    "        Compute the DAG constraint on weighted adjacency matrix w_adj\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        if self.constraint_mode == \"exp\":\n",
    "            return torch.trace(torch.matrix_exp(adj / self.base_radius)) - self.num_vars\n",
    "        elif self.constraint_mode == \"spectral_radius\":\n",
    "            return self.spectral_radius_adj(adj)\n",
    "        elif self.constraint_mode == \"exptrick\":\n",
    "            return (\n",
    "                torch.trace(\n",
    "                    torch.matrix_exp(\n",
    "                        self.gumbel_innout.get_proba_modules() / self.base_radius\n",
    "                    )\n",
    "                )\n",
    "                - self.num_modules\n",
    "            )\n",
    "        elif self.constraint_mode == \"spectraltrick\":\n",
    "            return self.compute_dag_constraint_spectral(\n",
    "                *self.gumbel_innout.get_proba_()\n",
    "            )\n",
    "        else:\n",
    "            raise ValueError(\n",
    "                \"constraint_mode needs to be in ['exp', 'spectral_radius', 'matrix_power'].\"\n",
    "            )\n",
    "\n",
    "    def compute_dag_constraint_power(self, w_adj):\n",
    "        \"\"\"\n",
    "        Compute the DAG constraint DIBS style via a matrix power.\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        d = w_adj.shape[0]\n",
    "        return (\n",
    "            torch.trace(\n",
    "                torch.linalg.matrix_power(\n",
    "                    torch.eye(w_adj.shape[0], device=w_adj.device)\n",
    "                    + w_adj / self.base_radius,\n",
    "                    d,\n",
    "                )\n",
    "            )\n",
    "            - d\n",
    "        )\n",
    "\n",
    "    def compute_dag_constraint_spectral(self, adj_node2module, adj_module2node):\n",
    "        \"\"\"\n",
    "        Compute the DAG constraint NO-BEARS style via the spectral norm.\n",
    "        :param np.ndarray w_adj: the weighted adjacency matrix (each entry in [0,1])\n",
    "        \"\"\"\n",
    "        return self.spectral_radius_iteration(adj_node2module, adj_module2node)\n",
    "    \n",
    "    def check_acyclicity(self):\n",
    "        adj = self.get_w_adj()\n",
    "        to_keep = (adj > 0.5).type_as(adj)\n",
    "        return is_acyclic(to_keep.cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f86b2ab",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> MLPModuleGaussianModel is use LightningModule to run the model, and will import the MLPModularGaussianModule, I only add several lines to calculate the kl loss in the training_step part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abc59705",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLPModuleGaussianModel(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Lightning module that runs augmented lagrangian\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_vars,\n",
    "        num_layers,\n",
    "        num_modules,\n",
    "        hid_dim,\n",
    "        nonlin=\"leaky_relu\",\n",
    "        lr_init=1e-3,\n",
    "        reg_coeff=0.1,\n",
    "        constraint_mode=\"exp\",\n",
    "        num_conditions=None\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.module = MLPModularGaussianModule(\n",
    "            num_vars,\n",
    "            num_layers,\n",
    "            num_modules,\n",
    "            hid_dim,\n",
    "            nonlin=nonlin,\n",
    "            constraint_mode=constraint_mode,\n",
    "            num_conditions=num_conditions)\n",
    "        # augmented lagrangian params\n",
    "        # mu: penalty\n",
    "        # gamma: multiplier\n",
    "        self.mu_init = 1e-8\n",
    "        self.gamma_init = 0.0\n",
    "        self.omega_gamma = 1e-4\n",
    "        self.omega_mu = 0.9\n",
    "        self.h_threshold = 1e-8\n",
    "        self.mu_mult_factor = 2\n",
    "        # opt params\n",
    "        self.save_hyperparameters()\n",
    "        self.hparams[\"name\"] = self.__class__.__name__\n",
    "        self.hparams[\"module_name\"] = self.module.__class__.__name__\n",
    "\n",
    "        self.lr_init = lr_init\n",
    "        self.reg_coeff = reg_coeff\n",
    "        self.constraint_mode = constraint_mode\n",
    "        self.num_conditions=num_conditions\n",
    "\n",
    "        # initialize stuff for learning loop\n",
    "        self.aug_lagrangians = []\n",
    "        self.not_nlls = []  # Augmented Lagrangrian minus (pseudo) NLL\n",
    "        self.nlls = []  # NLL on train\n",
    "        self.nlls_val = []  # NLL on validation\n",
    "        self.regs = []\n",
    "\n",
    "        # Augmented Lagrangian stuff\n",
    "        self.mu = self.mu_init\n",
    "        self.gamma = self.gamma_init\n",
    "\n",
    "        # bookkeeping for training\n",
    "        self.acyclic = 0.0\n",
    "        self.aug_lagrangians_val = []\n",
    "        self.not_nlls_val = []\n",
    "        self.constraint_value = 0.0\n",
    "        self.constraints_at_stat = []\n",
    "        self.reg_value = 0.0\n",
    "        self.internal_checkups = 0.0\n",
    "        self.stationary_points = 0.0\n",
    "\n",
    "    def forward(self, data):\n",
    "        if self.num_conditions is not None:\n",
    "            x=data\n",
    "            #not sure which to devide\n",
    "            log_likelihood = torch.sum(self.module.log_likelihood(x),dim=0)/x.shape[0]\n",
    "        else:\n",
    "            x, masks, regimes = data\n",
    "            log_likelihood = torch.sum(self.module.log_likelihood(x) * masks, dim=0) / masks.size(0)\n",
    "        return -torch.mean(log_likelihood)\n",
    "\n",
    "    def get_augmented_lagrangian(self, nll, constraint_violation, reg):\n",
    "        # compute augmented langrangian\n",
    "        return (nll+ self.reg_coeff * reg+ self.gamma * constraint_violation+ 0.5 * self.mu * constraint_violation**2)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # get data and compute loss\n",
    "        if self.num_conditions is not None:\n",
    "            x=batch\n",
    "            nll, constraint_violation, reg = self.module.losses(x, mask=False)\n",
    "        else:\n",
    "            x, masks, regimes = batch\n",
    "            \n",
    "            #output kl loss\n",
    "            #kl_loss=self.module.kl_losses(x)\n",
    "            #print(kl_loss)\n",
    "            \n",
    "            #calculate nll loss\n",
    "            nll, constraint_violation, reg = self.module.losses(x, masks)\n",
    "\n",
    "        aug_lagrangian = self.get_augmented_lagrangian(nll, constraint_violation, reg)\n",
    "\n",
    "        # logging\n",
    "        self.nlls.append(nll.item())\n",
    "        self.aug_lagrangians.append(aug_lagrangian.item())\n",
    "        self.not_nlls.append(aug_lagrangian.item() - nll.item())\n",
    "\n",
    "        self.log(\"Train/aug_lagrangian\", aug_lagrangian.detach())\n",
    "        self.log(\"Train/nll\", nll.detach())\n",
    "        self.log(\"Train/not_nll\", aug_lagrangian.detach() - nll.detach())\n",
    "        self.log(\"Aug_lag/mu\", self.mu)\n",
    "        self.log(\"Aug_lag/gamma\", self.gamma)\n",
    "\n",
    "        # return loss\n",
    "        return aug_lagrangian\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        if self.num_conditions is not None:\n",
    "            x=batch\n",
    "            nll, constraint_violation, reg = self.module.losses(x,mask=False)\n",
    "        else:\n",
    "            x, masks, regimes = batch\n",
    "            nll, constraint_violation, reg = self.module.losses(x,masks)\n",
    "        aug_lagrangian = self.get_augmented_lagrangian(nll, constraint_violation, reg)\n",
    "        return {\"aug_lagrangian\": aug_lagrangian,\"nll\": nll,\"constraint\": constraint_violation,\"reg\": reg}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        agg = {}\n",
    "        for k in outputs[0]:\n",
    "            agg[k] = torch.stack([dic[k] for dic in outputs]).mean().item()\n",
    "        self.aug_lagrangians_val += [agg[\"aug_lagrangian\"]]\n",
    "        self.constraint_value = agg[\"constraint\"]\n",
    "        self.reg_value = agg[\"reg\"]\n",
    "        self.not_nlls_val += [agg[\"aug_lagrangian\"] - agg[\"nll\"]]\n",
    "        self.nlls_val += [agg[\"nll\"]]\n",
    "        self.regs += [self.reg_value]\n",
    "        # self.acyclic = self.module.check_acyclicity()\n",
    "\n",
    "        self.log(\"Val/aug_lagrangian\", agg[\"aug_lagrangian\"])\n",
    "        self.log(\"Val/nll\", agg[\"nll\"])\n",
    "        self.log(\"Val/not_nll\", agg[\"aug_lagrangian\"] - agg[\"nll\"])\n",
    "        self.log(\"Val/constraint_violation\", agg[\"constraint\"])\n",
    "        self.log(\"Val/reg_value\", agg[\"reg\"])\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.RMSprop(self.module.parameters(), lr=self.lr_init)\n",
    "\n",
    "    def update_lagrangians(self):\n",
    "        self.internal_checkups += 1\n",
    "        self.log(\"Monitor/checkup\", self.internal_checkups)\n",
    "        # compute delta for gamma to check convergence status\n",
    "        delta_gamma = -np.inf\n",
    "        if len(self.aug_lagrangians_val) >= 3:\n",
    "            t0, t_half, t1 = (self.aug_lagrangians_val[-3],self.aug_lagrangians_val[-2],self.aug_lagrangians_val[-1])\n",
    "            # if the validation loss went up and down, do not update lagrangian and penalty coefficients.\n",
    "            if min(t0, t1) < t_half < max(t0, t1):\n",
    "                delta_gamma = -np.inf\n",
    "            else:\n",
    "                delta_gamma = (t1 - t0) / 100\n",
    "\n",
    "        # if we found a stationary point, but that is not satisfying the acyclicity constraints\n",
    "        if (\n",
    "            self.constraint_value > self.h_threshold\n",
    "            and not self.acyclic\n",
    "            and self.mu < 1e15\n",
    "            or self.stationary_points < 10\n",
    "        ):\n",
    "            if abs(delta_gamma) < self.omega_gamma or delta_gamma > 0:\n",
    "                self.stationary_points += 1\n",
    "                self.log(\"Monitor/stationary\", self.stationary_points)\n",
    "                self.gamma += self.mu * self.constraint_value\n",
    "\n",
    "                # Did the constraint improve sufficiently?\n",
    "                if len(self.constraints_at_stat) > 1:\n",
    "                    if (self.constraint_value>self.constraints_at_stat[-1] * self.omega_mu):\n",
    "                        self.mu *= self.mu_mult_factor\n",
    "                self.constraints_at_stat.append(self.constraint_value)\n",
    "\n",
    "                # little hack to make sure the moving average is going down.\n",
    "                gap_in_not_nll = (self.get_augmented_lagrangian(0.0, self.constraint_value, self.reg_value)- self.not_nlls_val[-1])\n",
    "                assert gap_in_not_nll > -1e-2\n",
    "                self.aug_lagrangians_val[-1] += gap_in_not_nll\n",
    "\n",
    "                # reset optimizer\n",
    "                self.trainer.optimizers = [self.configure_optimizers()]\n",
    "\n",
    "        # if we found a stationary point, that satisfies the acyclicity constraints, raise this flag, it will activate patience and terminate training soon\n",
    "        else:\n",
    "            self.trainer.satisfied_constraints = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfa8df39",
   "metadata": {},
   "source": [
    "<span style=\"font-size: 24px;\"> gumbel part remained unchanged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a344ddec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gumbel import gumbel_sigmoid, gumbel_softmax\n",
    "\n",
    "\n",
    "@njit\n",
    "def samesign(a, b):\n",
    "    return a * b > 0\n",
    "\n",
    "\n",
    "def bisect(func, low, high, T=20):\n",
    "    \"Find root of continuous function where f(low) and f(high) have opposite signs\"\n",
    "    flow = func(low)\n",
    "    fhigh = func(high)\n",
    "    assert not samesign(flow, fhigh)\n",
    "    for i in tqdm(range(T), desc=\"bisecting\"):\n",
    "        midpoint = (low + high) / 2.0\n",
    "        fmid = func(midpoint)\n",
    "        if samesign(flow, fmid):\n",
    "            low = midpoint\n",
    "            flow = fmid\n",
    "        else:\n",
    "            high = midpoint\n",
    "            fhigh = fmid\n",
    "    # after all those iterations, low has one sign, and high another one. midpoint is unknown\n",
    "    return high\n",
    "\n",
    "\n",
    "@njit\n",
    "def _is_acyclic(adjacency):\n",
    "    \"\"\"\n",
    "    Return true if adjacency is a acyclic\n",
    "    :param np.ndarray adjacency: adjacency matrix\n",
    "    \"\"\"\n",
    "    prod = np.eye(adjacency.shape[0], dtype=adjacency.dtype)\n",
    "    for _ in range(1, adjacency.shape[0] + 1):\n",
    "        prod = adjacency @ prod\n",
    "        if np.trace(prod) != 0:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def is_acyclic(adjacency):\n",
    "    return _is_acyclic(adjacency.astype(float))\n",
    "\n",
    "\n",
    "class GumbelAdjacency(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Probabilistic mask used for DAG learning.\n",
    "    Can sample a matrix and backpropagate using the\n",
    "    Gumbel straigth-through estimator.\n",
    "    :param int num_vars: number of variables\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_rows, num_cols=None):\n",
    "        super(GumbelAdjacency, self).__init__()\n",
    "        if num_cols is None:\n",
    "            # square matrix\n",
    "            self.num_vars = (num_rows, num_rows)\n",
    "        else:\n",
    "            self.num_vars = (num_rows, num_cols)\n",
    "        self.log_alpha = torch.nn.Parameter(torch.zeros(self.num_vars))\n",
    "        self.tau = 1\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def forward(self, bs):\n",
    "        adj = gumbel_sigmoid(self.log_alpha, bs, tau=self.tau, hard=True)\n",
    "        return adj\n",
    "\n",
    "    def get_proba(self):\n",
    "        \"\"\"Returns probability of getting one\"\"\"\n",
    "        return torch.sigmoid(self.log_alpha / self.tau)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        torch.nn.init.constant_(self.log_alpha, 5)\n",
    "\n",
    "\n",
    "class GumbelInNOut(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Random matrix M used for encoding egdes between modules and genes.\n",
    "    Category:\n",
    "    - 0 means no edge\n",
    "    - 1 means node2module edge\n",
    "    - 2 means module2node edge\n",
    "    Can sample a matrix and backpropagate using the\n",
    "    Gumbel straigth-through estimator.\n",
    "    :param int num_vars: number of variables\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,num_nodes, num_modules):\n",
    "        super(GumbelInNOut, self).__init__()\n",
    "        self.num_vars = (num_nodes, num_modules)\n",
    "        self.log_alpha = torch.nn.Parameter(torch.zeros(num_nodes, num_modules, 3))\n",
    "        self.register_buffer(\"freeze_node2module\",torch.zeros((num_nodes, num_modules)))\n",
    "        self.register_buffer(\"freeze_module2node\",torch.zeros((num_nodes, num_modules)))\n",
    "        self.tau = 1\n",
    "        self.drawhard = True\n",
    "        self.deterministic = False\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def forward(self, bs):\n",
    "        if not self.deterministic:\n",
    "            design = self.gumbel_softmax(self.log_alpha, bs, tau=self.tau, hard=self.drawhard)\n",
    "            node2module = design[:, :, :, 0]\n",
    "            module2node = design[:, :, :, 1]\n",
    "        else:\n",
    "            node2module = self.freeze_node2module.unsqueeze(0)\n",
    "            module2node = self.freeze_module2node.unsqueeze(0)\n",
    "        return node2module, module2node\n",
    "\n",
    "    def freeze_threshold(self, threshold):\n",
    "        \"\"\"Returns probability of being assigned into a bucket\"\"\"\n",
    "        design = torch.softmax(self.log_alpha / self.tau, -1)\n",
    "        node2module = design[:, :, 0]\n",
    "        module2node = design[:, :, 1]\n",
    "        max_in_out = torch.maximum(node2module, module2node)\n",
    "        # zero for low confidence\n",
    "        mask_keep = max_in_out >= threshold\n",
    "        # track argmax\n",
    "        self.freeze_node2module = (node2module == max_in_out) * mask_keep\n",
    "        self.freeze_module2node = (module2node == max_in_out) * mask_keep\n",
    "        self.deterministic = True\n",
    "        print(\"Freeze threshold:\" + str(self.freeze_module2node.device))\n",
    "\n",
    "    def get_proba_modules(self):\n",
    "        \"\"\"Returns probability of being assigned into a bucket\"\"\"\n",
    "        design = torch.softmax(self.log_alpha / self.tau, -1)\n",
    "        node2module = design[:, :, 0]\n",
    "        module2node = design[:, :, 1]\n",
    "        mat = module2node.T @ node2module\n",
    "        # above is correct except for diagonal values (individual values in the matrix product are corr.)\n",
    "        mask_modules = torch.ones(self.num_vars[1], self.num_vars[1]) - torch.eye(self.num_vars[1])\n",
    "        return mat * mask_modules.type_as(mat)\n",
    "\n",
    "    def get_proba_features(self, threshold=None,other=None):\n",
    "        \"\"\"Returns probability of being assigned into a bucket\"\"\"\n",
    "        design = torch.softmax(self.log_alpha / self.tau, -1)\n",
    "        node2module = design[:, :, 0]\n",
    "        module2node = design[:, :, 1]\n",
    "        if not threshold:\n",
    "            # return a differentiable tensor\n",
    "            mat = node2module @ module2node.T\n",
    "            # above is correct except for diagonal values (individual values in the matrix product are corr.)\n",
    "            mask_nodes = torch.ones(self.num_vars[0], self.num_vars[0]) - torch.eye(self.num_vars[0])\n",
    "            return mat * mask_nodes.type_as(mat)\n",
    "        else:\n",
    "            # here return a matrix without grad\n",
    "            # we're thresholding here according to the edge direction confidence\n",
    "            max_in_out = torch.maximum(design[:, :, 0], design[:, :, 1])\n",
    "            # zero for low confidence\n",
    "            mask_keep = design[:, :, 0] + design[:, :, 1] >= threshold\n",
    "            # track argmax\n",
    "            node2module = (design[:, :, 0] == max_in_out) * mask_keep\n",
    "            module2node = (design[:, :, 1] == max_in_out) * mask_keep\n",
    "            # that product below has no self cycles\n",
    "            if other!=None:\n",
    "                return (node2module.type_as(self.log_alpha),module2node.type_as(self.log_alpha))\n",
    "            return (\n",
    "                node2module.type_as(self.log_alpha)\n",
    "                @ module2node.type_as(self.log_alpha).T\n",
    "            )\n",
    "\n",
    "    def get_proba_(self):\n",
    "        design = torch.softmax(self.log_alpha / self.tau, -1)\n",
    "        node2module = design[:, :, 0]\n",
    "        module2node = design[:, :, 1]\n",
    "        return node2module, module2node\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        torch.nn.init.constant_(self.log_alpha, 1)\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da6c5c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f556d2b2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
