import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
from .backbone import *
import numpy as np

class GNNSafe(nn.Module):
    '''
    The model class of energy-based models for out-of-distribution detection.

    A backbone encoder (chosen by cfg["backbone"]) maps node features and
    edges to per-node logits; energy scores are derived from those logits
    with a temperature-scaled log-sum-exp.
    The entries cfg["use_reg"] and cfg["use_prop"] control the model versions:
        Energy:    cfg["use_reg"] = False, cfg["use_prop"] = False
        Energy FT: cfg["use_reg"] = True,  cfg["use_prop"] = False
        GNNSafe:   cfg["use_reg"] = False, cfg["use_prop"] = True
        GNNSafe++: cfg["use_reg"] = True,  cfg["use_prop"] = True
    '''

    # datasets handled as multi-label binary classification
    _MULTILABEL_DATASETS = ('proteins', 'ppi')

    def __init__(self, d, c, cfg):
        '''
        Build the backbone encoder.

        d: passed to the backbone as its input channel count
        c: passed to the backbone as its output channel count
        cfg: mapping; reads "backbone", "hidden_channels", "num_layers",
             "dropout" and, for the 'gcn' / 'gat' backbones, "use_bn"

        Raises NotImplementedError for an unrecognised cfg["backbone"].
        '''
        super().__init__()
        backbone = cfg["backbone"]
        if backbone == 'gcn':
            self.encoder = GCN_baseline(in_channels=d,
                        hidden_channels=cfg["hidden_channels"],
                        out_channels=c,
                        num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"],
                        use_bn=cfg["use_bn"])
        elif backbone == 'mlp':
            self.encoder = MLP(in_channels=d, hidden_channels=cfg["hidden_channels"],
                        out_channels=c, num_layers=cfg["num_layers"],
                        dropout=cfg["dropout"])
        elif backbone == 'gat':
            self.encoder = GAT_baseline(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"], dropout=cfg["dropout"], use_bn=cfg["use_bn"])
        elif backbone == 'mixhop':
            self.encoder = MixHop(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"], dropout=cfg["dropout"])
        elif backbone == 'gcnjk':
            self.encoder = GCNJK(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"], dropout=cfg["dropout"])
        elif backbone == 'gatjk':
            self.encoder = GATJK(d, cfg["hidden_channels"], c, num_layers=cfg["num_layers"], dropout=cfg["dropout"])
        else:
            raise NotImplementedError(f"unsupported backbone: {backbone!r}")

    def reset_parameters(self):
        '''re-initialise the encoder by delegating to its reset_parameters()'''
        self.encoder.reset_parameters()

    def forward(self, dataset, device):
        '''return predicted logits for dataset.x / dataset.edge_index on `device`'''
        x, edge_index = dataset.x.to(device), dataset.edge_index.to(device)
        return self.encoder(x, edge_index)

    def _neg_energy(self, logits, cfg):
        '''
        return T * logsumexp(logits / T) for every row of `logits`, T = cfg["T"]

        For the multi-label datasets each logit is paired with a zero logit
        and the per-label scores are summed over dim 1.
        '''
        T = cfg["T"]
        if cfg["dataset"] in self._MULTILABEL_DATASETS: # for multi-label binary classification
            logits = torch.stack([logits, torch.zeros_like(logits)], dim=2)
            return T * torch.logsumexp(logits / T, dim=-1).sum(dim=1)
        # for single-label multi-class classification
        return T * torch.logsumexp(logits / T, dim=-1)

    def propagation(self, e, edge_index, prop_layers=1, alpha=0.5):
        '''
        energy belief propagation, return the energy after propagation

        e: 1-D tensor of scores; its length is used as the number of nodes
        edge_index: unpacked as (row, col)
        prop_layers: number of propagation steps
        alpha: weight kept on a node's own score in each step; the
               aggregated neighbour scores receive (1 - alpha)
        '''
        e = e.unsqueeze(1)
        N = e.shape[0]
        row, col = edge_index
        d = degree(col, N).float()
        # each edge is weighted by 1 / degree of its `col` endpoint;
        # non-finite weights are zeroed
        value = torch.nan_to_num(1. / d[col], nan=0.0, posinf=0.0, neginf=0.0)
        adj = SparseTensor(row=col, col=row, value=value, sparse_sizes=(N, N))
        for _ in range(prop_layers):
            e = e * alpha + matmul(adj, e) * (1 - alpha)
        return e.squeeze(1)

    def detect(self, dataset, node_idx, device, cfg):
        '''
        return negative energy for the nodes selected by node_idx

        Scores are computed for every node of `dataset` (and propagated over
        its edges with cfg["K"] steps and weight cfg["alpha"] when
        cfg["use_prop"] is set) before node_idx is applied.
        '''
        x, edge_index = dataset.x.to(device), dataset.edge_index.to(device)
        logits = self.encoder(x, edge_index)
        neg_energy = self._neg_energy(logits, cfg)
        if cfg["use_prop"]: # use energy belief propagation
            neg_energy = self.propagation(neg_energy, edge_index, cfg["K"], cfg["alpha"])
        return neg_energy[node_idx]

    def loss_compute(self, dataset_ind, dataset_ood, criterion, device, cfg):
        '''
        return loss for training

        dataset_ind: in-distribution data; reads x, edge_index, y and splits['train']
        dataset_ood: OOD exposure data; reads x, edge_index and node_idx
        criterion: supervised loss; receives raw logits and float labels for
                   the multi-label datasets, log-softmax outputs and labels
                   squeezed on dim 1 otherwise
        cfg: reads "dataset", "use_reg", and when regularising also "T",
             "use_prop", "K", "alpha", "m_in", "m_out", "lamda"
        '''
        x_in, edge_index_in = dataset_ind.x.to(device), dataset_ind.edge_index.to(device)
        x_out, edge_index_out = dataset_ood.x.to(device), dataset_ood.edge_index.to(device)

        # get predicted logits from gnn classifier
        # (the OOD pass is run even when cfg["use_reg"] is False, as before)
        logits_in = self.encoder(x_in, edge_index_in)
        logits_out = self.encoder(x_out, edge_index_out)

        train_in_idx, train_ood_idx = dataset_ind.splits['train'], dataset_ood.node_idx

        # compute supervised training loss
        if cfg["dataset"] in self._MULTILABEL_DATASETS:
            sup_loss = criterion(logits_in[train_in_idx], dataset_ind.y[train_in_idx].to(device).to(torch.float))
        else:
            pred_in = F.log_softmax(logits_in[train_in_idx], dim=1)
            sup_loss = criterion(pred_in, dataset_ind.y[train_in_idx].squeeze(1).to(device))

        if not cfg["use_reg"]: # no energy regularization
            return sup_loss

        energy_in = -self._neg_energy(logits_in, cfg)
        energy_out = -self._neg_energy(logits_out, cfg)

        if cfg["use_prop"]: # use energy belief propagation
            energy_in = self.propagation(energy_in, edge_index_in, cfg["K"], cfg["alpha"])
            energy_out = self.propagation(energy_out, edge_index_out, cfg["K"], cfg["alpha"])

        # each energy vector is indexed with the node set of its own graph;
        # OOD energies must use train_ood_idx, not the in-distribution split
        energy_in = energy_in[train_in_idx]
        energy_out = energy_out[train_ood_idx]

        # truncate to have the same length
        min_n = min(energy_in.shape[0], energy_out.shape[0])
        energy_in = energy_in[:min_n]
        energy_out = energy_out[:min_n]

        # compute regularization loss: penalise in-distribution energy above
        # m_in and OOD energy below m_out
        reg_loss = torch.mean(F.relu(energy_in - cfg["m_in"]) ** 2 + F.relu(cfg["m_out"] - energy_out) ** 2)

        return sup_loss + cfg["lamda"] * reg_loss
