import torch
import torch.nn as nn

from _convs import EGNN
from _utils import remove_mean, remove_mean_with_mask

class EGNN_dynamics(nn.Module):
    """EGNN velocity field for a fixed number of particles on a fully connected graph.

    ``forward(t, xs)`` runs the EGNN on every sample of the batch and returns the
    coordinate update ``x_final - x`` passed through ``remove_mean``. As a side
    effect it accumulates ``self.egnn.reg_terms`` into ``self.reg_terms`` and
    counts the accumulated calls in ``self.reg_num_dt`` (see ``_update_reg_terms``).
    """

    # A forward() call with 0 <= t < _T_START_EPS resets the regularization
    # accumulator; calls with larger t add to it.
    _T_START_EPS = 1e-16

    def __init__(self, n_particles, n_dimension, hidden_nf=64,
                 act_fn=torch.nn.SiLU(), n_layers=4, recurrent=True, attention=False,
                 condition_time=True, tanh=False, agg='sum', normalize=False,
                 norm_const=1.0, normalize_type=0, reg_para=0, reg_clip=1e-16):
        """Build the EGNN and the fully connected edge list.

        Args:
            n_particles: Number of particles (graph nodes) per sample.
            n_dimension: Number of coordinates per particle.
            condition_time: If True, node features are multiplied by ``t``.
            reg_clip: Per-call regularization totals below this value are not
                accumulated into ``self.reg_terms``.

        All other arguments are forwarded unchanged to ``EGNN``.
        NOTE: the default ``act_fn`` is a single SiLU instance shared by every
        model built without an explicit ``act_fn`` (SiLU holds no parameters).
        """
        super().__init__()
        # One scalar node feature (ones, optionally scaled by t) and one edge
        # feature (squared pairwise distance); see forward().
        self.egnn = EGNN(in_node_nf=1, in_edge_nf=1, hidden_nf=hidden_nf, act_fn=act_fn,
                         n_layers=n_layers, recurrent=recurrent, attention=attention,
                         tanh=tanh, agg=agg, normalize=normalize, norm_const=norm_const,
                         normalize_type=normalize_type, reg_para=reg_para)

        self._n_particles = n_particles
        self._n_dimension = n_dimension
        self.edges = self._create_edges()
        # Single-entry cache of batched edges keyed by (n_batch, n_nodes, device).
        self._edges_dict = {}
        self.condition_time = condition_time

        # Regularization state updated on every forward() call.
        self.reg_terms = torch.tensor(0.0)
        # Initialised here so a first call with t >= _T_START_EPS can increment it.
        self.reg_num_dt = 0
        self.reg_clip = reg_clip
        self.listener = {}  # not used in this file; kept for external callers
        self.count = 0  # number of forward() calls so far

    def forward(self, t, xs, edges=None):
        """Return the EGNN coordinate update for a batch of configurations.

        Args:
            t: Time. It is compared against Python floats, so presumably a
                number or a single-element tensor -- TODO confirm with caller.
            xs: Tensor whose first dimension is the batch and which holds
                ``n_particles * n_dimension`` values per sample.
            edges: Ignored; the fully connected graph built in ``__init__`` is
                always used. Kept for interface compatibility.

        Returns:
            ``remove_mean`` applied to the update viewed as
            ``(n_batch, n_particles, n_dimension)``.
        """
        n_batch = xs.shape[0]
        # Shallow copy so the EGNN never aliases the cached list.
        edges = list(self._cast_edges2batch(self.edges, n_batch, self._n_particles, xs.device))
        x = xs.view(n_batch * self._n_particles, self._n_dimension).clone()
        h = torch.ones(n_batch * self._n_particles, 1, device=xs.device)
        if self.condition_time:
            h = h * t

        edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
        _, x_final = self.egnn(h, x, edges, edge_attr=edge_attr)
        self.count += 1

        self._update_reg_terms(t)

        vel = (x_final - x).view(n_batch, self._n_particles, self._n_dimension)
        return remove_mean(vel)

    def _update_reg_terms(self, t):
        """Fold this call's ``sum(self.egnn.reg_terms)`` into the running totals.

        For ``0 <= t < _T_START_EPS`` the accumulator is restarted: it becomes the
        current total (``reg_num_dt = 1``), or zero (``reg_num_dt = 0``) when the
        total is below ``reg_clip``. For larger ``t``, totals of at least
        ``reg_clip`` are added and ``reg_num_dt`` is incremented. Negative (or
        NaN) ``t`` leaves the state untouched.
        """
        if 0.0 <= t < self._T_START_EPS:
            step_reg = sum(self.egnn.reg_terms)
            if step_reg < self.reg_clip:
                self.reg_terms = torch.tensor(0.0)
                self.reg_num_dt = 0
            else:
                self.reg_terms = step_reg
                self.reg_num_dt = 1
        elif t >= self._T_START_EPS:
            try:
                step_reg = sum(self.egnn.reg_terms)
                if step_reg >= self.reg_clip:
                    self.reg_terms = self.reg_terms + step_reg
                    self.reg_num_dt += 1
            except Exception as err:
                # Best effort: a failed accumulation must not abort the solver
                # step, but report what went wrong.
                print("Norm errors:")
                print(self.reg_terms, err)

    def _create_edges(self):
        """Return ``[rows, cols]`` LongTensors for every ordered pair i != j.

        Pairs are emitted as (i, j) followed by (j, i) for each i < j.
        """
        rows, cols = [], []
        for i in range(self._n_particles):
            for j in range(i + 1, self._n_particles):
                rows.append(i)
                cols.append(j)
                rows.append(j)
                cols.append(i)
        return [torch.LongTensor(rows), torch.LongTensor(cols)]

    def _cast_edges2batch(self, edges, n_batch, n_nodes, device):
        """Return ``edges`` replicated for ``n_batch`` disjoint graphs on ``device``.

        Copy ``b`` has its node indices shifted by ``b * n_nodes``; copies are
        concatenated in batch order. Only the most recent
        ``(n_batch, n_nodes, device)`` combination is cached, which bounds memory.
        The device is part of the key so a model moved between devices never gets
        stale edges.
        """
        device = torch.device(device)
        key = (n_batch, n_nodes, device)
        if key not in self._edges_dict:
            rows, cols = edges
            # (n_batch, 1) offsets broadcast against (1, n_edges) indices.
            offsets = torch.arange(n_batch, dtype=rows.dtype, device=rows.device).unsqueeze(1) * n_nodes
            rows_total = (rows.unsqueeze(0) + offsets).reshape(-1).to(device)
            cols_total = (cols.unsqueeze(0) + offsets).reshape(-1).to(device)
            self._edges_dict = {key: [rows_total, cols_total]}
        return self._edges_dict[key]



class EGNN_dynamics_QM9(nn.Module):
    """EGNN dynamics over node coordinates plus optional node features.

    ``forward`` takes a 2-D node tensor ``xh`` whose first ``n_dims`` columns are
    coordinates and whose remaining columns (if any) are node features. It returns
    the coordinate update (through ``remove_mean``), concatenated with the updated
    node features when there are any.
    """

    def __init__(self, in_node_nf, context_node_nf,
                 n_dims, hidden_nf=64,
                 act_fn=torch.nn.SiLU(), n_layers=4, recurrent=True, attention=False,
                 condition_time=True, tanh=False, mode='egnn_dynamics', agg='sum'):
        """Build the EGNN.

        Args:
            in_node_nf: Node feature width fed to the EGNN, excluding context.
                NOTE(review): presumably this already counts the time column
                appended when ``condition_time`` is True -- TODO confirm.
            context_node_nf: Width of the per-node context appended in forward().
            n_dims: Number of coordinate columns at the start of ``xh``.
            mode: Stored as ``self.mode``; not read within this class.

        The remaining arguments are forwarded unchanged to ``EGNN``.
        """
        super().__init__()
        self.mode = mode
        self.egnn = EGNN(
            in_node_nf=in_node_nf + context_node_nf, in_edge_nf=1,
            hidden_nf=hidden_nf, act_fn=act_fn,
            n_layers=n_layers, recurrent=recurrent, attention=attention, tanh=tanh, agg=agg)
        self.in_node_nf = in_node_nf
        self.context_node_nf = context_node_nf
        self.n_dims = n_dims
        self._edges_dict = {}  # not used within this class
        self.condition_time = condition_time

    @staticmethod
    def _time_column(t, h, n_nodes):
        """Return ``t`` as an ``(n_nodes, 1)`` column with ``h``'s dtype and device.

        A number or single-element tensor is broadcast to every node; a tensor
        with ``n_nodes`` elements is used as a per-node time.
        """
        t = torch.as_tensor(t, dtype=h.dtype, device=h.device)
        if t.numel() == 1:
            return t.reshape(1, 1).expand(n_nodes, 1)
        return t.reshape(n_nodes, 1)

    def forward(self, t, xh, edges, context=None):
        """Run the EGNN on one (possibly batched-as-disjoint) graph.

        Args:
            t: Time; a number / single-element tensor, or a tensor with
                ``n_nodes`` elements for per-node times.
            xh: 2-D tensor ``(n_nodes, n_dims + n_features)``.
            edges: Pair of index tensors ``(rows, cols)`` into the nodes.
            context: Optional tensor viewable as ``(n_nodes, context_node_nf)``.

        Returns:
            ``remove_mean`` of the coordinate update viewed as ``(n_nodes, -1)``;
            when ``xh`` has feature columns, the updated features are
            concatenated after it along dim 1.
        """
        n_nodes, dims = xh.shape
        h_dims = dims - self.n_dims
        x = xh[:, :self.n_dims].clone()
        if h_dims == 0:
            # No node features: feed a constant placeholder channel.
            h = torch.ones(n_nodes, 1, device=xh.device)
        else:
            h = xh[:, self.n_dims:].clone()

        if self.condition_time:
            h = torch.cat([h, self._time_column(t, h, n_nodes)], dim=1)

        if context is not None:
            context = context.view(n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
        h_final, x_final = self.egnn(h, x, edges, edge_attr=edge_attr)
        vel = remove_mean((x_final - x).view(n_nodes, -1))

        if h_dims == 0:
            return vel

        # Drop the channels appended above (context, then time). An explicit
        # stop index is used because ``[:, :-k]`` would drop everything at k == 0.
        n_keep = h_final.shape[1]
        if context is not None:
            n_keep -= self.context_node_nf
        if self.condition_time:
            n_keep -= 1
        h_final = h_final[:, :n_keep].view(n_nodes, -1)
        return torch.cat([vel, h_final], dim=1)