

import torch
import torch.nn as nn

from _convs import EGNN
from _utils import remove_mean, remove_mean_with_mask


class EGNN_dynamics(nn.Module):
    """Velocity field for a fixed-size, fully connected particle system.

    Each sample in a batch is treated as a complete graph (no self-loops) over
    ``n_particles`` nodes. The wrapped ``EGNN`` updates the coordinates, and the
    coordinate change ``x_final - x`` is returned as a velocity after being
    passed through ``remove_mean``.

    Args:
        n_particles: Number of particles (graph nodes) per sample.
        n_dimension: Number of coordinate components per particle.
        hidden_nf, act_fn, n_layers, recurrent, attention, tanh, agg:
            Forwarded unchanged to ``EGNN``.
        condition_time: If True, the constant node features are multiplied by
            ``t`` in ``forward``.
    """

    def __init__(self, n_particles, n_dimension, hidden_nf=64,
                 act_fn=torch.nn.SiLU(), n_layers=4, recurrent=True, attention=False,
                 condition_time=True, tanh=False, agg='sum'):
        super().__init__()
        self.egnn = EGNN(in_node_nf=1, in_edge_nf=1, hidden_nf=hidden_nf, act_fn=act_fn,
                         n_layers=n_layers, recurrent=recurrent, attention=attention,
                         tanh=tanh, agg=agg)

        self._n_particles = n_particles
        self._n_dimension = n_dimension
        # Edge list of a single sample (CPU tensors); tiled over the batch on demand.
        self.edges = self._create_edges()
        # Cache of the most recent batched edge list (see _cast_edges2batch).
        self._edges_dict = {}
        self.condition_time = condition_time

    def forward(self, t, xs, edges=None):
        """Compute the velocity for a batch of particle configurations.

        Args:
            t: Time value. When ``condition_time`` is True it is multiplied into
                the ``(n_batch * n_particles, 1)`` node features, so it must
                broadcast against that shape (e.g. a Python scalar).
            xs: Coordinates with ``n_batch = xs.shape[0]``; reshaped to
                ``(n_batch * n_particles, n_dimension)``.
            edges: Ignored. The fully connected graph built at construction time
                is always used; the parameter is kept for signature compatibility.

        Returns:
            Tensor of shape ``(n_batch, n_particles, n_dimension)`` after
            ``remove_mean``.
        """
        n_batch = xs.shape[0]
        n_nodes = n_batch * self._n_particles
        # Shallow copy so nothing downstream can mutate the cached list itself.
        batch_edges = list(self._cast_edges2batch(self.edges, n_batch, self._n_particles, xs.device))
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = xs.reshape(n_nodes, self._n_dimension).clone()
        # Constant node features created directly with the input's dtype/device,
        # so e.g. float64 coordinates are not paired with float32 features.
        h = torch.ones(n_nodes, 1, dtype=xs.dtype, device=xs.device)
        if self.condition_time:
            h = h * t

        # Squared pairwise distances serve as the edge attributes.
        edge_attr = torch.sum((x[batch_edges[0]] - x[batch_edges[1]]) ** 2, dim=1, keepdim=True)
        _, x_final = self.egnn(h, x, batch_edges, edge_attr=edge_attr)
        vel = (x_final - x).view(n_batch, self._n_particles, self._n_dimension)
        return remove_mean(vel)

    def _create_edges(self):
        """Build the directed edges of a complete graph without self-loops.

        For every pair ``i < j`` (in row-major order), edge ``(i, j)`` is
        immediately followed by ``(j, i)``.

        Returns:
            ``[rows, cols]``: two int64 CPU tensors of length ``n * (n - 1)``.
        """
        n = self._n_particles
        # Upper-triangle pairs (i < j), row-major: same order as a nested i/j loop.
        src, dst = torch.triu_indices(n, n, offset=1)
        # Interleave (i, j) and (j, i) for each pair.
        rows = torch.stack([src, dst], dim=1).reshape(-1)
        cols = torch.stack([dst, src], dim=1).reshape(-1)
        return [rows, cols]

    def _cast_edges2batch(self, edges, n_batch, n_nodes, device):
        """Tile a single-sample edge list over ``n_batch`` disjoint graphs.

        Sample ``b`` owns node indices ``[b * n_nodes, (b + 1) * n_nodes)``, so
        its edges are the single-sample edges offset by ``b * n_nodes``.

        The result is cached under ``(n_batch, n_nodes, device)``. Including the
        device prevents returning indices on a stale device after the model or
        its inputs move. Only the latest entry is kept, which bounds memory.

        Args:
            edges: ``[rows, cols]`` index tensors of one sample.
            n_batch: Number of samples in the batch.
            n_nodes: Nodes per sample.
            device: Target device for the returned tensors.

        Returns:
            ``[rows_total, cols_total]`` on ``device``.
        """
        key = (n_batch, n_nodes, torch.device(device))
        if key not in self._edges_dict:
            rows, cols = edges
            offsets = torch.arange(n_batch, dtype=rows.dtype, device=rows.device) * n_nodes
            # (n_batch, 1) + (1, E) -> (n_batch, E); flattening keeps sample order.
            rows_total = (offsets.unsqueeze(1) + rows.unsqueeze(0)).reshape(-1).to(device)
            cols_total = (offsets.unsqueeze(1) + cols.unsqueeze(0)).reshape(-1).to(device)
            self._edges_dict = {key: [rows_total, cols_total]}
        return self._edges_dict[key]







class EGNN_dynamics_QM9(nn.Module):
    """EGNN dynamics over a caller-supplied graph of coordinates plus node features.

    Each input row is ``[coordinates (n_dims) | node features (optional)]``.
    Before the ``EGNN`` call, a time column (if ``condition_time``) and context
    columns (if ``context`` is given) are appended to the node features. They
    are stripped again from the EGNN's output features.

    Args:
        in_node_nf: Node-feature width handed to ``EGNN`` (before context is
            added). NOTE(review): presumably this must already count the time
            column when ``condition_time`` is True; confirm against callers.
        context_node_nf: Number of per-node context features.
        n_dims: Number of coordinate columns at the start of each input row.
        hidden_nf, act_fn, n_layers, recurrent, attention, tanh, agg:
            Forwarded unchanged to ``EGNN``.
        condition_time: If True, append a column filled with ``t`` to the node
            features.
        mode: Stored on the instance; not read by this class.
    """

    def __init__(self, in_node_nf, context_node_nf,
                 n_dims, hidden_nf=64,
                 act_fn=torch.nn.SiLU(), n_layers=4, recurrent=True, attention=False,
                 condition_time=True, tanh=False, mode='egnn_dynamics', agg='sum'):
        super().__init__()
        self.mode = mode
        self.egnn = EGNN(
            in_node_nf=in_node_nf + context_node_nf, in_edge_nf=1,
            hidden_nf=hidden_nf, act_fn=act_fn,
            n_layers=n_layers, recurrent=recurrent, attention=attention, tanh=tanh, agg=agg)
        self.in_node_nf = in_node_nf

        self.context_node_nf = context_node_nf
        self.n_dims = n_dims
        # Not used by this class; kept so the attribute still exists for callers.
        self._edges_dict = {}
        self.condition_time = condition_time

    def forward(self, t, xh, edges, context=None):
        """Compute the coordinate velocity and updated node features.

        Args:
            t: Time value written into the extra feature column when
                ``condition_time`` is True. It is passed to ``Tensor.fill_``, so
                it must be a Python number or a single-element tensor.
            xh: 2-D tensor ``(n_nodes, n_dims + n_features)``; ``n_features``
                may be 0.
            edges: Pair ``(rows, cols)`` of node-index tensors into ``xh``.
            context: Optional per-node context, reshaped to
                ``(n_nodes, context_node_nf)``.

        Returns:
            If ``xh`` has no feature columns, the velocity ``(n_nodes, n_dims)``.
            Otherwise, the velocity concatenated with the updated node features
            (time/context columns removed).
        """
        n_nodes, dims = xh.shape
        h_dims = dims - self.n_dims
        x = xh[:, :self.n_dims].clone()
        if h_dims == 0:
            # No node features: use a constant placeholder matching xh's dtype,
            # consistent with the branch below that inherits it from xh.
            h = torch.ones(n_nodes, 1, dtype=xh.dtype, device=xh.device)
        else:
            h = xh[:, self.n_dims:].clone()

        if self.condition_time:
            h_time = torch.empty_like(h[:, 0:1]).fill_(t)
            h = torch.cat([h, h_time], dim=1)

        has_context = context is not None
        if has_context:
            context = context.view(n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        # Squared pairwise distances serve as the edge attributes.
        edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
        h_final, x_final = self.egnn(h, x, edges, edge_attr=edge_attr)

        # NOTE(review): remove_mean receives a 2-D (n_nodes, n_dims) tensor
        # here. Whether it centres over nodes or over coordinate axes depends
        # on its implementation in _utils; confirm this is intended.
        vel = remove_mean(x_final - x)

        if h_dims == 0:
            return vel
        return torch.cat([vel, self._strip_conditioning(h_final, has_context)], dim=1)

    def _strip_conditioning(self, h_final, has_context):
        """Drop the time and context columns appended to the node features.

        Uses an explicit stop index rather than ``h[:, :-k]``, because
        ``h[:, :-0]`` would drop every column when ``context_node_nf == 0``.
        The stop index is clamped at 0, so removing more columns than exist
        yields an empty feature block.
        """
        n_keep = h_final.shape[1]
        if has_context:
            n_keep -= self.context_node_nf
        if self.condition_time:
            n_keep -= 1
        return h_final[:, :max(n_keep, 0)]
