import torch
import torch.nn as nn

from src.models.transformer_model import GraphTransformer


class TwoTrackScoring(nn.Module):
    """Score a (product, reactants) pair with two independent graph encoders.

    The product graph and the reactants graph are each encoded by their own
    ``GraphTransformer`` instance (separate modules, so no weight sharing).
    Only the ``y`` output of each encoder is used: the two embeddings are
    concatenated and passed through a two-layer MLP that emits a single
    scalar score per example.

    Args:
        n_layers: Number of layers passed to each ``GraphTransformer``.
        input_dims: Forwarded unchanged to both encoders.
        hidden_mlp_dims: Forwarded unchanged to both encoders;
            ``hidden_mlp_dims['y']`` additionally sets the width of each
            encoder's ``y`` output and the hidden width of the scoring MLP.
        hidden_dims: Forwarded unchanged to both encoders.
    """

    def __init__(self, n_layers, input_dims, hidden_mlp_dims, hidden_dims):
        super().__init__()
        # Only ``.y`` is read in ``forward``, so the X and E output widths
        # (presumably node / edge outputs) are set to 0.
        encoder_output_dims = {
            'X': 0,
            'E': 0,
            'y': hidden_mlp_dims['y'],
        }
        emb_dim = encoder_output_dims['y']

        # Both encoders are built by one helper so their configurations cannot
        # drift apart. Construction order (product, reactants, scoring block)
        # is kept so parameter initialisation draws from the RNG in the same
        # order, and attribute names are kept so state_dict keys are unchanged.
        self.product_encoder = self._build_encoder(
            n_layers, input_dims, hidden_mlp_dims, hidden_dims, encoder_output_dims,
        )
        self.reactants_encoder = self._build_encoder(
            n_layers, input_dims, hidden_mlp_dims, hidden_dims, encoder_output_dims,
        )
        # Layer order/indices kept identical (0: Linear, 1: ReLU, 2: Linear)
        # so existing checkpoints still load.
        self.scoring_block = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1),
        )

    @staticmethod
    def _build_encoder(n_layers, input_dims, hidden_mlp_dims, hidden_dims, output_dims):
        """Build one ``GraphTransformer`` with the configuration shared by both tracks."""
        return GraphTransformer(
            n_layers=n_layers,
            input_dims=input_dims,
            hidden_mlp_dims=hidden_mlp_dims,
            hidden_dims=hidden_dims,
            output_dims=output_dims,
            act_fn_in=nn.ReLU(),
            act_fn_out=nn.ReLU(),
            addition=False,
        )

    def forward(self, p_X, p_E, p_y, p_node_mask, r_X, r_E, r_y, r_node_mask):
        """Return one score per (product, reactants) pair.

        Args:
            p_X, p_E, p_y, p_node_mask: Product-graph inputs, passed
                positionally to ``product_encoder``.
            r_X, r_E, r_y, r_node_mask: Reactants-graph inputs, passed
                positionally to ``reactants_encoder``.

        Returns:
            Output of ``scoring_block``; its last dimension is 1. Assuming
            each encoder's ``.y`` is ``(bs, hidden_mlp_dims['y'])`` (TODO
            confirm against ``GraphTransformer``), the result is ``(bs, 1)``.
        """
        product_emb = self.product_encoder(p_X, p_E, p_y, p_node_mask).y  # (bs, encoder_output_dims['y'])
        reactants_emb = self.reactants_encoder(r_X, r_E, r_y, r_node_mask).y  # (bs, encoder_output_dims['y'])
        joint_emb = torch.cat([product_emb, reactants_emb], dim=-1)  # (bs, 2 * encoder_output_dims['y'])
        return self.scoring_block(joint_emb)
