# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from unicore import metrics
from unicore.losses import UnicoreLoss, register_loss
import numpy as np
from torch.distributions.gamma import Gamma
from torch.distributions.bernoulli import Bernoulli


# Hyper-parameters for the optional weight-sampling loop in ``myloss``.
# ``a_*`` values are used as Gamma concentrations (``sample_w_bernolli`` also
# reads ``a_minus``) and ``b_*`` values are added to the Gamma rates.
a_u, b_u = 1, 1  # Gamma prior for the per-row variable U (``sample_u``)
a_minus, b_minus = 10, 1  # prior for the weights of negative entries
a_plus, b_plus = 5, 1  # prior for the positive entry; concentration is 1 + a_plus
# Number of U / W sampling rounds run by ``myloss``; 0 keeps every weight at 1.
num_iteration = 0

def sample_u(w_matrix, sim_matrix):
    """Draw one auxiliary variable ``U`` per row from its Gamma conditional.

    For every row ``i``::

        U_i ~ Gamma(a_u, b_u + sum_j w_ij * s_ij)      # (shape, rate)

    where ``w_matrix`` and ``sim_matrix`` are multiplied element-wise and
    summed along dim 1.

    Args:
        w_matrix: weights, same shape as (or broadcastable to) ``sim_matrix``.
        sim_matrix: score matrix; presumably ``(batch, k)`` logits — TODO confirm.

    Returns:
        A float32 tensor with one sample per row of ``sim_matrix``.

    Note:
        ``Gamma`` needs a strictly positive rate. With PyTorch's default
        argument validation it raises ``ValueError`` when a row's weighted
        sum is ``<= -b_u`` (possible for strongly negative logits).
    """
    rate = b_u + (w_matrix * sim_matrix).sum(dim=1)
    concentration = torch.tensor(a_u, dtype=torch.float32, device=sim_matrix.device)
    return Gamma(concentration, rate.float()).sample()
    
def sample_w(U, s_matrix):
    """Draw per-entry Gamma weights for a square score matrix.

    Column 0 of each row is the positive entry and receives
    ``w ~ Gamma(1 + a_plus, U_i * s + b_plus)``. The other ``BS - 1`` entries
    of row ``i`` receive ``w ~ Gamma(a_minus, U_i * s + b_minus)``.

    Args:
        U: one value per row of ``s_matrix`` (as drawn by ``sample_u``).
        s_matrix: score matrix; the selection mask is ``(BS, BS)`` with
            ``BS = s_matrix.shape[0]``.

    Returns:
        ``(BS, BS)`` tensor with the positive weight in column 0 and the
        negative weights in columns 1..BS-1.
    """
    batch = s_matrix.shape[0]
    device = s_matrix.device

    # Boolean selector for column 0 (the positive pair) of every row.
    pos_mask = torch.zeros(batch, batch, dtype=torch.bool, device=device)
    pos_mask[:, 0] = True

    pos_scores = s_matrix.masked_select(pos_mask)
    neg_scores = s_matrix.masked_select(~pos_mask)

    pos_shape = torch.tensor(1 + a_plus).float().to(device)
    neg_shape = torch.tensor(a_minus).float().to(device)
    pos_dist = Gamma(pos_shape, U * pos_scores + b_plus)
    # masked_select flattens row-major, so U_i is repeated once per negative
    # entry of row i to line up with neg_scores.
    neg_rate = U.repeat_interleave(int(batch - 1)) * neg_scores + b_minus
    neg_dist = Gamma(neg_shape, neg_rate)

    pos_weights = pos_dist.sample().reshape(batch, 1)
    neg_weights = neg_dist.sample().reshape(batch, batch - 1)
    return torch.cat([pos_weights, neg_weights], dim=1)

def sample_w_bernolli(U, s_matrix, prior_prob=None):
    """Sample weights for a square score matrix: Gamma for positives, Bernoulli for negatives.

    Column 0 of each row is the positive entry and receives
    ``w ~ Gamma(1 + a_plus, U_i * s + b_plus)``. Each of the other ``BS - 1``
    entries of row ``i`` receives ``w ~ Bernoulli(p)`` with::

        p = pi * e / (1 - pi + pi * e),   e = exp(-U_i * s),   pi = prior_prob

    Args:
        U: one value per row of ``s_matrix`` (as drawn by ``sample_u``).
        s_matrix: ``(BS, BS)`` score matrix with the positive in column 0.
        prior_prob: prior probability ``pi`` that a negative is kept; must
            lie strictly inside (0, 1). Defaults to the module-level
            ``a_minus`` for backward compatibility. That constant is also
            used as a Gamma shape, so pass this explicitly.

    Returns:
        ``(BS, BS)`` float tensor: Gamma weights in column 0 and 0/1 weights
        in columns 1..BS-1.

    Raises:
        ValueError: if the prior probability is not strictly inside (0, 1).
    """
    import math

    if prior_prob is None:
        prior_prob = a_minus
    if not 0.0 < prior_prob < 1.0:
        raise ValueError(
            "sample_w_bernolli: prior_prob must be in (0, 1), got "
            f"{prior_prob!r} (a_minus is a Gamma shape, not a probability)"
        )

    BS = s_matrix.shape[0]
    device = s_matrix.device
    # Column 0 of every row holds the positive pair.
    pos_mask = torch.zeros(BS, BS, dtype=torch.bool, device=device)
    pos_mask[:, 0] = True

    s_plus = s_matrix.masked_select(pos_mask)
    s_minus = s_matrix.masked_select(~pos_mask)

    # Positive weights: Gamma, as in ``sample_w``.
    w_plus_dist = Gamma(
        torch.tensor(1 + a_plus, dtype=torch.float32, device=device),
        U * s_plus + b_plus,
    )
    w_plus = w_plus_dist.sample().reshape(BS, 1)

    # masked_select flattens row-major, so U_i must be repeated once per
    # negative in row i; a bare ``U * s_minus`` only broadcasts when BS <= 2.
    U_minus = U.repeat_interleave(BS - 1)
    # pi*e / (1 - pi + pi*e) == sigmoid(logit(pi) - U*s): the same probability,
    # but with no exp() that can overflow to inf/inf = nan for very negative U*s.
    logit_prior = math.log(prior_prob) - math.log1p(-prior_prob)
    w_minus_prob = torch.sigmoid(logit_prior - U_minus * s_minus)
    w_minus = Bernoulli(probs=w_minus_prob).sample().reshape(BS, BS - 1).float()

    return torch.cat([w_plus, w_minus], dim=1)

def myloss(zis, zjs, temperature=0.1, num_iterations=None):
    """InfoNCE-style contrastive loss between two batches of embeddings.

    Row ``i`` of ``zis`` is paired with row ``i`` of ``zjs`` (the positive);
    every other row of ``zjs`` serves as a negative for it. The
    temperature-scaled logits can optionally be reweighted by
    ``num_iterations`` rounds of ``sample_u`` / ``sample_w_bernolli`` before
    the softmax.

    Args:
        zis: 2-D embeddings, one row per sample.
        zjs: 2-D embeddings with the same number of rows as ``zis``.
        temperature: divisor applied to the dot-product similarities.
        num_iterations: number of weight-sampling rounds. ``None`` uses the
            module-level ``num_iteration``; 0 gives the plain, unweighted loss.

    Returns:
        Scalar tensor: the batch mean of ``-log p(positive)``.
    """
    if num_iterations is None:
        num_iterations = num_iteration

    batch_size = zis.shape[0]
    similarity_matrix = torch.matmul(zis, zjs.T)

    positives = torch.diag(similarity_matrix).view(-1, 1)
    # Off-diagonal entries, row by row, are the negatives. Building the mask
    # on the same device avoids a host->device copy on every call.
    off_diag = ~torch.eye(batch_size, dtype=torch.bool, device=similarity_matrix.device)
    negatives = similarity_matrix[off_diag].view(batch_size, -1)

    # Column 0 holds the positive logit for every row.
    logits = torch.cat((positives, negatives), dim=1) / temperature

    weights = torch.ones_like(logits)
    for _ in range(num_iterations):
        U = sample_u(weights, logits)
        weights = sample_w_bernolli(U, logits)

    # Log-probabilities give the standard cross-entropy / InfoNCE objective.
    # A raw softmax probability would saturate and give vanishing gradients.
    log_probs = F.log_softmax(logits * weights, dim=1)
    return -log_probs[:, 0].mean()


@register_loss("unimol")
class UniMolLoss(UnicoreLoss):
    """Uni-Mol pre-training loss extended with a contrastive term.

    The total loss is the weighted masked-token loss plus
    ``0.1 * myloss(feature_aug, feature_ori)``. It also includes, when the
    model produces the corresponding outputs and the weights are enabled, a
    masked coordinate loss, a masked pair-distance loss, and two norm
    regularisers. Per-term weights come from ``self.args``.
    """

    def __init__(self, task):
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()
        # Stored from the task; not read anywhere else in this class.
        self.seed = task.seed
        # Mean / std used by ``cal_dist_loss`` to standardise distance targets
        # when ``normalize=True`` (presumably dataset statistics — TODO confirm).
        self.dist_mean = 6.312581655060595
        self.dist_std = 3.3899264663911888

    def forward(self, model, sample, reduce=True):
        """Compute the combined pre-training loss for one batch.

        Args:
            model: called with ``sample["net_input"]`` plus
                ``encoder_masked_tokens``; must return the 7-tuple unpacked below.
            sample: batch dict with ``"net_input"`` and ``"target"`` entries.
                ``target`` needs ``tokens_target``, and also ``coord_target`` /
                ``distance_target`` when the matching model output is not None.
            reduce: accepted for interface compatibility; unused here.

        Returns:
            ``(loss, 1, logging_output)``. The sample size is always 1, so
            ``loss`` is a per-batch mean and ``reduce_metrics`` averages the
            logged values over batches.
        """
        input_key = "net_input"
        target_key = "target"
        # Positions where ``tokens_target`` is not padding; the token, coord and
        # distance losses are computed only at these positions.
        masked_tokens = sample[target_key]["tokens_target"].ne(self.padding_idx)
        sample_size = masked_tokens.long().sum()
        (
            logits_encoder,
            encoder_distance,
            encoder_coord,
            x_norm,
            delta_encoder_pair_rep_norm,
            feature_ori,
            feature_aug
        ) = model(**sample[input_key], encoder_masked_tokens=masked_tokens)
        target = sample[target_key]["tokens_target"]
        # Always true here: ``masked_tokens`` is the tensor computed above.
        if masked_tokens is not None:
            target = target[masked_tokens]
        # NOTE(review): assumes ``logits_encoder`` already holds one row per
        # masked position (it is compared element-wise with ``target``) — confirm.
        masked_token_loss = F.nll_loss(
            F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32),
            target,
            ignore_index=self.padding_idx,
            reduction="mean",
        )
        masked_pred = logits_encoder.argmax(dim=-1)
        masked_hit = (masked_pred == target).long().sum()
        masked_cnt = sample_size
        loss = masked_token_loss * self.args.masked_token_loss

        # Contrastive term between augmented and original features, with a
        # hard-coded weight of 0.1.
        new_loss=myloss(feature_aug,feature_ori)
        loss = loss+ 0.1* new_loss

        logging_output = {
            "sample_size": 1,
            "bsz": sample[target_key]["tokens_target"].size(0),
            "seq_len": sample[target_key]["tokens_target"].size(1)
            * sample[target_key]["tokens_target"].size(0),
            "masked_token_loss": masked_token_loss.data,
            "masked_token_hit": masked_hit.data,
            "masked_token_cnt": masked_cnt,
            "myloss":new_loss.data
        }

        if encoder_coord is not None:
            # real = mask + delta
            coord_target = sample[target_key]["coord_target"]
            masked_coord_loss = F.smooth_l1_loss(
                encoder_coord[masked_tokens].view(-1, 3).float(),
                coord_target[masked_tokens].view(-1, 3),
                reduction="mean",
                beta=1.0,
            )
            loss = loss + masked_coord_loss * self.args.masked_coord_loss
            # Log the unweighted term so it is comparable across weight settings.
            logging_output["masked_coord_loss"] = masked_coord_loss.data

        if encoder_distance is not None:
            dist_masked_tokens = masked_tokens
            masked_dist_loss = self.cal_dist_loss(
                sample, encoder_distance, dist_masked_tokens, target_key, normalize=True
            )
            loss = loss + masked_dist_loss * self.args.masked_dist_loss
            logging_output["masked_dist_loss"] = masked_dist_loss.data

        if self.args.x_norm_loss > 0 and x_norm is not None:
            loss = loss + self.args.x_norm_loss * x_norm
            logging_output["x_norm_loss"] = x_norm.data

        if (
            self.args.delta_pair_repr_norm_loss > 0
            and delta_encoder_pair_rep_norm is not None
        ):
            loss = (
                loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm
            )
            logging_output[
                "delta_pair_repr_norm_loss"
            ] = delta_encoder_pair_rep_norm.data

        logging_output["loss"] = loss.data
        return loss, 1, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs, split="valid") -> None:
        """Aggregate logging outputs from data parallel training.

        Each value is summed over ``logging_outputs`` and divided by the summed
        ``sample_size``. Because ``forward`` reports a sample size of 1, this
        is an average over batches. Optional terms are logged only when their
        sum is positive.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        bsz = sum(log.get("bsz", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        seq_len = sum(log.get("seq_len", 0) for log in logging_outputs)
        my_loss = sum(log.get("myloss", 0) for log in logging_outputs)
        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
        # Average total (padded) tokens per sequence.
        metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3)

        masked_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs)
        metrics.log_scalar(
            "masked_token_loss", masked_loss / sample_size, sample_size, round=3
        )

        metrics.log_scalar(
            "my_loss", my_loss / sample_size, sample_size, round=3
        )

        # NOTE(review): no guard for a zero total masked-token count.
        masked_acc = sum(
            log.get("masked_token_hit", 0) for log in logging_outputs
        ) / sum(log.get("masked_token_cnt", 0) for log in logging_outputs)
        metrics.log_scalar("masked_acc", masked_acc, sample_size, round=3)

        masked_coord_loss = sum(
            log.get("masked_coord_loss", 0) for log in logging_outputs
        )
        if masked_coord_loss > 0:
            metrics.log_scalar(
                "masked_coord_loss",
                masked_coord_loss / sample_size,
                sample_size,
                round=3,
            )

        masked_dist_loss = sum(
            log.get("masked_dist_loss", 0) for log in logging_outputs
        )
        if masked_dist_loss > 0:
            metrics.log_scalar(
                "masked_dist_loss", masked_dist_loss / sample_size, sample_size, round=3
            )

        x_norm_loss = sum(log.get("x_norm_loss", 0) for log in logging_outputs)
        if x_norm_loss > 0:
            metrics.log_scalar(
                "x_norm_loss", x_norm_loss / sample_size, sample_size, round=3
            )

        delta_pair_repr_norm_loss = sum(
            log.get("delta_pair_repr_norm_loss", 0) for log in logging_outputs
        )
        if delta_pair_repr_norm_loss > 0:
            metrics.log_scalar(
                "delta_pair_repr_norm_loss",
                delta_pair_repr_norm_loss / sample_size,
                sample_size,
                round=3,
            )

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True

    def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False):
        """Smooth-L1 loss between predicted and target pair distances.

        Only rows selected by ``masked_tokens`` are used. Within them, only
        entries whose target distance is > 0 count (presumably excluding
        padding and zero self-distances — TODO confirm). With ``normalize``,
        targets are standardised with ``self.dist_mean`` / ``self.dist_std``.
        """
        dist_masked_tokens = masked_tokens
        masked_distance = dist[dist_masked_tokens, :]
        masked_distance_target = sample[target_key]["distance_target"][
            dist_masked_tokens
        ]
        # Computed on the raw targets, before any normalisation below.
        non_pad_pos = masked_distance_target > 0
        if normalize:
            masked_distance_target = (
                masked_distance_target.float() - self.dist_mean
            ) / self.dist_std
        masked_dist_loss = F.smooth_l1_loss(
            masked_distance[non_pad_pos].view(-1).float(),
            masked_distance_target[non_pad_pos].view(-1),
            reduction="mean",
            beta=1.0,
        )
        return masked_dist_loss


@register_loss("unimol_infer")
class UniMolInferLoss(UnicoreLoss):
    """Inference-time "loss" that only harvests encoder representations.

    ``forward`` returns a zero loss. The useful output is the logging dict,
    which carries the first-position (CLS) representation and the
    padding-free pair representation of every sample, plus its identifier.
    """

    def __init__(self, task):
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()

    def forward(self, model, sample, reduce=True):
        """Run the model in feature-only mode and collect representations.

        Returns a tuple with three elements:
        1) the loss (always ``0``)
        2) the sample size, i.e. the number of sequences in the batch
        3) logging outputs with ``mol_repr_cls``, ``pair_repr``,
           ``data_name`` and ``bsz``

        Raises:
            NotImplementedError: if ``sample["target"]`` has neither a
                ``pdb_id`` nor a ``smi_name`` entry.
        """
        net_input = sample["net_input"]
        # True at every non-padding token position.
        keep = net_input["src_tokens"].ne(self.padding_idx)
        encoder_rep, encoder_pair_rep = model(**net_input, features_only=True)
        batch_size = net_input["src_tokens"].size(0)

        targets = sample["target"]
        # 'pdb_id' takes priority over 'smi_name' when both are present.
        available = [key for key in ('pdb_id', 'smi_name') if key in targets.keys()]
        if not available:
            raise NotImplementedError("No name key in the original data")
        name_key = available[0]

        # Strip padded rows and columns from each sample's pair representation.
        pair_reprs = [
            encoder_pair_rep[row][keep[row], :][:, keep[row]].data.cpu().numpy()
            for row in range(batch_size)
        ]
        logging_output = {
            "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(),
            "pair_repr": pair_reprs,
            "data_name": targets[name_key],
            "bsz": batch_size,
        }
        return 0, batch_size, logging_output
