import torch
import torch.nn as nn
import torch.nn.functional as F

from src.modules.graphtrans_module import Struct2Seq, cat_neighbors_nodes, gather_nodes, ProteinFeatures


class StructGNN_Model(nn.Module):
    def __init__(self, args):
        super(StructGNN_Model, self).__init__()
        self.args = args
        self.device = 'cuda:0'
        self.smoothing = args.smoothing
        self.model = Struct2Seq(
            vocab=args.vocab_size,
            num_letters=args.vocab_size,
            node_features=args.hidden,
            edge_features=args.hidden, 
            hidden_dim=args.hidden,
            k_neighbors=args.k_neighbors,
            protein_features=args.features,
            dropout=args.dropout,
            use_mpnn=True)
        
        self.featurizer =  ProteinFeatures(
            args.hidden, args.hidden, top_k=args.k_neighbors,
            features_type=args.features,
            dropout=args.dropout
        )
    
    def _get_features(self, batch):
        X, lengths, mask = batch['X'], batch['lengths'], batch['mask']
        V, E, E_idx = self.featurizer(X, lengths, mask)
        batch.update({'V':V, 
                "E":E,
                "E_idx":E_idx, 
                "mask": mask})
        return batch
        
        
    def forward(self, batch):
        """ Graph-conditioned sequence model """
        S, V, E, E_idx, mask = batch['S'], batch['V'], batch['E'], batch['E_idx'], batch['mask']

        # Prepare node and edge embeddings
        h_V = self.model.W_v(V)
        h_E = self.model.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1),  E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.model.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.model.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_ES_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_ESV_encoder = cat_neighbors_nodes(h_V, h_ES_encoder, E_idx)

        # Decoder uses masked self-attention
        mask_attend = self.model._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        
        if self.model.forward_attention_decoder:
            mask_fw = mask_1D * (1. - mask_attend)
            h_ESV_encoder_fw = h_ESV_encoder
        else:
            h_ESV_encoder_fw = 0
        for layer in self.model.decoder_layers:
            # Masked positions attend to encoder information, unmasked see. 
            h_ESV_dec = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV_dec + mask_fw*h_ESV_encoder_fw 
            h_V = layer(h_V, h_ESV, mask_V=mask) 

        logits = self.model.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs}

    def sample(self, V, E, E_idx, mask, chain_mask=None, temperature=1.0):
        """ Autoregressive decoding of a model """
         # Prepare node and edge embeddings
        h_V = self.model.W_v(V)
        h_E = self.model.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1),  E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.model.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Decoder alternates masked self-attention
        mask_attend = self.model._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)
        N_batch, N_nodes = V.size(0), V.size(1)
        h_S = torch.zeros_like(h_V)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device = self.device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.model.decoder_layers))]
        all_probs = []
        for t in range(N_nodes):
            # Hidden layers
            E_idx_t = E_idx[:,t:t+1,:]
            h_E_t = h_E[:,t:t+1,:,:]
            # use cache
            h_ES_enc_t = cat_neighbors_nodes(torch.zeros_like(h_S), h_E_t, E_idx_t)
            h_ESV_encoder_t = mask_fw[:,t:t+1,:,:] * cat_neighbors_nodes(h_V, h_ES_enc_t, E_idx_t)
            
            for l, layer in enumerate(self.model.decoder_layers):
                # Updated relational features for future states
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t) # [batch, 1, K, 384]
                h_ESV_t = mask_bw[:,t:t+1,:,:] * h_ESV_decoder_t + h_ESV_encoder_t # [batch, 1, K, 384]
                
                h_V_t = h_V_stack[l][:,t:t+1,:] # [batch, 1 128]
                h_V_stack[l+1][:,t,:] = layer(
                    h_V_t, h_ESV_t, mask_V=mask[:,t:t+1]
                ).squeeze(1) # [1, 128]

            # Sampling step
            h_V_t = h_V_stack[-1][:,t,:]
            logits = self.model.W_out(h_V_t) / temperature
            probs = F.softmax(logits, dim=-1)
            S_t = torch.multinomial(probs, 1).squeeze(-1)
            
            all_probs.append(probs)
            
            # Update
            h_S[:,t,:] = self.model.W_s(S_t)
            S[:,t] = S_t
        
        self.probs = torch.cat(all_probs, dim=0)
        return S