import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
from gnnModels_slim import TimeEmbedding


class Transformer(nn.Module):
    """Attention-based scorer and updater for an incrementally grown tree.

    The tree state is a pair ``(node_features, edge_index)``:

    * ``node_features`` has shape ``(batch, nnodes, hidden_dim)``.
    * ``edge_index`` has shape ``(batch, nnodes, 3)``.  Column 0 of every row
      except the last is read as that node's parent index; the remaining
      columns are never read by this class.  The last row is carried along
      unchanged and always stays last.

    ``_init`` builds the initial 4-node state (three learnable nodes whose
    parent is node 3), ``forward`` scores every non-last node as an attachment
    position, and ``update`` inserts two new nodes at a chosen position.

    NOTE(review): the naming (``ntips``, ``star``, "pendant") suggests this is
    a phylogenetic-tree topology model -- confirm against the callers.
    """

    def __init__(self, ntips, hidden_dim=100, n_head=4, device=torch.device('cpu'), **kwargs):
        """Create the attention layers, readout head and learnable start nodes.

        Args:
            ntips: Sizes the query table, which gets ``ntips - 3`` rows.
            hidden_dim: Width of node features and attention embeddings.
            n_head: Number of attention heads for both attention layers.
            device: Device handed to ``TimeEmbedding`` and used to create the
                query table.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.ntips = ntips
        self.hidden_dim, self.n_head = hidden_dim, n_head
        self.device = device
        self.MHA = nn.MultiheadAttention(hidden_dim, n_head, batch_first=True)
        self.MHSA = nn.MultiheadAttention(hidden_dim, n_head, batch_first=True)
        self.MHAnorm = nn.LayerNorm(hidden_dim)
        self.MHSAnorm = nn.LayerNorm(hidden_dim)
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(self.hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.star = nn.Parameter(torch.randn((3, hidden_dim)) / np.sqrt(hidden_dim))
        self.time_embedding = TimeEmbedding(2 * hidden_dim, device=self.device)
        # Fixed (non-learned) one-hot style queries, one row per step.
        # Registered as a buffer so it follows .to()/.cuda()/.half() together
        # with the rest of the module; persistent=False keeps the state_dict
        # keys identical to checkpoints saved before this was a buffer.
        self.register_buffer(
            "query",
            torch.eye(ntips - 3, hidden_dim, device=self.device),
            persistent=False,
        )

    def forward(self, node_features, edge_index, t):
        """Score every non-last node as the position for the next insertion.

        Args:
            node_features: ``(batch, nnodes, hidden_dim)`` node embeddings.
            edge_index: ``(batch, nnodes, 3)`` integer tensor; column 0 of the
                first ``nnodes - 1`` rows is used as the parent index.
            t: Step index.  Row ``t - 3`` of the query table is used, and ``t``
                is also passed to the time embedding.  Must index a single row
                (e.g. a Python int).

        Returns:
            A tuple ``(logits, new_feature)``: ``logits`` is
            ``(batch, nnodes - 1)`` and normalised so that its exponentials sum
            to one along the last axis; ``new_feature`` is
            ``(batch, 1, hidden_dim)``, the attention summary of the current
            nodes for this step.
        """
        # assumes the embedding is (1 or batch, 2 * hidden_dim) so that it
        # broadcasts over the node axis below -- TODO confirm in TimeEmbedding
        temb = self.time_embedding(t)
        batch_size, nnodes = edge_index.shape[0], edge_index.shape[1]
        node_features = self.MHAnorm(node_features)

        # One query per sample attends over all (normalised) node features.
        step_query = self.query[None, None, t - 3].repeat(batch_size, 1, 1)
        new_feature, _ = self.MHA(step_query, node_features, node_features)

        # Describe each edge by the element-wise max of its two endpoints.
        child_info = node_features[:, :-1]
        parent_index = edge_index[:, :-1, 0].unsqueeze(-1).expand(-1, -1, node_features.shape[-1])
        parent_info = torch.gather(node_features, 1, parent_index)
        edge_info = torch.maximum(child_info, parent_info)

        edge_info = torch.cat([new_feature.expand(-1, nnodes - 1, -1), edge_info], dim=-1)
        logits = self.readout(edge_info + temb.unsqueeze(1)).squeeze(-1)
        # Normalise to log-probabilities over the candidate positions.
        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        return logits, new_feature

    def update(self, node_features, edge_index, new_feature, pos):
        """Insert two nodes at position ``pos`` and return the grown state.

        The features of the chosen node, its parent, ``new_feature`` and their
        mean are refined jointly by one residual self-attention step.  The
        first two results overwrite the chosen node and its parent; the last
        two become new nodes placed just before the final row.

        Args:
            node_features: ``(batch, nnodes, hidden_dim)`` node embeddings.
            edge_index: ``(batch, nnodes, 3)`` integer tensor (not modified).
            new_feature: ``(batch, 1, hidden_dim)`` feature of the node to add.
            pos: ``(batch, 1)`` integer tensor, the chosen node per sample.

        Returns:
            ``(node_features, edge_index)`` with ``nnodes + 2`` rows each.
            Columns 1 and 2 of the two new edge rows are filled with ``-1``;
            those columns of existing rows are only re-indexed, never rewired.

        Raises:
            AssertionError: If ``pos`` does not have shape ``(batch, 1)``.
        """
        batch_size, nnodes = edge_index.shape[0], edge_index.shape[1]
        assert pos.shape == (batch_size, 1), (
            f"pos must have shape {(batch_size, 1)}, got {tuple(pos.shape)}"
        )
        parent_index = edge_index[:, :-1, 0]
        parent_pos = torch.gather(parent_index, 1, pos)

        # Index of (chosen node, its parent), broadcast over the feature axis.
        touched = torch.stack([pos, parent_pos], dim=1).expand(-1, -1, node_features.shape[-1])

        star_features = torch.gather(node_features, 1, touched)
        star_features = torch.cat([star_features, new_feature], dim=1)
        # dim1: pos, parent pos, new node, pendant.
        star_features = torch.cat([star_features, star_features.mean(dim=1, keepdim=True)], dim=1)
        star_features = self.MHSAnorm(star_features)
        next_star_features, _ = self.MHSA(star_features, star_features, star_features)
        star_features = next_star_features + star_features

        # Write the refined endpoints back, then splice the two new nodes in
        # just before the last row so that row keeps the final position.
        node_features = torch.scatter(node_features, dim=1, index=touched, src=star_features[:, :2])
        node_features = torch.cat(
            [node_features[:, :-1], star_features[:, 2:], node_features[:, -1].unsqueeze(1)],
            dim=1,
        )

        # The last node moves from index nnodes - 1 to nnodes + 1; re-point
        # every reference to it.  torch.where returns a fresh tensor, so the
        # in-place write below never touches the caller's edge_index.
        edge_index = torch.where(edge_index < nnodes - 1, edge_index, nnodes + 1)
        parent_pos = torch.where(parent_pos < nnodes - 1, parent_pos, nnodes + 1)

        # The chosen node now hangs below the new node stored at index nnodes.
        batch_ids = torch.arange(batch_size, device=edge_index.device)
        edge_index[batch_ids, pos[:, 0], 0] = nnodes

        # Rows for the two new nodes: index nnodes - 1 (new_feature) has
        # parent nnodes; index nnodes (the mean) takes over the old parent.
        unset = torch.full_like(parent_pos, -1)
        first_new_row = torch.cat([torch.full_like(parent_pos, nnodes), unset, unset], dim=-1)
        second_new_row = torch.cat([parent_pos, unset, unset], dim=-1)
        added_edge_index = torch.stack([first_new_row, second_new_row], dim=1)

        edge_index = torch.cat(
            [edge_index[:, :-1], added_edge_index, edge_index[:, -1].unsqueeze(1)],
            dim=1,
        )
        return node_features, edge_index

    def _init(self, batch_size):
        """Build the initial 4-node state for ``batch_size`` samples.

        Nodes 0-2 are the learnable ``star`` rows and node 3 is their mean.
        Rows 0-2 of the edge table are ``[3, -1, -1]`` and row 3 is
        ``[0, 1, 2]``.

        Returns:
            ``(node_features, edge_index)`` of shapes
            ``(batch_size, 4, hidden_dim)`` and ``(batch_size, 4, 3)``.  Both
            are expanded views that share memory across the batch axis, so
            they must not be written to in place.
        """
        node_features = torch.cat([self.star, self.star.mean(0, keepdim=True)], dim=0)
        node_features = node_features.unsqueeze(0).expand(batch_size, -1, -1)
        # Created on the parameters' current device so that the state stays
        # consistent with the module after .to()/.cuda().
        edge_index = torch.tensor(
            [
                [3, -1, -1],
                [3, -1, -1],
                [3, -1, -1],
                [0, 1, 2],
            ],
            device=self.star.device,
        )
        edge_index = edge_index.unsqueeze(0).expand(batch_size, -1, -1)
        return node_features, edge_index


