import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F

# Euclidean embedding model
class L2Model(nn.Module):
    """Euclidean (L2) distance embedding model.

    Row embeddings ``X`` and column embeddings ``Y`` are combined into the
    reconstruction ``A_hat[i, j] = beta - ||X[i] - Y[j]||_2``, where ``beta``
    is a learnable one-element bias.

    During SVD pre-training (see ``init_pre_svd``) ``X`` and ``Y`` are held
    fixed and a learnable vector ``S`` rescales the embedding columns by
    ``sqrt(softplus(S))`` before the distances are computed.
    """

    def __init__(self,
                 X: torch.Tensor,
                 Y: torch.Tensor,
                 inference_only: bool = False):
        """Wrap the given embedding matrices.

        Args:
            X: row embeddings (indexed along dim 0).
            Y: column embeddings (indexed along dim 0).
            inference_only: when True, ``X`` and ``Y`` are kept fixed
                instead of being turned into trainable parameters.
        """
        super().__init__()
        self._attach_embedding("X", X, trainable=not inference_only)
        self._attach_embedding("Y", Y, trainable=not inference_only)
        self.beta = nn.Parameter(torch.rand(1))  # (scalar) free parameter bias term
        self.S = None  # ! only set if pretraining on SVD objective

    def _attach_embedding(self, name: str, value: torch.Tensor, trainable: bool) -> None:
        """Store one embedding matrix as a parameter (trainable) or buffer (fixed)."""
        if trainable:
            setattr(self, name, nn.Parameter(value))
        elif isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter):
            # A plain attribute would be skipped by `.to()` / `.cuda()` /
            # `.double()`, leaving the fixed embeddings on a different
            # device/dtype than `beta` and `S`. A non-persistent buffer follows
            # the module while keeping `state_dict()` keys unchanged.
            self.register_buffer(name, value, persistent=False)
        else:
            # nn.Parameter (or non-tensor) input: keep plain attribute
            # assignment semantics, exactly as before.
            setattr(self, name, value)

    @classmethod
    def init_random(cls,
                    n_row: int,
                    n_col: int,
                    rank: int,
                    **kwargs):
        """
        Initializes the low rank approximation tensors,
            with values drawn from std. gaussian distribution.

        Args:
            n_row: number of rows of ``X``.
            n_col: number of rows of ``Y``.
            rank: embedding dimension shared by ``X`` and ``Y``.
            **kwargs: forwarded to the constructor.
        """
        X = torch.randn(n_row, rank)
        Y = torch.randn(n_col, rank)
        return cls(X, Y, **kwargs)

    @classmethod
    def init_pre_svd(cls,
                     U: torch.Tensor,
                     V: torch.Tensor,
                     **kwargs):
        """
        Method 1 of 2. [together with init_post_svd]
        Initialize a model for pre-training (to improve initialization point),
            with unitary matrices U and V from SVD on A, for learning S.

        ``U`` and ``V`` are stored as fixed (non-trainable) embeddings; a
        randomly initialized vector ``S`` of length ``U.shape[1]`` is attached
        as a trainable parameter. ``inference_only`` is set here and therefore
        must not be passed in ``kwargs``.
        """
        assert U.shape == V.shape, "U & V must be dimensions (n,r) & (n,r), respectively, r: emb. rank, n: # of nodes"
        model = cls(U, V, inference_only=True, **kwargs)  # we only learn S in A = USV^T
        model.S = nn.Parameter(torch.randn(U.shape[1]))
        return model

    @classmethod
    def init_post_svd(cls,
                      U: torch.Tensor,
                      V: torch.Tensor,
                      S: torch.Tensor,
                      **kwargs):
        """
        Method 2 of 2. [together with init_pre_svd]
        Initialize a model for further training, using U, V and learned S to
            compute an improved initialization point.

        Each column of ``U`` and ``V`` is multiplied by
        ``1 / sqrt(softplus(S))``; ``S`` is expected to be a 1-D vector with
        one entry per embedding column.

        NOTE(review): ``reconstruct`` scales by ``sqrt(softplus(S))`` during
        pre-training whereas this uses its reciprocal -- confirm the inverse
        is intended.
        """
        inv_sqrt_scale = torch.sqrt(F.softplus(S)) ** (-1)
        # Broadcasting over columns equals right-multiplying by diag(scale)
        # without materializing the dense (r, r) matrix.
        X = U * inv_sqrt_scale
        Y = V * inv_sqrt_scale
        return cls(X, Y, **kwargs)

    def reconstruct(self, node_indices: torch.Tensor = None):
        """Return ``beta - cdist(X, Y)`` for all rows or for a subset.

        Args:
            node_indices: optional index applied to BOTH ``X`` and ``Y``
                along dim 0; when None every row is used.

        Returns:
            Tensor of pairwise scores between the selected rows of ``X`` and
            the selected rows of ``Y``.
        """
        if node_indices is None:
            X, Y = self.X, self.Y
        else:
            X, Y = self.X[node_indices], self.Y[node_indices]

        if self.S is not None:
            # softplus keeps the scale non-negative; sqrt because the scale is
            # applied to both matrices. Column-wise broadcasting is the same
            # as multiplying by diag(scale).
            scale = torch.sqrt(F.softplus(self.S))
            X, Y = X * scale, Y * scale

        return self.beta - torch.cdist(X, Y, p=2)

    def reconstruct_subset(self,
                           links_list: torch.Tensor,
                           nonlinks_list: torch.Tensor):
        """Score selected (row, column) pairs only (Case Control).

        ``links_list`` and ``nonlinks_list`` are horizontally stacked; entry 0
        of the result indexes ``X`` and entry 1 indexes ``Y`` (assumes both
        inputs are shaped (2, k) -- TODO confirm against callers).

        Returns:
            ``beta - ||X[i] - Y[j]||_2`` for every stacked pair, links first.

        NOTE(review): unlike ``reconstruct``, ``S`` is not applied here.
        """
        # ! for Case Control
        pair_index = torch.hstack([links_list, nonlinks_list])
        row_emb = self.X[pair_index[0]]
        col_emb = self.Y[pair_index[1]]
        return self.beta - torch.norm(row_emb - col_emb, p=2, dim=1)

    def forward(self):
        """Return ``(X, Y, S)`` while pre-training on the SVD target, else ``(X, Y, beta)``."""
        if self.S is not None:  # during pretraining, i.e. SVD target
            return self.X, self.Y, self.S

        return self.X, self.Y, self.beta
    
