# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import torch.nn.functional as F
import pandas as pd
from unicore import metrics
from unicore.losses import UnicoreLoss, register_loss
from unicore.losses.cross_entropy import CrossEntropyLoss
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import numpy as np
import warnings
from sklearn.metrics import top_k_accuracy_score
from rdkit.ML.Scoring.Scoring import CalcBEDROC


def calculate_bedroc(y_true, y_score, alpha):
    """
    Calculate BEDROC score.

    Parameters:
    - y_true: true binary labels (0 or 1), one per item
    - y_score: predicted scores or probabilities, one per item
    - alpha: parameter controlling the degree of early retrieval emphasis;
      forwarded unchanged to RDKit's ``CalcBEDROC``

    Returns:
    - BEDROC score
    """
    # Build an (N, 2) array whose rows are [score, label].
    scores = np.expand_dims(y_score, axis=1)
    y_true = np.expand_dims(y_true, axis=1)
    scores = np.concatenate((scores, y_true), axis=1)
    # Rank rows from highest to lowest score (sort on the first column, descending).
    scores = scores[scores[:, 0].argsort()[::-1]]
    # Column 1 holds the labels. Use the caller-supplied alpha; it was previously
    # ignored in favour of a hard-coded 80.5.
    return CalcBEDROC(scores, 1, alpha)

@register_loss("decoder_loss")
class DecoderLoss(CrossEntropyLoss):
    """Token-level NLL loss for a decoder that reconstructs ``selfie_tokens``.

    The model is called with ``features_only=True`` and no classification head,
    and its output is passed straight to ``F.nll_loss``. It is therefore expected
    to already contain log-probabilities.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        In evaluation mode the trimmed decoder output ("prob") and the target
        tokens are also logged so ``reduce_metrics`` can compute token accuracy.
        """
        net_output = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=None,
            fix_encoder=fix_encoder
        )
        loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        targets = sample["net_input"]["selfie_tokens"]
        sample_size = targets.size(0)

        # Same trimming as in compute_loss: keep only positions up to the target length.
        lprobs = net_output[:, :, :targets.shape[-1]]
        if not self.training:
            logging_output = {
                "loss": loss.data,
                "prob": lprobs.data,
                "target": targets.data,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return the decoder NLL divided by the target sequence length.

        Assumes ``net_output`` holds log-probabilities laid out as
        (batch, vocab, seq), the layout ``F.nll_loss`` expects for a
        (batch, seq) target. TODO confirm against the model.
        """
        targets = sample["net_input"]["selfie_tokens"]
        # Drop decoder positions beyond the target length.
        lprobs = net_output[:, :, :targets.shape[-1]]
        nll_loss = F.nll_loss(
            lprobs,
            targets,
            reduction="sum" if reduce else "none",
        ) / lprobs.shape[-1]
        return nll_loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        loss = sum(log.get("loss", 0).float() for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss_all", loss / sample_size / math.log(2), sample_size, round=3
        )

        if "valid" in split or "test" in split:
            # Token accuracy pooled over all batches: argmax over dim 1 (the
            # class axis for F.nll_loss) compared with the target tokens.
            preds = torch.cat(
                [log.get("prob").argmax(dim=1).flatten() for log in logging_outputs],
                dim=0,
            )
            targets = torch.cat(
                [log.get("target").flatten() for log in logging_outputs], dim=0
            )
            acc = (preds == targets).float().mean(dim=-1)
            metrics.log_scalar(
                f"{split}_acc", acc, sample_size, round=3
            )

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train


@register_loss("decoder_vae_loss")
class DecoderVAELoss(CrossEntropyLoss):
    """Decoder reconstruction NLL plus a KL term for a VAE-style latent.

    The model output is unpacked as ``(out, z, mu, log_var)``. ``out`` is passed
    to ``F.nll_loss``, so it is expected to hold log-probabilities.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, candidate_reps, candidate_embs, candidate_smiles, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        The ``candidate_*`` arguments are forwarded unchanged to the model.
        """
        net_output = model(
            **sample["net_input"],
            candidate_reps=candidate_reps,
            candidate_embs=candidate_embs,
            candidate_smiles=candidate_smiles,
            features_only=True,
            classification_head_name=None,
            fix_encoder=fix_encoder
        )
        loss, nll_loss, kld_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        targets = sample["net_input"]["selfie_tokens"]
        sample_size = targets.size(0)

        # Same trimming as in compute_loss: keep only positions up to the target length.
        lprobs = net_output[0][:, :, :targets.shape[-1]]
        if not self.training:
            logging_output = {
                "loss": loss.data,
                "kld_loss": kld_loss.data,
                "nll_loss": nll_loss.data,
                "prob": lprobs.data,
                "target": targets.data,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "kld_loss": kld_loss.data,
                "nll_loss": nll_loss.data,
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(loss, nll_loss, kld_loss)``.

        ``loss`` is a fixed 0.2 / 0.8 blend of the KL term and the
        reconstruction NLL. The NLL is divided by the target sequence length.
        Assumes ``out`` is laid out as (batch, vocab, seq). TODO confirm.
        """
        out, _z, mu, log_var = net_output
        targets = sample["net_input"]["selfie_tokens"]
        # Drop decoder positions beyond the target length.
        lprobs = out[:, :, :targets.shape[-1]]
        nll_loss = F.nll_loss(
            lprobs,
            targets,
            reduction="sum" if reduce else "none",
        ) / lprobs.shape[-1]
        # Closed-form KL between N(mu, exp(log_var)) and N(0, I), summed over all
        # latent elements.
        # NOTE(review): this is divided by lprobs.shape[1], while the NLL above
        # uses lprobs.shape[-1]. Under a (batch, vocab, seq) layout these are
        # different axes; confirm the intended normalizer.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / lprobs.shape[1]
        kld_weight = 0.2
        loss = kld_weight * kld_loss + (1 - kld_weight) * nll_loss
        return loss, nll_loss, kld_loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_all = sum(log.get("loss", 0).float() for log in logging_outputs)
        loss_kld = sum(log.get("kld_loss", 0).float() for log in logging_outputs)
        loss_nll = sum(log.get("nll_loss", 0).float() for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss_all", loss_all / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "loss_kld", loss_kld / sample_size, sample_size, round=3
        )
        metrics.log_scalar(
            "loss_nll", loss_nll / sample_size, sample_size, round=3
        )
        if "valid" in split or "test" in split:
            # Token accuracy pooled over all batches: argmax over dim 1 (the
            # class axis for F.nll_loss) compared with the target tokens.
            # Full probability tensors are not concatenated here; nothing reads them.
            preds = torch.cat(
                [log.get("prob").argmax(dim=1).flatten() for log in logging_outputs],
                dim=0,
            )
            targets = torch.cat(
                [log.get("target").flatten() for log in logging_outputs], dim=0
            )
            acc = (preds == targets).float().mean(dim=-1)
            metrics.log_scalar(
                f"{split}_acc", acc, sample_size, round=3
            )

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train


@register_loss("finetune_cross_entropy")
class FinetuneCrossEntropyLoss(CrossEntropyLoss):
    """Softmax cross-entropy over a classification head's logits."""

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.args.classification_head_name,
            fix_encoder=fix_encoder
        )
        logit_output = net_output[0]
        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
        sample_size = sample["target"]["finetune_target"].size(0)
        if not self.training:
            # One softmax row per prediction, flattened to (N, num_classes).
            probs = F.softmax(logit_output.float(), dim=-1).view(
                -1, logit_output.size(-1)
            )
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": sample["target"]["finetune_target"].view(-1).data,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return the NLL of the log-softmax logits against the finetune targets."""
        lprobs = F.log_softmax(net_output.float(), dim=-1)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        targets = sample["target"]["finetune_target"].view(-1)
        loss = F.nll_loss(
            lprobs,
            targets,
            reduction="sum" if reduce else "none",
        )
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" in split or "test" in split:
            # Count correct predictions with a tensor reduction. Builtin sum over
            # a tensor iterates element by element in Python.
            acc_sum = sum(
                (log.get("prob").argmax(dim=-1) == log.get("target")).sum()
                for log in logging_outputs
            )
            probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0)
            metrics.log_scalar(
                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
            )
            if probs.size(-1) == 2:
                # binary classification task, add auc score
                targets = torch.cat(
                    [log.get("target") for log in logging_outputs], dim=0
                )
                smi_list = [
                    item for log in logging_outputs for item in log.get("smi_name")
                ]
                df = pd.DataFrame(
                    {
                        "probs": probs[:, 1].cpu(),
                        "targets": targets.cpu(),
                        "smi": smi_list,
                    }
                )
                auc = roc_auc_score(df["targets"], df["probs"])
                # Average scores of entries sharing a SMILES before scoring again.
                df = df.groupby("smi").mean()
                agg_auc = roc_auc_score(df["targets"], df["probs"])
                metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
                metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train

@register_loss("ce")
class CEntropyLoss(CrossEntropyLoss):
    """Binary cross-entropy on raw logits, with sigmoid-based validation metrics."""

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(
            **sample["net_input"],
            smi_list = sample["smi_name"],
            pocket_list = sample["pocket_name"],
            features_only=True,
            fix_encoder=fix_encoder
        )
        logit_output = net_output
        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
        sample_size = sample["target"]["finetune_target"].size(0)
        if not self.training:
            # Independent per-logit probabilities (sigmoid, not softmax).
            probs = torch.sigmoid(logit_output.float()).view(-1, logit_output.size(-1))
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": sample["target"]["finetune_target"].view(-1).data,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": sample["target"]["finetune_target"].size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return the binary cross-entropy between logits and finetune targets.

        Both tensors are flattened, so a ``(batch,)`` or ``(batch, 1)`` logit
        tensor lines up with the flattened targets. The targets are cast to the
        logits' dtype because ``binary_cross_entropy_with_logits`` requires
        floating-point targets of the same size as the input.
        """
        logits = net_output.float().reshape(-1)
        targets = sample["target"]["finetune_target"].reshape(-1).to(logits.dtype)
        loss = F.binary_cross_entropy_with_logits(
            logits,
            targets,
            reduction="sum" if reduce else "none",
        )
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" in split or "test" in split:
            # "prob" holds independent sigmoid scores, and its 2-D layout depends
            # on the logit shape. Flatten each batch, then threshold at 0.5.
            # An argmax over the last axis is not a class prediction for a
            # sigmoid head.
            probs = torch.cat(
                [log.get("prob").reshape(-1) for log in logging_outputs], dim=0
            )
            targets = torch.cat(
                [log.get("target").reshape(-1) for log in logging_outputs], dim=0
            ).to(probs.dtype)
            preds = (probs > 0.5).to(probs.dtype)
            acc = (preds == targets).sum() / max(probs.numel(), 1)
            metrics.log_scalar(
                f"{split}_acc", acc, sample_size, round=3
            )

            probs_np = probs.cpu().numpy()
            targets_np = targets.cpu().numpy()
            smi_list = [
                item for log in logging_outputs for item in log.get("smi_name")
            ]
            # AUC is undefined unless both classes are present. The per-SMILES
            # aggregation also needs exactly one score per SMILES entry.
            if len(np.unique(targets_np)) == 2 and len(smi_list) == len(probs_np):
                df = pd.DataFrame(
                    {
                        "probs": probs_np,
                        "targets": targets_np,
                        "smi": smi_list,
                    }
                )
                auc = roc_auc_score(df["targets"], df["probs"])
                # Average scores of entries sharing a SMILES before scoring again.
                df = df.groupby("smi").mean()
                agg_auc = roc_auc_score(df["targets"], df["probs"])
                metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
                metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train


@register_loss("in_batch_softmax")
class IBSLoss(CrossEntropyLoss):
    """Symmetric in-batch softmax (contrastive) loss.

    Row ``i`` of the model's logit matrix is scored against column ``i`` as its
    positive. Every other column in the row acts as an in-batch negative. Per
    the variable names, rows are presumably pockets and columns molecules;
    confirm against the model.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        ``net_output[0]`` is the logit matrix. ``net_output[1]`` is logged as "scale".
        """
        net_output = model(
            **sample["net_input"],
            smi_list = sample["smi_name"],
            pocket_list = sample["pocket_name"],
            features_only=True,
            fix_encoder=fix_encoder,
            is_train = self.training
        )

        logit_output = net_output[0]
        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
        sample_size = logit_output.size(0)
        # Diagonal targets, created on the logits' device instead of assuming CUDA.
        targets = torch.arange(sample_size, dtype=torch.long, device=logit_output.device)
        if not self.training:
            # Keep only the first sample_size columns, giving a square matrix.
            logit_output = logit_output[:, :sample_size]
            probs = F.softmax(logit_output.float(), dim=-1).view(
                -1, logit_output.size(-1)
            )
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": targets,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": targets.size(0),
                "scale": net_output[1].data
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": targets.size(0),
                "scale": net_output[1].data
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Average of the row-wise and column-wise in-batch softmax NLLs."""
        # Row direction: softmax over columns, positive at column i.
        lprobs_pocket = F.log_softmax(net_output.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        targets = torch.arange(sample_size, dtype=torch.long, device=lprobs_pocket.device)
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="sum" if reduce else "none",
        )

        # Column direction: softmax over rows. Only the first sample_size
        # columns have a matching row.
        lprobs_mol = F.log_softmax(torch.transpose(net_output.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="sum" if reduce else "none",
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        metrics.log_scalar("scale", logging_outputs[0].get("scale"), round=3)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" in split or "test" in split:
            acc_sum = sum(
                (log.get("prob").argmax(dim=-1) == log.get("target")).sum()
                for log in logging_outputs
            )

            # Each batch's probabilities are (B_i, B_i). The last batch is
            # dropped whenever there is more than one, presumably because it can
            # be smaller and would then break column-wise concatenation.
            kept_logs = logging_outputs if len(logging_outputs) == 1 else logging_outputs[:-1]
            probs = torch.cat([log.get("prob") for log in kept_logs], dim=0)

            metrics.log_scalar(
                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
            )

            metrics.log_scalar(
                "valid_neg_loss", -loss_sum / sample_size / math.log(2), sample_size, round=3
            )

            targets = torch.cat(
                [log.get("target") for log in logging_outputs], dim=0
            )
            # Align targets with the retained probability rows.
            targets = targets[:len(probs)]

            # Per-row BEDROC with a one-hot label marking the positive column.
            bedroc_scores = []
            for prob, target in zip(probs, targets):
                label = torch.zeros_like(prob)
                label[target] = 1.0
                bedroc_scores.append(calculate_bedroc(label.cpu(), prob.cpu(), 80.5))
            bedroc = np.mean(bedroc_scores)

            auc = roc_auc_score(targets.cpu(), probs.cpu(), average='macro', multi_class='ovr')
            top_k_acc = top_k_accuracy_score(targets.cpu(), probs.cpu(), k=3, normalize=True)
            metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
            metrics.log_scalar(f"{split}_bedroc", bedroc, sample_size, round=3)
            metrics.log_scalar(f"{split}_top3_acc", top_k_acc, sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train

@register_loss("in_batch_softmax_hns")
class IBSHLoss(CrossEntropyLoss):
    """In-batch softmax contrastive loss plus a hard-negative term during training.

    The model output is read as ``(logits, hns_logits, scale)``. The
    hard-negative term is only computed when ``self.training`` is true.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # NOTE(review): fix_encoder is accepted but, unlike the sibling losses,
        # not forwarded to the model. Confirm this is intended.
        net_output = model(
            **sample["net_input"],
            smi_list = sample["smi_name"],
            pocket_list = sample["pocket_name"],
            features_only=True,
            is_train = self.training,
        )

        logit_output = net_output[0]
        logit_output_hns = net_output[1]
        loss, loss_hns = self.compute_loss(model, logit_output, logit_output_hns, sample, reduce=reduce)
        sample_size = logit_output.size(0)
        # Diagonal targets, created on the logits' device instead of assuming CUDA.
        targets = torch.arange(sample_size, dtype=torch.long, device=logit_output.device)
        if not self.training:
            # Keep only the first sample_size columns, giving a square matrix.
            logit_output = logit_output[:, :sample_size]
            probs = F.softmax(logit_output.float(), dim=-1).view(
                -1, logit_output.size(-1)
            )
            logging_output = {
                "loss": loss.data,
                "loss_hns": loss_hns.data,
                "prob": probs.data,
                "target": targets,
                "smi_name": sample["smi_name"],
                "sample_size": sample_size,
                "bsz": targets.size(0),
                "scale": net_output[2].data
            }
        else:
            logging_output = {
                "loss": loss.data,
                "loss_hns": loss_hns.data,
                "sample_size": sample_size,
                "bsz": targets.size(0),
                "scale": net_output[2].data
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, net_ouput_hns, sample, reduce=True):
        """Return ``(loss, loss_hns)``.

        ``loss`` averages the row-wise and column-wise in-batch softmax NLLs.
        During training it also adds half of the hard-negative NLL, which
        targets class 1 for every row of ``net_ouput_hns``. In eval mode the
        hard-negative term is skipped and ``loss`` is returned in both slots.
        """
        # Row direction: softmax over columns, positive at column i.
        lprobs_pocket = F.log_softmax(net_output.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        targets = torch.arange(sample_size, dtype=torch.long, device=lprobs_pocket.device)
        loss_pocket = F.nll_loss(
            lprobs_pocket,
            targets,
            reduction="sum" if reduce else "none",
        )
        # Column direction: softmax over rows. Only the first sample_size
        # columns have a matching row.
        lprobs_mol = F.log_softmax(torch.transpose(net_output.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        lprobs_mol = lprobs_mol[:sample_size]
        loss_mol = F.nll_loss(
            lprobs_mol,
            targets,
            reduction="sum" if reduce else "none",
        )
        if not self.training:
            loss = 0.5 * loss_pocket + 0.5 * loss_mol
            # NOTE(review): the logged "loss_hns" therefore equals the total
            # loss during validation.
            return loss, loss

        lprobs_hns = F.log_softmax(net_ouput_hns.float(), dim=-1)
        lprobs_hns = lprobs_hns.view(-1, lprobs_hns.size(-1))
        # Every hard-negative row targets class 1; ones_like keeps targets' device.
        target_hns = torch.ones_like(targets)
        loss_hns = F.nll_loss(
            lprobs_hns,
            target_hns,
            reduction="sum" if reduce else "none",
        )

        loss = 0.5 * loss_pocket + 0.5 * loss_mol + 0.5 * loss_hns
        return loss, loss_hns

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        metrics.log_scalar("scale", logging_outputs[0].get("scale"), round=3)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        loss_hns_sum = sum(log.get("loss_hns", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("loss_hns", loss_hns_sum / sample_size / math.log(2), sample_size, round=3)

        if "valid" in split or "test" in split:
            acc_sum = sum(
                (log.get("prob").argmax(dim=-1) == log.get("target")).sum()
                for log in logging_outputs
            )

            # Each batch's probabilities are (B_i, B_i). The last batch is
            # dropped whenever there is more than one, presumably because it can
            # be smaller and would then break column-wise concatenation.
            kept_logs = logging_outputs if len(logging_outputs) == 1 else logging_outputs[:-1]
            probs = torch.cat([log.get("prob") for log in kept_logs], dim=0)

            metrics.log_scalar(
                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
            )

            metrics.log_scalar(
                "valid_neg_loss", -loss_sum / sample_size / math.log(2), sample_size, round=3
            )

            targets = torch.cat(
                [log.get("target") for log in logging_outputs], dim=0
            )
            # Align targets with the retained probability rows.
            targets = targets[:len(probs)]

            auc = roc_auc_score(targets.cpu(), probs.cpu(), average='macro', multi_class='ovr')
            top_k_acc = top_k_accuracy_score(targets.cpu(), probs.cpu(), k=3, normalize=True)
            metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3)
            metrics.log_scalar(f"{split}_top3_acc", top_k_acc, sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return is_train

@register_loss("in_batch_softmax_colbert")
class IBSColbertLoss(CrossEntropyLoss):
    """In-batch softmax retrieval loss with a student ("s") and a teacher ("t") branch.

    ``model`` returns three score matrices per batch:

    * ``net_output[0]``: student scores. The matrix itself is used for the
      "pocket" direction and its transpose for the "mol" direction.
    * ``net_output[1]`` / ``net_output[2]``: teacher scores for the "pocket"
      and "mol" directions respectively.

    Row ``i`` of every matrix is trained to select column ``i``; the other
    in-batch columns act as negatives. Because the student matrix and its
    transpose share one target vector, it is assumed to be square
    (bsz x bsz) -- TODO confirm against the model.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Outside training, the logging output also carries the student and
        teacher ("pocket" direction) softmax matrices, the diagonal targets and
        ``sample["smi_name"]`` so that ``reduce_metrics`` can compute retrieval
        metrics.
        """
        net_output = model(
            **sample["net_input"],
            smi_list=sample["smi_name"],
            pocket_list=sample["pocket_name"],
            features_only=True,
            fix_encoder=fix_encoder,
        )
        logit_output_s = net_output[0]
        logit_output_t_pocket = net_output[1]
        (
            loss,
            loss_s,
            loss_t,
            loss_pocket_t,
            loss_mol_t,
            loss_kd_pocket,
            loss_kd_mol,
        ) = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = logit_output_s.size(0)
        # The positive for row i is column i. Build the targets on the logits'
        # device rather than hard-coding CUDA.
        targets = torch.arange(
            sample_size, dtype=torch.long, device=logit_output_s.device
        )

        logging_output = {
            "loss": loss.data,
            "loss_s": loss_s.data,
            "loss_t": loss_t.data,
            "loss_pocket_t": loss_pocket_t.data,
            "loss_mol_t": loss_mol_t.data,
            "loss_kd_pocket": loss_kd_pocket.data,
            "loss_kd_mol": loss_kd_mol.data,
        }
        if not self.training:
            probs_s = F.softmax(logit_output_s.float(), dim=-1).view(
                -1, logit_output_s.size(-1)
            )
            probs_t = F.softmax(logit_output_t_pocket.float(), dim=-1).view(
                -1, logit_output_t_pocket.size(-1)
            )
            logging_output["prob_s"] = probs_s.data
            logging_output["prob_t"] = probs_t.data
            logging_output["target"] = targets
            logging_output["smi_name"] = sample["smi_name"]
        logging_output["sample_size"] = sample_size
        logging_output["bsz"] = targets.size(0)
        return loss, sample_size, logging_output

    def compute_loss_once(self, net_output_pocket, net_output_mol, mode, reduce=True):
        """Diagonal-target cross-entropy for one ("pocket", "mol") pair of score matrices.

        Args:
            net_output_pocket: score matrix for the "pocket" direction.
            net_output_mol: score matrix for the "mol" direction.
            mode: branch tag ("s" or "t"); currently unused.
            reduce: sum the per-row losses when True, otherwise keep one loss per row.

        Returns:
            ``(loss, loss_pocket, loss_mol, lprobs_pocket, lprobs_mol)``, where
            ``loss`` is the equal-weight average of the two directional losses
            and the ``lprobs`` are the 2-D log-softmax matrices.
        """
        reduction = "sum" if reduce else "none"

        lprobs_pocket = F.log_softmax(net_output_pocket.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        # The positive for row i is column i; both directions share these targets.
        targets = torch.arange(
            lprobs_pocket.size(0), dtype=torch.long, device=lprobs_pocket.device
        )
        loss_pocket = F.nll_loss(lprobs_pocket, targets, reduction=reduction)

        lprobs_mol = F.log_softmax(net_output_mol.float(), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))
        loss_mol = F.nll_loss(lprobs_mol, targets, reduction=reduction)

        loss = 0.5 * loss_pocket + 0.5 * loss_mol
        return loss, loss_pocket, loss_mol, lprobs_pocket, lprobs_mol

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Combine the student and teacher losses and measure their divergence.

        Returns:
            ``(loss, loss_s, loss_t, loss_pocket_t, loss_mol_t, loss_kd_pocket,
            loss_kd_mol)``. Only ``0.5 * loss_t + 0.5 * loss_s`` is optimized;
            the two KL terms are returned for logging and do not contribute to
            ``loss``.
        """
        loss_s, _, _, lprobs_pocket_s, lprobs_mol_s = self.compute_loss_once(
            net_output[0], net_output[0].T, mode="s", reduce=reduce
        )
        loss_t, loss_pocket_t, loss_mol_t, lprobs_pocket_t, lprobs_mol_t = (
            self.compute_loss_once(
                net_output[1], net_output[2], mode="t", reduce=reduce
            )
        )
        # KL(teacher || student) with log-space inputs. The teacher side is
        # detached, so gradients from these terms could only reach the student.
        kl_loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
        loss_kd_pocket = kl_loss(lprobs_pocket_s, lprobs_pocket_t.detach())
        loss_kd_mol = kl_loss(lprobs_mol_s, lprobs_mol_t.detach())
        loss = 0.5 * loss_t + 0.5 * loss_s
        return (
            loss,
            loss_s,
            loss_t,
            loss_pocket_t,
            loss_mol_t,
            loss_kd_pocket,
            loss_kd_mol,
        )

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        for key in (
            "loss",
            "loss_s",
            "loss_t",
            "loss_pocket_t",
            "loss_mol_t",
            "loss_kd_pocket",
            "loss_kd_mol",
        ):
            key_sum = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(
                key, key_sum / sample_size / math.log(2), sample_size, round=3
            )

        if "valid" in split or "test" in split:
            metrics.log_scalar(
                "valid_neg_loss",
                -loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )
            IBSColbertLoss._log_retrieval_metrics(
                logging_outputs, "prob_s", "s", split, sample_size
            )
            IBSColbertLoss._log_retrieval_metrics(
                logging_outputs, "prob_t", "t", split, sample_size
            )

    @staticmethod
    def _log_retrieval_metrics(logging_outputs, prob_key, suffix, split, sample_size):
        """Log accuracy, BEDROC, macro one-vs-rest AUC and top-3 accuracy for one branch.

        ``prob_key`` selects the softmax matrices stored by ``forward``
        ("prob_s" or "prob_t"). ``suffix`` is appended to every metric name
        (e.g. ``{split}_auc_s``).
        """
        # Top-1 accuracy over every evaluated row of every batch.
        acc_sum = sum(
            (log.get(prob_key).argmax(dim=-1) == log.get("target")).sum()
            for log in logging_outputs
        )
        metrics.log_scalar(
            f"{split}_acc_{suffix}", acc_sum / sample_size, sample_size, round=3
        )

        # Score matrices can only be stacked row-wise when they have the same
        # number of candidate columns. Keep every batch whose width matches the
        # first one (e.g. a smaller trailing batch is excluded), and take the
        # targets from exactly those batches so rows stay aligned.
        num_candidates = logging_outputs[0].get(prob_key).size(-1)
        kept = [
            log
            for log in logging_outputs
            if log.get(prob_key).size(-1) == num_candidates
        ]
        probs = torch.cat([log.get(prob_key) for log in kept], dim=0).cpu().numpy()
        targets = torch.cat([log.get("target") for log in kept], dim=0).cpu().numpy()

        # BEDROC per query row: the diagonal column is the single positive.
        bedrocs = []
        for row_scores, positive in zip(probs, targets):
            labels = np.zeros_like(row_scores)
            labels[positive] = 1.0
            bedrocs.append(calculate_bedroc(labels, row_scores, 80.5))
        bedroc = np.mean(bedrocs)

        auc = roc_auc_score(targets, probs, average="macro", multi_class="ovr")
        top_k_acc = top_k_accuracy_score(targets, probs, k=3, normalize=True)
        metrics.log_scalar(f"{split}_bedroc_{suffix}", bedroc, sample_size, round=3)
        metrics.log_scalar(f"{split}_auc_{suffix}", auc, sample_size, round=3)
        metrics.log_scalar(
            f"{split}_top3_acc_{suffix}", top_k_acc, sample_size, round=3
        )

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return is_train

@register_loss("multi_task_BCE")
class MultiTaskBCELoss(CrossEntropyLoss):
    """Multi-task binary cross-entropy with masking of missing labels.

    Each task column of ``sample["target"]["finetune_target"]`` holds a binary
    label. Entries below -0.5 are treated as missing and are excluded from both
    the training loss and the validation AUC.
    """

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(
            **sample["net_input"],
            masked_tokens=None,
            features_only=True,
            classification_head_name=self.args.classification_head_name,
            fix_encoder=fix_encoder,
        )
        logit_output = net_output[0]
        finetune_target = sample["target"]["finetune_target"]
        # Labels below -0.5 mark missing entries and are left out of the loss.
        is_valid = finetune_target > -0.5
        loss = self.compute_loss(
            model, logit_output, sample, reduce=reduce, is_valid=is_valid
        )
        sample_size = finetune_target.size(0)
        if not self.training:
            probs = torch.sigmoid(logit_output.float()).view(-1, logit_output.size(-1))
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": finetune_target.view(-1).data,
                "num_task": self.args.num_classes,
                "sample_size": sample_size,
                "conf_size": self.args.conf_size,
                "bsz": finetune_target.size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": finetune_target.size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None):
        """Binary cross-entropy over the entries selected by the ``is_valid`` mask."""
        pred = net_output[is_valid].float()
        targets = sample["target"]["finetune_target"][is_valid].float()
        loss = F.binary_cross_entropy_with_logits(
            pred,
            targets,
            reduction="sum" if reduce else "none",
        )
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training.

        In validation/test splits, also logs ``{split}_agg_auc``: the mean
        ROC-AUC over tasks that have both positive and negative labels.

        Raises:
            RuntimeError: if no task has both classes present.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" not in split and "test" not in split:
            return

        num_task = logging_outputs[0].get("num_task", 0)
        conf_size = logging_outputs[0].get("conf_size", 0)

        def _average_rows(key):
            # [test_size * conf_size, num_task] -> [test_size, num_task]: every
            # group of conf_size consecutive rows (presumably the conformers
            # of one sample) is averaged.
            return (
                torch.cat([log.get(key, 0) for log in logging_outputs], dim=0)
                .view(-1, conf_size, num_task)
                .cpu()
                .numpy()
                .mean(axis=1)
            )

        y_true = _average_rows("target")
        y_pred = _average_rows("prob")

        agg_auc_list = []
        for task_true, task_pred in zip(y_true.T, y_pred.T):
            # AUC is only defined when both classes are present for this task.
            if np.sum(task_true == 1) > 0 and np.sum(task_true == 0) > 0:
                # Drop missing labels (negative values) before scoring.
                is_labeled = task_true > -0.5
                agg_auc_list.append(
                    roc_auc_score(task_true[is_labeled], task_pred[is_labeled])
                )
        if len(agg_auc_list) < y_true.shape[1]:
            warnings.warn("Some target is missing!")
        if not agg_auc_list:
            raise RuntimeError(
                "No task has both positive and negative labels. Cannot compute AUC."
            )
        agg_auc = sum(agg_auc_list) / len(agg_auc_list)
        metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return is_train

@register_loss("BCE")
class BCELoss(CrossEntropyLoss):
    """Binary cross-entropy on the raw model output against ``finetune_target``."""

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True, fix_encoder=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(
            **sample["net_input"],
            smi_list=sample["smi_name"],
            pocket_list=sample["pocket_name"],
            features_only=True,
            fix_encoder=fix_encoder,
        )
        logit_output = net_output
        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
        finetune_target = sample["target"]["finetune_target"]
        sample_size = finetune_target.size(0)

        if not self.training:
            probs = torch.sigmoid(logit_output.float())
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": finetune_target.view(-1).data,
                "sample_size": sample_size,
                "bsz": finetune_target.size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": finetune_target.size(0),
            }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None):
        """Binary cross-entropy with logits; ``is_valid`` is accepted but unused."""
        pred = net_output.float()
        targets = sample["target"]["finetune_target"].float()
        loss = F.binary_cross_entropy_with_logits(
            pred,
            targets,
            reduction="sum" if reduce else "none",
        )
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training.

        In validation/test splits, also logs ``{split}_agg_auc``: the ROC-AUC
        of all concatenated probabilities against the concatenated targets.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" in split or "test" in split:
            y_true = (
                torch.cat([log.get("target", 0) for log in logging_outputs], dim=0)
                .cpu()
                .numpy()
            )
            # Targets were flattened with .view(-1) in forward. Flatten the
            # probabilities the same way so both are aligned 1-D arrays (a
            # (bsz, 1) output would otherwise reach sklearn as a column).
            y_pred = (
                torch.cat([log.get("prob") for log in logging_outputs], dim=0)
                .reshape(-1)
                .cpu()
                .numpy()
            )
            agg_auc = roc_auc_score(y_true, y_pred)
            metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return is_train

@register_loss("finetune_cross_entropy_pocket")
class FinetuneCrossEntropyPocketLoss(FinetuneCrossEntropyLoss):
    """Cross-entropy fine-tuning loss that reports accuracy, precision, recall and F1."""

    def __init__(self, task):
        super().__init__(task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.args.classification_head_name,
        )
        logit_output = net_output[0]
        loss = self.compute_loss(model, logit_output, sample, reduce=reduce)
        finetune_target = sample["target"]["finetune_target"]
        sample_size = finetune_target.size(0)
        if not self.training:
            probs = F.softmax(logit_output.float(), dim=-1).view(
                -1, logit_output.size(-1)
            )
            logging_output = {
                "loss": loss.data,
                "prob": probs.data,
                "target": finetune_target.view(-1).data,
                "sample_size": sample_size,
                "bsz": finetune_target.size(0),
            }
        else:
            logging_output = {
                "loss": loss.data,
                "sample_size": sample_size,
                "bsz": finetune_target.size(0),
            }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training.

        In validation/test splits, also logs top-1 accuracy and sklearn's
        precision/recall/F1 (default ``average="binary"``) of the argmax
        predictions. Every metric is weighted by ``sample_size``.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if "valid" in split or "test" in split:
            acc_sum = sum(
                (log.get("prob").argmax(dim=-1) == log.get("target")).sum()
                for log in logging_outputs
            )
            metrics.log_scalar(
                f"{split}_acc", acc_sum / sample_size, sample_size, round=3
            )
            preds = (
                torch.cat(
                    [log.get("prob").argmax(dim=-1) for log in logging_outputs], dim=0
                )
                .cpu()
                .numpy()
            )
            targets = (
                torch.cat([log.get("target", 0) for log in logging_outputs], dim=0)
                .cpu()
                .numpy()
            )
            # Pass sample_size as the weight for every metric. Previously pre/rec
            # fell back to log_scalar's default weight of 1 while acc/f1 used
            # sample_size.
            metrics.log_scalar(
                f"{split}_pre", precision_score(targets, preds), sample_size, round=3
            )
            metrics.log_scalar(
                f"{split}_rec", recall_score(targets, preds), sample_size, round=3
            )
            metrics.log_scalar(
                f"{split}_f1", f1_score(targets, preds), sample_size, round=3
            )
