
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as tla
import numpy.linalg as la
from sklearn.metrics import roc_auc_score 


def compute_loss(pos_score, neg_score, labels):
    """
    Binary cross-entropy (with logits) over the positive and negative scores.

    Parameters
    ----------
    pos_score -- raw (pre-sigmoid) scores of the positive samples
    neg_score -- raw (pre-sigmoid) scores of the negative samples
    labels -- targets for the concatenated scores, positives first then negatives

    Returns
    -------
    loss tensor returned by F.binary_cross_entropy_with_logits with its default settings
    """
    # Stack the scores in the same order the labels are expected in.
    logits = torch.cat((pos_score, neg_score))
    return F.binary_cross_entropy_with_logits(logits, labels)

def compute_error_curl(output,target,B2):
    """
    Relative error of output w.r.t. target and the norm of B2.T @ output (torch).

    Parameters
    ----------
    output -- predicted tensor, must have the same shape as target
    target -- reference tensor
    B2 -- matrix whose transpose is applied to output

    Returns
    -------
    tuple (||output - target|| / ||target||, ||B2.T @ output||), both 2-norms
    """
    assert output.shape == target.shape
    residual = output - target
    relative_error = tla.vector_norm(residual) / tla.vector_norm(target)
    curl_norm = tla.vector_norm(B2.T @ output)
    return relative_error, curl_norm

def compute_error_curl_np(output,target,B2):
    """
    Relative error of output w.r.t. target and the norm of B2.T @ output (numpy).

    Parameters
    ----------
    output -- predicted array, must have the same shape as target
    target -- reference array
    B2 -- matrix whose transpose is applied to output

    Returns
    -------
    tuple (norm(output - target) / norm(target), norm(B2.T @ output)) using la.norm defaults
    """
    assert output.shape == target.shape
    residual = output - target
    relative_error = la.norm(residual) / la.norm(target)
    curl_norm = la.norm(B2.T @ output)
    return relative_error, curl_norm

def compute_loss_forex(output, target, mask, B2, curl_weight=0.0):
    """
    Relative-error loss with an optional L1 penalty on B2.T @ output.

    Parameters
    ----------
    output -- predicted tensor, must have the same shape as target
    target -- reference tensor
    mask -- accepted for interface compatibility; not used by this function
    B2 -- matrix whose transpose is applied to output in the penalty term
    curl_weight -- multiplier of the L1 penalty; the default 0.0 reproduces the
                   previously hard-coded weight, so existing callers are unaffected

    Returns
    -------
    ||output - target||_2 / ||target||_2 + curl_weight * ||B2.T @ output||_1
    """
    assert output.shape == target.shape
    relative_error = tla.vector_norm(output - target) / tla.vector_norm(target)
    # The penalty is always evaluated (even with a zero weight) so that the result,
    # including NaN/inf propagation from B2.T @ output, matches the original formula.
    curl_penalty = tla.vector_norm(B2.T @ output, ord=1)
    return relative_error + curl_weight * curl_penalty


def compute_error_curl_interp(output,target,B2,mask):
    """
    Relative error of output w.r.t. target and the norm of B2.T @ output (torch).

    Parameters
    ----------
    output -- predicted tensor, must have the same shape as target
    target -- reference tensor
    B2 -- matrix whose transpose is applied to output
    mask -- accepted but not used here
            NOTE(review): confirm whether the error should be restricted to masked entries

    Returns
    -------
    tuple (||output - target|| / ||target||, ||B2.T @ output||), both 2-norms
    """
    assert output.shape == target.shape
    diff_norm = tla.vector_norm(output - target)
    target_norm = tla.vector_norm(target)
    curl_norm = tla.vector_norm(B2.T @ output)
    return diff_norm / target_norm, curl_norm

def compute_error_curl_np_interp(output,target,B2,mask):
    """
    Relative error of output w.r.t. target and the norm of B2.T @ output (numpy).

    Parameters
    ----------
    output -- predicted array, must have the same shape as target
    target -- reference array
    B2 -- matrix whose transpose is applied to output
    mask -- accepted but not used here
            NOTE(review): confirm whether the error should be restricted to masked entries

    Returns
    -------
    tuple (norm(output - target) / norm(target), norm(B2.T @ output)) using la.norm defaults
    """
    assert output.shape == target.shape
    diff_norm = la.norm(output - target)
    target_norm = la.norm(target)
    curl_norm = la.norm(B2.T @ output)
    return diff_norm / target_norm, curl_norm
def compute_loss_forex_interp(output, target, mask, B2, curl_weight=0.0):
    """
    Relative-error loss with an optional L1 penalty on B2.T @ output.

    Parameters
    ----------
    output -- predicted tensor, must have the same shape as target
    target -- reference tensor
    mask -- accepted for interface compatibility; not used by this function
            NOTE(review): confirm whether the loss should be restricted to masked entries
    B2 -- matrix whose transpose is applied to output in the penalty term
    curl_weight -- multiplier of the L1 penalty; the default 0.0 reproduces the
                   previously hard-coded weight, so existing callers are unaffected

    Returns
    -------
    ||output - target||_2 / ||target||_2 + curl_weight * ||B2.T @ output||_1
    """
    assert output.shape == target.shape
    relative_error = tla.vector_norm(output - target) / tla.vector_norm(target)
    # The penalty is always evaluated (even with a zero weight) so that the result,
    # including NaN/inf propagation from B2.T @ output, matches the original formula.
    curl_penalty = tla.vector_norm(B2.T @ output, ord=1)
    return relative_error + curl_weight * curl_penalty

def count_parameters(model):
    """
    Count the scalar entries of all trainable parameters of a model.

    Parameters
    ----------
    model -- object exposing parameters(), e.g. an nn.Module

    Returns
    -------
    sum of numel() over the parameters whose requires_grad is true
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters do not contribute to the count.
        if param.requires_grad:
            total += param.numel()
    return total

def compute_auc(pos_score, neg_score):
    """
    ROC AUC of the scores, treating pos_score as class 1 and neg_score as class 0.

    Parameters
    ----------
    pos_score -- tensor of scores for the positive samples
    neg_score -- tensor of scores for the negative samples

    Returns
    -------
    value returned by roc_auc_score for the concatenated labels and scores
    """
    # Scores are detached and moved to the CPU before the numpy conversion.
    y_score = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    # Labels follow the same ordering: ones for the positives, then zeros for the negatives.
    positives = torch.ones(pos_score.shape[0])
    negatives = torch.zeros(neg_score.shape[0])
    y_true = torch.cat([positives, negatives]).numpy()
    return roc_auc_score(y_true, y_score)


class MLPPredictor(nn.Module):
    """Two-layer MLP that scores a triangle from the features of its three edges."""

    def __init__(self, h_feats):
        super().__init__()
        # Hidden layer: takes the three edge feature vectors (h_feats each) concatenated together
        self.W1 = nn.Linear(3 * h_feats, h_feats)
        # Output layer: maps the hidden representation to a single value
        self.W2 = nn.Linear(h_feats, 1)
        # Non-linearity applied between the two linear layers
        self.m = nn.Sigmoid()

    def forward(self, fi, fj, fk):
        """
        Computes a scalar score for each triangle

        Parameters
        ----------
        edge features
        fi -- features of the first edge, dim: # triangles x # features
        fj -- features of the second edge, dim: # triangles x # features
        fk -- features of the third edge, dim: # triangles x # features

        Returns
        -------
        one score per triangle (the trailing singleton dimension is squeezed away)
        """
        # Concatenate along the feature dimension: # triangles x (3 * # features)
        edge_feats = torch.cat((fi, fj, fk), dim=1)
        hidden = self.m(self.W1(edge_feats))
        score = self.W2(hidden)
        return score.squeeze(1)
    
class MLPPredictor_forex(nn.Module):
    """Two-layer MLP mapping a single feature tensor to one output value per row."""

    def __init__(self, h_feats):
        super().__init__()
        # Hidden layer keeps the feature dimension (h_feats -> h_feats); no concatenation is involved
        self.W1 = nn.Linear(h_feats, h_feats)
        # Output layer: maps the hidden representation to a single value
        self.W2 = nn.Linear(h_feats, 1)
        # Non-linearity applied between the two linear layers
        self.m = nn.Sigmoid()

    def forward(self, f):
        """
        Computes one output value for each input feature vector

        Parameters
        ----------
        f -- input features, last dimension of size h_feats

        Returns
        -------
        output of the second linear layer; the last dimension has size 1 and is not squeezed
        """
        hidden = self.m(self.W1(f))
        prediction = self.W2(hidden)
        return prediction
