import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PRN(nn.Module):
    """Run a wrapped encoder under a label-agreement mask, then softmax.

    ``forward`` builds a square mask over the second-to-last axis of ``src``,
    calls ``encoder(src, tar, src_mask)``, applies ReLU and finally a softmax
    over the last dimension.
    """

    def __init__(self,  encoder):
        """Store the encoder and create the ReLU / softmax heads.

        Args:
            encoder: callable invoked as ``encoder(src, tar, src_mask)``.
        """
        super().__init__()
        self.encoder = encoder

        # NOTE(review): in-place ReLU overwrites whatever tensor the encoder
        # returns; this is only safe if the encoder returns a fresh tensor
        # that autograd does not still need — confirm against the encoder.
        self.relu = nn.ReLU(inplace=True)
        self.softmax = nn.Softmax(dim=-1)

    def encode(self, src, tar, src_mask):
        """Call the encoder and rectify its output with ReLU.

        Args:
            src: input tensor forwarded unchanged to the encoder.
            tar: optional second input forwarded unchanged (may be ``None``).
            src_mask: int64 mask forwarded unchanged to the encoder.

        Returns:
            ``relu(encoder(src, tar, src_mask))``.
        """
        return self.relu(self.encoder(src, tar, src_mask))

    @staticmethod
    def _label_agreement_mask(src):
        """Return ``mask[..., i, j] = 1`` iff rows i and j of ``src`` share an argmax.

        The argmax is taken over the last dimension of ``src``; the result is
        an int64 tensor of shape ``src.shape[:-1] + (src.shape[-2],)``.
        """
        labels = torch.argmax(src, dim=-1)
        # Broadcasting (..., L, 1) against (..., 1, L) yields the pairwise
        # comparison without materialising repeated/transposed copies.
        return (labels.unsqueeze(-1) == labels.unsqueeze(-2)).long()

    def forward(self, src, tar=None):
        """Build the attention mask, run the encoder and softmax its output.

        Args:
            src: input tensor; presumably (batch, seq_len, num_classes) —
                TODO confirm with callers.
            tar: optional second input. When given, the mask is all ones
                (every position may attend to every other); when ``None``,
                positions are connected only if their argmax labels agree.

        Returns:
            Softmax over the last dimension of the ReLU'd encoder output.
        """
        if tar is not None:
            # Fully connected mask; the label graph would be discarded anyway,
            # so it is not computed on this path.
            src_mask = torch.ones(
                [src.shape[0], src.shape[-2], src.shape[-2]],
                dtype=torch.long,
                device=src.device,
            )
        else:
            src_mask = self._label_agreement_mask(src)

        encoder_out = self.encode(src, tar, src_mask)

        return self.softmax(encoder_out)
