import torch
from torch.nn import Linear, ReLU, SiLU, Sequential
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
from torch_scatter import scatter

import torch.nn as nn
import numpy as np
import scipy as sp
import torch.nn.functional as F



class MPN(nn.Module):
    """
    Scalar message-passing network on a fully connected point cloud.

    Every node i carries one scalar feature h[i], initialised to zero. Each
    layer computes, for all ordered pairs (i, j),

        m[i, j] = sin(A . [h[i], h[j], ||X[i] - X[j]||]),

    aggregates row i by sorting m[i, :] and taking the inner product with C
    (a permutation-invariant readout), and updates h[i] = B . [h[i], M[i]].
    The parameters A, B and C are shared by all layers.

    Args:
        n: number of points; only the first n rows of the input are used.
        num_layers: number of message-passing rounds.
        d: kept for interface compatibility; not used.
    """
    def __init__(self, n, num_layers, d=3):
        super().__init__()
        # Device the parameters are created on; forward() follows the
        # parameters themselves so the module still works after `.to(...)`.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_layers = num_layers

        self.n = n
        self.A = nn.Parameter(torch.randn(3, device=self.device))
        self.B = nn.Parameter(torch.randn(2, device=self.device))
        self.C = nn.Parameter(torch.randn(n, device=self.device))

    def forward(self, X):
        """
        Args:
            X: (n, ...) point coordinates; trailing dimensions are flattened
               per point before distances are taken.
        Returns:
            h: (n,) node features after `num_layers` rounds.
        """
        device, dtype = self.A.device, self.A.dtype
        points = torch.as_tensor(X, dtype=dtype, device=device)[: self.n]
        points = points.flatten(start_dim=1)

        # Pairwise Euclidean distances, computed once: they do not depend on h.
        dist = torch.linalg.norm(points.unsqueeze(1) - points.unsqueeze(0), dim=-1)

        h = torch.zeros(self.n, dtype=dtype, device=device)
        for _ in range(self.num_layers):
            # Built from tensor ops (not torch.tensor([...])) so that gradients
            # reach A and C and flow through h across layers.
            pre_act = (
                self.A[0] * h.unsqueeze(1)
                + self.A[1] * h.unsqueeze(0)
                + self.A[2] * dist
            )
            messages = torch.sin(pre_act)
            sorted_messages, _ = torch.sort(messages, dim=1)
            M = sorted_messages @ self.C
            h = self.B[0] * h + self.B[1] * M
        return h
    

def tr12(input):
    """Return `input` with dimensions 1 and 2 swapped."""
    return input.transpose(1, 2)

def gen_cross_product(input):
    """
    Generalized cross product of n-1 vectors in R^n.

    Args:
        input: tensor whose last two dimensions are (n-1) x n, one vector per
            row; any leading dimensions must have size 1 (e.g. 1 x (n-1) x n).
            Rows beyond the first n-1 are ignored.
    Returns:
        Tensor of shape (n,) on the device of `input`, orthogonal to every
        input row. Component i is (-1)**i times the determinant of the input
        with column i removed (cofactor expansion along a symbolic first row).
        Floating-point inputs keep their dtype.
    Raises:
        ValueError: if fewer than n-1 rows are supplied.
    """
    # Reshape instead of a bare squeeze(): squeeze() would also drop the row
    # dimension when there is a single vector (the n = 2 case).
    mat = input.reshape(input.shape[-2], input.shape[-1])
    if not (mat.is_floating_point() or mat.is_complex()):
        # torch.det needs a floating dtype.
        mat = mat.to(torch.get_default_dtype())

    n = mat.size(1)
    if mat.size(0) < n - 1:
        raise ValueError(
            f"gen_cross_product needs {n - 1} vectors of length {n}, got {mat.size(0)}"
        )
    mat = mat[: n - 1]

    # Use the input's own device rather than guessing from CUDA availability.
    columns = torch.arange(n, device=mat.device)
    components = []
    for i in range(n):
        keep = torch.cat([columns[:i], columns[i + 1:]], dim=0)
        minor = torch.det(mat[:, keep])
        components.append(-minor if i % 2 else minor)
    return torch.stack(components)

class embed_vec_sort(nn.Module):
    """
    Permutation-invariant embedding of a set of n vectors in R^d.

    Input size:  b x d x n  (b = batch size, d = feature dimension, n = set size)
    Output size: b x d_out, where d_out defaults to 2*d*n + 1.

    Each vector is projected onto d_out learned directions (columns of A); for
    every direction the n projections are sorted and combined with the learned
    weights w.
    """
    def __init__(self, d, n, d_out = None, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        super().__init__()
        self.device = device
        self.d = d
        self.n = n
        self.d_out = 2 * d * n + 1 if d_out is None else d_out

        # Projection directions (d x d_out) and per-rank weights (1 x n x d_out).
        self.A = nn.Parameter(torch.randn([d, self.d_out], requires_grad=True, device=device))
        self.w = nn.Parameter(torch.randn([1, n, self.d_out], requires_grad=True, device=device))

    def forward(self, input):
        # b x n x d_out: projection of every set element on every direction.
        projected = torch.tensordot(tr12(input), self.A, dims=([2], [0]))
        # Sort over the set dimension so the result ignores element order.
        ranked = tr12(projected).sort(dim=2).values
        weighted = ranked * tr12(self.w)
        return weighted.sum(dim=2)


# Calculates an embedding of n vectors in R^d that is invariant to permutations,
# rotations and, optionally, translations.
#
# Conceptual input size: b x (d + d_feature) x n
# b: batch size (forward() takes a single unbatched (d + d_feature) x n tensor and adds b = 1 itself)
# d: dimension of Euclidean space (currently only d=3 is supported)
# d_feature: accompanying feature dimension (default: 0)
# n: number of points
#
# Input layout: input[:, 0:d, :] holds the Euclidean coordinates. The remaining rows input[:, d:(d+d_feature), :] hold the accompanying feature vectors.
class embed_graph(nn.Module):
    """
    Embedding of a point cloud (with optional per-point features) built from
    per-pair local frames.

    For every ordered pair of points (v1, v2) a frame M = [v1, v2, v1 x v2] is
    formed. The Gram matrix M^T M (plus the pair's features) and a sorted
    embedding of M^T v over all remaining points v describe the cloud relative
    to that pair. The per-pair descriptors are then combined by a second
    sorted embedding.

    Args:
        d: dimension of the Euclidean space (the cross product in forward()
            requires d = 3).
        n: number of points.
        d_feature: number of feature rows accompanying the coordinates.
        embed_mij_dim: output size of the per-pair set embedding; defaults to
            the default size of `embed_vec_sort`.
        embed_global_dim: output size of the final embedding; defaults to
            `d_graph`.
        translation_invariant: subtract the centroid before embedding.
        is_compact: selects the compact formula for `d_graph`.
        device: device on which parameters and work buffers are created.
    """
    def __init__(self, d, n, d_feature = 0, embed_mij_dim=None, embed_global_dim=None, translation_invariant = True, is_compact = True, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        super().__init__()
        self.device = device
        self.d = d
        self.n = n
        self.d_feature = d_feature

        self.translation_invariant = translation_invariant

        # Number of ordered (d-1)-tuples of distinct indices: n! / (n-d+1)!.
        # Stored as a plain int so it can be used directly as a tensor size.
        self.n_combs = int(np.prod(range(n-d+2, n+1)))

        # Points left over once the d-1 frame-defining points are removed.
        n_rest = n - (d - 1)

        # Per-combination descriptor: flattened d x d Gram matrix, features of
        # the d-1 chosen points, and the set embedding of the remaining points.
        if embed_mij_dim is None:
            self.d_comb = d*d + (d-1)*d_feature + 2*(d+d_feature)*n_rest + 1
        else:
            self.d_comb = d*d + (d-1)*d_feature + embed_mij_dim

        # Dimension of the embedding for the entire graph
        if is_compact:
            self.d_graph = 2*(d+d_feature)*n + 1
        else:
            self.d_graph = 2*self.d_comb*self.n_combs

        if embed_mij_dim is None:
            self.embed_comb = embed_vec_sort(d+d_feature, n_rest, device=device)
        else:
            self.embed_comb = embed_vec_sort(d+d_feature, n_rest, d_out=embed_mij_dim, device=device)

        if embed_global_dim is None:
            self.embed_graph = embed_vec_sort(self.d_comb, self.n_combs, self.d_graph, device=device)
        else:
            self.embed_graph = embed_vec_sort(self.d_comb, self.n_combs, embed_global_dim, device=device)


    def forward(self, input):
        """
        Args:
            input: (d + d_feature) x n tensor; rows 0..d-1 are coordinates,
                the remaining rows are per-point features.
        Returns:
            1 x d_out embedding produced by the final `embed_vec_sort`.
        Raises:
            ValueError: if the input shape does not match (d + d_feature, n).
        """
        # A single unbatched cloud is expected; add the batch dimension here.
        input = input.reshape(1, input.size(dim=0), input.size(dim=1))
        b, d_tot, n = input.shape
        d_feature = self.d_feature
        d = d_tot - d_feature
        if d != self.d or n != self.n:
            raise ValueError(
                f"embed_graph expected input of shape "
                f"({self.d + d_feature}, {self.n}), got ({d_tot}, {n})"
            )

        X = input[:, :d, :]
        feats = input[:, d:, :] if d_feature > 0 else None

        if self.translation_invariant:
            # Out-of-place so the caller's tensor is never modified.
            X = X - X.mean(dim=2, keepdim=True)

        # Every unordered (d-1)-subset, followed by the same subsets reversed,
        # so for d = 3 each pair of points appears in both orders.
        combs = torch.combinations(torch.arange(start=0, end=n), r=d-1, with_replacement=False)
        combs = torch.cat([combs, combs.flip(dims=[1])], dim=0).tolist()

        global_vecs = torch.zeros([b, self.d_comb, self.n_combs], device=self.device)

        for i, comb in enumerate(combs):
            M0 = X[:, :, comb]

            # TODO: change to a generalized cross product to support d != 3.
            xprod = torch.linalg.cross(M0[:, :, 0], M0[:, :, 1], dim=1).unsqueeze(2)

            # For a vector combination (v1, v2), the matrix M consists of columns [v1, v2, v1 x v2]
            M = torch.cat([M0, xprod], dim=2)
            Mt = tr12(M)
            vec1 = torch.flatten(torch.bmm(Mt, M), 1)
            if feats is not None:
                vec1 = torch.cat([vec1, torch.flatten(feats[:, :, comb], 1)], dim=1)

            # Complement of the current combination
            chosen = set(comb)
            combcomp = [k for k in range(self.n) if k not in chosen]

            # Product of M^T with all vectors except v1,v2
            vecs_to_embed = torch.bmm(Mt, X[:, :, combcomp])
            if feats is not None:
                vecs_to_embed = torch.cat([vecs_to_embed, feats[:, :, combcomp]], dim=1)

            vec2 = self.embed_comb(vecs_to_embed)

            global_vecs[:, :, i] = torch.cat([vec1, vec2], dim=1)

        return self.embed_graph(global_vecs)

class EGNNLayer(MessagePassing):
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""E(n) Equivariant GNN Layer

        Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        # (input: both endpoint features plus the scalar distance)
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
        # (one scalar per edge that scales the displacement vector)
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - source/target node indices of the edges
        Returns:
            out: [(n, d), (n, 3)] - updated node features and coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute messages from both endpoint features and their distance
        pos_diff = pos_i - pos_j
        dists = torch.linalg.norm(pos_diff, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)  # torch.clamp(updates, min=-100, max=100)
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        # `dim_size` (the number of target nodes, supplied by propagate) keeps
        # the output at n rows even when the highest-indexed nodes receive no
        # message; without it `update` fails on the shape mismatch.
        msgs, pos_diffs = inputs
        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        # Aggregate displacement vectors (always averaged, independent of `aggr`)
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""Vanilla Message Passing GNN layer

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            edge_index: (2, e) - source/target node indices of the edges
        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h)
        return out

    def message(self, h_i, h_j):
        # Compute messages from the features of both endpoints
        msg = torch.cat([h_i, h_j], dim=-1)
        msg = self.mlp_msg(msg)
        return msg

    def aggregate(self, inputs, index, dim_size=None):
        # `dim_size` (the number of target nodes, supplied by propagate) keeps
        # the output at n rows even when the highest-indexed nodes receive no
        # message; without it `update` fails on the shape mismatch.
        msg_aggr = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return msg_aggr

    def update(self, aggr_out, h):
        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))
        return upd_out

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
