import torch
from torch import nn as nn
from torch.nn import Sequential

from peagang.models.components.utilities_classes import PointNetBlock
from peagang.models.components.utilities_functions import sn_wrap
from peagang.utils.utils import zero_diag


class KernelEdges(nn.Module):
    """Dense edge readout built from an RBF-style kernel on node embeddings.

    For node features ``X`` of shape (B, N, F), the weight between nodes ``i``
    and ``j`` of the same graph is ``exp(-(||x_i - x_j||_p / sigma) ** 2)``,
    where ``sigma = exp(log_sigma)`` is a learnable bandwidth. The kernel
    matrix is passed through ``zero_diag`` before being returned.

    Args:
        p: order of the norm used for the pairwise distances.
    """

    def __init__(self, p=2):
        super().__init__()
        self.p = p
        # Learnable 0-dim log-bandwidth, initialised to 1 (i.e. sigma = e).
        self.log_sigma = torch.nn.Parameter(torch.ones(()), requires_grad=True)

    def _pairwise_distances(self, X):
        """Return the symmetric (B, N, N) matrix of per-graph p-norm distances.

        Only the strictly-lower-triangular pairs are computed and then mirrored,
        so the diagonal stays exactly zero and the norm is never evaluated on a
        zero vector ``x_i - x_i``.
        """
        n_nodes = X.shape[1]
        # new_zeros keeps the dtype and device of X.
        dists = X.new_zeros(X.shape[0], n_nodes, n_nodes)
        rows, cols = torch.tril_indices(n_nodes, n_nodes, offset=-1, device=X.device)
        # dim=-1 reduces over the feature axis only, giving one distance per
        # graph and pair (B, P) instead of a single scalar for the whole batch.
        d = torch.norm(X[:, rows, :] - X[:, cols, :], p=self.p, dim=-1)
        dists[:, rows, cols] = d
        # TODO/Note: mirroring makes the result symmetric, so directed graphs
        # cannot be represented this way.
        dists[:, cols, rows] = d
        return dists

    def forward(self, X):
        """Compute kernel edge weights.

        Args:
            X: node features, shape (B, N, F).

        Returns:
            Tensor of shape (B, N, N): ``zero_diag`` applied to the kernel
            matrix ``exp(-(dists / sigma) ** 2)``.
        """
        dists = self._pairwise_distances(X)
        sigma = self.log_sigma.exp()
        # Kernel decays with distance: 1 at distance 0, tending to 0 far away.
        kernel = (-((dists / sigma) ** 2)).exp()
        # The kernel diagonal is exp(0) = 1; zero_diag is applied on top of it
        # (presumably to drop self-loops -- see the project helper).
        return zero_diag(kernel)


class RescaledSoftmax(nn.Module):
    """Edge readout: row-wise softmax of inner products, min-max rescaled.

    Args:
        p: stored on the module; not used by ``forward``.
        with_bias: if True, build a small bias network (two ``PointNetBlock``
            layers with an activation in between) that is applied to the
            inner products in ``forward``.
        spectral_norm: forwarded to ``PointNetBlock``.
        feat_dim: node feature dimension; required when ``with_bias`` is True.
        bias_hidden: hidden width of the bias network.
        act: activation class (not instance); defaults to ``nn.ReLU``.
        bias_mode: ``"add"`` (default) adds the bias to the inner products,
            ``"mult"`` multiplies them by it.

    Raises:
        ValueError: on an unknown ``bias_mode`` or when ``with_bias`` is True
            but ``feat_dim`` is missing.
    """

    def __init__(
        self,
        p=2,
        with_bias=False,
        spectral_norm=None,
        feat_dim=None,
        bias_hidden=128,
        act=None,
        bias_mode="add",
    ):
        super().__init__()
        if bias_mode not in ("add", "mult"):
            raise ValueError(
                f"bias_mode must be 'add' or 'mult', got {bias_mode!r}"
            )
        if act is None:
            act = nn.ReLU
        self.p = p
        self.sigma = 1.0
        self.bias_mode = bias_mode
        if with_bias:
            if feat_dim is None:
                raise ValueError("feat_dim is required when with_bias=True")
            # node wise temperature
            self.bias = nn.Sequential(
                PointNetBlock(feat_dim, bias_hidden, spectral_norm=spectral_norm),
                act(),
                PointNetBlock(bias_hidden, 1, spectral_norm=spectral_norm),
            )
        else:
            self.bias = None

    def forward(self, X):
        """Compute a symmetric (B, N, N) edge-weight tensor from ``X`` (B, N, F)."""
        # X: B N F
        # TODO: QK/QQ attention
        # inner product => B N N
        prod = X @ X.permute(0, -1, -2)
        # self.bias is either a module or None, so test for its presence; the
        # combination rule lives in bias_mode. getattr keeps modules that were
        # pickled before bias_mode existed working with the additive default.
        if self.bias is not None:
            node_bias = self.bias(X)
            # NOTE(review): if node_bias is constant along the last axis (e.g.
            # shape B N 1), an additive bias largely cancels in the max-rescale
            # below; "mult" then behaves like a per-node temperature -- confirm
            # which one is intended.
            if getattr(self, "bias_mode", "add") == "mult":
                prod = prod * node_bias
            else:
                prod = prod + node_bias
        zeroed = zero_diag(prod)
        sm = zeroed.softmax(-1)
        # zero diag again
        sm_zero = zero_diag(sm)
        # renormalize: per-row min-max rescaling
        ma = sm_zero.max(-1)[0].unsqueeze(-1)
        mi = sm_zero.min(-1)[0].unsqueeze(-1)
        # machine epsilon guards against division by zero when max == min
        e = torch.finfo(ma.dtype).eps
        A_nonsym = (sm_zero - mi) / (ma - mi + e)
        # symmetrize: keep the upper triangle (diagonal included) and mirror it
        Atriu = torch.triu(A_nonsym)
        A = Atriu + Atriu.permute(0, -1, -2)
        return A


class BiasedSigmoid(nn.Module):
    """Edge readout: sigmoid of inner products plus a learned bias.

    Args:
        feat_dim: node feature dimension (input width of the bias trunk).
        hidden_dim: hidden width of the bias trunk.
        spectral_norm: forwarded to ``sn_wrap`` / ``PointNetBlock``.
        bias: ``"scalar"`` builds the trunk from two (optionally spectrally
            normalised) linear layers; ``"nodes"`` builds it from two
            ``PointNetBlock`` layers.
        act: activation class (not instance); defaults to ``nn.ReLU``.

    Raises:
        ValueError: if ``bias`` is neither ``"scalar"`` nor ``"nodes"``.
    """

    def __init__(
        self, feat_dim, hidden_dim=128, spectral_norm=None, bias="scalar", act=None
    ):
        super().__init__()
        # Fail early: an unknown mode would otherwise leave self.trunk undefined
        # and only surface as an AttributeError inside forward().
        if bias not in ("scalar", "nodes"):
            raise ValueError(f"bias must be 'scalar' or 'nodes', got {bias!r}")
        if act is None:
            act = nn.ReLU
        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        if bias == "scalar":
            self.trunk = Sequential(
                sn_wrap(nn.Linear(feat_dim, hidden_dim), spectral_norm),
                act(),
                sn_wrap(nn.Linear(hidden_dim, 1), spectral_norm),
            )
        else:
            # node wise temperature
            self.trunk = nn.Sequential(
                PointNetBlock(feat_dim, hidden_dim, spectral_norm=spectral_norm),
                act(),
                PointNetBlock(hidden_dim, 1, spectral_norm=spectral_norm),
            )

    def forward(self, X):
        """Compute a symmetric (B, N, N) edge-weight tensor from ``X`` (B, N, F)."""
        # X: B N F
        # aggregate along node_dim
        # NOTE(review): the pooled tensor is fed to the trunk in both modes, so
        # the "nodes" trunk also sees the node-averaged features -- confirm
        # that this is intended.
        x = X.mean(1)
        # B 1 1 (for the "scalar" trunk)
        b = self.trunk(x).unsqueeze(-1)
        # B N N
        prod = X @ X.permute(0, 2, 1)
        biased_prod = prod + b
        _A = torch.sigmoid(biased_prod)
        A_nonsym = zero_diag(_A)
        # symmetrize: keep the upper triangle (diagonal included) and mirror it
        Atriu = torch.triu(A_nonsym)
        A = Atriu + Atriu.permute(0, -1, -2)
        return A


if __name__ == "__main__":
    # Smoke test: build each readout, then run one forward/backward pass per
    # readout on a fresh random batch of shape (3, 5, 10).
    kernel_readout = KernelEdges()
    sigmoid_readout = BiasedSigmoid(feat_dim=10, spectral_norm="diff", bias="nodes")
    softmax_readout = RescaledSoftmax(feat_dim=10, spectral_norm="diff", with_bias=True)
    for readout in (kernel_readout, sigmoid_readout, softmax_readout):
        features = torch.randn(3, 5, 10)
        edges = readout(features)
        loss = edges.sum()
        loss.backward()
