import pytorch_lightning as pl
from omegaconf import OmegaConf

import numpy as np
import torch

from losses import losses
from modules import EncoderBlock, LayerNormalizingBlock
from modules import IterRefEmbedding, TransformerEmbedding
from modules import HopfieldBlock_chemTrainSpace
from metrics import auc_score_train, deltaAUPRC_score_train
from distance_metrics import distance_metrics
from optimizer import define_opimizer

class ClassicSimSearch(pl.LightningModule):
    """Similarity search applied directly to the molecule inputs.

    No encoder is involved: the configured similarity function is evaluated
    between the query molecules and each of the two support sets, and the
    inactive-set score is subtracted from the active-set score.
    """

    def __init__(self, config: OmegaConf):
        """Store the config and resolve the similarity function.

        Args:
            config: Configuration object; ``config.model.similarityBlock.type``
                selects the entry of ``distance_metrics`` that is used.
        """
        super().__init__()

        # Keep the full config around; forward() reads scaling / l2Norm from it.
        self.config = config

        # Similarity function, looked up by its configured name.
        metric_name = config.model.similarityBlock.type
        self.similarity_function = distance_metrics[metric_name]

        # Sigmoid module (instantiated here, but not applied in forward()).
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Return active-set similarity minus inactive-set similarity.

        Args:
            queryMols: Query molecule representations.
            supportMolsActive: Representations of the active support set.
            supportMolsInactive: Representations of the inactive support set.
            supportSetActivesSize: Size information for the active support
                set, forwarded unchanged to the similarity function.
            supportSetInactivesSize: Size information for the inactive support
                set, forwarded unchanged to the similarity function.

        Returns:
            The difference of the two similarity results (no sigmoid applied).
        """
        block_cfg = self.config.model.similarityBlock

        def score_against(support_mols, support_size):
            # Both support sets are scored with identical settings.
            return self.similarity_function(queryMols, support_mols, support_size,
                                            device=self.device,
                                            scaling=block_cfg.scaling,
                                            l2Norm=block_cfg.l2Norm)

        active_scores = score_against(supportMolsActive, supportSetActivesSize)
        inactive_scores = score_against(supportMolsInactive, supportSetInactivesSize)

        return active_scores - inactive_scores

class NeuralSearch(pl.LightningModule):
    """Similarity search on learned embeddings.

    Pipeline (see ``forward``): shared encoder -> optional layer-norm block ->
    similarity to the active support set minus similarity to the inactive
    support set -> scaled sigmoid.

    NOTE(review): this class relies on the ``training_epoch_end`` /
    ``validation_epoch_end`` hooks, which pytorch_lightning removed in 2.0 -
    confirm the pinned Lightning version.
    """

    def __init__(self, config: OmegaConf):
        """Build the sub-modules from ``config``.

        Args:
            config: Configuration object. Read here:
                ``model.training.loss``, ``model.layerNormBlock.usage``,
                ``model.similarityBlock.type`` and ``model.prediction_scaling``.
        """
        super(NeuralSearch, self).__init__()

        # Config
        self.config = config

        # Loss functions
        self.LossFunction = losses[config.model.training.loss]

        # Hyperparameter
        self.save_hyperparameters(config)

        # Encoder
        self.encoder = EncoderBlock(config)

        # Layernormalizing-block (only created when enabled in the config)
        if self.config.model.layerNormBlock.usage == True:
            self.layerNormBlock = LayerNormalizingBlock(config)

        # Similarity Block
        self.similarity_function = distance_metrics[config.model.similarityBlock.type]

        # Output function
        self.sigmoid = torch.nn.Sigmoid()
        self.prediction_scaling = config.model.prediction_scaling

    def forward(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Compute predictions for the query molecules.

        Args:
            queryMols: Query molecule inputs for the encoder.
            supportMolsActive: Active support-set inputs for the encoder.
            supportMolsInactive: Inactive support-set inputs for the encoder.
            supportSetActivesSize: Size information for the active support
                set, forwarded to the similarity function.
            supportSetInactivesSize: Size information for the inactive support
                set, forwarded to the similarity function.

        Returns:
            ``sigmoid(prediction_scaling * (sim_active - sim_inactive))``.
        """
        # Embeddings (one shared encoder for all three inputs)
        query_embedding = self.encoder(queryMols)
        supportActives_embedding = self.encoder(supportMolsActive)
        supportInactives_embedding = self.encoder(supportMolsInactive)

        # Layer normalization:
        if self.config.model.layerNormBlock.usage == True:
            (query_embedding, supportActives_embedding,
             supportInactives_embedding) = self.layerNormBlock(query_embedding, supportActives_embedding,
                                                               supportInactives_embedding)

        # Similarities (identical settings for both support sets):
        sim_cfg = self.config.model.similarityBlock
        predictions_supportActives = self.similarity_function(query_embedding, supportActives_embedding,
                                                              supportSetActivesSize,
                                                              device=self.device,
                                                              scaling=sim_cfg.scaling,
                                                              l2Norm=sim_cfg.l2Norm)
        _predictions_supportInactives = self.similarity_function(query_embedding, supportInactives_embedding,
                                                                 supportSetInactivesSize,
                                                                 device=self.device,
                                                                 scaling=sim_cfg.scaling,
                                                                 l2Norm=sim_cfg.l2Norm)
        predictions = predictions_supportActives - _predictions_supportInactives
        predictions = self.sigmoid(self.prediction_scaling * predictions)

        return predictions

    def training_step(self, batch, batch_idx):
        """Run one training batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the keys ``queryMolecule``, ``label``,
                ``supportSetActives``, ``supportSetInactives``,
                ``supportSetActivesSize``, ``supportSetInactivesSize`` and
                ``taskIdx``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``training_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label']
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize)
        predictions = torch.squeeze(predictions)

        loss = self.LossFunction(predictions, labels.reshape(-1))

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the same keys as in ``training_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``validation_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label'].squeeze().float()
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        # Fixed: this previously read 'supportSetActivesSize', so the inactive
        # support set was scored with the size of the active one.
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize).float()

        loss = self.LossFunction(predictions.reshape(-1), labels)

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def training_epoch_end(self, step_outputs):
        """Aggregate the training step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``training_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        # Range of the predictions, logged for debugging.
        pred_max = torch.max(predictions)
        pred_min = torch.min(predictions)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_training = {'loss_train': epoch_loss, 'auc_train': auc, 'dAUPRC_train': deltaAUPRC,
                             'debug_pred_max_train': pred_max,
                             'debug_pred_min_train': pred_min}
        # The original passed the string 'training' positionally, which lands in
        # log_dict's ``prog_bar`` parameter; being truthy it enabled the
        # progress bar. Spelled out explicitly here, behaviour unchanged.
        self.log_dict(log_dict_training, prog_bar=True, on_epoch=True)

    def validation_epoch_end(self, step_outputs):
        """Aggregate the validation step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``validation_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_val = {'loss_val': epoch_loss, 'auc_val': auc, 'dAUPRC_val': deltaAUPRC}
        # See training_epoch_end: the former positional 'validation' string was
        # interpreted as a truthy ``prog_bar`` flag.
        self.log_dict(log_dict_val, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """Return whatever ``define_opimizer`` builds for this model's parameters."""
        return define_opimizer(self.config, self.parameters())

class CAM(pl.LightningModule):
    """Similarity search on embeddings refined by a transformer block.

    Pipeline (see ``forward``): shared encoder -> layer-norm block ->
    ``TransformerEmbedding`` -> layer-norm block -> similarity to the active
    support set minus similarity to the inactive support set -> scaled sigmoid.

    NOTE(review): this class relies on the ``training_epoch_end`` /
    ``validation_epoch_end`` hooks, which pytorch_lightning removed in 2.0 -
    confirm the pinned Lightning version.
    """

    def __init__(self, config: OmegaConf):
        """Build the sub-modules from ``config``.

        Args:
            config: Configuration object. Read here:
                ``model.training.loss``, ``model.layerNormBlock.usage``,
                ``model.similarityBlock.type`` and ``model.prediction_scaling``.
        """
        super(CAM, self).__init__()

        # Config
        self.config = config

        # Loss functions
        self.LossFunction = losses[config.model.training.loss]

        # Hyperparameter
        self.save_hyperparameters(config)

        # Encoder
        self.encoder = EncoderBlock(config)

        # Layernormalizing-block
        # NOTE(review): forward() calls self.layerNormBlock unconditionally, so
        # a config with layerNormBlock.usage == False fails there with an
        # AttributeError - confirm whether the flag is meant to be honoured.
        if self.config.model.layerNormBlock.usage == True:
            self.layerNormBlock = LayerNormalizingBlock(config)

        # Transformer
        self.transformer = TransformerEmbedding(self.config)

        # Similarity Block
        self.similarity_function = distance_metrics[config.model.similarityBlock.type]

        # Output function
        self.sigmoid = torch.nn.Sigmoid()
        self.prediction_scaling = config.model.prediction_scaling

    def forward(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Compute predictions for the query molecules.

        Args:
            queryMols: Query molecule inputs for the encoder.
            supportMolsActive: Active support-set inputs for the encoder.
            supportMolsInactive: Inactive support-set inputs for the encoder.
            supportSetActivesSize: Size information for the active support
                set, forwarded to the transformer and the similarity function.
            supportSetInactivesSize: Size information for the inactive support
                set, forwarded to the transformer and the similarity function.

        Returns:
            ``sigmoid(prediction_scaling * (sim_active - sim_inactive))``.
        """
        # Embeddings (one shared encoder for all three inputs)
        query_embedding = self.encoder(queryMols)
        supportActives_embedding = self.encoder(supportMolsActive)
        supportInactives_embedding = self.encoder(supportMolsInactive)

        # LayerNorm before Transformer (applied regardless of the usage flag)
        (query_embedding, supportActives_embedding,
         supportInactives_embedding) = self.layerNormBlock(query_embedding, supportActives_embedding,
                                                           supportInactives_embedding)

        # Transformer part
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.transformer(query_embedding, supportActives_embedding, supportInactives_embedding,
                             supportSetActivesSize, supportSetInactivesSize)

        # Layer normalization (the same block instance is reused):
        (query_embedding, supportActives_embedding,
         supportInactives_embedding) = self.layerNormBlock(query_embedding, supportActives_embedding,
                                                           supportInactives_embedding)

        # Similarities (identical settings for both support sets):
        sim_cfg = self.config.model.similarityBlock
        predictions_supportActives = self.similarity_function(query_embedding, supportActives_embedding,
                                                              supportSetActivesSize,
                                                              device=self.device,
                                                              scaling=sim_cfg.scaling,
                                                              l2Norm=sim_cfg.l2Norm)

        _predictions_supportInactives = self.similarity_function(query_embedding, supportInactives_embedding,
                                                                 supportSetInactivesSize,
                                                                 device=self.device,
                                                                 scaling=sim_cfg.scaling,
                                                                 l2Norm=sim_cfg.l2Norm)

        predictions = predictions_supportActives - _predictions_supportInactives
        predictions = self.sigmoid(self.prediction_scaling * predictions)

        return predictions

    def training_step(self, batch, batch_idx):
        """Run one training batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the keys ``queryMolecule``, ``label``,
                ``supportSetActives``, ``supportSetInactives``,
                ``supportSetActivesSize``, ``supportSetInactivesSize`` and
                ``taskIdx``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``training_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label']
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize)
        predictions = torch.squeeze(predictions)

        loss = self.LossFunction(predictions, labels.reshape(-1))

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the same keys as in ``training_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``validation_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label'].squeeze().float()
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        # Fixed: this previously read 'supportSetActivesSize', so the inactive
        # support set was processed with the size of the active one.
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize).float()

        loss = self.LossFunction(predictions.reshape(-1), labels)

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def training_epoch_end(self, step_outputs):
        """Aggregate the training step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``training_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        # Range of the predictions, logged for debugging.
        pred_max = torch.max(predictions)
        pred_min = torch.min(predictions)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_training = {'loss_train': epoch_loss, 'auc_train': auc, 'dAUPRC_train': deltaAUPRC,
                             'debug_pred_max_train': pred_max,
                             'debug_pred_min_train': pred_min}
        # The original passed the string 'training' positionally, which lands in
        # log_dict's ``prog_bar`` parameter; being truthy it enabled the
        # progress bar. Spelled out explicitly here, behaviour unchanged.
        self.log_dict(log_dict_training, prog_bar=True, on_epoch=True)

    def validation_epoch_end(self, step_outputs):
        """Aggregate the validation step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``validation_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_val = {'loss_val': epoch_loss, 'auc_val': auc, 'dAUPRC_val': deltaAUPRC}
        # See training_epoch_end: the former positional 'validation' string was
        # interpreted as a truthy ``prog_bar`` flag.
        self.log_dict(log_dict_val, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """Return whatever ``define_opimizer`` builds for this model's parameters."""
        return define_opimizer(self.config, self.parameters())

class IterRefLSTM(pl.LightningModule):
    """Similarity search on embeddings refined by an ``IterRefEmbedding`` block.

    Pipeline (see ``forward``): shared encoder -> optional layer-norm block ->
    ``IterRefEmbedding`` -> similarity to the active support set minus
    similarity to the inactive support set -> scaled sigmoid.

    NOTE(review): this class relies on the ``training_epoch_end`` /
    ``validation_epoch_end`` hooks, which pytorch_lightning removed in 2.0 -
    confirm the pinned Lightning version.
    """

    def __init__(self, config: OmegaConf):
        """Build the sub-modules from ``config``.

        Args:
            config: Configuration object. Read here:
                ``model.training.loss``, ``model.layerNormBlock.usage``,
                ``model.similarityBlock.type`` and ``model.prediction_scaling``.
        """
        # Fixed: the original ``super(IterRefLSTM self)`` lacked the comma and
        # was a SyntaxError that prevented the module from being imported.
        super(IterRefLSTM, self).__init__()

        # Config
        self.config = config

        # Loss functions
        self.LossFunction = losses[config.model.training.loss]

        # Hyperparameter
        self.save_hyperparameters(config)

        # Encoder
        self.encoder = EncoderBlock(config)

        # Layernormalizing-block (only created when enabled in the config)
        if self.config.model.layerNormBlock.usage == True:
            self.layerNormBlock = LayerNormalizingBlock(config)

        # IterRefEmbedding-block
        self.iterRefEmbeddingBlock = IterRefEmbedding(config)

        # Similarity Block
        self.similarity_function = distance_metrics[config.model.similarityBlock.type]

        # Output function
        self.sigmoid = torch.nn.Sigmoid()
        self.prediction_scaling = config.model.prediction_scaling

    def forward(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Compute predictions for the query molecules.

        Args:
            queryMols: Query molecule inputs for the encoder.
            supportMolsActive: Active support-set inputs for the encoder.
            supportMolsInactive: Inactive support-set inputs for the encoder.
            supportSetActivesSize: Size information for the active support
                set, forwarded to the similarity function.
            supportSetInactivesSize: Size information for the inactive support
                set, forwarded to the similarity function.

        Returns:
            ``sigmoid(prediction_scaling * (sim_active - sim_inactive))``.
        """
        # Embeddings (one shared encoder for all three inputs)
        query_embedding = self.encoder(queryMols)
        supportActives_embedding = self.encoder(supportMolsActive)
        supportInactives_embedding = self.encoder(supportMolsInactive)

        # Layer normalization:
        if self.config.model.layerNormBlock.usage == True:
            (query_embedding, supportActives_embedding,
             supportInactives_embedding) = self.layerNormBlock(query_embedding, supportActives_embedding,
                                                               supportInactives_embedding)

        # IterRef (does not receive the support-set sizes)
        (query_embedding, supportActives_embedding,
         supportInactives_embedding) = self.iterRefEmbeddingBlock(query_embedding, supportActives_embedding,
                                                                  supportInactives_embedding)

        # Similarities (identical settings for both support sets):
        sim_cfg = self.config.model.similarityBlock
        predictions_supportActives = self.similarity_function(query_embedding, supportActives_embedding,
                                                              supportSetActivesSize,
                                                              device=self.device,
                                                              scaling=sim_cfg.scaling,
                                                              l2Norm=sim_cfg.l2Norm)
        _predictions_supportInactives = self.similarity_function(query_embedding, supportInactives_embedding,
                                                                 supportSetInactivesSize,
                                                                 device=self.device,
                                                                 scaling=sim_cfg.scaling,
                                                                 l2Norm=sim_cfg.l2Norm)

        predictions = predictions_supportActives - _predictions_supportInactives

        predictions = self.sigmoid(self.prediction_scaling * predictions)

        return predictions

    def training_step(self, batch, batch_idx):
        """Run one training batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the keys ``queryMolecule``, ``label``,
                ``supportSetActives``, ``supportSetInactives``,
                ``supportSetActivesSize``, ``supportSetInactivesSize`` and
                ``taskIdx``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``training_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label']
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize)
        predictions = torch.squeeze(predictions)

        loss = self.LossFunction(predictions, labels.reshape(-1))

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the same keys as in ``training_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``validation_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label'].squeeze().float()
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        # Fixed: this previously read 'supportSetActivesSize', so the inactive
        # support set was scored with the size of the active one.
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize).float()

        loss = self.LossFunction(predictions.reshape(-1), labels)

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def training_epoch_end(self, step_outputs):
        """Aggregate the training step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``training_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        # Range of the predictions, logged for debugging.
        pred_max = torch.max(predictions)
        pred_min = torch.min(predictions)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_training = {'loss_train': epoch_loss, 'auc_train': auc, 'dAUPRC_train': deltaAUPRC,
                             'debug_pred_max_train': pred_max,
                             'debug_pred_min_train': pred_min}
        # The original passed the string 'training' positionally, which lands in
        # log_dict's ``prog_bar`` parameter; being truthy it enabled the
        # progress bar. Spelled out explicitly here, behaviour unchanged.
        self.log_dict(log_dict_training, prog_bar=True, on_epoch=True)

    def validation_epoch_end(self, step_outputs):
        """Aggregate the validation step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``validation_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_val = {'loss_val': epoch_loss, 'auc_val': auc, 'dAUPRC_val': deltaAUPRC}
        # See training_epoch_end: the former positional 'validation' string was
        # interpreted as a truthy ``prog_bar`` flag.
        self.log_dict(log_dict_val, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """Return whatever ``define_opimizer`` builds for this model's parameters."""
        return define_opimizer(self.config, self.parameters())

class MHNfs(pl.LightningModule):
    """Similarity search with Hopfield retrieval and a transformer block.

    Pipeline (see ``forward``): shared encoder -> layer-norm block ->
    ``HopfieldBlock_chemTrainSpace`` (fed with a reference-set embedding) ->
    layer-norm block -> ``TransformerEmbedding`` -> optional layer-norm block
    -> similarity to the active support set minus similarity to the inactive
    support set -> scaled sigmoid.

    The reference-set embedding is recomputed from a random 5% sample of the
    stored training exemplars before every training step and at the end of
    every training epoch (see ``_update_referenceSet_embedding``).

    NOTE(review): this class relies on the ``training_epoch_end`` /
    ``validation_epoch_end`` hooks, which pytorch_lightning removed in 2.0 -
    confirm the pinned Lightning version.
    """

    # TODO: Set path to FSMOl training split input molecule matrix
    # (a file readable by ``np.load``). Must be set before instantiation.
    TRAIN_EXEMPLAR_MEMORY_PATH = None

    def __init__(self, config: OmegaConf):
        """Load the exemplar memory and build the sub-modules from ``config``.

        Args:
            config: Configuration object. Read here:
                ``system.ressources.device``, ``model.associationSpace_dim``,
                ``model.layerNormBlock.affine``, ``model.layerNormBlock.usage``,
                ``model.training.loss``, ``model.similarityBlock.type`` and
                ``model.prediction_scaling``.

        Raises:
            ValueError: If ``TRAIN_EXEMPLAR_MEMORY_PATH`` has not been set.
        """
        super(MHNfs, self).__init__()

        # Config
        self.config = config

        # Load reference set
        # Fixed: the original statement had an inline "# TODO ... #" comment
        # that swallowed ").astype('float32')", leaving an unbalanced
        # parenthesis (SyntaxError) and an np.load() call without a path.
        if self.TRAIN_EXEMPLAR_MEMORY_PATH is None:
            raise ValueError(
                "MHNfs.TRAIN_EXEMPLAR_MEMORY_PATH is not set: point it to the "
                "FSMOl training split input molecule matrix before creating the model."
            )
        exemplar_matrix = np.load(self.TRAIN_EXEMPLAR_MEMORY_PATH).astype('float32')
        # A leading axis of size 1 is added; the memory itself stays on the CPU.
        self.trainExemplarMemory = torch.unsqueeze(torch.from_numpy(exemplar_matrix), 0).to('cpu')

        # Placeholder of ones; replaced by _update_referenceSet_embedding().
        self.referenceSet_embedding = torch.ones(1,
                                                 512,
                                                 1024).to(config.system.ressources.device)

        self.layerNorm_refSet = torch.nn.LayerNorm(config.model.associationSpace_dim,
                                                   elementwise_affine=config.model.layerNormBlock.affine)

        # Loss functions
        self.LossFunction = losses[config.model.training.loss]

        # Hyperparameter
        self.save_hyperparameters(config)

        # Encoder
        self.encoder = EncoderBlock(config)

        # Hopfield for trained chemical space retrieval
        self.hopfield_chemTrainSpace = HopfieldBlock_chemTrainSpace(self.config)

        # Layernormalizing-block
        # NOTE(review): forward() and the _get_* helpers call
        # self.layerNormBlock unconditionally, so a config with
        # layerNormBlock.usage == False fails there with an AttributeError -
        # confirm whether the flag is meant to be honoured.
        if self.config.model.layerNormBlock.usage == True:
            self.layerNormBlock = LayerNormalizingBlock(config)

        # Transformer
        self.transformer = TransformerEmbedding(self.config)

        # Similarity Block
        self.similarity_function = distance_metrics[config.model.similarityBlock.type]

        # Output function
        self.sigmoid = torch.nn.Sigmoid()
        self.prediction_scaling = config.model.prediction_scaling

    def forward(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Compute predictions for the query molecules.

        Args:
            queryMols: Query molecule inputs for the encoder.
            supportMolsActive: Active support-set inputs for the encoder.
            supportMolsInactive: Inactive support-set inputs for the encoder.
            supportSetActivesSize: Size information for the active support
                set, forwarded to the Hopfield block, the transformer and the
                similarity function.
            supportSetInactivesSize: Size information for the inactive support
                set, forwarded to the same three components.

        Returns:
            ``sigmoid(prediction_scaling * (sim_active - sim_inactive))``.
        """
        # Embeddings (one shared encoder for all three inputs)
        query_embedding = self.encoder(queryMols)
        supportActives_embedding = self.encoder(supportMolsActive)
        supportInactives_embedding = self.encoder(supportMolsInactive)

        # Retrieve updated representations from chemical training space
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.hopfield_chemTrainSpace(query_embedding, supportActives_embedding, supportInactives_embedding,
                                         supportSetActivesSize, supportSetInactivesSize,
                                         self.referenceSet_embedding)

        # Transformer part
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.transformer(query_embedding, supportActives_embedding, supportInactives_embedding,
                             supportSetActivesSize, supportSetInactivesSize)

        # Layer normalization (only this last one is guarded by the flag):
        if self.config.model.layerNormBlock.usage == True:
            (
                query_embedding, supportActives_embedding, supportInactives_embedding
            ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)

        # Similarities (identical settings for both support sets):
        sim_cfg = self.config.model.similarityBlock
        predictions_supportActives = self.similarity_function(query_embedding, supportActives_embedding,
                                                              supportSetActivesSize,
                                                              device=self.device,
                                                              scaling=sim_cfg.scaling,
                                                              l2Norm=sim_cfg.l2Norm)

        _predictions_supportInactives = self.similarity_function(query_embedding, supportInactives_embedding,
                                                                 supportSetInactivesSize,
                                                                 device=self.device,
                                                                 scaling=sim_cfg.scaling,
                                                                 l2Norm=sim_cfg.l2Norm)

        predictions = predictions_supportActives - _predictions_supportInactives

        predictions = self.sigmoid(self.prediction_scaling * predictions)

        return predictions

    def training_step(self, batch, batch_idx):
        """Refresh the reference embedding, then run one training batch.

        Args:
            batch: Mapping with the keys ``queryMolecule``, ``label``,
                ``supportSetActives``, ``supportSetInactives``,
                ``supportSetActivesSize``, ``supportSetInactivesSize`` and
                ``taskIdx``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``training_epoch_end``.
        """
        # The reference embedding is computed without gradient tracking.
        with torch.no_grad():
            self._update_referenceSet_embedding()

        queryMols = batch['queryMolecule']
        labels = batch['label']
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive, supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize)
        predictions = torch.squeeze(predictions)

        loss = self.LossFunction(predictions, labels.reshape(-1))

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and return loss, predictions, labels and task ids.

        Args:
            batch: Mapping with the same keys as in ``training_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the keys ``loss``, ``predictions``, ``labels`` and
            ``target_idx``; consumed by ``validation_epoch_end``.
        """
        queryMols = batch['queryMolecule']
        labels = batch['label'].squeeze().float()
        supportMolsActive = batch['supportSetActives']
        supportMolsInactive = batch['supportSetInactives']
        supportSetActivesSize = batch['supportSetActivesSize']
        # Fixed: this previously read 'supportSetActivesSize', so the inactive
        # support set was processed with the size of the active one.
        supportSetInactivesSize = batch['supportSetInactivesSize']
        target_idx = batch['taskIdx']

        predictions = self.forward(queryMols, supportMolsActive,
                                   supportMolsInactive,
                                   supportSetActivesSize, supportSetInactivesSize).float()

        loss = self.LossFunction(predictions.reshape(-1), labels)

        output = {'loss': loss, 'predictions': predictions, 'labels': labels, 'target_idx': target_idx}
        return output

    def training_epoch_end(self, step_outputs):
        """Refresh the reference embedding, aggregate the epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``training_step``.
        """
        with torch.no_grad():
            self._update_referenceSet_embedding()

        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        # Range of the predictions, logged for debugging.
        pred_max = torch.max(predictions)
        pred_min = torch.min(predictions)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_training = {'loss_train': epoch_loss, 'auc_train': auc, 'dAUPRC_train': deltaAUPRC,
                             'debug_pred_max_train': pred_max,
                             'debug_pred_min_train': pred_min}
        # The original passed the string 'training' positionally, which lands in
        # log_dict's ``prog_bar`` parameter; being truthy it enabled the
        # progress bar. Spelled out explicitly here, behaviour unchanged.
        self.log_dict(log_dict_training, prog_bar=True, on_epoch=True)

    def validation_epoch_end(self, step_outputs):
        """Aggregate the validation step outputs of one epoch and log metrics.

        Args:
            step_outputs: List of the dicts returned by ``validation_step``.
        """
        # Predictions
        predictions = torch.cat([x['predictions'] for x in step_outputs], dim=0)
        labels = torch.cat([x['labels'] for x in step_outputs], dim=0)
        epoch_loss = torch.mean(torch.tensor([x["loss"] for x in step_outputs]))
        target_ids = torch.cat([x['target_idx'] for x in step_outputs], dim=0)

        auc, _, _ = auc_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())
        deltaAUPRC, _, _ = deltaAUPRC_score_train(predictions.cpu(), labels.cpu(), target_ids.cpu())

        log_dict_val = {'loss_val': epoch_loss, 'auc_val': auc, 'dAUPRC_val': deltaAUPRC}
        # See training_epoch_end: the former positional 'validation' string was
        # interpreted as a truthy ``prog_bar`` flag.
        self.log_dict(log_dict_val, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """Return whatever ``define_opimizer`` builds for this model's parameters."""
        return define_opimizer(self.config, self.parameters())

    def _update_referenceSet_embedding(self):
        """Recompute ``referenceSet_embedding`` from a random exemplar sample.

        Draws 5% of the rows along axis 1 of ``trainExemplarMemory`` (count
        rounded with ``np.rint``) without repetition, encodes them and applies
        ``layerNorm_refSet``. Callers wrap this in ``torch.no_grad()``.
        """
        max_rows = self.trainExemplarMemory.shape[1]
        number_requested_rows = int(np.rint(0.05 * max_rows))

        # The first k entries of a random permutation are k distinct row indices.
        sampled_rows = torch.randperm(max_rows)[:number_requested_rows]

        # The memory is kept on the CPU; only the sampled rows are moved to the
        # module's device so the encoder does not receive a tensor on a
        # different device (no-op when the module itself lives on the CPU).
        sampled_exemplars = self.trainExemplarMemory[:, sampled_rows, :].to(self.device)

        self.referenceSet_embedding = self.layerNorm_refSet(
            self.encoder(sampled_exemplars)
        )

    def _get_retrival_and_referenceembedding(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Return the embeddings before and after the Hopfield retrieval step.

        Returns:
            Tuple of seven items: the layer-normalised encoder embeddings
            (query, actives, inactives), the layer-normalised Hopfield outputs
            (query, actives, inactives) and the current reference-set embedding.
        """
        # Embeddings
        query_embedding_init = self.encoder(queryMols)
        supportActives_embedding_init = self.encoder(supportMolsActive)
        supportInactives_embedding_init = self.encoder(supportMolsInactive)

        # LayerNorm
        (
            query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init
        ) = self.layerNormBlock(query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init)

        # Retrieve updated representations from chemical training space
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.hopfield_chemTrainSpace(query_embedding_init, supportActives_embedding_init,
                                         supportInactives_embedding_init,
                                         supportSetActivesSize, supportSetInactivesSize,
                                         self.referenceSet_embedding)

        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)

        return (query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init,
                query_embedding, supportActives_embedding, supportInactives_embedding, self.referenceSet_embedding)

    def _get_hopfield_association_mtx(self, queryMols, supportMolsActive, supportMolsInactive):
        """Return the Hopfield association matrix for the given inputs.

        The three layer-normalised embeddings are concatenated along axis 1,
        flattened into a single leading entry and passed, together with the
        reference-set embedding, to the Hopfield layer's
        ``get_association_matrix``.
        """
        # Embeddings
        query_embedding_init = self.encoder(queryMols)
        supportActives_embedding_init = self.encoder(supportMolsActive)
        supportInactives_embedding_init = self.encoder(supportMolsInactive)

        # LayerNorm
        (
            query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init
        ) = self.layerNormBlock(query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init)

        # Retrieve updated representations from chemical training space
        S = torch.cat((query_embedding_init, supportActives_embedding_init, supportInactives_embedding_init), 1)

        # Merge the first two axes so the result has a leading axis of size 1.
        S_flattend = S.reshape(1, S.shape[0] * S.shape[1], S.shape[2])

        association_mtx = self.hopfield_chemTrainSpace.hopfield.get_association_matrix(
            (self.referenceSet_embedding, S_flattend, self.referenceSet_embedding)
        )

        return association_mtx

    def _get_crossAttentionEmbeddings(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Return the embeddings entering and leaving the transformer block.

        Returns:
            Tuple of six items: the layer-normalised transformer inputs
            (query, actives, inactives) followed by the layer-normalised
            transformer outputs (query, actives, inactives).
        """
        # Embeddings
        query_embedding = self.encoder(queryMols)
        supportActives_embedding = self.encoder(supportMolsActive)
        supportInactives_embedding = self.encoder(supportMolsInactive)

        # LayerNorm
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)

        # Retrieve updated representations from chemical training space
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.hopfield_chemTrainSpace(query_embedding, supportActives_embedding,
                                         supportInactives_embedding,
                                         supportSetActivesSize, supportSetInactivesSize,
                                         self.referenceSet_embedding)

        (
            query_embedding_input, supportActives_embedding_input, supportInactives_embedding_input
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)

        # Cross Attention module
        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.transformer(query_embedding_input, supportActives_embedding_input, supportInactives_embedding_input,
                             supportSetActivesSize, supportSetInactivesSize)

        (
            query_embedding, supportActives_embedding, supportInactives_embedding
        ) = self.layerNormBlock(query_embedding, supportActives_embedding, supportInactives_embedding)

        return (query_embedding_input, supportActives_embedding_input, supportInactives_embedding_input,
                query_embedding, supportActives_embedding, supportInactives_embedding)

    def _get_retrival_and_referenceembedding_justhopfieldblock(self, queryMols, supportMolsActive, supportMolsInactive,
                supportSetActivesSize=0, supportSetInactivesSize=0):
        """Apply only the Hopfield block to the given inputs.

        Unlike the other helpers, the inputs are passed to the Hopfield block
        as they are (no encoder, no layer norm).

        Returns:
            Tuple ``(query_embedding, referenceSet_embedding)``; the two
            support-set outputs of the Hopfield block are discarded.
        """
        # Retrieve updated representations from chemical training space
        (
            query_embedding, _, _
        ) = self.hopfield_chemTrainSpace(queryMols, supportMolsActive, supportMolsInactive,
                                         supportSetActivesSize, supportSetInactivesSize,
                                         self.referenceSet_embedding)
        return query_embedding, self.referenceSet_embedding
