import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch_scatter import scatter_mean

from src.modules.gvp_module import GVP, GVPConvLayer, LayerNorm, tuple_index


class GVP_Model(nn.Module):
    '''
    GVP-GNN for structure-conditioned autoregressive 
    protein design as described in manuscript.
    
    Takes in protein structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a categorical distribution
    over 20 amino acids at each position in a `torch.Tensor` of 
    shape [n_nodes, 20].
    
    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.
    
    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.
    
    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param node_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers in each of the encoder
                       and decoder modules
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, args, drop_rate=0.1):
    
        super(GVP_Model, self).__init__()
        self.args = args
        self.node_in_dim = (6, 3)
        self.node_h_dim = (100, 16)
        self.edge_in_dim = (32, 1)
        self.edge_h_dim = (32, 1)
        self.num_layers = 3
        self.W_v = nn.Sequential(
            GVP(self.node_in_dim, self.node_h_dim, activations=(None, None)),
            LayerNorm(self.node_h_dim)
        )
        self.W_e = nn.Sequential(
            GVP(self.edge_in_dim, self.edge_h_dim, activations=(None, None)),
            LayerNorm(self.edge_h_dim)
        )
        
        self.encoder_layers = nn.ModuleList(
                GVPConvLayer(self.node_h_dim, self.edge_h_dim, drop_rate=drop_rate) 
            for _ in range(self.num_layers))
        
        self.W_s = nn.Embedding(33, 20)
        self.edge_h_dim = (self.edge_h_dim[0] + 20, self.edge_h_dim[1])
      
        self.decoder_layers = nn.ModuleList(
                GVPConvLayer(self.node_h_dim, self.edge_h_dim, 
                             drop_rate=drop_rate, autoregressive=True) 
            for _ in range(self.num_layers))
        
        self.W_out = GVP(self.node_h_dim, (33, 0), activations=(None, None))

        self.encode_t = 0
        self.decode_t = 0

    def _get_features(self, batch):
        return batch

    def forward(self, batch):
        '''
        Forward pass to be used at train-time, or evaluating likelihood.
        
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: int `torch.Tensor` of shape [num_nodes]
        '''
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)
        
        encoder_embeddings = h_V
        
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])
        
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E, autoregressive_x = encoder_embeddings)
        
        logits = self.W_out(h_V)
        log_probs = nn.functional.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs, 'logits': logits}
    
    def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1):
        '''
        Samples sequences auto-regressively from the distribution
        learned by the model.
        
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param n_samples: number of samples
        :param temperature: temperature to use in softmax 
                            over the categorical distribution
        
        :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on the
                 residue-to-int mapping of the original training data
        '''
        
        with torch.no_grad():
            device = edge_index.device
            L = h_V[0].shape[0]

            h_V = self.W_v(h_V)
            h_E = self.W_e(h_E)
            
            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)   

            h_V = (h_V[0].repeat(n_samples, 1),
                   h_V[1].repeat(n_samples, 1, 1))
            
            h_E = (h_E[0].repeat(n_samples, 1),
                   h_E[1].repeat(n_samples, 1, 1))
            
            edge_index = edge_index.expand(n_samples, -1, -1)
            offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1)
            edge_index = torch.cat(tuple(edge_index + offset), dim=-1)
            
            seq = torch.zeros(n_samples * L, device=device, dtype=torch.int)
            h_S = torch.zeros(n_samples * L, 20, device=device)
    
            h_V_cache = [(h_V[0].clone(), h_V[1].clone()) for _ in self.decoder_layers]
            
            all_probs = []
            for i in range(L):
                h_S_ = h_S[edge_index[0]]
                h_S_[edge_index[0] >= edge_index[1]] = 0
                h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])
                        
                edge_mask = edge_index[1] % L == i
                edge_index_ = edge_index[:, edge_mask]
                h_E_ = tuple_index(h_E_, edge_mask)
                node_mask = torch.zeros(n_samples * L, device=device, dtype=torch.bool)
                node_mask[i::L] = True
                
                for j, layer in enumerate(self.decoder_layers):
                    out = layer(h_V_cache[j], edge_index_, h_E_,
                               autoregressive_x=h_V_cache[0], node_mask=node_mask)
                    
                    out = tuple_index(out, node_mask)
                    
                    if j < len(self.decoder_layers)-1:
                        h_V_cache[j+1][0][i::L] = out[0]
                        h_V_cache[j+1][1][i::L] = out[1]
                    
                logits = self.W_out(out)
                seq[i::L] = Categorical(logits=logits / temperature).sample()
                h_S[i::L] = self.W_s(seq[i::L])
                all_probs.append(torch.softmax(logits, dim=-1))

            self.probs = torch.cat(all_probs, dim=0)
            return seq.view(n_samples, L)
    
    def test_recovery(self, protein):
        h_V = (protein.node_s, protein.node_v)
        h_E = (protein.edge_s, protein.edge_v) 
        sample = self.sample(h_V, protein.edge_index, h_E, n_samples=1)
        return sample.squeeze(0)