import torch 


class DotProductDecoder(torch.nn.Module):
    """Score every pair of nodes by the inner product of their embeddings."""

    def forward(self, z):
        """Return ``z`` multiplied by itself with axes 1 and 2 swapped.

        For a 3-D ``(batch, n_nodes, n_feat)`` input (assumed, not checked) the
        result is the per-batch Gram matrix ``(batch, n_nodes, n_nodes)``.
        """
        z_t = z.transpose(1, 2)
        return torch.matmul(z, z_t)

class MLPDecoder(torch.nn.Module):
    """Decode node embeddings into a dense matrix of pairwise scores with an MLP.

    Every ordered node pair ``(i, j)`` is mapped to a feature vector of size
    ``in_channels``. That vector goes through ``Linear -> ReLU -> Linear`` and
    yields a single unnormalised score ``A[..., i, j]``.

    Args:
        in_channels: Size of the per-pair feature vector, i.e. the last
            dimension of ``z`` in ``forward``.
        permutation_invariant: If ``True``, pair features are inner products
            over the third axis of ``z`` (see ``forward``), so reordering that
            axis does not change the output (up to floating-point rounding).
            If ``False``, pair features are the elementwise product of the two
            node embeddings.
    """

    def __init__(self, in_channels, permutation_invariant=False):
        super().__init__()
        # Both modes share the same two-layer head; they differ only in how
        # the pair features are built in ``forward``.
        self.lin = torch.nn.Linear(in_channels, in_channels)
        self.lin_final = torch.nn.Linear(in_channels, 1)
        self.permutation_invariant = permutation_invariant

    def forward(self, z):
        """Return the ``(-1, n_nodes, n_nodes)`` tensor of pairwise scores.

        Args:
            z: If ``permutation_invariant``, a 4-D tensor
                ``(n_graphs, n_nodes, n_games, n_feat)`` with
                ``n_feat == in_channels``. Otherwise nodes are indexed along
                dim 1 and the last dim must equal ``in_channels``; a 3-D
                ``(batch, n_nodes, in_channels)`` input is assumed.
        """
        if self.permutation_invariant:
            n_graphs, n_nodes, n_games, n_feat = z.shape
            # Fold the feature axis into the batch axis so a single bmm yields,
            # for every (graph, feature), the node Gram matrix over games.
            z = z.permute(0, 3, 1, 2).reshape(n_graphs * n_feat, n_nodes, n_games)
            z_tilde = torch.bmm(z, z.transpose(1, 2))  # (n_graphs*n_feat, n_nodes, n_nodes)
            # Restore (n_graphs, n_nodes, n_nodes, n_feat) so Linear acts on features.
            z_tilde = z_tilde.reshape(n_graphs, n_feat, n_nodes, n_nodes).permute(0, 2, 3, 1)
            h = self.lin(z_tilde).relu()
        else:
            n_nodes = z.shape[1]
            # Flat position i * n_nodes + j holds the pair (src=i, dst=j): the
            # row-major order the final reshape relies on. The indices are
            # created on z's device so GPU inputs are not indexed with a CPU tensor.
            nodes = torch.arange(n_nodes, dtype=torch.long, device=z.device)
            all_src = torch.repeat_interleave(nodes, n_nodes)
            all_dst = nodes.repeat(n_nodes)
            h = self.lin(z[:, all_src] * z[:, all_dst]).relu()

        return self.lin_final(h).reshape(-1, n_nodes, n_nodes)


class CorrelationCoefficientDecoder(torch.nn.Module):
    """Decode node embeddings into pairwise Pearson correlations mapped to [0, 1]."""

    def forward(self, z, eps=1e-8):
        """Return ``(corr + 1) / 2`` for every pair of rows of ``z``.

        Args:
            z: Tensor whose last axis holds each node's feature vector; rows are
                paired over dim 1 via ``transpose(1, 2)``, so a 3-D
                ``(batch, n_nodes, n_feat)`` input is assumed.
            eps: Lower bound on the correlation denominator. Without it a node
                whose features are all equal (zero variance) produces
                0 / 0 = NaN, which poisons an MSE loss. With it such pairs get a
                correlation of approximately 0 (output ~0.5).
        """
        z_shift = z - torch.mean(z, dim=-1, keepdim=True)
        cov = z_shift @ z_shift.transpose(1, 2)
        z_var = (z_shift * z_shift).sum(dim=-1, keepdim=True)
        # Clamp *before* the sqrt: sqrt's backward at 0 is 0/0, so clamping its
        # output instead would still give NaN gradients for constant rows.
        var_prod = (z_var * z_var.transpose(1, 2)).clamp_min(eps * eps)
        corr_coeff = cov / torch.sqrt(var_prod)

        # Shift from [-1, 1] to [0, 1] so the model can be trained with MSE.
        return (corr_coeff + 1) / 2

class CosineSimilarityDecoder(torch.nn.Module):
    """Score node pairs by the cosine similarity of their embeddings."""

    def forward(self, z, pearson=False, eps=1e-8):
        """Return pairwise cosine similarities ``<z_i, z_j> / (|z_i| |z_j|)``.

        Args:
            z: Tensor with node embeddings along the last axis; pairs are formed
                over dim 1 via ``transpose(1, 2)``, so a 3-D
                ``(batch, n_nodes, n_feat)`` input is assumed.
            pearson: Accepted but not used by this implementation; kept so
                existing callers that pass it keep working.
            eps: Lower bound applied to each embedding norm. Without it an
                all-zero embedding produces 0 / 0 = NaN. With it such pairs score
                0, as ``torch.nn.functional.cosine_similarity`` does.
        """
        dots = z @ torch.transpose(z, 1, 2)
        # Norms >= eps are left untouched, so ordinary inputs are unaffected.
        z_norm = torch.norm(z, dim=-1, keepdim=True).clamp_min(eps)
        return dots / (z_norm * torch.transpose(z_norm, 1, 2))