from argparse import ArgumentParser
from math import sqrt
from datetime import datetime

import os
import dill
import dgl
import dgl.function as fn
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics as tm
import wandb
from dgl.nn.pytorch import GraphConv
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from project.utils.deepinteract_constants import FEATURE_INDICES, RESIDUE_COUNT_LIMIT, NODE_COUNT_LIMIT
from project.utils.graph_utils import src_dot_dst, scaling, imp_exp_attn, out_edge_features, exp
from project.utils.esm2.esm2 import ESM2
from project.utils.saprot.saprot import SaProt
from project.utils.deepinter_modules import ResnetTriangleStack
from project.utils.triangle_custom_config import config_parse_args
from project.utils.deepssinter_utils import glorot_orthogonal, get_geo_feats_from_edges, calculate_top_k_prec, \
    extract_object, memory_mb, print_time_diff, rbf


# ------------------
# PyTorch Modules
# ------------------

# Adapted from https://arxiv.org/pdf/1708.02002v2 paper
# and https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html code
class FocalLoss(nn.Module):
    """Focal loss function"""
    def __init__(self, alpha=0.25, gamma=2, reduction='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        # Assume input are probabilities (after sigmoid) of shape (La, Lb) and target are binary labels of shape (La, Lb)
        # Here targets are assumed to be of the same shape as inputs and contain binary labels
        if target.dtype != torch.float32:
            target = target.float()
        ce_loss = F.binary_cross_entropy(input, target, reduction='none')

        # p_t is the probability of the class with label=1 (contact) if target=1 
        # else it's the probability of the class with label=0 (not a contact)
        p_t = input * target + (1 - input) * (1 - target)

        # Calculate focal loss
        loss = ce_loss * ce_loss * ((1 - p_t) ** self.gamma)

        # Calculate alpha-balanced focal loss
        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise ValueError(
                'Invalid value for reduction. Supported reduction modes: "mean", "sum", or "none".'
            )
        


# -------------------------------------------------------------------------------------------------------------------------------------
# Following code adapted from 'A Generalization of Transformer Networks to Graphs' (https://github.com/graphdeeplearning/graphtransformer):
# -------------------------------------------------------------------------------------------------------------------------------------
class MultiHeadGeometricAttentionLayer(nn.Module):
    """Compute attention scores with a DGLGraph's node and edge (geometric) features."""

    def __init__(self, num_input_feats: int, num_output_feats: int,
                 num_heads: int, using_bias: bool, update_edge_feats: bool):
        super().__init__()

        # Declare shared variables
        self.num_output_feats = num_output_feats
        self.num_heads = num_heads
        self.using_bias = using_bias
        self.update_edge_feats = update_edge_feats

        # Define node features' query, key, and value tensors, and define edge features' projection tensors
        self.Q = nn.Linear(num_input_feats, self.num_output_feats * self.num_heads, bias=using_bias)
        self.K = nn.Linear(num_input_feats, self.num_output_feats * self.num_heads, bias=using_bias)
        self.V = nn.Linear(num_input_feats, self.num_output_feats * self.num_heads, bias=using_bias)
        self.edge_feats_projection = nn.Linear(num_input_feats, self.num_output_feats * self.num_heads, bias=using_bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        if self.using_bias:
            glorot_orthogonal(self.Q.weight, scale=scale)
            self.Q.bias.data.fill_(0)

            glorot_orthogonal(self.K.weight, scale=scale)
            self.K.bias.data.fill_(0)

            glorot_orthogonal(self.V.weight, scale=scale)
            self.V.bias.data.fill_(0)

            glorot_orthogonal(self.edge_feats_projection.weight, scale=scale)
            self.edge_feats_projection.bias.data.fill_(0)
        else:
            glorot_orthogonal(self.Q.weight, scale=scale)
            glorot_orthogonal(self.K.weight, scale=scale)
            glorot_orthogonal(self.V.weight, scale=scale)
            glorot_orthogonal(self.edge_feats_projection.weight, scale=scale)

    def propagate_attention(self, graph: dgl.DGLGraph):
        # Compute attention scores
        graph.apply_edges(src_dot_dst('K_h', 'Q_h', 'score'))  # , edges)

        # Scale and clip attention scores
        graph.apply_edges(scaling('score', np.sqrt(self.num_output_feats), 5.0))

        # Use available edge features to modify the attention scores
        graph.apply_edges(imp_exp_attn('score', 'proj_e'))

        # Copy edge features as e_out to be passed to edge_feats_MLP
        if self.update_edge_feats:
            graph.apply_edges(out_edge_features('score'))

        # Apply softmax to attention scores, followed by clipping
        graph.apply_edges(exp('score', 5.0))

        # Send weighted values to target nodes
        e_ids = graph.edges()
        graph.send_and_recv(e_ids, fn.u_mul_e('V_h', 'score', 'V_h'), fn.sum('V_h', 'wV'))
        graph.send_and_recv(e_ids, fn.copy_e('score', 'score'), fn.sum('score', 'z'))

    def forward(self, graph: dgl.DGLGraph, node_feats: torch.Tensor, edge_feats: torch.Tensor):
        with graph.local_scope():
            e_out = None
            node_feats_q = self.Q(node_feats)
            node_feats_k = self.K(node_feats)
            node_feats_v = self.V(node_feats)
            edge_feats_projection = self.edge_feats_projection(edge_feats)

            # Reshape tensors into [num_nodes, num_heads, feat_dim] to get projections for multi-head attention
            graph.ndata['Q_h'] = node_feats_q.view(-1, self.num_heads, self.num_output_feats)
            graph.ndata['K_h'] = node_feats_k.view(-1, self.num_heads, self.num_output_feats)
            graph.ndata['V_h'] = node_feats_v.view(-1, self.num_heads, self.num_output_feats)
            graph.edata['proj_e'] = edge_feats_projection.view(-1, self.num_heads, self.num_output_feats)

            # Disperse attention information
            self.propagate_attention(graph)

            # Compute final node and edge representations after multi-head attention
            h_out = graph.ndata['wV'] / (graph.ndata['z'] + torch.full_like(graph.ndata['z'], 1e-6))  # Add eps to all
            if self.update_edge_feats:
                e_out = graph.edata['e_out']

        # Return attention-updated node and edge representations
        return h_out, e_out


# -------------------------------------------------------------------------------------------------------------------------------------
# Following code curated for DeepInteract (https://github.com/BioinfoMachineLearning/DeepInteract):
# -------------------------------------------------------------------------------------------------------------------------------------

class InitEdgeModule(nn.Module):
    """An edge initializer module."""

    def __init__(
            self,
            node_count_limit: int,
            num_edge_feats: int,
            num_dist_feats: int,
            num_dir_feats: int,
            num_orient_feats: int,
            num_amide_feats: int,
            num_hidden_channels: int,
            activ_fn=nn.SiLU,
            feature_indices=FEATURE_INDICES
    ):
        super().__init__()

        # Record parameters given
        self.activ_fn = activ_fn
        self.feature_indices = feature_indices

        # --------------------
        # Initializer Module
        # --------------------
        # Establish node embedding strategy
        self.node_embedding = nn.Embedding(node_count_limit, num_hidden_channels)

        # Define linear layers for the edge initializer module
        self.edge_messages_linear_0 = nn.Linear(num_edge_feats, num_hidden_channels, bias=False)
        self.dist_linear_0 = nn.Linear(num_dist_feats, num_hidden_channels, bias=False)
        self.dir_linear_0 = nn.Linear(num_dir_feats, num_hidden_channels, bias=False)
        self.orient_linear_0 = nn.Linear(num_orient_feats, num_hidden_channels, bias=False)
        self.amide_linear_0 = nn.Linear(num_amide_feats, num_hidden_channels, bias=False)

        self.combined_linear_0 = nn.Linear(7 * num_hidden_channels, num_hidden_channels, bias=False)

        self.edge_messages_linear_1 = nn.Linear(num_edge_feats, num_hidden_channels, bias=False)
        self.dist_linear_1 = nn.Linear(num_dist_feats, num_hidden_channels, bias=False)
        self.dir_linear_1 = nn.Linear(num_dir_feats, num_hidden_channels, bias=False)
        self.orient_linear_1 = nn.Linear(num_orient_feats, num_hidden_channels, bias=False)
        self.amide_linear_1 = nn.Linear(num_amide_feats, num_hidden_channels, bias=False)

        combined_out_channels = num_edge_feats + num_dist_feats + num_dir_feats + num_orient_feats + num_amide_feats
        self.combined_linear_1 = nn.Linear(num_hidden_channels, combined_out_channels, bias=False)
        self.combined_linear_2 = nn.Linear(combined_out_channels, num_hidden_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        self.node_embedding.weight.data.uniform_(-sqrt(3), sqrt(3))

        glorot_orthogonal(self.edge_messages_linear_0.weight, scale=scale)
        glorot_orthogonal(self.dist_linear_0.weight, scale=scale)
        glorot_orthogonal(self.dir_linear_0.weight, scale=scale)
        glorot_orthogonal(self.orient_linear_0.weight, scale=scale)
        glorot_orthogonal(self.amide_linear_0.weight, scale=scale)

        glorot_orthogonal(self.combined_linear_0.weight, scale=scale)

        glorot_orthogonal(self.edge_messages_linear_1.weight, scale=scale)
        glorot_orthogonal(self.dist_linear_1.weight, scale=scale)
        glorot_orthogonal(self.dir_linear_1.weight, scale=scale)
        glorot_orthogonal(self.orient_linear_1.weight, scale=scale)
        glorot_orthogonal(self.amide_linear_1.weight, scale=scale)

        glorot_orthogonal(self.combined_linear_1.weight, scale=scale)
        glorot_orthogonal(self.combined_linear_2.weight, scale=scale)

    def init_edge_module_message_func(self, edges: dgl.udf.EdgeBatch):
        """Edge Initialization Module: Compute the messages for an EdgeBatch of edges.
        This function is set up as a User Defined Function in DGL.

        Parameters
        ----------
        edges : EdgeBatch
            A batch of edges for which to compute messages
        """

        # Embed node features given
        node_indices = edges.src['i_all'][0]  # Rely on temp node indices being cached prior to the call to message_func
        node_feats = self.node_embedding(node_indices)
        src_node_feats, dst_node_feats = node_feats[edges.src['i']], node_feats[edges.dst['i']]

        # Update edge features given
        edge_pos_enc_index = self.feature_indices['edge_pos_enc']
        edge_pos_enc = edges.data['f'][:, edge_pos_enc_index]
        edge_weight_index = self.feature_indices['edge_weights']  # Only use edge weights for now
        edge_weights = edges.data['f'][:, edge_weight_index]
        edge_messages_init = torch.cat((edge_pos_enc.reshape(-1, 1), edge_weights.reshape(-1, 1)), dim=1)
        edge_messages_0 = self.edge_messages_linear_0(edge_messages_init)
        dist_feats, dir_feats, orient_feats, amide_feats = get_geo_feats_from_edges(
            edges.data['f'],
            self.feature_indices
        )
        dist_feats_0 = self.activ_fn(self.dist_linear_0(dist_feats))
        dir_feats_0 = self.activ_fn(self.dir_linear_0(dir_feats))
        orient_feats_0 = self.activ_fn(self.orient_linear_0(orient_feats))
        amide_feats_0 = self.activ_fn(self.amide_linear_0(amide_feats.reshape(-1, 1)))
        combined_edge_logits = self.activ_fn(self.combined_linear_0(torch.cat([
            src_node_feats,
            dst_node_feats,
            edge_messages_0,
            dist_feats_0,
            dir_feats_0,
            orient_feats_0,
            amide_feats_0
        ], dim=1)))

        # Gate edge representations
        edge_messages_1_input = torch.cat((edge_pos_enc.reshape(-1, 1), edge_weights.reshape(-1, 1)), dim=1)
        edge_messages_1 = self.edge_messages_linear_1(edge_messages_1_input) * combined_edge_logits
        dist_feats_1 = self.activ_fn(self.dist_linear_1(dist_feats)) * combined_edge_logits
        dir_feats_1 = self.activ_fn(self.dir_linear_1(dir_feats)) * combined_edge_logits
        orient_feats_1 = self.activ_fn(self.orient_linear_1(orient_feats)) * combined_edge_logits
        amide_feats_1 = self.activ_fn(self.amide_linear_1(amide_feats.reshape(-1, 1))) * combined_edge_logits

        # Combine gated edge representations with linear transformations
        combined_edge_feats = edge_messages_1 + dist_feats_1 + dir_feats_1 + orient_feats_1 + amide_feats_1
        edge_feats = self.combined_linear_1(combined_edge_feats)
        edge_feats = self.combined_linear_2(edge_feats)

        return {
            'f': edge_feats
        }

    def forward(self, graph: dgl.DGLGraph):
        """Perform a forward pass of the edge initializer module to get initial edge representations."""
        with graph.local_scope():
            nodes = graph.nodes()
            num_nodes = len(nodes)
            graph.ndata['i'] = nodes  # Record indices of nodes for message passing
            graph.ndata['i_all'] = torch.cat((nodes, nodes.repeat(num_nodes - 1))).reshape(num_nodes, num_nodes)
            graph.apply_edges(self.init_edge_module_message_func)
            edge_feats = graph.edata['f']
        return edge_feats


class ConformationModule(nn.Module):
    """A geometry-evolving (i.e., conforming) module."""

    def __init__(
            self,
            num_dist_feats: int,
            num_dir_feats: int,
            num_orient_feats: int,
            num_amide_feats: int,
            dist_embed_size: int,
            dir_embed_size: int,
            orient_embed_size: int,
            amide_embed_size: int,
            shared_embed_size: int,
            num_hidden_channels: int,
            num_pre_res_blocks: int,
            num_post_res_blocks: int,
            activ_fn=nn.SiLU(),
            norm_to_apply='batch',
            feature_indices=FEATURE_INDICES
    ):
        super().__init__()

        # Record parameters given
        self.activ_fn = activ_fn
        self.feature_indices = feature_indices

        # --------------------
        # Conformation Module
        # --------------------
        # Define geometric modules for a conformation module
        self.dist_linear_0 = nn.Linear(num_dist_feats, dist_embed_size, bias=False)
        self.dist_linear_1 = nn.Linear(dist_embed_size, num_hidden_channels, bias=False)

        self.dir_linear_0 = nn.Linear(num_dir_feats, dir_embed_size, bias=False)
        self.dir_linear_1 = nn.Linear(dir_embed_size, shared_embed_size, bias=False)

        self.orient_linear_0 = nn.Linear(num_orient_feats, orient_embed_size, bias=False)
        self.orient_linear_1 = nn.Linear(orient_embed_size, shared_embed_size, bias=False)

        self.amide_linear_0 = nn.Linear(num_amide_feats, amide_embed_size, bias=False)
        self.amide_linear_1 = nn.Linear(amide_embed_size, shared_embed_size, bias=False)

        # Define utility modules for a conformation module's representation projections
        self.nbr_linear = nn.Linear(num_hidden_channels, num_hidden_channels)
        self.orig_msg_linear = nn.Linear(num_hidden_channels, num_hidden_channels)

        self.downward_proj = nn.Linear(num_hidden_channels, shared_embed_size, bias=False)
        self.upward_proj = nn.Linear(shared_embed_size, num_hidden_channels, bias=False)

        self.res_connect_linear = nn.Linear(num_hidden_channels, num_hidden_channels)

        # Define residual blocks
        self.pre_res_blocks = nn.ModuleList([
            ResBlock(num_hidden_channels, activ_fn, norm_to_apply) for _ in range(num_pre_res_blocks)
        ])
        self.post_res_blocks = nn.ModuleList([
            ResBlock(num_hidden_channels, activ_fn, norm_to_apply) for _ in range(num_post_res_blocks)
        ])

        # Define final linear layers
        self.final_dist_linear = nn.Linear(num_dist_feats, num_hidden_channels, bias=False)
        self.final_dir_linear = nn.Linear(num_dir_feats, num_hidden_channels, bias=False)
        self.final_orient_linear = nn.Linear(num_orient_feats, num_hidden_channels, bias=False)
        self.final_amide_linear = nn.Linear(num_amide_feats, num_hidden_channels, bias=False)
        self.final_linear = nn.Linear(num_hidden_channels, num_hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        # Initialize geometric modules for a conformation module
        scale = 2.0
        glorot_orthogonal(self.dist_linear_0.weight, scale=scale)
        glorot_orthogonal(self.dist_linear_1.weight, scale=scale)

        glorot_orthogonal(self.dir_linear_0.weight, scale=scale)
        glorot_orthogonal(self.dir_linear_1.weight, scale=scale)

        glorot_orthogonal(self.orient_linear_0.weight, scale=scale)
        glorot_orthogonal(self.orient_linear_1.weight, scale=scale)

        glorot_orthogonal(self.amide_linear_0.weight, scale=scale)
        glorot_orthogonal(self.amide_linear_1.weight, scale=scale)

        # Initialize utility modules for a conformation module's representation projections
        glorot_orthogonal(self.nbr_linear.weight, scale=scale)
        self.nbr_linear.bias.data.fill_(0)

        glorot_orthogonal(self.orig_msg_linear.weight, scale=scale)
        self.orig_msg_linear.bias.data.fill_(0)

        glorot_orthogonal(self.downward_proj.weight, scale=scale)
        glorot_orthogonal(self.upward_proj.weight, scale=scale)

        glorot_orthogonal(self.res_connect_linear.weight, scale=scale)
        self.res_connect_linear.bias.data.fill_(0)

        # Initialize final linear layers
        glorot_orthogonal(self.final_dist_linear.weight, scale=scale)
        glorot_orthogonal(self.final_dir_linear.weight, scale=scale)
        glorot_orthogonal(self.final_orient_linear.weight, scale=scale)
        glorot_orthogonal(self.final_amide_linear.weight, scale=scale)
        glorot_orthogonal(self.final_linear.weight, scale=scale)
        self.final_linear.bias.data.fill_(0)

    def conformation_module_message_func(self, edges: dgl.udf.EdgeBatch):
        """Conformation Module: Compute the updated geometric features for an EdgeBatch of edges.
        This function is set up as a User Defined Function in DGL.

        Parameters
        ----------
        edges : EdgeBatch
            A batch of edges for which to compute messages
        """

        # Secure neighboring edge features of interest and cache original edge features for residual reconnections
        src_nbr_e_ids = edges.data['src_nbr_e_ids'].permute(1, 0)
        dst_nbr_e_ids = edges.data['dst_nbr_e_ids'].permute(1, 0)
        src_nbr_edge_feats = edges.data['f'][src_nbr_e_ids]
        dst_nbr_edge_feats = edges.data['f'][dst_nbr_e_ids]
        nbr_edge_feats = torch.cat((src_nbr_edge_feats, dst_nbr_edge_feats))

        nbr_edge_feats = self.activ_fn(self.nbr_linear(nbr_edge_feats))
        orig_edge_feats = edges.data['orig_f']
        res_edge_feats = edges.data['f']  # Cache for future residual connections

        # - Gather initial geometric edge features
        dist_feats, dir_feats, orient_feats, amide_feats = get_geo_feats_from_edges(
            orig_edge_feats,
            self.feature_indices
        )

        # - Create and apply gating with distance embedding
        embedded_dist_feats = self.dist_linear_1(self.dist_linear_0(dist_feats))
        nbr_edge_feats = nbr_edge_feats * embedded_dist_feats

        # - Reduce dimensionality of neighboring edges' representations
        nbr_edge_feats = self.activ_fn(self.downward_proj(nbr_edge_feats))

        # - Curate and apply gating with direction embedding
        nbr_edge_feats = nbr_edge_feats * self.dir_linear_1(self.dir_linear_0(dir_feats))

        # - Make and apply gating with orientation embedding
        nbr_edge_feats = nbr_edge_feats * self.orient_linear_1(self.orient_linear_0(orient_feats))

        # - Make and apply gating with amide plane-amide plane angle embedding
        nbr_edge_feats = nbr_edge_feats * self.amide_linear_1(self.amide_linear_0(amide_feats.reshape(-1, 1)))

        # - Aggregate each edge's neighboring edge representations and increase resulting edge representations' size
        nbr_edge_feats = torch.sum(nbr_edge_feats, dim=0)  # Each edge is (partially) the sum of its neighboring edges
        nbr_edge_feats = self.activ_fn(self.upward_proj(nbr_edge_feats))

        # - Reintroduce the original edge features and combine them with the learned neighboring geometric edge features
        edge_feats = self.orig_msg_linear(res_edge_feats)
        edge_feats = edge_feats + nbr_edge_feats

        # Use pre-residual connection residual modules as specified
        for layer in self.pre_res_blocks:
            edge_feats = layer(edge_feats)

        # Make a residual reconnection with the original edge features given
        edge_feats = res_edge_feats + self.activ_fn(self.res_connect_linear(edge_feats))

        # Use post-residual connection residual modules as specified
        for layer in self.post_res_blocks:
            edge_feats = layer(edge_feats)

        # Gate with each type of original geometric edge feature
        gated_edge_feats_0 = self.final_dist_linear(dist_feats) * edge_feats
        gated_edge_feats_1 = self.final_dir_linear(dir_feats) * edge_feats
        gated_edge_feats_2 = self.final_orient_linear(orient_feats) * edge_feats
        gated_edge_feats_3 = self.final_amide_linear(amide_feats.reshape(-1, 1)) * edge_feats

        # Combine gated edge feature representations to get final edge representations after residual reconnecting
        combined_gated_edge_feats = gated_edge_feats_0 + gated_edge_feats_1 + gated_edge_feats_2 + gated_edge_feats_3
        edge_feats = res_edge_feats + self.activ_fn(self.final_linear(combined_gated_edge_feats))

        return {
            'f': edge_feats
        }

    def forward(self, graph: dgl.DGLGraph, orig_edge_feats: torch.Tensor):
        """Perform a forward pass of a conformation module to get intermediate edge representations."""
        with graph.local_scope():
            graph.edata['orig_f'] = orig_edge_feats
            graph.apply_edges(self.conformation_module_message_func)
            edge_features = graph.edata['f']
        return edge_features


class ResBlock(nn.Module):
    """A residual block for a conformation module."""

    def __init__(self, hidden_channels, activ_fn=nn.SiLU(), norm_to_apply='batch'):
        super().__init__()

        # Record parameters given
        self.activ_fn = activ_fn

        # Define projection layers for a conformation module residual block
        norm_layer = nn.LayerNorm(hidden_channels) if norm_to_apply == 'layer' else nn.BatchNorm1d(hidden_channels)
        self.res_block = nn.ModuleList([
            nn.Linear(hidden_channels, hidden_channels),
            norm_layer,
            activ_fn,
            nn.Linear(hidden_channels, hidden_channels),
            norm_layer,
            activ_fn,
            nn.Linear(hidden_channels, hidden_channels),
            norm_layer,
            activ_fn
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        for layer in self.res_block:
            not_norm_layer = not (hasattr(layer, 'normalized_shape') or hasattr(layer, 'running_mean'))
            if hasattr(layer, 'weight') and not_norm_layer:  # Skip init for activation functions and normalizing layers
                glorot_orthogonal(layer.weight, scale=scale)
                layer.bias.data.fill_(0)

    def forward(self, x):
        """Perform a forward pass using residual layers with intermediate activation functions applied."""
        x_res = x
        for layer in self.res_block:
            x_res = layer(x_res)
        return x + x_res


class GeometricTransformerModule(nn.Module):
    """A Geometric Transformer module (equivalent to one layer of graph convolutions)."""

    def __init__(
            self,
            shared_embed_size: int,
            num_dist_feats: int,
            num_dir_feats: int,
            num_orient_feats: int,
            num_amide_feats: int,
            dist_embed_size: int,
            dir_embed_size: int,
            orient_embed_size: int,
            amide_embed_size: int,
            num_hidden_channels: int,
            num_pre_res_blocks: int,
            num_post_res_blocks: int,
            activ_fn=nn.SiLU(),
            residual=True,
            num_attention_heads=4,
            norm_to_apply='batch',
            dropout_rate=0.1,
            knn=20,
            num_layers=4,
            feature_indices=FEATURE_INDICES,
            disable_geometric_mode=False
    ):
        super().__init__()

        """Geometry-Focused Graph Transformer Layer

        Parameters
        ----------
        shared_embed_size : int
            Size of shared embedding in a conformation module.
        dist_embed_size : int
            Size of distance embedding in a conformation module.
        dir_embed_size : int
            Size of direction embedding in a conformation module.
        orient_embed_size : int
            Size of orientation embedding in a conformation module.
        amide_embed_size : int
            Size of embedding in a conformation module for amide plane-amide plane normal vector angles.
        num_hidden_channels : int
            Hidden channel size for both nodes and edges.
        num_pre_res_blocks : int
            Number of residual blocks to apply prior to a residual reconnection.
        num_post_res_blocks : int
            Number of residual blocks to apply following a residual reconnection.
        activ_fn : Module
            Activation function to apply in MLPs.
        residual : bool
            Whether to use a residual update strategy for node features.
        num_attention_heads : int
            How many attention heads to apply to the input node features in parallel.
        norm_to_apply : str
            Which normalization scheme to apply to node and edge representations (i.e. 'batch' or 'layer').
        dropout_rate : float
            How much dropout (i.e. forget rate) to apply before activation functions.
        knn : int
            How many nearest neighbors were used when constructing node neighborhoods.
        num_layers : int
            How many layers of geometric attention to apply.
        feature_indices : dict
            A dictionary listing the start and end indices for each node and edge feature.
        disable_geometric_mode : bool
            Whether to convert the Geometric Transformer into the original Graph Transformer.
        """

        # Record parameters given
        self.activ_fn = activ_fn
        self.residual = residual
        self.num_attention_heads = num_attention_heads
        self.norm_to_apply = norm_to_apply
        self.dropout_rate = dropout_rate
        self.knn = knn
        self.num_layers = num_layers
        self.feature_indices = feature_indices
        self.disable_geometric_mode = disable_geometric_mode

        # --------------------
        # Conformation Module
        # --------------------
        # Define all modules related to a conformation module
        if not self.disable_geometric_mode:
            self.conformation_module = ConformationModule(num_dist_feats=num_dist_feats,
                                                          num_dir_feats=num_dir_feats,
                                                          num_orient_feats=num_orient_feats,
                                                          num_amide_feats=num_amide_feats,
                                                          dist_embed_size=dist_embed_size,
                                                          dir_embed_size=dir_embed_size,
                                                          orient_embed_size=orient_embed_size,
                                                          amide_embed_size=amide_embed_size,
                                                          shared_embed_size=shared_embed_size,
                                                          num_hidden_channels=num_hidden_channels,
                                                          num_pre_res_blocks=num_pre_res_blocks,
                                                          num_post_res_blocks=num_post_res_blocks,
                                                          activ_fn=activ_fn,
                                                          norm_to_apply=norm_to_apply,
                                                          feature_indices=feature_indices)

        # --------------------
        # Transformer Module
        # --------------------
        # Define all modules related to a Geometric Transformer module
        self.apply_layer_norm = 'layer' in self.norm_to_apply.lower()

        self.num_hidden_channels, self.num_output_feats = num_hidden_channels, num_hidden_channels
        if self.apply_layer_norm:
            self.layer_norm1_node_feats = nn.LayerNorm(self.num_output_feats)
            self.layer_norm1_edge_feats = nn.LayerNorm(self.num_output_feats)
        else:  # Otherwise, default to using batch normalization
            self.batch_norm1_node_feats = nn.BatchNorm1d(self.num_output_feats)
            self.batch_norm1_edge_feats = nn.BatchNorm1d(self.num_output_feats)

        self.mha_module = MultiHeadGeometricAttentionLayer(
            self.num_hidden_channels,
            self.num_output_feats // self.num_attention_heads,
            self.num_attention_heads,
            self.num_hidden_channels != self.num_output_feats,  # Only use bias if a Linear() has to change sizes
            update_edge_feats=True
        )

        self.O_node_feats = nn.Linear(self.num_output_feats, self.num_output_feats)
        self.O_edge_feats = nn.Linear(self.num_output_feats, self.num_output_feats)

        # MLP for node features
        dropout = nn.Dropout(p=self.dropout_rate) if self.dropout_rate > 0.0 else nn.Identity()
        self.node_feats_MLP = nn.ModuleList([
            nn.Linear(self.num_output_feats, self.num_output_feats * 2, bias=False),
            self.activ_fn,
            dropout,
            nn.Linear(self.num_output_feats * 2, self.num_output_feats, bias=False)
        ])

        if self.apply_layer_norm:
            self.layer_norm2_node_feats = nn.LayerNorm(self.num_output_feats)
            self.layer_norm2_edge_feats = nn.LayerNorm(self.num_output_feats)
        else:  # Otherwise, default to using batch normalization
            self.batch_norm2_node_feats = nn.BatchNorm1d(self.num_output_feats)
            self.batch_norm2_edge_feats = nn.BatchNorm1d(self.num_output_feats)

        # MLP for edge features
        self.edge_feats_MLP = nn.ModuleList([
            nn.Linear(self.num_output_feats, self.num_output_feats * 2, bias=False),
            self.activ_fn,
            dropout,
            nn.Linear(self.num_output_feats * 2, self.num_output_feats, bias=False)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        glorot_orthogonal(self.O_node_feats.weight, scale=scale)
        self.O_node_feats.bias.data.fill_(0)

        glorot_orthogonal(self.O_edge_feats.weight, scale=scale)
        self.O_edge_feats.bias.data.fill_(0)

        for layer in self.node_feats_MLP:
            if hasattr(layer, 'weight'):  # Skip initialization for activation functions
                glorot_orthogonal(layer.weight, scale=scale)

        for layer in self.edge_feats_MLP:
            if hasattr(layer, 'weight'):
                glorot_orthogonal(layer.weight, scale=scale)

    def run_gt_layer(self, graph: dgl.DGLGraph, node_feats: torch.Tensor,
                     edge_feats: torch.Tensor, orig_edge_feats: torch.Tensor):
        """Perform a forward pass of geometric attention using a multi-head attention (MHA) module."""
        node_feats_in1 = node_feats  # Cache node representations for first residual connection
        edge_feats_in1 = edge_feats  # Cache edge representations for first residual connection

        # Update the input graph's geometric edge features iteratively
        if not self.disable_geometric_mode:
            edge_feats = self.conformation_module(graph, orig_edge_feats)

        # Apply first round of normalization before applying geometric attention, for performance enhancement
        if self.apply_layer_norm:
            node_feats = self.layer_norm1_node_feats(node_feats)
            edge_feats = self.layer_norm1_edge_feats(edge_feats)
        else:  # Otherwise, default to using batch normalization
            node_feats = self.batch_norm1_node_feats(node_feats)
            edge_feats = self.batch_norm1_edge_feats(edge_feats)

        # Get multi-head attention output using provided node and edge representations
        node_attn_out, edge_attn_out = self.mha_module(graph, node_feats, edge_feats)

        node_feats = node_attn_out.view(-1, self.num_output_feats)
        edge_feats = edge_attn_out.view(-1, self.num_output_feats)

        node_feats = F.dropout(node_feats, self.dropout_rate, training=self.training)
        edge_feats = F.dropout(edge_feats, self.dropout_rate, training=self.training)

        node_feats = self.O_node_feats(node_feats)
        edge_feats = self.O_edge_feats(edge_feats)

        # Make first residual connection
        if self.residual:
            node_feats = node_feats_in1 + node_feats  # Make first node residual connection
            edge_feats = edge_feats_in1 + edge_feats  # Make first edge residual connection

        node_feats_in2 = node_feats  # Cache node representations for second residual connection
        edge_feats_in2 = edge_feats  # Cache edge representations for second residual connection

        # Apply second round of normalization after first residual connection has been made
        if self.apply_layer_norm:
            node_feats = self.layer_norm2_node_feats(node_feats)
            edge_feats = self.layer_norm2_edge_feats(edge_feats)
        else:  # Otherwise, default to using batch normalization
            node_feats = self.batch_norm2_node_feats(node_feats)
            edge_feats = self.batch_norm2_edge_feats(edge_feats)

        # Apply MLPs for node and edge features
        for layer in self.node_feats_MLP:
            node_feats = layer(node_feats)
        for layer in self.edge_feats_MLP:
            edge_feats = layer(edge_feats)

        # Make second residual connection
        if self.residual:
            node_feats = node_feats_in2 + node_feats  # Make second node residual connection
            edge_feats = edge_feats_in2 + edge_feats  # Make second edge residual connection

        # Return edge representations along with node representations (for tasks other than interface prediction)
        return node_feats, edge_feats

    def forward(self, graph: dgl.DGLGraph, orig_edge_feats: torch.Tensor):
        """Perform a forward pass of a Geometric Transformer to get intermediate node and edge representations."""
        node_feats, edge_feats = self.run_gt_layer(graph, graph.ndata['f'], graph.edata['f'], orig_edge_feats)
        return node_feats, edge_feats


class FinalGeometricTransformerModule(nn.Module):
    """A (final layer) Geometric Transformer module that combines node and edge representations using self-attention."""

    def __init__(
            self,
            shared_embed_size: int,
            num_dist_feats: int,
            num_dir_feats: int,
            num_orient_feats: int,
            num_amide_feats: int,
            dist_embed_size: int,
            dir_embed_size: int,
            orient_embed_size: int,
            amide_embed_size: int,
            num_hidden_channels: int,
            num_pre_res_blocks: int,
            num_post_res_blocks: int,
            activ_fn=nn.SiLU(),
            residual=True,
            num_attention_heads=4,
            norm_to_apply='batch',
            dropout_rate=0.1,
            knn=20,
            num_layers=4,
            feature_indices=FEATURE_INDICES,
            disable_geometric_mode=False
    ):
        super().__init__()

        """Geometry-Focused Graph Transformer (Final) Layer

        Parameters
        ----------
        shared_embed_size : int
            Size of shared embedding in a conformation module.
        dist_embed_size : int
            Size of distance embedding in a conformation module.
        dir_embed_size : int
            Size of direction embedding in a conformation module.
        orient_embed_size : int
            Size of orientation embedding in a conformation module.
        amide_embed_size : int
            Size of embedding in a conformation module for amide plane-amide plane normal vector angles.
        num_hidden_channels : int
            Hidden channel size for both nodes and edges.
        num_pre_res_blocks : int
            Number of residual blocks to apply prior to a residual reconnection.
        num_post_res_blocks : int
            Number of residual blocks to apply following a residual reconnection.
        activ_fn : Module
            Activation function to apply in MLPs.
        residual : bool
            Whether to use a residual update strategy for node features.
        num_attention_heads : int
            How many attention heads to apply to the input node features in parallel.
        norm_to_apply : str
            Which normalization scheme to apply to node and edge representations (i.e. 'batch' or 'layer').
        dropout_rate : float
            How much dropout (i.e. forget rate) to apply before activation functions.
        knn : int
            How many nearest neighbors were used when constructing node neighborhoods.
        num_layers : int
            How many layers of geometric attention to apply.
        feature_indices : dict
            A dictionary listing the start and end indices for each node and edge feature.
        disable_geometric_mode : bool
            Whether to convert the Geometric Transformer into the original Graph Transformer.
        """

        # Record parameters given
        self.activ_fn = activ_fn
        self.residual = residual
        self.num_attention_heads = num_attention_heads
        self.norm_to_apply = norm_to_apply
        self.dropout_rate = dropout_rate
        self.knn = knn
        self.num_layers = num_layers
        self.feature_indices = feature_indices
        self.disable_geometric_mode = disable_geometric_mode

        # --------------------
        # Conformation Module
        # --------------------
        # Define all modules related to a conformation module
        if self.disable_geometric_mode:
            total_edge_feats_size = 4 + num_dist_feats + num_dir_feats + num_orient_feats + num_amide_feats
            self.conformation_module = nn.Linear(total_edge_feats_size, num_hidden_channels, bias=False)
        else:
            self.conformation_module = ConformationModule(num_dist_feats=num_dist_feats,
                                                          num_dir_feats=num_dir_feats,
                                                          num_orient_feats=num_orient_feats,
                                                          num_amide_feats=num_amide_feats,
                                                          dist_embed_size=dist_embed_size,
                                                          dir_embed_size=dir_embed_size,
                                                          orient_embed_size=orient_embed_size,
                                                          amide_embed_size=amide_embed_size,
                                                          shared_embed_size=shared_embed_size,
                                                          num_hidden_channels=num_hidden_channels,
                                                          num_pre_res_blocks=num_pre_res_blocks,
                                                          num_post_res_blocks=num_post_res_blocks,
                                                          activ_fn=activ_fn,
                                                          norm_to_apply=norm_to_apply,
                                                          feature_indices=feature_indices)

        # --------------------
        # Transformer Module
        # --------------------
        # Define all modules related to a Geometric Transformer module
        self.apply_layer_norm = 'layer' in self.norm_to_apply.lower()

        self.num_hidden_channels, self.num_output_feats = num_hidden_channels, num_hidden_channels
        if self.apply_layer_norm:
            self.layer_norm1_node_feats = nn.LayerNorm(self.num_output_feats)
            self.layer_norm1_edge_feats = nn.LayerNorm(self.num_output_feats)
        else:  # Otherwise, default to using batch normalization
            self.batch_norm1_node_feats = nn.BatchNorm1d(self.num_output_feats)
            self.batch_norm1_edge_feats = nn.BatchNorm1d(self.num_output_feats)

        self.mha_module = MultiHeadGeometricAttentionLayer(
            self.num_hidden_channels,
            self.num_output_feats // self.num_attention_heads,
            self.num_attention_heads,
            self.num_hidden_channels != self.num_output_feats,  # Only use bias if a Linear() has to change sizes
            update_edge_feats=False
        )

        self.O_node_feats = nn.Linear(self.num_output_feats, self.num_output_feats)

        # MLP for node features
        dropout = nn.Dropout(p=self.dropout_rate) if self.dropout_rate > 0.0 else nn.Identity()
        self.node_feats_MLP = nn.ModuleList([
            nn.Linear(self.num_output_feats, self.num_output_feats * 2, bias=False),
            self.activ_fn,
            dropout,
            nn.Linear(self.num_output_feats * 2, self.num_output_feats, bias=False)
        ])

        if self.apply_layer_norm:
            self.layer_norm2_node_feats = nn.LayerNorm(self.num_output_feats)
        else:  # Otherwise, default to using batch normalization
            self.batch_norm2_node_feats = nn.BatchNorm1d(self.num_output_feats)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        glorot_orthogonal(self.O_node_feats.weight, scale=scale)
        self.O_node_feats.bias.data.fill_(0)

        for layer in self.node_feats_MLP:
            if hasattr(layer, 'weight'):  # Skip initialization for activation functions
                glorot_orthogonal(layer.weight, scale=scale)

        if self.disable_geometric_mode:
            glorot_orthogonal(self.conformation_module.weight, scale=scale)

    def run_gt_layer(self, graph: dgl.DGLGraph, node_feats: torch.Tensor, orig_edge_feats: torch.Tensor):
        """Perform a forward pass of geometric attention using a multi-head attention (MHA) module."""
        node_feats_in1 = node_feats  # Cache node representations for first residual connection

        # Update the input graph's geometric edge features iteratively
        if self.disable_geometric_mode:
            edge_pos_enc_index = self.feature_indices['edge_pos_enc']
            edge_pos_enc = graph.edata['f'][:, edge_pos_enc_index]
            edge_weight_index = self.feature_indices['edge_weights']
            edge_weights = graph.edata['f'][:, edge_weight_index]
            edge_messages_init = torch.cat((edge_pos_enc.reshape(-1, 1), edge_weights.reshape(-1, 1)), dim=1)
            edge_feats_init = torch.cat((edge_messages_init, orig_edge_feats), dim=1)
            edge_feats = self.conformation_module(edge_feats_init)
        else:
            edge_feats = self.conformation_module(graph, orig_edge_feats)

        # Apply first round of normalization before applying geometric attention, for performance enhancement
        if self.apply_layer_norm:
            node_feats = self.layer_norm1_node_feats(node_feats)
            edge_feats = self.layer_norm1_edge_feats(edge_feats)
        else:  # Otherwise, default to using batch normalization
            node_feats = self.batch_norm1_node_feats(node_feats)
            edge_feats = self.batch_norm1_edge_feats(edge_feats)

        # Get multi-head attention output using provided node and edge representations
        node_attn_out, _ = self.mha_module(graph, node_feats, edge_feats)

        node_feats = node_attn_out.view(-1, self.num_output_feats)

        node_feats = F.dropout(node_feats, self.dropout_rate, training=self.training)

        node_feats = self.O_node_feats(node_feats)

        # Make first residual connection
        if self.residual:
            node_feats = node_feats_in1 + node_feats  # Make first node residual connection

        node_feats_in2 = node_feats  # Cache node representations for second residual connection

        # Apply second round of normalization after first residual connection has been made
        if self.apply_layer_norm:
            node_feats = self.layer_norm2_node_feats(node_feats)
        else:  # Otherwise, default to using batch normalization
            node_feats = self.batch_norm2_node_feats(node_feats)

        # Apply MLP for node features
        for layer in self.node_feats_MLP:
            node_feats = layer(node_feats)

        # Make second residual connection
        if self.residual:
            node_feats = node_feats_in2 + node_feats  # Make second node residual connection

        # Return node representations
        return node_feats

    def forward(self, graph: dgl.DGLGraph, orig_edge_feats: torch.Tensor):
        """Perform a forward pass of a Geometric Transformer to get final node representations."""
        node_feats = self.run_gt_layer(graph, graph.ndata['f'], orig_edge_feats)
        return node_feats


class SEBlock(torch.nn.Module):
    """A squeeze-and-excitation block for PyTorch."""

    def __init__(self, ch, ratio=16):
        super(SEBlock, self).__init__()
        self.ratio = ratio
        self.linear1 = torch.nn.Linear(ch, ch // ratio)
        self.linear2 = torch.nn.Linear(ch // ratio, ch)
        self.act = nn.ReLU()

    def forward(self, in_block):
        x = torch.reshape(in_block, (in_block.shape[0], in_block.shape[1], -1))
        x = torch.mean(x, dim=-1)
        x = self.act(self.linear1(x))
        x = self.act(self.linear2(x))
        x = torch.sigmoid(x)
        return torch.einsum('bcij,bc->bcij', in_block, x)


class ResNet(nn.Module):
    """A custom ResNet module for PyTorch."""

    # Parameter initialization
    def __init__(self,
                 num_channels,
                 num_chunks,
                 module_name,
                 activ_fn=F.elu,
                 inorm=False,
                 initial_projection=False,
                 extra_blocks=False,
                 dilation_cycle=None,
                 verbose=False):

        self.num_channel = num_channels
        self.num_chunks = num_chunks
        self.module_name = module_name
        self.activ_fn = activ_fn
        self.inorm = inorm
        self.initial_projection = initial_projection
        self.extra_blocks = extra_blocks
        self.dilation_cycle = [1, 2, 4, 8] if dilation_cycle is None else dilation_cycle
        self.verbose = verbose

        super(ResNet, self).__init__()

        if self.initial_projection:
            self.add_module(f'resnet_{self.module_name}_init_proj', nn.Conv2d(in_channels=num_channels,
                                                                              out_channels=num_channels,
                                                                              kernel_size=(1, 1)))

        for i in range(self.num_chunks):
            for dilation_rate in self.dilation_cycle:
                if self.inorm:
                    self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_1',
                                    nn.InstanceNorm2d(num_channels, eps=1e-06, affine=True))
                    self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_2',
                                    nn.InstanceNorm2d(num_channels // 2, eps=1e-06, affine=True))
                    self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_3',
                                    nn.InstanceNorm2d(num_channels // 2, eps=1e-06, affine=True))

                self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_1',
                                nn.Conv2d(num_channels, num_channels // 2, kernel_size=(1, 1)))
                self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_2',
                                nn.Conv2d(num_channels // 2,
                                          num_channels // 2,
                                          kernel_size=(3, 3),
                                          dilation=dilation_rate,
                                          padding=dilation_rate))
                self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_3',
                                nn.Conv2d(num_channels // 2, num_channels, kernel_size=(1, 1)))
                self.add_module(f'resnet_{self.module_name}_{i}_{dilation_rate}_se_block',
                                SEBlock(num_channels, ratio=16))

        if self.extra_blocks:
            for i in range(2):
                if self.inorm:
                    self.add_module(f'resnet_{self.module_name}_extra{i}_inorm_1',
                                    nn.InstanceNorm2d(num_channels, eps=1e-06, affine=True))
                    self.add_module(f'resnet_{self.module_name}_extra{i}_inorm_2',
                                    nn.InstanceNorm2d(num_channels // 2, eps=1e-06, affine=True))
                    self.add_module(f'resnet_{self.module_name}_extra{i}_inorm_3',
                                    nn.InstanceNorm2d(num_channels // 2, eps=1e-06, affine=True))

                self.add_module(f'resnet_{self.module_name}_extra{i}_conv2d_1',
                                nn.Conv2d(num_channels, num_channels // 2, kernel_size=(1, 1)))
                self.add_module(f'resnet_{self.module_name}_extra{i}_conv2d_2',
                                nn.Conv2d(num_channels // 2,
                                          num_channels // 2,
                                          kernel_size=(3, 3),
                                          dilation=(1, 1),
                                          padding=(1, 1)))
                self.add_module(f'resnet_{self.module_name}_extra{i}_conv2d_3',
                                nn.Conv2d(num_channels // 2, num_channels, kernel_size=(1, 1)))
                self.add_module(f'resnet_{self.module_name}_extra{i}_se_block',
                                SEBlock(num_channels, ratio=16))

    def forward(self, x):
        """Compute ResNet output."""
        activ_fn = self.activ_fn

        if self.initial_projection:
            x = self._modules[f'resnet_{self.module_name}_init_proj'](x)

        for i in range(self.num_chunks):
            for dilation_rate in self.dilation_cycle:
                _residual = x

                # Internal block
                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_1'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_1'](x)

                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_2'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_2'](x)

                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_inorm_3'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_conv2d_3'](x)

                x = self._modules[f'resnet_{self.module_name}_{i}_{dilation_rate}_se_block'](x)

                x = x + _residual

        if self.extra_blocks:
            for i in range(2):
                _residual = x

                # Internal block
                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_extra{i}_inorm_1'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_extra{i}_conv2d_1'](x)

                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_extra{i}_inorm_2'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_extra{i}_conv2d_2'](x)

                if self.inorm:
                    x = self._modules[f'resnet_{self.module_name}_extra{i}_inorm_3'](x)
                x = activ_fn(x)
                x = self._modules[f'resnet_{self.module_name}_extra{i}_conv2d_3'](x)

                x = self._modules[f'resnet_{self.module_name}_extra{i}_se_block'](x)

                x = x + _residual

        return x


class MultiHeadRegionalAttention(nn.Module):
    """A multi-head attention block for PyTorch that operates regionally."""

    @staticmethod
    def get_stretch_weight(s):
        w = np.zeros((s * s, 1, 1, s, s))
        for i in range(s):
            for j in range(s):
                w[s * i + j, 0, 0, i, j] = 1
        return np.asarray(w).astype(np.float32)

    def __init__(self, in_dim=3, region_size=3, d_k=16, d_v=32, n_head=4, att_drop=0.1, output_score=False):
        super(MultiHeadRegionalAttention, self).__init__()
        self.temper = int(np.sqrt(d_k))
        self.dk_per_head = d_k // n_head
        self.dv_per_head = d_v // n_head
        self.dropout_layer = nn.Dropout(att_drop)
        self.output_score = output_score
        self.q_layer = nn.Conv2d(in_dim, d_k, kernel_size=(1, 1), bias=False)
        self.k_layer = nn.Conv2d(in_dim, d_k, kernel_size=(1, 1), bias=False)
        self.v_layer = nn.Conv2d(in_dim, d_v, kernel_size=(1, 1), bias=False)
        self.softmax_layer = nn.Softmax(1)
        self.stretch_layer = nn.Conv3d(in_channels=1,
                                       out_channels=region_size * region_size,
                                       kernel_size=(1, region_size, region_size),
                                       bias=False,
                                       padding=(0, 1, 1))
        self.stretch_layer.weight = nn.Parameter(
            torch.tensor(self.get_stretch_weight(region_size)), requires_grad=False
        )

    def forward(self, x):
        """Compute attention output and attention score."""
        Q = self.stretch_layer(self.q_layer(x).unsqueeze(1))
        K = self.stretch_layer(self.k_layer(x).unsqueeze(1))
        V = self.stretch_layer(self.v_layer(x).unsqueeze(1))
        qk = torch.mul(Q, K).permute(0, 2, 1, 3, 4)
        qk1 = qk.view((-1, self.dk_per_head, qk.shape[2], qk.shape[3], qk.shape[4]))
        attention_score = self.softmax_layer(torch.div(torch.sum(qk1, 1), self.temper))
        attention_score2 = self.dropout_layer(attention_score)
        attention_score2 = torch.repeat_interleave(attention_score2.unsqueeze(0).permute(0, 2, 1, 3, 4),
                                                   repeats=self.dv_per_head, dim=2)
        attention_out = torch.sum(torch.mul(attention_score2, V), dim=1)
        return attention_out, attention_score if self.output_score else attention_out


class ResNet2DInputWithOptAttention(nn.Module):
    """A ResNet and (optionally) regionally-attentive convolution module for a pair of 2D feature tensors."""

    def __init__(self,
                 num_chunks=4,
                 init_channels=128,
                 num_channels=128,
                 num_classes=2,
                 use_attention=False,
                 n_head=4,
                 module_name=None,
                 activ_fn=F.elu,
                 dropout=0.1,
                 verbose=False):
        super(ResNet2DInputWithOptAttention, self).__init__()
        self.num_chunks = num_chunks
        self.init_channels = init_channels
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.use_attention = use_attention
        self.n_head = n_head
        self.module_name = module_name
        self.activ_fun = activ_fn
        self.dropout = dropout
        self.verbose = verbose

        self.add_module('conv2d_1', nn.Conv2d(in_channels=self.init_channels,
                                              out_channels=self.num_channels,
                                              kernel_size=(1, 1),
                                              padding=(0, 0)))
        self.add_module('inorm_1', nn.InstanceNorm2d(self.num_channels, eps=1e-06, affine=True))

        self.add_module('base_resnet', ResNet(num_channels,
                                              self.num_chunks,
                                              module_name='base_resnet',
                                              activ_fn=self.activ_fun,
                                              inorm=True,
                                              initial_projection=True,
                                              extra_blocks=False))

        self.add_module('phase2_resnet', ResNet(num_channels,
                                                num_chunks=1,
                                                module_name='bin_resnet',
                                                activ_fn=self.activ_fun,
                                                inorm=False,
                                                initial_projection=True,
                                                extra_blocks=True))
        self.add_module('phase2_conv', nn.Conv2d(in_channels=self.num_channels,
                                                 out_channels=self.num_classes,
                                                 kernel_size=(1, 1),
                                                 padding=(0, 0)))
        if self.use_attention:
            self.add_module('MHA2D_1', MultiHeadRegionalAttention(self.num_channels,
                                                                  d_v=self.num_channels,
                                                                  n_head=self.n_head,
                                                                  att_drop=self.dropout,
                                                                  output_score=True))
            self.add_module('MHA2D_2', MultiHeadRegionalAttention(self.num_channels,
                                                                  d_v=self.num_channels,
                                                                  n_head=self.n_head,
                                                                  att_drop=self.dropout,
                                                                  output_score=True))

        # Reset learnable parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        # Reinitialize final output layer
        final_layer_bias = self._modules['phase2_conv'].bias.clone()
        final_layer_bias[1] = -7.0  # -7 chosen as the second term's bias s.t. positives are predicted with prob=0.001
        self._modules['phase2_conv'].bias = nn.Parameter(final_layer_bias, requires_grad=True)

    def forward(self, f2d_tile: torch.Tensor):
        """Compute final convolution output."""
        activ_fun = self.activ_fun
        out_conv2d_1 = self._modules['conv2d_1'](f2d_tile)
        out_inorm_1 = activ_fun(self._modules['inorm_1'](out_conv2d_1))

        # First ResNet
        out_base_resnet = activ_fun(self._modules['base_resnet'](out_inorm_1))
        if self.use_attention:
            out_base_resnet, attention_scores_1 = self._modules['MHA2D_1'](out_base_resnet)
            out_base_resnet = activ_fun(out_base_resnet)

        # Second ResNet
        out_bin_predictor = activ_fun(self._modules['phase2_resnet'](out_base_resnet))
        if self.use_attention:
            out_bin_predictor, attention_scores_2 = self._modules['MHA2D_2'](out_bin_predictor)
            out_bin_predictor = activ_fun(out_bin_predictor)

        # Output convolution
        out_layer = self._modules['phase2_conv'](out_bin_predictor)
        return out_layer


# ------------------
# DGL Modules
# ------------------

class DGLGeometricTransformer(nn.Module):
    """A geometry-focused graph transformer, as a DGL module.

    DGLGeometricTransformer stands for a geometry-focused Graph Transformer layer. It is the
    equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph conv layer in a GCN.
    """

    def __init__(
            self,
            node_count_limit=NODE_COUNT_LIMIT,
            shared_embed_size=64,
            dist_embed_size=8,
            dir_embed_size=8,
            orient_embed_size=8,
            amide_embed_size=8,
            num_hidden_channels=128,
            num_pre_res_blocks=2,
            num_post_res_blocks=2,
            activ_fn=nn.SiLU(),
            transformer_residual=True,
            num_attention_heads=4,
            norm_to_apply='batch',
            dropout_rate=0.1,
            knn=20,
            num_layers=4,
            feature_indices=FEATURE_INDICES,
            disable_geometric_mode=False,
            **kwargs
    ):
        """Geometry-Focused Graph Transformer Layer

        Parameters
        ----------
        node_count_limit : int
            Maximum number of nodes allowed in a given graph.
        shared_embed_size : int
            Size of shared embedding in a conformation module.
        dist_embed_size : int
            Size of distance embedding in a conformation module.
        dir_embed_size : int
            Size of direction embedding in a conformation module.
        orient_embed_size : int
            Size of orientation embedding in a conformation module.
        amide_embed_size : int
            Size of embedding in a conformation module for amide plane-amide plane normal vector angles.
        num_hidden_channels : int
            Hidden channel size for both nodes and edges.
        num_pre_res_blocks : int
            Number of residual blocks to apply prior to a residual reconnection.
        num_post_res_blocks : int
            Number of residual blocks to apply following a residual reconnection.
        activ_fn : Module
            Activation function to apply in MLPs.
        transformer_residual : bool
            Whether to use a transformer-residual update strategy for node features.
        num_attention_heads : int
            How many attention heads to apply to the input node features in parallel.
        norm_to_apply : str
            Which normalization scheme to apply to node and edge representations (i.e. 'batch' or 'layer').
        dropout_rate : float
            How much dropout (i.e. forget rate) to apply before activation functions.
        knn : int
            How many nearest neighbors were used when constructing node neighborhoods.
        num_layers : int
            How many layers of geometric attention to apply.
        feature_indices : dict
            A dictionary listing the start and end indices for each node and edge feature.
        disable_geometric_mode : bool
            Whether to convert the Geometric Transformer into the original Graph Transformer.
        """
        super().__init__()
        assert feature_indices is not None, 'Dictionary of indices for node and edge features to be used must be given'

        # Initialize model parameters
        self.activ_fn = activ_fn
        self.transformer_residual = transformer_residual
        self.num_attention_heads = num_attention_heads
        self.norm_to_apply = norm_to_apply
        self.dropout_rate = dropout_rate
        self.knn = knn
        self.num_layers = num_layers
        self.feature_indices = feature_indices
        self.disable_geometric_mode = disable_geometric_mode

        # Parse dimensionality of each type of node and edge feature
        num_dist_feats = feature_indices['edge_dist_feats_end'] - feature_indices['edge_dist_feats_start']
        num_dir_feats = feature_indices['edge_dir_feats_end'] - feature_indices['edge_dir_feats_start']
        num_orient_feats = feature_indices['edge_orient_feats_end'] - feature_indices['edge_orient_feats_start']
        num_amide_feats = 1

        # --------------------
        # Initializer Modules
        # --------------------
        # Define all modules related to edge and node initialization
        if self.disable_geometric_mode:
            total_edge_feats_size = 4 + num_dist_feats + num_dir_feats + num_orient_feats + num_amide_feats
            self.init_edge_module = nn.Linear(total_edge_feats_size, num_hidden_channels, bias=False)
        else:
            self.init_edge_module = InitEdgeModule(node_count_limit=node_count_limit,
                                                   num_edge_feats=2,  # Only use pos. encod. and edge weights initially
                                                   num_dist_feats=num_dist_feats,
                                                   num_dir_feats=num_dir_feats,
                                                   num_orient_feats=num_orient_feats,
                                                   num_amide_feats=num_amide_feats,
                                                   num_hidden_channels=num_hidden_channels,
                                                   activ_fn=activ_fn,
                                                   feature_indices=feature_indices)

        # --------------------
        # Transformer Module
        # --------------------
        # Define all modules related to a variable number of Geometric Transformer modules
        num_intermediate_layers = max(0, num_layers - 1)
        gt_block_modules = [
            GeometricTransformerModule(shared_embed_size=shared_embed_size,
                                       num_dist_feats=num_dist_feats,
                                       num_dir_feats=num_dir_feats,
                                       num_orient_feats=num_orient_feats,
                                       num_amide_feats=num_amide_feats,
                                       dist_embed_size=dist_embed_size,
                                       dir_embed_size=dir_embed_size,
                                       orient_embed_size=orient_embed_size,
                                       amide_embed_size=amide_embed_size,
                                       num_hidden_channels=num_hidden_channels,
                                       num_pre_res_blocks=num_pre_res_blocks,
                                       num_post_res_blocks=num_post_res_blocks,
                                       activ_fn=activ_fn,
                                       residual=transformer_residual,
                                       num_attention_heads=num_attention_heads,
                                       norm_to_apply=norm_to_apply,
                                       dropout_rate=dropout_rate,
                                       knn=knn,
                                       num_layers=num_layers,
                                       feature_indices=feature_indices,
                                       disable_geometric_mode=self.disable_geometric_mode)
            for _ in range(num_intermediate_layers)
        ]
        if num_layers > 0:
            gt_block_modules.extend([
                FinalGeometricTransformerModule(shared_embed_size=shared_embed_size,
                                                num_dist_feats=num_dist_feats,
                                                num_dir_feats=num_dir_feats,
                                                num_orient_feats=num_orient_feats,
                                                num_amide_feats=num_amide_feats,
                                                dist_embed_size=dist_embed_size,
                                                dir_embed_size=dir_embed_size,
                                                orient_embed_size=orient_embed_size,
                                                amide_embed_size=amide_embed_size,
                                                num_hidden_channels=num_hidden_channels,
                                                num_pre_res_blocks=num_pre_res_blocks,
                                                num_post_res_blocks=num_post_res_blocks,
                                                activ_fn=activ_fn,
                                                residual=transformer_residual,
                                                num_attention_heads=num_attention_heads,
                                                norm_to_apply=norm_to_apply,
                                                dropout_rate=dropout_rate,
                                                knn=knn,
                                                num_layers=num_layers,
                                                feature_indices=feature_indices,
                                                disable_geometric_mode=self.disable_geometric_mode)
            ])
        self.gt_block = nn.ModuleList(gt_block_modules)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        scale = 2.0
        if self.disable_geometric_mode:
            glorot_orthogonal(self.init_edge_module.weight, scale=scale)

    def forward(
            self,
            graph: dgl.DGLGraph
    ):
        """Forward pass of the network

        Parameters
        ----------
        graph : DGLGraph
            DGL input graph
        """
        # Cache the original batch number of nodes and edges
        batch_num_nodes, batch_num_edges = graph.batch_num_nodes(), graph.batch_num_edges()

        # Cache original edge features
        orig_edge_feats = graph.edata['f']

        # Prepare initial edge representations
        if self.disable_geometric_mode:
            edge_pos_enc_index = self.feature_indices['edge_pos_enc']
            edge_pos_enc = graph.edata['f'][:, edge_pos_enc_index]
            edge_weight_index = self.feature_indices['edge_weights']
            edge_weights = graph.edata['f'][:, edge_weight_index]
            edge_messages_init = torch.cat((edge_pos_enc.reshape(-1, 1), edge_weights.reshape(-1, 1)), dim=1)
            edge_feats_init = torch.cat((edge_messages_init, graph.edata['f']), dim=1)
            graph.edata['f'] = self.init_edge_module(edge_feats_init)
        else:
            graph.edata['f'] = self.init_edge_module(graph)

        # Apply a given number of intermediate geometric attention layers to the node and edge features given
        for gt_layer in self.gt_block[:-1]:
            graph.ndata['f'], graph.edata['f'] = gt_layer(graph, orig_edge_feats)

        # Apply final layer to update node representations by merging current node and edge representations
        graph.ndata['f'] = self.gt_block[-1](graph, orig_edge_feats)

        # Retain the original batch number of nodes and edges
        graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)

        # Return updated node and edge features inside provided DGLGraph
        return graph

    def __repr__(self):
        return f'DGLGeometricTransformer(structure=' \
               f'h_in{self.num_node_input_feats}-h_hid{self.node_output_embedding_hidden_feats}-h_out{self.num_node_input_feats}' \
               f'-e_in{self.num_edge_input_feats}-e_hid{self.edge_input_mlp_output_size}-e_out{self.num_edge_input_feats})'


# ------------------
# Lightning Modules
# ------------------

class LitGINI(pl.LightningModule):
    """A geometry-focused inter-graph node interaction (GINI) module."""

    def __init__(self, num_node_input_feats: int, num_edge_input_feats: int, gnn_activ_fn=nn.SiLU(),
                 num_classes=2, max_num_graph_nodes=NODE_COUNT_LIMIT, max_num_residues=RESIDUE_COUNT_LIMIT,
                 testing_with_casp_capri=False, training_with_db5=False, pos_prob_threshold=0.5,
                 gnn_layer_type='geotran', num_gnn_layers=2, num_gnn_hidden_channels=128,
                 num_gnn_attention_heads=4, knn=20, interact_module_type='deepinter', num_interact_layers=14,
                 num_interact_hidden_channels=128, use_interact_attention=False, num_interact_attention_heads=4,
                 disable_geometric_mode=False, num_epochs=50, pn_ratio=0.1, dropout_rate=0.2, metric_to_track='val_ce',
                 weight_decay=1e-2, batch_size=1, lr=1e-3, pad=False, viz_every_n_epochs=1, use_wandb_logger=True,
                 weight_classes=False, fine_tune=False, ckpt_path=None, use_af2=False):
        """Initialize all the parameters for a LitGINI module."""
        super().__init__()

        # Build the network
        self.num_node_input_feats = num_node_input_feats
        self.num_edge_input_feats = num_edge_input_feats
        self.gnn_activ_fn = gnn_activ_fn
        self.num_classes = num_classes
        self.max_num_graph_nodes = max_num_graph_nodes
        self.max_num_residues = max_num_residues
        self.testing_with_casp_capri = testing_with_casp_capri
        self.training_with_db5 = training_with_db5
        self.pos_prob_threshold = pos_prob_threshold

        # GNN module's keyword arguments provided via the command line
        self.gnn_layer_type = gnn_layer_type
        self.num_gnn_layers = num_gnn_layers
        self.num_gnn_hidden_channels = num_gnn_hidden_channels
        self.num_gnn_attention_heads = num_gnn_attention_heads
        self.nbrhd_size = knn

        # Interaction module's keyword arguments provided via the command line
        self.interact_module_type = interact_module_type
        self.num_interact_layers = num_interact_layers
        self.num_interact_hidden_channels = num_interact_hidden_channels
        self.use_interact_attention = use_interact_attention
        self.num_interact_attention_heads = num_interact_attention_heads
        self.disable_geometric_mode = disable_geometric_mode

        # Derive shortcut booleans for convenient future reference
        self.using_dil_resnet = self.interact_module_type.lower() == 'dil_resnet'
        self.using_deeplab = self.interact_module_type.lower() == 'deeplab'
        self.using_deepinter = self.interact_module_type.lower() == 'deepinter'

        # Model hyperparameter keyword arguments provided via the command line
        self.num_epochs = num_epochs
        self.pn_ratio = pn_ratio
        self.dropout_rate = dropout_rate
        self.metric_to_track = metric_to_track
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.lr = lr
        self.pad = pad
        self.viz_every_n_epochs = viz_every_n_epochs  # Visualize model predictions every 'n' epochs
        self.use_wandb_logger = use_wandb_logger  # Whether to use WandB as the primary means of logging
        self.weight_classes = weight_classes  # Whether to use class weighting in our training Cross Entropy
        self.fine_tune = fine_tune  # Whether to fine-tune a trained LitGINI on a new dataset
        self.ckpt_path = ckpt_path  # A path to a trained LitGINI checkpoint given if fine-tuning

        # Set up GNN node and edge embedding layers (if requested)
        self.using_gcn = self.gnn_layer_type.lower() == 'gcn'
        self.using_node_embedding = self.num_node_input_feats != self.num_gnn_hidden_channels
        self.node_in_embedding = nn.Linear(self.num_node_input_feats, self.num_gnn_hidden_channels, bias=False) \
            if self.using_node_embedding \
            else nn.Identity()

        # Assemble the layers of the network
        if self.fine_tune:
            # Load in trained LitGINI
            lit_gini = LitGINI.load_from_checkpoint(self.ckpt_path,
                                                    use_wandb_logger=use_wandb_logger,
                                                    batch_size=self.batch_size,
                                                    lr=self.lr,
                                                    weight_decay=self.weight_decay,
                                                    dropout_rate=self.dropout_rate)
            self.gnn_module, self.interact_module = lit_gini.gnn_module, lit_gini.interact_module
            # Freeze the interaction module during fine-tuning
            for param in self.interact_module.parameters():
                param.requires_grad = False
        else:
            self.build_gnn_module(), self.build_interaction_module()

        # Declare loss functions and metrics for training, validation, and testing
        self.loss_fn = FocalLoss(alpha=0.25, gamma=1.5, reduction='sum')
        self.train_acc = tm.Accuracy(num_classes=self.num_classes, average=None)
        self.train_prec = tm.Precision(num_classes=self.num_classes, average=None)
        self.train_recall = tm.Recall(num_classes=self.num_classes, average=None)

        self.val_acc = tm.Accuracy(num_classes=self.num_classes, average=None)
        self.val_prec = tm.Precision(num_classes=self.num_classes, average=None)
        self.val_recall = tm.Recall(num_classes=self.num_classes, average=None)
        self.val_auroc = tm.AUROC(num_classes=self.num_classes, average=None)
        self.val_auprc = tm.AveragePrecision(num_classes=self.num_classes)
        self.val_f1 = tm.F1(num_classes=self.num_classes, average=None)

        self.test_acc = tm.Accuracy(num_classes=self.num_classes, average=None)
        self.test_prec = tm.Precision(num_classes=self.num_classes, average=None)
        self.test_recall = tm.Recall(num_classes=self.num_classes, average=None)
        self.test_auroc = tm.AUROC(num_classes=self.num_classes, average=None)
        self.test_auprc = tm.AveragePrecision(num_classes=self.num_classes)
        self.test_f1 = tm.F1(num_classes=self.num_classes, average=None)

        # Reset learnable parameters and log hyperparameters
        self.reset_parameters()
        self.save_hyperparameters()

        # Linear layers for dimensionality reduction
        self.linear_layer_2688_to_1280 = nn.Linear(2688, 1280)
        self.linear_layer_1384_to_64 = nn.Linear(1384, 64)
        self.linear_layer_1320_to_64 = nn.Linear(1320, 64)

        # Esm2
        self.esm2 = ESM2('utils/esm2/esm2_t33_650M_UR50D.pt', frozen=True, device='cuda')
        for name, param in self.esm2.named_parameters(): param.requires_grad = False

        # SaProt
        self.saprot = SaProt(model_dir='utils/saprot/SaProt_650M_AF2.pt', frozen=True, device='cuda')
        self.use_af2 = use_af2
        if use_af2:
            self.struct_seq_dir = '/home/anonymity/data_gu02/SaProt/preprocessed_struct_seq_deepinter_test_af2'
        else:
            self.struct_seq_dir = '/home/anonymity/data_gu02/SaProt/preprocessed_struct_seq'

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        if self.using_node_embedding:
            # Reinitialize node input embedding
            glorot_orthogonal(self.node_in_embedding.weight, scale=2.0)

    def build_gnn_module(self):
        """Define all layers for the chosen GNN module."""
        # Marshal all GNN layers, allowing the user to choose which kind of graph learning scheme they would like to use
        num_node_input_feats = self.num_gnn_hidden_channels \
            if self.using_node_embedding \
            else self.num_node_input_feats
        if self.using_gcn:
            gnn_layers = [GraphConv(in_feats=num_node_input_feats,
                                    out_feats=num_node_input_feats,
                                    weight=True,
                                    activation=None,
                                    allow_zero_in_degree=False) for _ in range(self.num_gnn_layers)]
        else:  # Default to using a Geometric Transformer for learning node representations
            if self.num_gnn_layers > 0:
                gnn_layers = [DGLGeometricTransformer(node_count_limit=self.max_num_graph_nodes,
                                                      shared_embed_size=64,
                                                      dist_embed_size=8,
                                                      dir_embed_size=8,
                                                      orient_embed_size=8,
                                                      amide_embed_size=8,
                                                      num_hidden_channels=self.num_gnn_hidden_channels,
                                                      num_pre_res_blocks=2,
                                                      num_post_res_blocks=2,
                                                      activ_fn=self.gnn_activ_fn,
                                                      transformer_residual=True,
                                                      num_attention_heads=self.num_gnn_attention_heads,
                                                      norm_to_apply='batch',
                                                      dropout_rate=self.dropout_rate,
                                                      knn=self.nbrhd_size,
                                                      num_layers=self.num_gnn_layers,
                                                      feature_indices=FEATURE_INDICES,
                                                      disable_geometric_mode=self.disable_geometric_mode)]
            else:
                gnn_layers = []
        self.gnn_module = nn.ModuleList(gnn_layers)

    def get_interact_module(self):
        """Retrieve an interaction module of a specific type (e.g. Dilated ResNet or DeepLabV3Plus)."""
        interact_module = ResnetTriangleStack(config_parse_args(), 20)
        return interact_module

    def build_interaction_module(self):
        """Define all layers for the chosen interaction module."""
        # Dilated ResNets and DeepLabV3Plus package all their forward pass logic
        self.interact_module = self.get_interact_module()

    # ---------------------
    # Training
    # ---------------------
    def gnn_forward(self, graph: dgl.DGLGraph):
        """Make a forward pass through a single GNN module."""
        # Embed input features a priori
        if self.using_node_embedding:
            # print("Shape of graph.ndata['f']:", graph.ndata['f'].shape)
            # print("Shape of nn.Linear input features:", self.num_node_input_feats)
            graph.ndata['f'] = self.node_in_embedding(graph.ndata['f']).squeeze()
        if self.using_gcn:
            # Forward propagate with each GNN layer
            for layer in self.gnn_module:
                # Cache the original batch number of nodes and edges
                batch_num_nodes, batch_num_edges = graph.batch_num_nodes(), graph.batch_num_edges()
                graph.ndata['f'] = layer(graph, graph.ndata['f'], edge_weight=graph.edata['f'][:, 1]).squeeze()
                # Retain the original batch number of nodes and edges
                graph.set_batch_num_nodes(batch_num_nodes), graph.set_batch_num_edges(batch_num_edges)
        else:  # The GeometricTransformer updates simply by returning a graph containing the updated node/edge feats
            for layer in self.gnn_module:
                graph = layer(graph)  # Geometric Transformers can handle their own depth
        # Unbatch and collect each individual graph's predicted node features
        graphs = dgl.unbatch(graph)
        node_feats = [graph.ndata['f'] for graph in graphs]
        return node_feats
    
    def interact_forward(self, rec1d: torch.Tensor, rec2d: torch.Tensor, lig1d: torch.Tensor, lig2d: torch.Tensor, com2d: torch.Tensor, mask=None): #rec1d, rec2d, lig1d, lig2d, com2d, (mask)
        """Make a forward pass through the interaction module."""
        # Dilated ResNets and DeepLabV3Plus package all their forward pass logic
        logits = self.interact_module(rec1d, rec2d, lig1d, lig2d, com2d)
        return logits

    def shared_step(self, graph1: dgl.DGLGraph, graph2: dgl.DGLGraph, sequence1, sequence2, dist1, dist2, cropped_indices, pdb_ids, return_representations=False):
        """Make a forward pass through the entire siamese network."""
        display_specific_time_consumptions = False
        display_specific_memory_consumptions = False

        # Learn structural features for each structure's nodes
        ## Sequence 1
        if display_specific_memory_consumptions: print('Shared step start: ', memory_mb())
        start_time = datetime.now()
        graph1_node_feats = self.gnn_forward(graph1)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='Geometric Transformer graph1')
        if display_specific_memory_consumptions: print('Post-GNN forward graph 1: ', memory_mb())

        ## Sequence 2
        start_time = datetime.now()
        graph2_node_feats = self.gnn_forward(graph2)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='Geometric Transformer graph2')
        if display_specific_memory_consumptions: print('Post-GNN forward graph 2: ', memory_mb())

        # Read preprocessed structural sequences
        if self.use_af2:
            preprocessed_data = dill.load(open(os.path.join(self.struct_seq_dir, f'{pdb_ids[0]}_af2_struct_seq.dill'), 'rb'))
        else:
            preprocessed_data = dill.load(open(os.path.join(self.struct_seq_dir, f'{pdb_ids[0]}_struct_seq.dill'), 'rb'))
        struct_seq1 = preprocessed_data['seqA']
        struct_seq2 = preprocessed_data['seqB']

        # Learn Saprot ESM2 features for each structure's sequence
        ## Sequence 1
        start_time = datetime.now()
        esm2_representation1, esm2_attention1 = self.esm2(sequence1[0])
        if display_specific_memory_consumptions: print('Post-ESM2 sequence 1: ', memory_mb())
        saprot_representation1, saprot_attention1 = self.saprot.forward(struct_seq1)
        if display_specific_memory_consumptions: print('Post-SaProt sequence 1: ', memory_mb())
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2+SaProt sequence 1')
        ## Sequence 2
        start_time = datetime.now()
        esm2_representation2, esm2_attention2 = self.esm2(sequence2[0])
        if display_specific_memory_consumptions: print('Post-ESM2 sequence 2: ', memory_mb())
        saprot_representation2, saprot_attention2 = self.saprot.forward(struct_seq2)
        if display_specific_memory_consumptions: print('Post-SaProt sequence 2: ', memory_mb())
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2+SaProt sequence 2')

        ## Crop ESM2 features to match the cropped indices
        cropped_indices = cropped_indices[0]
        esm2_representation1 = esm2_representation1[cropped_indices[0]:cropped_indices[2], :]
        esm2_attention1 = esm2_attention1[cropped_indices[0]:cropped_indices[2], cropped_indices[0]:cropped_indices[2], :]
        esm2_representation2 = esm2_representation2[cropped_indices[1]:cropped_indices[3], :]
        esm2_attention2 = esm2_attention2[cropped_indices[1]:cropped_indices[3], cropped_indices[1]:cropped_indices[3], :]
        saprot_representation1 = saprot_representation1[cropped_indices[0]:cropped_indices[2], :]
        saprot_attention1 = saprot_attention1[cropped_indices[0]:cropped_indices[2], cropped_indices[0]:cropped_indices[2], :]
        saprot_representation2 = saprot_representation2[cropped_indices[1]:cropped_indices[3], :]
        saprot_attention2 = saprot_attention2[cropped_indices[1]:cropped_indices[3], cropped_indices[1]:cropped_indices[3], :]

        # Concatenate geometric and ESM2 features for each structure
        graph1_combined_feats = torch.cat((graph1_node_feats[0], esm2_representation1, saprot_representation1), dim=1) # (La, 2688) -> 1280 + 1280 +128
        graph2_combined_feats = torch.cat((graph2_node_feats[0], esm2_representation2, saprot_representation2), dim=1) # (Lb, 2688) -> 1280 + 1280 +128
    
        # Learn SaProt ESM2 features for the complex (sequence 1 + sequence 2)
        start_time = datetime.now()
        _, esm2_complex_attention = self.esm2(sequence1[0] + sequence2[0])
        _, saprot_complex_attention = self.saprot.forward(struct_seq1 + struct_seq2)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2 sequence 1+2')
        
        ## Extract the top-right (La x Lb) submatrix of the complex attention matrix and crop to match the cropped indices
        cropped_esm2_complex_attention = esm2_complex_attention[:len(sequence1[0]), len(sequence1[0]):, :]
        cropped_esm2_complex_attention = cropped_esm2_complex_attention[cropped_indices[0]:cropped_indices[2], cropped_indices[1]:cropped_indices[3], :]
        cropped_saprot_complex_attention = saprot_complex_attention[:len(sequence1[0]), len(sequence1[0]):, :]
        cropped_saprot_complex_attention = cropped_saprot_complex_attention[cropped_indices[0]:cropped_indices[2], cropped_indices[1]:cropped_indices[3], :]

        # Apply rbf on distance matrices
        dist1 = [rbf(dist1[0])]
        dist2 = [rbf(dist2[0])]

        # Combine graph features with ESM2 features (and rbf distance matrices for chains)
        combined_rec_feats = torch.cat([esm2_attention1, saprot_attention1, dist1[0].permute(1,2,0)], dim=2) # (La, La, 1384) -> 660 + 660 + 64
        combined_lig_feats = torch.cat([esm2_attention2, saprot_attention2, dist2[0].permute(1,2,0)], dim=2) # (Lb, Lb, 1384) -> 660 + 660 + 64
        combined_cropped_complex_attention = torch.cat([cropped_esm2_complex_attention, cropped_saprot_complex_attention], dim=2) # (La, Lb, 1320) -> 660 + 660

        # Reduce dimensionality of combined features
        # print(f'graph1_combined_feats: {graph1_combined_feats.shape}')
        # print(f'graph2_combined_feats: {graph2_combined_feats.shape}')
        graph1_combined_feats = self.linear_layer_2688_to_1280(graph1_combined_feats)
        graph2_combined_feats = self.linear_layer_2688_to_1280(graph2_combined_feats)
        combined_rec_feats = self.linear_layer_1384_to_64(combined_rec_feats.unsqueeze(0))
        combined_lig_feats = self.linear_layer_1384_to_64(combined_lig_feats.unsqueeze(0))
        combined_cropped_complex_attention = self.linear_layer_1320_to_64(combined_cropped_complex_attention.unsqueeze(0))

        if display_specific_memory_consumptions: print('Before interact forward: ', memory_mb())
        # Make a forward pass through the interaction module (Resnet + Triangle)
        start_time = datetime.now()
        logits_list = [self.interact_forward(graph1_combined_feats, combined_rec_feats, graph2_combined_feats, combined_lig_feats, combined_cropped_complex_attention)]
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ResnetTriangleStack')
        if display_specific_memory_consumptions: print('Post-interact forward: ', memory_mb())

        # Return network prediction
        if return_representations:
            # Return network prediction and learned node and edge representations for both graphs
            g1_nf, g1_ef = graph1.ndata['f'].detach().cpu().numpy(), graph1.edata['f'].detach().cpu().numpy()
            g2_nf, g2_ef = graph2.ndata['f'].detach().cpu().numpy(), graph2.edata['f'].detach().cpu().numpy()
            return logits_list, g1_nf, g1_ef, g2_nf, g2_ef
        else:
            return logits_list
        
    def shared_step_predict(self, graph1: dgl.DGLGraph, graph2: dgl.DGLGraph, sequence1, sequence2, struct_seq1, struct_seq2, dist1, dist2):
        """Make a forward pass through the entire siamese network."""
        display_specific_time_consumptions = False
        display_specific_memory_consumptions = False

        # Learn structural features for each structure's nodes
        ## Sequence 1
        if display_specific_memory_consumptions: print('Shared step start: ', memory_mb())
        start_time = datetime.now()
        graph1_node_feats = self.gnn_forward(graph1)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='Geometric Transformer graph1')
        if display_specific_memory_consumptions: print('Post-GNN forward graph 1: ', memory_mb())

        ## Sequence 2
        start_time = datetime.now()
        graph2_node_feats = self.gnn_forward(graph2)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='Geometric Transformer graph2')
        if display_specific_memory_consumptions: print('Post-GNN forward graph 2: ', memory_mb())

        # Unbatch and structure sequences
        struct_seq1 = struct_seq1[0]
        struct_seq2 = struct_seq2[0]

        # Learn Saprot ESM2 features for each structure's sequence
        ## Sequence 1
        start_time = datetime.now()
        esm2_representation1, esm2_attention1 = self.esm2(sequence1[0])
        if display_specific_memory_consumptions: print('Post-ESM2 sequence 1: ', memory_mb())
        saprot_representation1, saprot_attention1 = self.saprot.forward(struct_seq1)
        if display_specific_memory_consumptions: print('Post-SaProt sequence 1: ', memory_mb())
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2+SaProt sequence 1')
        ## Sequence 2
        start_time = datetime.now()
        esm2_representation2, esm2_attention2 = self.esm2(sequence2[0])
        if display_specific_memory_consumptions: print('Post-ESM2 sequence 2: ', memory_mb())
        saprot_representation2, saprot_attention2 = self.saprot.forward(struct_seq2)
        if display_specific_memory_consumptions: print('Post-SaProt sequence 2: ', memory_mb())
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2+SaProt sequence 2')

        # Concatenate geometric and ESM2 features for each structure
        graph1_combined_feats = torch.cat((graph1_node_feats[0], esm2_representation1, saprot_representation1), dim=1) # (La, 2688) -> 1280 + 1280 +128
        graph2_combined_feats = torch.cat((graph2_node_feats[0], esm2_representation2, saprot_representation2), dim=1) # (Lb, 2688) -> 1280 + 1280 +128
    
        # Learn SaProt ESM2 features for the complex (sequence 1 + sequence 2)
        start_time = datetime.now()
        _, esm2_complex_attention = self.esm2(sequence1[0] + sequence2[0])
        _, saprot_complex_attention = self.saprot.forward(struct_seq1 + struct_seq2)
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ESM2 sequence 1+2')
        
        ## Extract the top-right (La x Lb) submatrix of the complex attention matrix
        cropped_esm2_complex_attention = esm2_complex_attention[:len(sequence1[0]), len(sequence1[0]):, :]
        cropped_saprot_complex_attention = saprot_complex_attention[:len(sequence1[0]), len(sequence1[0]):, :]

        # Apply rbf on distance matrices
        dist1 = [rbf(dist1[0])]
        dist2 = [rbf(dist2[0])]

        # Combine graph features with ESM2 features (and rbf distance matrices for chains)
        combined_rec_feats = torch.cat([esm2_attention1, saprot_attention1, dist1[0].permute(1,2,0)], dim=2) # (La, La, 1384) -> 660 + 660 + 64
        combined_lig_feats = torch.cat([esm2_attention2, saprot_attention2, dist2[0].permute(1,2,0)], dim=2) # (Lb, Lb, 1384) -> 660 + 660 + 64
        combined_cropped_complex_attention = torch.cat([cropped_esm2_complex_attention, cropped_saprot_complex_attention], dim=2) # (La, Lb, 1320) -> 660 + 660

        # Reduce dimensionality of combined features
        # print(f'graph1_combined_feats: {graph1_combined_feats.shape}')
        # print(f'graph2_combined_feats: {graph2_combined_feats.shape}')
        graph1_combined_feats = self.linear_layer_2688_to_1280(graph1_combined_feats)
        graph2_combined_feats = self.linear_layer_2688_to_1280(graph2_combined_feats)
        combined_rec_feats = self.linear_layer_1384_to_64(combined_rec_feats.unsqueeze(0))
        combined_lig_feats = self.linear_layer_1384_to_64(combined_lig_feats.unsqueeze(0))
        combined_cropped_complex_attention = self.linear_layer_1320_to_64(combined_cropped_complex_attention.unsqueeze(0))

        if display_specific_memory_consumptions: print('Before interact forward: ', memory_mb())
        # Make a forward pass through the interaction module (Resnet + Triangle)
        start_time = datetime.now()
        logits_list = [self.interact_forward(graph1_combined_feats, combined_rec_feats, graph2_combined_feats, combined_lig_feats, combined_cropped_complex_attention)]
        end_time = datetime.now()
        if display_specific_time_consumptions: print_time_diff(start_time, end_time, task_name='ResnetTriangleStack')
        if display_specific_memory_consumptions: print('Post-interact forward: ', memory_mb())

        # Return network prediction
        return logits_list

    def downsample_examples(self, examples: torch.tensor):
        """Randomly sample enough negative pairs to achieve requested positive-negative class ratio (via shuffling)."""
        examples = examples[torch.randperm(len(examples))]  # Randomly shuffle training examples
        pos_examples = examples[examples[:, 2] == 1]  # Find out how many interacting pairs there are
        num_neg_pairs_to_sample = int(len(pos_examples) / self.pn_ratio)  # Determine negative sample size
        neg_examples = examples[examples[:, 2] == 0][:num_neg_pairs_to_sample]  # Sample negative pairs
        downsampled_examples = torch.cat((neg_examples, pos_examples))
        return downsampled_examples

    def on_train_epoch_start(self):
        '''Lightning calls this at the beginning of every test epoch.'''
        # Print divider containing current epoch number at the start of each epoch
        print("\n-----------------------------------------------------------------------------------------------")
        print(f"Train epoch start: {self.current_epoch}")
        print("-----------------------------------------------------------------------------------------------")

    def training_step(self, batch, batch_idx):
        """Lightning calls this inside the training loop."""
        # Separate training batch from validation batch (with the latter being used for visualizations)
        train_batch, val_batch = batch['train_batch'], batch['val_batch']

        # Make a forward pass through the network for a singular training batch of protein complexes
        # Retrieve data from the batch
        graph1, graph2, seqA, seqB, distA, distB, examples_list, pdb_ids, cropped_indices = train_batch[:9]

        # Print the pdb_id of the current train complex
        # print(f'Training: {pdb_ids[0]}')

        # Forward propagate with network layers
        logits_list = self.shared_step(graph1, graph2, seqA, seqB, distA, distB, cropped_indices, pdb_ids)  # The forward method must be named something new

        # Assume there are only one element in logits_list and examples_list
        labels = examples_list[0][:, -1].squeeze().reshape(graph1.num_nodes(), graph2.num_nodes())

        # [Focal Loss] Calculate the protein interface prediction (PICP) loss
        loss = self.loss_fn(logits_list[0].squeeze(), labels)  # Calculate loss of a single complex

        # Log training step metric(s)
        self.log(f'train_ce', loss, on_step=False, on_epoch=True, sync_dist=True)

        return {
            'loss': loss,
        }

    def training_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
        """Lightning calls this at the end of every training epoch."""
        # Reset training TorchMetrics for all devices
        self.train_acc.reset()
        self.train_prec.reset()
        self.train_recall.reset()

    def validation_step(self, batch, batch_idx):
        """Lightning calls this inside the validation loop."""
        # Make a forward pass through the network for a singular validation batch of protein complexes
        graph1, graph2, seqA, seqB, distA, distB, examples_list, pdb_ids, cropped_indices = batch[:9]

        # Print the pdb_id of the current validation complex
        # print(f'Validation: {pdb_ids[0]}')

        # Forward propagate with network layers
        logits_list = self.shared_step(graph1, graph2, seqA, seqB, distA, distB, cropped_indices, pdb_ids)
        
        # Assume there are only one element in examples_list and logits_list
        # Make predictions
        preds = logits_list[0].squeeze().reshape(graph1.num_nodes() * graph2.num_nodes()).unsqueeze(1)
        empty_column = torch.zeros(preds.shape[0], 1, device=self.device)
        preds = torch.cat((empty_column, preds), dim=1)
        labels = examples_list[0][:, -1]

        # Calculate top-k metrics
        l = min(graph1.num_nodes(), graph2.num_nodes())
        sorted_pred_indices = torch.argsort(preds[:, 1], descending=True)

        top_1_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=1)
        top_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=10)
        top_25_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=25)
        top_50_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=50)
        top_l_by_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 10))
        top_l_by_5_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 5))
        top_l_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=l)

        labels = labels.squeeze().reshape(graph1.num_nodes(), graph2.num_nodes())

        # [Focal Loss] Calculate the protein interface prediction (PICP)
        loss = self.loss_fn(logits_list[0].squeeze(), labels)  # Calculate loss of a single complex

        # Log validation step metric(s)
        self.log(f'val_ce', loss, sync_dist=True)
        self.log('val_top_1_prec', top_1_prec, sync_dist=True)
        self.log('val_top_10_prec', top_10_prec, sync_dist=True)
        self.log('val_top_25_prec', top_25_prec, sync_dist=True)
        self.log('val_top_50_prec', top_50_prec, sync_dist=True)
        self.log('val_top_l_by_10_prec', top_l_by_10_prec, sync_dist=True)
        self.log('val_top_l_by_5_prec', top_l_by_5_prec, sync_dist=True)
        self.log('val_top_l_prec', top_l_prec, sync_dist=True)

        return {
            'loss': loss,
            'val_top_1_prec': top_1_prec,
            'val_top_10_prec': top_10_prec,
            'val_top_25_prec': top_25_prec,
            'val_top_50_prec': top_50_prec,
            'val_top_l_by_10_prec': top_l_by_10_prec,
            'val_top_l_by_5_prec': top_l_by_5_prec,
            'val_top_l_prec': top_l_prec
        }

    def validation_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
        """Lightning calls this at the end of every validation epoch."""
        # Tuplize scores for the current device (e.g. Rank 0)
        val_top_1_precs = [output_dict['val_top_1_prec'] for output_dict in outputs]
        val_top_10_precs = [output_dict['val_top_10_prec'] for output_dict in outputs]
        val_top_25_precs = [output_dict['val_top_25_prec'] for output_dict in outputs]
        val_top_50_precs = [output_dict['val_top_50_prec'] for output_dict in outputs]
        val_top_l_by_10_precs = [output_dict['val_top_l_by_10_prec'] for output_dict in outputs]
        val_top_l_by_5_precs = [output_dict['val_top_l_by_5_prec'] for output_dict in outputs]
        val_top_l_precs = [output_dict['val_top_l_prec'] for output_dict in outputs]

        # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
        def gather_and_concat_precs(precs):
            '''Fix issue of wrong tensor dimensions'''
            gathered_precs = []
            for prec in self.all_gather(precs):
                if prec.dim() == 0:
                    gathered_precs.append(prec.unsqueeze(0))
                else:
                    gathered_precs.append(prec)
            return torch.cat(gathered_precs)
        val_top_1_precs = gather_and_concat_precs(val_top_1_precs)
        val_top_10_precs = gather_and_concat_precs(val_top_10_precs)
        val_top_25_precs = gather_and_concat_precs(val_top_25_precs)
        val_top_50_precs = gather_and_concat_precs(val_top_50_precs)
        val_top_l_by_10_precs = gather_and_concat_precs(val_top_l_by_10_precs)
        val_top_l_by_5_precs = gather_and_concat_precs(val_top_l_by_5_precs)
        val_top_l_precs = gather_and_concat_precs(val_top_l_precs)

        # Reset validation TorchMetrics for all devices
        self.val_acc.reset()
        self.val_prec.reset()
        self.val_recall.reset()
        self.val_f1.reset()
        self.val_auroc.reset()
        self.val_auprc.reset()

        # Log metric(s) aggregated from all ranks
        self.log('mean_val_top_1_prec', torch.mean(val_top_1_precs))
        self.log('mean_val_top_10_prec', torch.mean(val_top_10_precs))
        self.log('mean_val_top_25_prec', torch.mean(val_top_25_precs))
        self.log('mean_val_top_50_prec', torch.mean(val_top_50_precs))
        self.log('mean_val_top_l_by_10_prec', torch.mean(val_top_l_by_10_precs))
        self.log('mean_val_top_l_by_5_prec', torch.mean(val_top_l_by_5_precs))
        self.log('mean_val_top_l_prec', torch.mean(val_top_l_precs))

        # Print out validation metrics
        print()
        print(f'Epoch {self.current_epoch} finish date and time: ', datetime.now().strftime("%m-%d-%Y %H:%M:%S"))
        print('* New metrics: ')
        print('- Means -')
        print('mean_val_top_1_prec: ', torch.mean(val_top_1_precs).item())
        print('mean_val_top_10_prec: ', torch.mean(val_top_10_precs).item())
        print('mean_val_top_25_prec: ', torch.mean(val_top_25_precs).item())
        print('mean_val_top_50_prec: ', torch.mean(val_top_50_precs).item())
        print('mean_val_top_l_by_10_prec: ', torch.mean(val_top_l_by_10_precs).item())
        print('mean_val_top_l_by_5_prec: ', torch.mean(val_top_l_by_5_precs).item())
        print('mean_val_top_l_prec: ', torch.mean(val_top_l_precs).item())
        print()

    def test_step(self, batch, batch_idx):
        """Lightning calls this inside the testing loop."""
        # Make a forward pass through the network for a singular test batch of protein complexes
        # Retrieve data from the batch
        graph1, graph2, seqA, seqB, distA, distB, examples_list, pdb_ids, cropped_indices = batch[:9]

        # Forward propagate with network layers
        logits_list = self.shared_step(graph1, graph2, seqA, seqB, distA, distB, cropped_indices, pdb_ids)

        # Assume there are only one element in examples_list and logits_list
        # Make predictions
        preds = logits_list[0].squeeze().reshape(graph1.num_nodes() * graph2.num_nodes()).unsqueeze(1)
        empty_column = torch.zeros(preds.shape[0], 1, device=self.device)
        preds = torch.cat((empty_column, preds), dim=1)
        labels = examples_list[0][:, -1]

        # Calculate top-k metrics
        l = min(graph1.num_nodes(), graph2.num_nodes())  # Use the smallest length of the two chains as our denominator
        sorted_pred_indices = torch.argsort(preds[:, 1], descending=True)
        top_1_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=1)
        top_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=10)
        top_25_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=25)
        top_50_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=50)
        top_l_by_10_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 10))
        top_l_by_5_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=(l // 5))
        top_l_prec = calculate_top_k_prec(sorted_pred_indices, labels, k=l)

        labels = labels.squeeze().reshape(graph1.num_nodes(), graph2.num_nodes())

        # [Focal Loss] Calculate the protein interface prediction (PICP) loss
        loss = self.loss_fn(logits_list[0].squeeze(), labels)  # Calculate loss of a single complex

        # Manually evaluate test performance by collecting all predicted and ground-truth interaction tensors
        test_preds = preds.detach()
        test_preds_rounded = test_preds.clone()
        test_preds_rounded[:, 0] = (test_preds[:, 0] >= (1 - self.pos_prob_threshold)).float()
        test_preds_rounded[:, 1] = (test_preds[:, 1] >= self.pos_prob_threshold).float()
        test_preds = test_preds[:, 1].reshape(graph1.num_nodes(), graph2.num_nodes()).cpu().numpy()
        test_preds_rounded = test_preds_rounded[:, 1].reshape(graph1.num_nodes(), graph2.num_nodes()).cpu().numpy()

        test_labels = examples_list[0][:, 2].detach()
        test_labels = test_labels.reshape(graph1.num_nodes(), graph2.num_nodes()).float().cpu().numpy()

        # Log test step metric(s)
        self.log(f'test_ce', loss, sync_dist=True)
        self.log('test_top_1_prec', top_1_prec, sync_dist=True)
        self.log('test_top_10_prec', top_10_prec, sync_dist=True)
        self.log('test_top_25_prec', top_25_prec, sync_dist=True)
        self.log('test_top_50_prec', top_50_prec, sync_dist=True)
        self.log('test_top_l_by_10_prec', top_l_by_10_prec, sync_dist=True)
        self.log('test_top_l_by_5_prec', top_l_by_5_prec, sync_dist=True)
        self.log('test_top_l_prec', top_l_prec, sync_dist=True)

        return {
            'loss': loss,
            'test_preds': test_preds,
            'test_preds_rounded': test_preds_rounded,
            'test_labels': test_labels,
            'top_1_prec': top_1_prec,
            'top_10_prec': top_10_prec,
            'top_25_prec': top_25_prec,
            'top_50_prec': top_50_prec,
            'top_l_by_10_prec': top_l_by_10_prec,
            'top_l_by_5_prec': top_l_by_5_prec,
            'top_l_prec': top_l_prec,
            'target': pdb_ids
        }

    def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT):
        """Lightning calls this at the end of every test epoch."""
        # Tuplize scores for the current device (e.g. Rank 0)
        test_top_1_precs = [output_dict['top_1_prec'] for output_dict in outputs]
        test_top_10_precs = [output_dict['top_10_prec'] for output_dict in outputs]
        test_top_25_precs = [output_dict['top_25_prec'] for output_dict in outputs]
        test_top_50_precs = [output_dict['top_50_prec'] for output_dict in outputs]
        test_top_l_by_10_precs = [output_dict['top_l_by_10_prec'] for output_dict in outputs]
        test_top_l_by_5_precs = [output_dict['top_l_by_5_prec'] for output_dict in outputs]
        test_top_l_precs = [output_dict['top_l_prec'] for output_dict in outputs]

        # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
        def gather_and_concat_precs(precs):
            '''Fix issue of wrong tensor dimensions'''
            gathered_precs = []
            for prec in self.all_gather(precs):
                if prec.dim() == 0:
                    gathered_precs.append(prec.unsqueeze(0))
                else:
                    gathered_precs.append(prec)
            return torch.cat(gathered_precs)
        test_top_1_precs = gather_and_concat_precs(test_top_1_precs)
        test_top_10_precs = gather_and_concat_precs(test_top_10_precs)
        test_top_25_precs = gather_and_concat_precs(test_top_25_precs)
        test_top_50_precs = gather_and_concat_precs(test_top_50_precs)
        test_top_l_by_10_precs = gather_and_concat_precs(test_top_l_by_10_precs)
        test_top_l_by_5_precs = gather_and_concat_precs(test_top_l_by_5_precs)
        test_top_l_precs = gather_and_concat_precs(test_top_l_precs)

        if self.use_wandb_logger:
            test_preds = [wandb.Image(output_dict['test_preds']) for output_dict in outputs]  # Convert to image
            test_preds_rounded = [wandb.Image(output_dict['test_preds_rounded']) for output_dict in outputs]  # Rounded
            test_labels = [wandb.Image(output_dict['test_labels']) for output_dict in outputs]  # Convert to image
        else:  # Assume we are instead using the TensorBoardLogger
            test_preds = [output_dict['test_preds'] for output_dict in outputs]
            test_preds_rounded = [(output_dict['test_preds_rounded']) for output_dict in outputs]
            test_labels = [output_dict['test_labels'] for output_dict in outputs]

        # Write out test top-k metric results to CSV
        metrics_data = {
            'top_1_prec': [extract_object(output_dict['top_1_prec']) for output_dict in outputs],
            'top_10_prec': [extract_object(output_dict['top_10_prec']) for output_dict in outputs],
            'top_25_prec': [extract_object(output_dict['top_25_prec']) for output_dict in outputs],
            'top_50_prec': [extract_object(output_dict['top_50_prec']) for output_dict in outputs],
            'top_l_by_10_prec': [extract_object(output_dict['top_l_by_10_prec']) for output_dict in outputs],
            'top_l_by_5_prec': [extract_object(output_dict['top_l_by_5_prec']) for output_dict in outputs],
            'top_l_prec': [extract_object(output_dict['top_l_prec']) for output_dict in outputs],
            'target': [extract_object(output_dict['target']) for output_dict in outputs],
        }
        metrics_df = pd.DataFrame(data=metrics_data)
        metrics_df_name_prefix = 'dips_plus_test'
        metrics_df_name_prefix = 'casp_capri' if self.testing_with_casp_capri else metrics_df_name_prefix
        metrics_df_name_prefix = 'db5_plus_test' if self.training_with_db5 else metrics_df_name_prefix
        metrics_df_name = metrics_df_name_prefix + '_top_metrics.csv'
        metrics_df.to_csv(metrics_df_name)

        if not self.testing_with_casp_capri:  # Testing with either DIPS-Plus or DB5-Plus
            # Filter out all but the first 55 test predictions and labels to reduce storage requirements
            test_preds, test_preds_rounded, test_labels = test_preds[:55], test_preds_rounded[:55], test_labels[:55]

        # Reset test TorchMetrics for all devices
        self.test_acc.reset()
        self.test_prec.reset()
        self.test_recall.reset()
        self.test_f1.reset()
        self.test_auroc.reset()
        self.test_auprc.reset()

        # Log median and mean test metrics
        self.log('med_test_top_1_prec', torch.median(test_top_1_precs).item())
        self.log('med_test_top_10_prec', torch.median(test_top_10_precs).item())
        self.log('med_test_top_25_prec', torch.median(test_top_25_precs).item())
        self.log('med_test_top_50_prec', torch.median(test_top_50_precs).item())
        self.log('med_test_top_l_by_10_prec', torch.median(test_top_l_by_10_precs).item())
        self.log('med_test_top_l_by_5_prec', torch.median(test_top_l_by_5_precs).item())
        self.log('med_test_top_l_prec', torch.median(test_top_l_precs).item())
        self.log('mean_test_top_1_prec', torch.mean(test_top_1_precs).item())
        self.log('mean_test_top_10_prec', torch.mean(test_top_10_precs).item())
        self.log('mean_test_top_25_prec', torch.mean(test_top_25_precs).item())
        self.log('mean_test_top_50_prec', torch.mean(test_top_50_precs).item())
        self.log('mean_test_top_l_by_10_prec', torch.mean(test_top_l_by_10_precs).item())
        self.log('mean_test_top_l_by_5_prec', torch.mean(test_top_l_by_5_precs).item())
        self.log('mean_test_top_l_prec', torch.mean(test_top_l_precs).item())

        # Log test predictions with their ground-truth interaction tensors to WandB for visual inspection
        if self.use_wandb_logger:
            self.trainer.logger.experiment.log({'test_preds': test_preds})
            self.trainer.logger.experiment.log({'test_preds_rounded': test_preds_rounded})
            self.trainer.logger.experiment.log({'test_labels': test_labels})
        else:  # Assume we are instead using the TensorBoardLogger
            for i, (t_preds, t_preds_rounded, t_labels) in enumerate(zip(test_preds, test_preds_rounded, test_labels)):
                self.logger.experiment.add_image('test_preds', t_preds, i, dataformats='HW')
                self.logger.experiment.add_image('test_preds_rounded', t_preds_rounded, i, dataformats='HW')
                self.logger.experiment.add_image('test_labels', t_labels, i, dataformats='HW')

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Lightning calls this inside the predict loop."""
        # Make predictions for a batch of protein complexes
        graph1, graph2, seqA, seqB, struct_seqA, struct_seqB, distA, distB = batch[:8]

        logits_list = self.shared_step_predict(graph1, graph2, seqA, seqB, struct_seqA, struct_seqB, distA, distB)
        return logits_list

    # ---------------------
    # Overide on_save_checkpoint to remove esm2 from checkpoint
    # ---------------------
    # def on_save_checkpoint(self, checkpoint):
    #     # Copy state_dict to avoid modifying the original state_dict
    #     state_dict = self.state_dict().copy()

    #     # Remove 'esm2'-related keys from state_dict
    #     keys_to_remove = [key for key in state_dict if key.startswith('esm2')]
    #     for key in keys_to_remove:
    #         del state_dict[key]

    #     # Replace the original state_dict with the modified one
    #     checkpoint['state_dict'] = state_dict

    #     return checkpoint
    
    # ---------------------
    # Training Setup
    # ---------------------
    def configure_optimizers(self):
        """Called to configure the trainer's optimizer(s)."""
        optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=1e-8, verbose=True),
                "monitor": self.metric_to_track,
            }
        }

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # -----------------
        # Model arguments
        # -----------------
        parser.add_argument('--gnn_layer_type', type=str, default='geotran',
                            help='Which type of GNN layer to use'
                                 ' (i.e., gcn for GraphConv,'
                                 ' geotran w/ --disable_geometric_mode for DGLGraphTransformer,'
                                 ' or geotran for DGLGeometricTransformer)')
        parser.add_argument('--num_gnn_hidden_channels', type=int, default=128,
                            help='Dimensionality of GNN filters (for nodes and edges alike after embedding)')
        parser.add_argument('--num_gnn_attention_heads', type=int, default=4, 
                            help='How many multi-head GNN attention blocks to run in parallel')
        parser.add_argument('--interact_module_type', type=str, default='deepinter',
                            help='Which type of dense prediction interaction module to use'
                                 ' (i.e. deepinter for DeepInter triangle, or dil_resnet for ResNet2DInputWithOptAttention, or deeplab for DeepLabV3Plus)')
        parser.add_argument('--num_interact_hidden_channels', type=int, default=128,
                            help='Dimensionality of interaction module filters')
        parser.add_argument('--use_interact_attention', action='store_true', dest='use_interact_attention',
                            help='Whether to employ attention in, for example, a Dilated ResNet')
        parser.add_argument('--num_interact_attention_heads', type=int, default=4,
                            help='How many multi-head interact attention blocks to use in parallel')
        parser.add_argument('--disable_geometric_mode', action='store_true', dest='disable_geometric_mode',
                            help='Whether to convert the Geometric Transformer into the original Graph Transformer')
        parser.add_argument('--viz_every_n_epochs', type=int, default=1,
                            help='By how many epochs to space out model prediction visualizations during training')
        parser.add_argument('--weight_classes', action='store_true', dest='weight_classes',
                            help='Whether to use class weighting in our training Cross Entropy')
        parser.add_argument('--fine_tune', action='store_true', dest='fine_tune',
                            help='Whether to fine-tune a trained LitGINI on a new dataset')
        parser.add_argument('--left_pdb_filepath', type=str, default='test_data/4heq_l.pdb',
                            help='A filepath to the left input PDB chain')
        parser.add_argument('--right_pdb_filepath', type=str, default='test_data/4heq_r.pdb',
                            help='A filepath to the right input PDB chain')
        parser.add_argument('--left_fasta_filepath', type=str, default='test_data/4heq_l.pdb',
                            help='A filepath to the left input sequence')
        parser.add_argument('--right_fasta_filepath', type=str, default='test_data/4heq_r.pdb',
                            help='A filepath to the right input sequence')
        return parser
