import json
from typing import Any, Dict, Iterable, Literal, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from .fdag import BasicFG, SPNFG
from .util import pred_score, shd_metric
from .scheduler import ConstScheduler, LinearScheduler, CosineScheduler, SigmoidScheduler, SparsityScheduler
import abcdefg.config as config
import lightning as L


class CGM(L.LightningModule):
    """Masked encoder/decoder model that learns a node -> module -> node graph.

    The encoder is a stack of per-module MLPs whose first layer is gated by a sampled
    node-to-module mask. It yields one latent value per module: a Gaussian mean/std pair,
    or a deterministic value when ``fix_factor`` is True. The decoder is a stack of
    per-variable MLPs whose first layer is gated by a sampled module-to-node mask. It
    yields a Gaussian mean and standard deviation for each observed variable.

    The masks are drawn from ``self.sample_uv`` (``BasicFG`` or ``SPNFG``). The predicted
    node-to-node graph is ``(n2m @ m2n.T) > 0``.

    Training uses manual optimization with two AdamW optimizers: one for the neural
    network weights and one for the factor-graph parameters.
    """

    def __init__(self,
                 num_vars: int,
                 num_interventions: int,
                 num_modules: int,
                 num_layers: int,
                 num_dec_layers: int,
                 hid_dim: int,
                 batch_size: int,
                 graph_model: Literal['BasicFG', 'SPNFG'],
                 sample_mask: bool = True,
                 hard_mask: bool = True,
                 tau: float = 1.0,
                 fix_factor: bool = False,
                 z_prior: Dict = {
                     'mean': 0.0,
                     'std': 1.0
                 },
                 edge_prior: float = 0.25,
                 var_type: Literal['const', 'learned'] = 'const',
                 noise_level: float = 1.0,
                 nonlin: str = 'relu',
                 lr_nn: float = 1e-4,
                 lr_fg: float = 1e-3,
                 loss_coeff: Dict[str, float] = {
                    'kl_z': 0.25,
                    'kl_fg': 0.25,
                    'l1_reg': 1.0,
                    'l1_reg_int': 1.0,
                 },
                 coeff_scheduler: Literal['const', 'linear', 'cosine', 'sigmoid', 'sparsity'] = 'sigmoid',
                 robust_loss: bool = False,
                 loss_memory: float = 0.9):
        """Build the model.

        The dict defaults are kept so that ``save_hyperparameters`` records the same
        values as before. They are never mutated: ``loss_coeff`` is copied and
        ``z_prior`` is only read.

        Args:
            num_vars: Number of observed variables.
            num_interventions: Number of intervention-indicator columns appended as the
                last columns of the input. They are stripped before the reconstruction loss.
            num_modules: Number of latent modules.
            num_layers: Encoder depth. ``num_layers - 1`` hidden layers are built.
            num_dec_layers: Number of decoder layers. The last one emits one value per variable.
            hid_dim: Hidden layer width.
            batch_size: Stored on the module; not used inside this class.
            graph_model: Factor-graph sampler class, ``'BasicFG'`` or ``'SPNFG'``.
            sample_mask: Forwarded to the factor-graph sampler.
            hard_mask: Forwarded to the factor-graph sampler.
            tau: Stored on the module; not used inside this class.
            fix_factor: If True, the encoder outputs latents directly (no mean/std, no z KL).
            z_prior: ``{'mean': ..., 'std': ...}`` of the Gaussian prior on the latents.
            edge_prior: Prior edge probability; the prior is ``[edge_prior, 1 - edge_prior]``.
            var_type: ``'const'`` for a per-variable ``log_std`` parameter initialised from
                ``noise_level``; ``'learned'`` for a decoder head that predicts the std.
            noise_level: Initial standard deviation when ``var_type == 'const'``.
            nonlin: Activation name (``relu``, ``leaky_relu``, ``sigmoid``, ``tanh``, ``linear``).
            lr_nn: Learning rate of the network optimizer.
            lr_fg: Learning rate of the factor-graph optimizer.
            loss_coeff: Base weights for ``kl_z``, ``kl_fg``, ``l1_reg`` and ``l1_reg_int``.
            coeff_scheduler: Name of the scheduler that scales the regularisation weights.
            robust_loss: Use a Huber-based reconstruction term instead of the Gaussian NLL.
            loss_memory: Exponential-moving-average factor for ``val_loss_accum``.

        Raises:
            NotImplementedError: for an unknown ``var_type`` or ``graph_model``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.num_vars = num_vars
        self.num_interventions = num_interventions
        self.num_modules = num_modules
        # num_partitions is the num_modules+1
        self.num_partitions = num_modules + 1
        self.num_layers = num_layers
        self.num_dec_layers = num_dec_layers
        self.hid_dim = hid_dim
        self.batch_size = batch_size
        self.sample_mask = sample_mask
        self.hard_mask = hard_mask
        self.fix_factor = fix_factor
        self.nonlin = nonlin
        self.lr_nn = lr_nn
        self.lr_fg = lr_fg
        self.robust_loss = robust_loss
        self.loss_memory = loss_memory
        self.loss_accum = None  # Exponential moving average of val_loss
        self.wait_smooth = 1  # Skip the first evaluation, which starts before training

        # Manual optimization is required because two optimizers are stepped per batch
        self.automatic_optimization = False

        self.weights_n2m = nn.ParameterList()
        self.weights_m2n = nn.ParameterList()
        self.biases_n2m = nn.ParameterList()
        self.biases_m2n = nn.ParameterList()
        self.weights_m2std = None
        self.biases_m2std = None

        self.var_type = var_type
        self.act_std = nn.Softplus()  # NOTE(review): not referenced by decode()
        if var_type == 'const':
            self.register_parameter(
                'log_std', nn.Parameter(torch.log(noise_level * torch.ones((self.num_vars,))))
            )
        elif var_type == 'learned':
            self.log_std = None
        else:
            raise NotImplementedError(f"var_type {var_type} is not implemented.")

        if graph_model == 'BasicFG':
            self.sample_uv = BasicFG(
                self.num_vars,
                self.num_interventions,
                self.num_modules,
                sample_mask=self.sample_mask,
                hard_mask=self.hard_mask,
                **config.get_graph_params('BasicFG')
            ).to(self.device)
        elif graph_model == 'SPNFG':
            self.sample_uv = SPNFG(
                self.num_vars,
                self.num_interventions,
                self.num_modules,
                sample_mask=self.sample_mask,
                hard_mask=self.hard_mask,
                **config.get_graph_params('SPNFG')
            )
        else:
            raise NotImplementedError(f"Graph model {graph_model} is not implemented.")
        self.tau = tau

        self.init_nn()

        # Gaussian prior for z
        self.register_buffer('mu_p', torch.tensor([z_prior['mean']]))
        self.register_buffer('std_p', torch.tensor([z_prior['std']]))

        # Priors for the factor graph: edge on/off, and a uniform distribution over the
        # num_modules + 1 partitions
        self.register_buffer('p_edge', torch.tensor([edge_prior, 1 - edge_prior]))
        self.register_buffer('p_y', torch.tensor([1 / (self.num_modules + 1)] * (self.num_modules + 1)))

        # Coefficients. Copy so the shared default dict is never aliased by an instance.
        self.loss_coeff = dict(loss_coeff)
        self.coeff_scheduler_type = coeff_scheduler
        self.coeff_scheduler = None

        # Ground truth used for evaluation on simulated data
        self.true_graph = None
        self.true_graph_int = None

    def init_scheduler(self):
        """Create the coefficient schedulers named by ``self.coeff_scheduler_type``.

        Two instances of the same scheduler class are built from the same config kwargs:
        ``coeff_scheduler`` scales the KL and L1 terms, and ``coeff_scheduler_int`` scales
        the intervention L1 term.

        Raises:
            NotImplementedError: if the scheduler type is unknown.
        """
        kwargs = config.get_scheduler_params(self.coeff_scheduler_type)
        scheduler_classes = {
            'const': ConstScheduler,
            'linear': LinearScheduler,
            'cosine': CosineScheduler,
            'sigmoid': SigmoidScheduler,
            'sparsity': SparsityScheduler,
        }
        scheduler_cls = scheduler_classes.get(self.coeff_scheduler_type)
        if scheduler_cls is None:
            raise NotImplementedError(f"Scheduler {self.coeff_scheduler_type} is not implemented.")
        self.coeff_scheduler = scheduler_cls(**kwargs)
        self.coeff_scheduler_int = scheduler_cls(**kwargs)

    def init_nn(self):
        """Create and initialise the encoder/decoder parameters and the activation.

        Encoder (node -> module) has ``num_layers - 1`` layers, each stored as stacked
        tensors of shape ``(num_modules, out_dim, in_dim)``. Its first layer reads
        ``num_vars + num_interventions`` inputs.

        Decoder (module -> node) has ``num_dec_layers`` layers of shape
        ``(num_vars, out_dim, in_dim)``. Its first layer reads ``num_modules`` inputs
        and its last emits one value per variable.

        Weights use Xavier-normal initialisation with the gain of ``nonlin``; biases start
        at zero.

        Raises:
            ValueError: if ``nonlin`` is not a supported activation.
        """
        for i in range(self.num_layers - 1):
            in_dim = self.num_vars + self.num_interventions if i == 0 else self.hid_dim
            out_dim = self.hid_dim
            self.weights_n2m.append(nn.Parameter(torch.zeros(self.num_modules, out_dim, in_dim)))
            self.biases_n2m.append(nn.Parameter(torch.zeros(self.num_modules, out_dim)))

        for i in range(self.num_dec_layers):
            in_dim = self.num_modules if i == 0 else self.hid_dim
            is_last = i == self.num_dec_layers - 1
            out_dim = 1 if is_last else self.hid_dim
            self.weights_m2n.append(nn.Parameter(torch.zeros(self.num_vars, out_dim, in_dim)))
            self.biases_m2n.append(nn.Parameter(torch.zeros(self.num_vars, out_dim)))
            if self.var_type == 'learned' and is_last:
                # Standard-deviation head, parallel to the last decoder layer (same input).
                self.weights_m2std = nn.Parameter(torch.zeros(self.num_vars, out_dim, in_dim))
                self.biases_m2std = nn.Parameter(torch.zeros(self.num_vars, out_dim))

        # Init params. The node-outer loop order is kept so a given seed yields the same weights.
        gain = nn.init.calculate_gain(self.nonlin)
        with torch.no_grad():
            for node in range(self.num_modules):
                for w in self.weights_n2m:
                    nn.init.xavier_normal_(w[node], gain=gain)
                for b in self.biases_n2m:
                    b[node].zero_()
            for node in range(self.num_vars):
                for w in self.weights_m2n:
                    nn.init.xavier_normal_(w[node], gain=gain)
                for b in self.biases_m2n:
                    b[node].zero_()
            if self.var_type == 'learned':
                for node in range(self.num_vars):
                    nn.init.xavier_normal_(self.weights_m2std[node], gain=gain)
                    self.biases_m2std[node].zero_()

        # Output layer of the encoder: a mean/std pair per module (VAE) or z directly
        if not self.fix_factor:
            self.register_parameter('weight_mu', nn.Parameter(torch.zeros(self.num_modules, self.hid_dim)))
            self.register_parameter('bias_mu', nn.Parameter(torch.zeros(self.num_modules,)))
            self.register_parameter('weight_std', nn.Parameter(torch.zeros(self.num_modules, self.hid_dim)))
            self.register_parameter('bias_std', nn.Parameter(torch.zeros(self.num_modules,)))
            nn.init.xavier_normal_(self.weight_mu)
            nn.init.xavier_normal_(self.weight_std)
        else:
            self.register_parameter('weight_z', nn.Parameter(torch.zeros(self.num_modules, self.hid_dim)))
            self.register_parameter('bias_z', nn.Parameter(torch.zeros(self.num_modules,)))
            nn.init.xavier_normal_(self.weight_z)

        activations = {
            'relu': nn.ReLU,
            'leaky_relu': nn.LeakyReLU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
            'linear': nn.Identity,
        }
        if self.nonlin not in activations:
            raise ValueError(f"Activation function {self.nonlin} is not implemented.")
        self.act = activations[self.nonlin]()

    def configure_optimizers(self):
        """Return ``(network optimizer, factor-graph optimizer)``, both AdamW.

        ``log_std`` (``var_type == 'const'``) is not handed to either optimizer, so it
        keeps its initial value.
        """
        nn_params = list(self.weights_n2m) + list(self.biases_n2m) \
            + list(self.weights_m2n) + list(self.biases_m2n)
        if self.fix_factor:
            nn_params += [self.weight_z, self.bias_z]
        else:
            nn_params += [self.weight_mu, self.weight_std, self.bias_mu, self.bias_std]
        if self.var_type == 'learned':
            nn_params += [self.weights_m2std, self.biases_m2std]

        optimizer = torch.optim.AdamW(
            nn_params, lr=self.lr_nn, weight_decay=config.get_optim_params('weight_decay')
        )
        optimizer_fg = torch.optim.AdamW(
            list(self.sample_uv.parameters()), lr=self.lr_fg
        )
        return optimizer, optimizer_fg

    def _linear_schedule(self, step: int, total_steps: int, start: float, end: float):
        """Linearly interpolate a coefficient from ``start`` (step 0) to ``end`` (``total_steps``)."""
        return start + (end - start) * step / total_steps

    def _sigmoid_schedule(self, step: int, mid: int, scale: int, start: float, end: float):
        """Sigmoid interpolation from ``start`` to ``end``, centred at ``mid`` with width ``scale``."""
        return start + (end - start) * 1 / (1 + np.exp(-(step - mid) / scale))

    def set_true_graph(self, true_graph: np.ndarray, true_graph_int: Union[np.ndarray, None] = None):
        """Set the ground truth graph for evaluation.

        Args:
            true_graph (np.ndarray): Ground truth adjacency matrix.
            true_graph_int (np.ndarray): Ground truth adjacency matrix for unknown interventions.
        """
        self.true_graph = true_graph
        self.true_graph_int = true_graph_int

    # Functions for the forward pass
    def reparameterize(self, mu: Tensor, std: Tensor) -> Tensor:
        """Draw ``mu + std * eps`` with ``eps ~ N(0, 1)`` (reparameterization trick)."""
        epsilon = torch.randn_like(std)
        return mu + std * epsilon

    def encode(self, x: Tensor, mask_n2m: Tensor) -> Tuple[Tensor, Union[Tensor, None]]:
        """Encode the input data into one latent value per module.

        Args:
            x (Tensor): Input data of size ``(n_batch, num_vars + num_interventions)``.
            mask_n2m (Tensor): Node-to-module mask of size
                ``(batch, num_vars + num_interventions, num_modules)``. It gates the first
                encoder layer.

        Returns:
            Tuple[Tensor, Optional[Tensor]]: With ``fix_factor``, ``(z, None)``; otherwise
            ``(mu, std)``. Each tensor has size ``(n_batch, num_modules)``.
        """
        batch_size = x.shape[0]
        h = x
        for layer in range(self.num_layers - 1):
            weight, bias = self.weights_n2m[layer], self.biases_n2m[layer]
            num_module, num_hid, _ = weight.shape
            if layer == 0:
                # t: module, i: hidden unit, j: input node; the mask gates node j -> module t
                h = torch.einsum("tij,bjt,bj->bti", weight, mask_n2m, h) + bias
            else:
                h = torch.einsum("tij,btj->bti", weight, h) + bias
            h = self.act(h).reshape(batch_size, num_module, num_hid)

        # Output layer: z directly, or the mean/std of the approximate posterior
        if self.fix_factor:
            return torch.einsum("tj,btj->bt", self.weight_z, h) + self.bias_z, None

        mu = torch.einsum("tj,btj->bt", self.weight_mu, h) + self.bias_mu
        std = F.softplus(torch.einsum("tj,btj->bt", self.weight_std, h) + self.bias_std)
        return mu, std

    def decode(self, x, mask_m2n):
        """Reconstruct the observed variables from the module latents.

        Args:
            x (:class:`torch.tensor`): Latent variables of size ``(n_batch, num_modules)``.
            mask_m2n (:class:`torch.tensor`): Module-to-node mask of size
                ``(batch, num_modules, num_vars)``, as transposed in :meth:`forward`. It
                gates the first decoder layer.

        Returns:
            Tuple[Tensor, Tensor]: Squeezed mean and standard deviation of the
            reconstruction. For ``n_batch > 1`` the mean is ``(n_batch, num_vars)``.
        """
        batch_size = x.shape[0]
        out = x
        out_std = None
        for layer in range(self.num_dec_layers):
            weight, bias = self.weights_m2n[layer], self.biases_m2n[layer]
            num_node, num_hid, _ = weight.shape
            is_last = layer == self.num_dec_layers - 1

            # Standard-deviation head: same input as the last mean layer.
            # biases_m2std is a single (num_vars, 1) tensor, so it is added whole.
            if self.var_type == 'learned' and is_last:
                if layer == 0:
                    out_std = torch.einsum("tij,bjt,bj->bti", self.weights_m2std, mask_m2n, out) \
                        + self.biases_m2std
                else:
                    out_std = torch.einsum("tij,btj->bti", self.weights_m2std, out) \
                        + self.biases_m2std

            # Linear layer. t: output variable, i: output unit, j: module (first layer) or hidden unit
            if layer == 0:
                out = torch.einsum("tij,bjt,bj->bti", weight, mask_m2n, out) + bias
            else:
                out = torch.einsum("tij,btj->bti", weight, out) + bias

            if not is_last:
                out = self.act(out).reshape(batch_size, num_node, num_hid)

        # Standard deviation: a learned constant per variable, or the std head clamped to > 0
        if self.var_type == 'const':
            out_std = torch.exp(self.log_std)
        elif self.var_type == 'learned':
            out_std = F.relu(out_std - 1e-10) + 1e-10
        else:
            raise NotImplementedError(f"var_type {self.var_type} is not implemented.")
        return out.squeeze(), out_std.squeeze()

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Sample one mask pair from the factor graph, then encode and decode ``x``.

        Side effects: stores ``mask_n2m`` and ``mask_m2n`` on ``self``, and also ``mu_q``
        and ``std_q`` when not ``fix_factor``. :meth:`losses` reads these values.
        """
        mask_n2m, mask_m2n = self.sample_uv(1)
        if mask_m2n.ndim == 3:
            mask_m2n = torch.transpose(mask_m2n, 1, 2)
        else:
            mask_m2n = torch.transpose(mask_m2n, 0, 1)

        # VAE part
        if self.fix_factor:
            z, _ = self.encode(x, mask_n2m)
        else:
            mu, std = self.encode(x, mask_n2m)
            std = F.relu(std - 1e-16) + 1e-16
            self.mu_q = mu
            self.std_q = std
            z = self.reparameterize(mu, std)

        self.mask_n2m = mask_n2m
        self.mask_m2n = mask_m2n

        return self.decode(z, mask_m2n)

    # Loss terms
    def _neg_log_likelihood(self, x):
        """Return the elementwise negative log-likelihood of ``x`` under a fresh forward pass.

        The last ``num_interventions`` columns of ``x`` are intervention indicators and
        are dropped before the comparison.
        """
        x_mean, x_std = self.forward(x)
        if self.num_interventions > 0:
            x = x[:, :-self.num_interventions]
        if self.robust_loss:
            # Huber loss on std-normalised values for robustness to outliers
            hb_loss = nn.HuberLoss(delta=100.0, reduction='none')
            return hb_loss(x_mean / x_std, x / x_std) + 0.5 * torch.log(2 * np.pi * x_std.pow(2))
        return 0.5 * ((x_mean - x) / x_std).pow(2) + 0.5 * torch.log(2 * np.pi * x_std.pow(2))

    def nll(self, x, masks):
        """Negative log-likelihood summed over variables and averaged over the batch.

        If ``masks`` is given, it multiplies the elementwise terms before the sum.
        """
        elementwise = self._neg_log_likelihood(x)
        if masks is not None:
            elementwise = elementwise * masks
        return torch.mean(torch.sum(elementwise, dim=1), dim=0)

    def mse(self, x, masks=None):
        """Mean-squared reconstruction error, optionally restricted to ``masks`` entries."""
        x_mean, _ = self.forward(x)
        if self.num_interventions > 0:
            x = x[:, :-self.num_interventions]
        if masks is not None:
            return ((x_mean - x).pow(2) * masks).sum() / masks.sum()
        return (x_mean - x).pow(2).mean()

    def kl_losses(self, p, q):
        """ KL divergence between two distributions
        :deprecated: Use kl_losses_norm instead
        """
        kl_loss = nn.KLDivLoss(reduction="none")
        return torch.mean(torch.sum(kl_loss(p, q), dim=1))

    def kl_losses_norm(self, mu_p, mu_q, std_p, std_q):
        """Return KL(N(mu_p, std_p^2) || N(mu_q, std_q^2)), summed over dim 1 and averaged over dim 0.

        If either term comes out NaN/inf, it is recomputed with a small epsilon added.
        """
        # kl(p, q)
        kl_1 = torch.log(std_q / std_p)
        if torch.any(torch.isnan(kl_1)) or torch.any(torch.isinf(kl_1)):
            kl_1 = torch.log(std_q + 1e-10) - torch.log(std_p + 1e-10)
        kl_2 = ((std_p.pow(2) + (mu_p - mu_q).pow(2)) / (2 * std_q.pow(2))) - 0.5
        if torch.any(torch.isnan(kl_2)) or torch.any(torch.isinf(kl_2)):
            kl_2 = ((std_p.pow(2) + (mu_p - mu_q).pow(2) + 1e-10) / (2 * std_q.pow(2) + 1e-10)) - 0.5
        return torch.mean(torch.sum((kl_1 + kl_2), dim=1))

    def losses(self, x, mask=None):
        """Compute all loss terms for one batch.

        Returns:
            ``(re_loss, uv_kl, z_kl, l1_reg, l1_reg_int)``.

            - ``re_loss``: reconstruction NLL.
            - ``uv_kl``: factor-graph KL against the priors.
            - ``z_kl``: KL(posterior || prior) on the latents, or 0 with ``fix_factor``.
            - ``l1_reg``: L1 norm of the node and module masks.
            - ``l1_reg_int``: L1 norm of the intervention rows of ``mask_n2m``, or 0
              without interventions.
        """
        re_loss = self.nll(x, mask)  # runs forward(), which sets the masks/posterior read below
        uv_kl = self.sample_uv.kl_loss(self.p_edge, self.p_y)
        if self.fix_factor:
            z_kl = 0
        else:
            # Posterior passed in the "p" slot, so this is KL(q(z|x) || prior)
            z_kl = self.kl_losses_norm(self.mu_q, self.mu_p, self.std_q, self.std_p)
        if self.num_interventions == 0:
            l1_reg = torch.norm(self.mask_m2n, p=1, dim=(1, 2)).mean() \
                + torch.norm(self.mask_n2m, p=1, dim=(1, 2)).mean()
            l1_reg_int = 0
        else:
            l1_reg = torch.norm(self.mask_m2n, p=1, dim=(1, 2)).mean() \
                + torch.norm(self.mask_n2m[:, :-self.num_interventions, :], p=1, dim=(1, 2)).mean()
            l1_reg_int = torch.norm(self.mask_n2m[:, -self.num_interventions:, :], p=1, dim=(1, 2)).mean()

        return re_loss, uv_kl, z_kl, l1_reg, l1_reg_int

    @staticmethod
    def _squeeze_to_numpy(t):
        """Drop a leading batch dim of size 1 and convert to numpy; ``None`` passes through."""
        if t is None:
            return None
        return t.squeeze(0).detach().cpu().numpy()

    def _fixed_fg_numpy(self, **kwargs):
        """Return the factor graph's ``fixed_fg(**kwargs)`` masks ``(n2m, m2n, i2m)`` as numpy arrays."""
        with torch.no_grad():
            n2m, m2n, i2m = self.sample_uv.fixed_fg(**kwargs)
        return self._squeeze_to_numpy(n2m), self._squeeze_to_numpy(m2n), self._squeeze_to_numpy(i2m)

    def eval_graph(self, truth, truth_int=None, fix_graph=True):
        """Evaluate the graph structure.

        Args:
            truth (array like): Ground truth adjacency matrix.
            truth_int (array like): Ground truth adjacency matrix for unknown interventions.
            fix_graph (bool, optional): Use the hard fixed graph if True, otherwise one
                sampled graph. Defaults to True.

        Returns:
            ``(precision, recall, f1, shd, precision_int, recall_int, f1_int, shd_int)``.
            The ``*_int`` entries are NaN when ``truth_int`` is None or the model has no
            intervention-to-module graph.
        """
        if fix_graph:
            n2m, m2n, i2m = self._fixed_fg_numpy(hard=True)
        else:
            with torch.no_grad():
                n2m, m2n = self.sample_uv.sample()
                i2m = None
                if self.num_interventions > 0:
                    i2m = n2m[:, -self.num_interventions:, :]
                    n2m = n2m[:, :-self.num_interventions, :]
            n2m = self._squeeze_to_numpy(n2m)
            m2n = self._squeeze_to_numpy(m2n)
            i2m = self._squeeze_to_numpy(i2m)

        graph_pred = (np.matmul(n2m, m2n.T) > 0).astype(int)

        # Precision and recall
        precision, recall, f1 = pred_score(graph_pred, truth)

        # Structural Hamming Distance
        shd = shd_metric(graph_pred, truth)

        # Metrics for unknown interventions (needs both a ground truth and a predicted i2m)
        precision_int, recall_int, f1_int, shd_int = np.nan, np.nan, np.nan, np.nan
        if truth_int is not None and i2m is not None:
            graph_int_pred = (np.matmul(i2m, m2n.T) > 0).astype(int)
            precision_int, recall_int, f1_int = pred_score(graph_int_pred, truth_int)

            # Embed the bipartite intervention->feature graph in a square adjacency matrix
            n_int, n_feature = graph_int_pred.shape
            adj_int = np.zeros((n_int + n_feature, n_int + n_feature))
            adj_int[n_feature:, :n_feature] = graph_int_pred
            adj_int_truth = np.zeros((n_int + n_feature, n_int + n_feature))
            adj_int_truth[n_feature:, :n_feature] = truth_int
            shd_int = shd_metric(adj_int, adj_int_truth)
        return (precision, recall, f1, shd, precision_int, recall_int, f1_int, shd_int)

    # Helpers shared by the Lightning hooks
    @staticmethod
    def _unpack_batch(batch):
        """Split a batch into ``(x, masks)`` as float tensors.

        A tuple/list batch is read as ``(x, masks, ...)``; anything else is ``x`` alone
        and ``masks`` is None.
        """
        if isinstance(batch, (tuple, list)):
            return batch[0].float(), batch[1].float()
        return batch.float(), None

    def _unscheduled_loss(self, re_loss, kl_fg, kl_z, l1_reg, l1_reg_int):
        """Total loss weighted by the base ``loss_coeff`` values only (no scheduler)."""
        return re_loss + self.loss_coeff['kl_fg'] * kl_fg + self.loss_coeff['kl_z'] * kl_z \
            + self.loss_coeff['l1_reg'] * l1_reg + self.loss_coeff['l1_reg_int'] * l1_reg_int

    def _graph_metrics(self):
        """Evaluate the fixed graph against ``true_graph`` and return a dict of loggable metrics."""
        precision, recall, f1, shd, precision_int, recall_int, f1_int, shd_int = \
            self.eval_graph(self.true_graph, self.true_graph_int)
        graph_metrics = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'shd': shd,
        }
        if self.true_graph_int is not None:
            graph_metrics['precision_int'] = precision_int
            graph_metrics['recall_int'] = recall_int
            graph_metrics['f1_int'] = f1_int
            graph_metrics['shd_int'] = shd_int
        return graph_metrics

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback converting numpy/torch scalars to Python scalars."""
        if isinstance(obj, (np.generic, torch.Tensor)):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Lightning functions
    def on_fit_start(self):
        """ Initialize the scheduler. """
        self.init_scheduler()

    def on_train_epoch_start(self):
        """Advance the coefficient scheduler.

        The sparsity scheduler (from the second epoch on) receives the previous epoch's
        logged ``train_l1_reg``. Every other scheduler receives the current epoch index.
        """
        epoch = self.trainer.current_epoch
        if self.coeff_scheduler_type == 'sparsity' and epoch > 0:
            # Epoch-aggregated value logged by training_step (on_epoch=True)
            l1_reg = self.trainer.callback_metrics['train_l1_reg']
            self.coeff_scheduler.update_coefficient(float(l1_reg))
        else:
            self.coeff_scheduler.update_coefficient(epoch)
        # NOTE(review): coeff_scheduler_int is never updated here, so training_step always
        # reads its initial coefficient — confirm this is intended.

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step on both optimizers with scheduled loss weights."""
        x, masks = self._unpack_batch(batch)
        re_loss, kl_fg, kl_z, l1_reg, l1_reg_int = self.losses(x, masks)

        # Scheduled coefficients
        coeff = self.coeff_scheduler.get_coefficient()
        coeff_kl_fg = coeff * self.loss_coeff['kl_fg']
        coeff_kl_z = coeff * self.loss_coeff['kl_z']
        coeff_l1_reg = coeff * self.loss_coeff['l1_reg']
        coeff_int = self.coeff_scheduler_int.get_coefficient()
        coeff_l1_reg_int = coeff_int * self.loss_coeff['l1_reg_int']

        loss = re_loss + coeff_kl_fg * kl_fg + coeff_kl_z * kl_z + coeff_l1_reg * l1_reg \
            + coeff_l1_reg_int * l1_reg_int

        optimizer, optimizer_fg = self.optimizers()
        self.manual_backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        optimizer_fg.step()
        optimizer_fg.zero_grad()

        train_metrics = {
            'train_loss': loss,
            'train_nll': re_loss,
            'train_kl_fg': kl_fg,
            'train_kl_z': kl_z,
            'train_l1_reg': l1_reg,
            'coeff': coeff
        }
        if self.num_interventions > 0:
            train_metrics['train_l1_reg_int'] = l1_reg_int
        self.log_dict(train_metrics, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses (weighted by the base, unscheduled coefficients) and MSE."""
        x, masks = self._unpack_batch(batch)
        re_loss, kl_fg, kl_z, l1_reg, l1_reg_int = self.losses(x, masks)
        loss = self._unscheduled_loss(re_loss, kl_fg, kl_z, l1_reg, l1_reg_int)
        mse = self.mse(x, masks)

        val_metrics = {
            'val_loss': loss,
            'val_nll': re_loss,
            'val_kl_fg': kl_fg,
            'val_kl_z': kl_z,
            'val_l1_reg': l1_reg,
            'val_mse': mse
        }
        if self.num_interventions > 0:
            val_metrics['val_l1_reg_int'] = l1_reg_int
        self.log_dict(val_metrics, on_step=False, on_epoch=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        """Log graph metrics (if a ground truth is set) and the smoothed validation loss."""
        if self.true_graph is not None:
            self.log_dict(self._graph_metrics(), on_step=False, on_epoch=True, logger=True)

        # Exponential moving average of val_loss, skipping the pre-training evaluation
        val_loss = self.trainer.callback_metrics.get('val_loss')
        if self.wait_smooth > 0:
            self.wait_smooth -= 1
        elif self.wait_smooth == 0:
            self.loss_accum = val_loss
            self.wait_smooth -= 1
            self.log('val_loss_accum', self.loss_accum, on_step=False, on_epoch=True, logger=True)
        else:
            self.loss_accum = self.loss_memory * self.loss_accum + (1 - self.loss_memory) * val_loss
            self.log('val_loss_accum', self.loss_accum, on_step=False, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Log test loss (unscheduled coefficients), NLL and MSE."""
        x, masks = self._unpack_batch(batch)
        re_loss, kl_fg, kl_z, l1_reg, l1_reg_int = self.losses(x, masks)
        loss = self._unscheduled_loss(re_loss, kl_fg, kl_z, l1_reg, l1_reg_int)
        mse = self.mse(x, masks)

        test_metrics = {
            'test_loss': loss,
            'test_nll': re_loss,
            'test_mse': mse
        }
        self.log_dict(test_metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_test_epoch_end(self):
        """Evaluate the learned graph and write the results to ``trainer.log_dir``.

        If a ground truth is set, the metrics of the fixed graph are logged and written to
        ``metrics.json``. The metrics of 1000 sampled graphs are saved to
        ``eval_sampled_graph.npz``.
        """
        if self.true_graph is None:
            return
        graph_metrics = self._graph_metrics()
        self.log_dict(graph_metrics, on_step=False, on_epoch=True, logger=True)
        with open(f'{self.trainer.log_dir}/metrics.json', 'w') as f:
            json.dump(graph_metrics, f, default=self._json_default)

        # Evaluate sampled graphs and save metrics
        keys = ('precision', 'recall', 'f1', 'shd', 'precision_int', 'recall_int', 'f1_int', 'shd_int')
        results = {key: [] for key in keys}
        for _ in range(1000):
            sample_metrics = self.eval_graph(self.true_graph, self.true_graph_int, fix_graph=False)
            for key, value in zip(keys, sample_metrics):
                results[key].append(value)
        np.savez(f'{self.trainer.log_dir}/eval_sampled_graph.npz', **results)

    def on_save_checkpoint(self, checkpoint):
        """Add the fixed graph (``graph``, ``n2m``, ``m2n`` and, if present, ``i2m``) to the checkpoint."""
        n2m, m2n, i2m = self._fixed_fg_numpy()
        if i2m is not None:
            checkpoint['i2m'] = i2m.astype(int)
        checkpoint['graph'] = (np.matmul(n2m, m2n.T) > 0).astype(int)
        checkpoint['n2m'] = n2m
        checkpoint['m2n'] = m2n

    def save_graph(self, path):
        """Save the weighted (``hard=False``) and default fixed graphs as ``.npy`` files.

        Args:
            path (str): Directory to write into; it is created if missing.
        """
        os.makedirs(path, exist_ok=True)

        # Weighted graph
        n2m, m2n, i2m = self._fixed_fg_numpy(hard=False)
        np.save(f'{path}/weighted_graph.npy', n2m @ m2n.T)
        np.save(f'{path}/weighted_n2m.npy', n2m)
        np.save(f'{path}/weighted_m2n.npy', m2n)
        if i2m is not None:
            np.save(f'{path}/weighted_i2m.npy', i2m)

        # Binary graph
        n2m, m2n, i2m = self._fixed_fg_numpy()
        np.save(f'{path}/graph.npy', (np.matmul(n2m, m2n.T) > 0).astype(int))
        np.save(f'{path}/n2m.npy', n2m)
        np.save(f'{path}/m2n.npy', m2n)
        if i2m is not None:
            np.save(f'{path}/i2m.npy', i2m)

