import torch
import torch.nn as nn

from egnn.egnn_new import EGNN, GNN
from equivariant_diffusion.utils import remove_mean, remove_mean_with_mask


class EGNN_regressor(nn.Module):
    """Graph-level property regressor / multi-label classifier on top of an (E)GNN.

    Per-node features are encoded by an EGNN (``mode='egnn_dynamics'``) or a
    plain GNN (``mode='gnn_dynamics'``), decoded per node by ``node_dec``,
    sum-pooled over the unmasked nodes of each graph and mapped to the final
    prediction by ``graph_dec``.
    """

    def __init__(self, in_node_nf, context_node_nf,
                 n_dims, hidden_nf=64, device='cpu',
                 act_fn=torch.nn.SiLU(), n_layers=4, attention=False,
                 tanh=False, mode='egnn_dynamics', norm_constant=0,
                 inv_sublayers=2, sin_embedding=False, normalization_factor=100, aggregation_method='sum',
                 include_atomic_numbers=False, condition_time=True,
                 classifier_head=False, n_props=1):
        '''
        :param in_node_nf: Number of invariant node features in the input
            (excluding the time feature that is added when ``condition_time`` is True).
        :param context_node_nf: Number of per-node context features that
            ``_forward`` concatenates to the node features when a context is given.
        :param n_dims: Number of coordinate columns at the front of ``xh``.
        :param hidden_nf: Hidden width of the encoder and of both decoders.
        :param mode: 'egnn_dynamics' or 'gnn_dynamics'.
        :param condition_time: If True, the time ``t`` is appended as one extra node feature.
        :param classifier_head: If True, ``graph_dec`` outputs 1024 logits
            (multi-label binary classifier for a Morgan fingerprint); otherwise
            it outputs ``n_props`` regression targets.
        :param n_props: Number of regression targets (unused with ``classifier_head``).
        The remaining arguments are forwarded to ``EGNN`` / ``GNN``.
        '''
        super().__init__()
        if condition_time:
            # t is appended as one extra invariant node feature in _forward
            in_node_nf = in_node_nf + 1

        self.mode = mode
        if mode == 'egnn_dynamics':
            self.egnn = EGNN(
                in_node_nf=in_node_nf + context_node_nf, out_node_nf=hidden_nf,
                in_edge_nf=1, hidden_nf=hidden_nf, device=device, act_fn=act_fn,
                n_layers=n_layers, attention=attention, tanh=tanh, norm_constant=norm_constant,
                inv_sublayers=inv_sublayers, sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method)
        elif mode == 'gnn_dynamics':
            # The plain GNN sees the n_dims coordinates as ordinary input features
            # and emits n_dims extra output columns that _forward discards.
            self.gnn = GNN(
                in_node_nf=in_node_nf + context_node_nf + n_dims, out_node_nf=hidden_nf + n_dims,
                in_edge_nf=0, hidden_nf=hidden_nf, device=device,
                act_fn=act_fn, n_layers=n_layers, attention=attention,
                normalization_factor=normalization_factor, aggregation_method=aggregation_method)

        # Per-node decoder applied before sum pooling.
        self.node_dec = nn.Sequential(nn.Linear(hidden_nf, hidden_nf),
                                      act_fn,
                                      nn.Linear(hidden_nf, hidden_nf)
                                      ).to(device)

        if classifier_head:
            # multi-label binary classifier for morgan fingerprint
            self.graph_dec = nn.Sequential(nn.Linear(hidden_nf, 1024),
                                           act_fn,
                                           nn.Linear(1024, 1024)).to(device)
        else:
            # 1 or multiple regression targets
            self.graph_dec = nn.Sequential(nn.Linear(hidden_nf, hidden_nf),
                                           act_fn,
                                           nn.Linear(hidden_nf, n_props)
                                           ).to(device)

        self.in_node_nf = in_node_nf  # includes the time feature if condition_time
        self.context_node_nf = context_node_nf  # read by _forward on the context path
        self.hidden_nf = hidden_nf
        self.include_atomic_numbers = include_atomic_numbers
        self.device = device
        self.n_dims = n_dims
        self.condition_time = condition_time
        self.classifier_head = classifier_head
        # Cache of fully-connected edge indices: {n_nodes: {batch_size: [rows, cols]}}
        self._edges_dict = {}

    def forward(self, t, xh, node_mask, edge_mask, context=None):
        """Not implemented; call ``_forward`` (or ``wrap_forward`` / ``unwrap_forward``)."""
        raise NotImplementedError

    def wrap_forward(self, node_mask, edge_mask, context):
        """Return ``fwd(time, state)`` that calls ``_forward`` with the given masks and context bound."""
        def fwd(time, state):
            return self._forward(time, state, node_mask, edge_mask, context)
        return fwd

    def unwrap_forward(self):
        """Return the bound ``_forward`` method."""
        return self._forward

    def _forward(self, t, xh, node_mask, edge_mask, context=None):
        """Predict graph-level properties from per-node inputs.

        xh is the output of the encoder: a (bs, n_nodes, n_dims + h_dims) tensor
        with the coordinates in the first ``n_dims`` columns.
        context should be usually None as we're trying to predict the property.

        :param t: Time conditioning (used only if ``condition_time``); either one
            value per graph (``bs`` elements) or a single value shared by the batch.
        :param node_mask: Reshaped to (bs*n_nodes, 1); zeroes padded nodes.
        :param edge_mask: Reshaped to (bs*n_nodes*n_nodes, 1) and passed to the EGNN.
        :param context: Optional per-node context, reshaped to (bs*n_nodes, context_node_nf).
        :return: ``graph_dec`` output of shape (bs, n_props), squeezed to (bs,)
            when n_props == 1, or (bs, 1024) with the classifier head.
        :raises Exception: if ``self.mode`` is not a supported mode.
        """
        bs, n_nodes, dims = xh.shape
        h_dims = dims - self.n_dims
        edges = self.get_adj_matrix(n_nodes, bs, self.device)
        edges = [e.to(self.device) for e in edges]
        node_mask = node_mask.view(bs * n_nodes, 1)
        edge_mask = edge_mask.view(bs * n_nodes * n_nodes, 1)
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask
        x = xh[:, 0:self.n_dims].clone()
        if h_dims == 0:
            # no invariant features: use a constant feature per node
            h = torch.ones(bs * n_nodes, 1).to(self.device)
        else:
            h = xh[:, self.n_dims:].clone()

        if self.condition_time:
            t = t.reshape(-1)
            if t.numel() == 1:
                # one time value shared by every graph in the batch
                t = t.expand(bs)
            # t may differ over the batch dimension; every node of a graph gets its graph's t
            h_time = t.reshape(bs, 1).repeat(1, n_nodes).reshape(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)

        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        if self.mode == 'egnn_dynamics':
            # only the invariant features are needed; updated coordinates are discarded
            h_final, _ = self.egnn(h, x, edges, node_mask=node_mask, edge_mask=edge_mask)
        elif self.mode == 'gnn_dynamics':
            output = self.gnn(torch.cat([x, h], dim=1), edges, node_mask=node_mask)
            # the first n_dims output columns are coordinate channels, unused here
            h_final = output[:, self.n_dims:]
        else:
            raise Exception("Wrong mode %s" % self.mode)

        # h_final has shape (bs*n_nodes, hidden_nf). node_dec also runs on padded
        # nodes; their outputs are zeroed by node_mask before pooling.
        h = self.node_dec(h_final) * node_mask
        h = h.view(-1, n_nodes, self.hidden_nf).sum(dim=1)  # (bs, hidden_nf)
        pred = self.graph_dec(h)  # (bs, n_props) or (bs, 1024)
        return pred.squeeze(1)

    def get_adj_matrix(self, n_nodes, batch_size, device):
        """Return fully-connected edge indices (self-loops included) for a batch.

        Graph ``b`` occupies flat node indices ``[b*n_nodes, (b+1)*n_nodes)``.
        Every ordered pair (i, j) inside a graph is an edge, enumerated
        graph-major, then by i, then by j. Results are cached per
        (n_nodes, batch_size). The cache key ignores ``device``, so the tensors
        stay on the device used by the first call.

        :return: ``[rows, cols]``, two int64 tensors of length batch_size * n_nodes**2.
        """
        edges_by_batch = self._edges_dict.setdefault(n_nodes, {})
        if batch_size not in edges_by_batch:
            node_idx = torch.arange(n_nodes, device=device)
            # (batch_size, 1) offset of each graph's first node in the flat index
            offsets = (torch.arange(batch_size, device=device) * n_nodes).unsqueeze(1)
            rows = (node_idx.repeat_interleave(n_nodes).unsqueeze(0) + offsets).reshape(-1)
            cols = (node_idx.repeat(n_nodes).unsqueeze(0) + offsets).reshape(-1)
            edges_by_batch[batch_size] = [rows, cols]
        return edges_by_batch[batch_size]
