import math
import torch
from torch import nn
from torch.nn import functional as F

from ..manifolds import StereographicModel
from .networks import StereographicTransformer, StereographicLogits, StereographicLayerNorm


class Encoder(nn.Module):
    """
    Tokenized Graph Transformer.

    Every node and every edge of the input graph becomes one token. A token
    is built from hop-mixed Laplacian eigenvector identifiers passed through
    ``lap_encoder``, plus an order embedding (index 0 for nodes, 1 for
    edges); node tokens additionally receive the encoded, hop-mixed node
    features. The token matrix is mapped onto the stereographic manifold
    with ``expmap0`` and processed by ``StereographicTransformer`` layers.

    ``args`` must provide: ``c``, ``num_heads``, ``num_layers``,
    ``tied_weights``, ``attn_type``, ``feat_dim``, ``dim``, ``lap_k``,
    ``dropout``, ``lap_dropout``, ``lap_sign_flip``, ``act`` (name of a
    class in ``torch.nn``), ``hops``, ``task``, ``layer_norm`` and, when
    ``task == 'nc'``, ``n_classes``.
    """

    def __init__(self, args):
        super().__init__()

        # One curvature parameter vector (one entry per head) per layer,
        # initialised to args.c, or to 0 when args.c is None.
        init_c = 0. if args.c is None else float(args.c)
        self.c = nn.ParameterList([
            nn.Parameter(torch.Tensor([init_c] * args.num_heads))
            for _ in range(args.num_layers)
        ])
        self.manifold = StereographicModel()
        self.num_layers = args.num_layers
        self.tied_weights = args.tied_weights
        self.attn_type = args.attn_type
        self.feat_dim = args.feat_dim
        self.dim = args.dim
        self.lap_k = args.lap_k
        self.dropout = args.dropout
        self.lap_dropout = args.lap_dropout
        self.lap_sign_flip = args.lap_sign_flip
        self.num_heads = args.num_heads
        self.act = getattr(nn, args.act)()
        self.hops = args.hops
        self.task = args.task

        # Only the weight of hop_mixer is used (softmax-normalised in
        # forward) to blend the hops + 1 per-hop slices into one.
        self.hop_mixer = nn.Linear(self.hops + 1, 1, bias=False)
        self.input_encoder = nn.Linear(self.feat_dim, self.dim, bias=True)

        nn.init.xavier_uniform_(self.input_encoder.weight, gain=1 / math.sqrt(2))
        nn.init.constant_(self.input_encoder.bias, 0)

        # Input width is 4 * lap_k because forward concatenates two
        # identifier vectors (node/node or source/destination) per token.
        self.lap_encoder = nn.Sequential(
            nn.Linear(4 * self.lap_k, self.dim, bias=False),
            self.act,
            nn.Linear(self.dim, self.dim, bias=False)
        )

        self.lap_eig_dropout = nn.Dropout(p=self.lap_dropout) if self.lap_dropout > 0 else None
        self.order_embed = nn.Embedding(2, self.dim)
        self.embed_dropout = torch.nn.Dropout(p=self.dropout)

        if self.tied_weights:
            # A single layer (using self.c[0]) is applied num_layers times.
            layers = [StereographicTransformer(self.manifold, self.c[0], args)]
        else:
            layers = [
                StereographicTransformer(self.manifold, self.c[idx], args)
                for idx in range(self.num_layers)
            ]
        self.layers = nn.Sequential(*layers)

        # NOTE(review): with tied_weights the layers run under self.c[0],
        # while norm / cls / sqdist below read self.c[-1]. Both parameters
        # start from the same value but are separate nn.Parameters --
        # confirm this is intended when num_layers > 1.
        if args.layer_norm:
            self.norm = StereographicLayerNorm(self.manifold, self.c[-1], args.dim, args.num_heads)
        else:
            self.norm = nn.Identity()

        # The classification head exists only for the 'nc' task.
        if args.task == 'nc':
            self.cls = StereographicLogits(self.manifold, args.dim, args.n_classes, self.c[-1], args.num_heads)

    @staticmethod
    def get_random_sign_flip(eigvec):
        """
        Multiply each slice along the last axis of ``eigvec`` by a random sign.

        One sign (+1 or -1, each with probability 0.5) is drawn per index of
        the last dimension and broadcast over all leading dimensions.

        The product is deliberately computed with autograd enabled: in
        ``forward`` the argument depends on ``hop_mixer.weight``, and
        wrapping the multiplication in ``torch.no_grad()`` would cut that
        gradient path. The signs themselves never require grad.
        """
        K = eigvec.shape[-1]
        draws = torch.rand(K, device=eigvec.device)
        # Map [0.5, 1) -> +1.0 and [0, 0.5) -> -1.0.
        signs = 2.0 * (draws >= 0.5).to(draws.dtype) - 1.0
        return eigvec * signs

    def logit(self, embeddings):
        """
        Return log-softmax (over dim 1) of the classifier output.

        Requires ``self.cls``, which is only created when ``task == 'nc'``.
        """
        return F.log_softmax(self.cls(embeddings), dim=1)

    def sqdist(self, x_1, x_2):
        """Return ``manifold.sqdist`` between ``x_1`` and ``x_2`` using ``self.c[-1]``."""
        return self.manifold.sqdist(x_1, x_2, self.c[-1])

    def forward(self, data):
        """
        Encode a graph and return the embeddings of its node tokens.

        ``data`` is a mapping with keys ``'features'``, ``'adj_train_norm'``,
        ``'lap_eigvec'`` and ``'edge_index'``. The second-to-last axis of
        ``features`` and ``lap_eigvec`` must have ``hops + 1`` entries (it is
        contracted with the hop-mixer weights).
        Presumably ``features`` is (N, hops + 1, feat_dim), ``lap_eigvec`` is
        (N, hops + 1, 2 * lap_k) and ``edge_index`` is (2, M) -- TODO confirm
        against the data loader.

        Returns the first N rows of the final token matrix, where N is
        ``adj_train_norm.shape[0]``.
        """
        x = data['features']
        adj = data['adj_train_norm']
        lap_eigvec = data['lap_eigvec']
        edge_index = data['edge_index']
        N = adj.shape[0]
        M = edge_index.shape[-1]
        device = x.device

        # Softmax-normalised hop weights as a (hops + 1, 1) column, shared
        # by the feature and the Laplacian-eigenvector mixing.
        hop_weights = F.softmax(self.hop_mixer.weight, dim=-1).t()

        # squeeze(-1) drops only the mixed axis; a bare squeeze() would also
        # collapse a size-1 node or feature axis.
        node_feat = torch.matmul(x.transpose(-1, -2), hop_weights).squeeze(-1)
        node_feat = self.input_encoder(node_feat)

        node_id = torch.matmul(lap_eigvec.transpose(-1, -2), hop_weights).squeeze(-1)

        if self.lap_sign_flip and self.training:
            node_id = self.get_random_sign_flip(node_id)

        if self.lap_eig_dropout is not None:
            node_id = self.lap_eig_dropout(node_id)

        # Step 1: Laplacian positional encodings. Node tokens use their own
        # identifier twice; edge tokens use (source, destination).
        src_id = node_id[edge_index[0]]
        dst_id = node_id[edge_index[1]]
        node_pos = torch.cat((node_id, node_id), dim=1)
        edge_pos = torch.cat((src_id, dst_id), dim=1)
        tokens = self.lap_encoder(torch.cat((node_pos, edge_pos), dim=0))

        # Step 2: add order embeddings (index 0 for the N node tokens,
        # index 1 for the M edge tokens), built directly on the device.
        order_ids = torch.cat((
            torch.zeros(N, dtype=torch.long, device=device),
            torch.ones(M, dtype=torch.long, device=device),
        ))
        tokens = tokens + self.order_embed(order_ids)

        # Step 3: add node features to the node tokens only.
        tokens[:N] = tokens[:N] + node_feat

        # Embedding dropout is skipped for the 'lp' task.
        if self.task != 'lp':
            tokens = self.embed_dropout(tokens)

        # All-ones mask with one entry per token.
        mask = torch.ones_like(tokens[:, 0])

        # Step 4: project to product stereographic space
        tokens = self.manifold.expmap0(tokens, self.c[0])
        tokens = self.manifold.proj(tokens, self.c[0])

        # Layers that operate on the stereographic product space
        if self.tied_weights:
            shared_layer = self.layers[0]
            for _ in range(self.num_layers):
                tokens = shared_layer(tokens, mask)
        else:
            last_idx = self.num_layers - 1
            for idx in range(self.num_layers):
                tokens = self.layers[idx](tokens, mask)
                if idx < last_idx:
                    # Transition to the product space of the next layer:
                    # back to the tangent space at the origin under the
                    # current curvature, then out under the next one.
                    tokens = self.manifold.logmap0(tokens, self.c[idx])
                    tokens = self.manifold.expmap0(tokens, self.c[idx + 1])
                    tokens = self.manifold.proj(tokens, self.c[idx + 1])

        tokens = self.norm(tokens)

        # Only node tokens are returned; edge tokens are discarded.
        return tokens[:N]

