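"""
=============================================
E(n)-Equivariant Graph Convolutional Layers
=============================================

Segment-reduction helpers, the ``E_GCL`` layer, the ``EGNN`` stack built from
it, and ``EGNN_output_h``, a wrapper that returns only the node features.
"""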
import torch
from torch import nn
import torch.nn.functional as F

from _utils import *

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the rows of ``data`` that share a segment id.

    PyTorch counterpart of TensorFlow's ``unsorted_segment_sum``.

    Args:
        data: 2-D tensor of shape ``(n, d)``.
        segment_ids: 1-D integer tensor of length ``n``; row ``k`` of ``data``
            is added into output row ``segment_ids[k]``.
        num_segments: Number of output rows.

    Returns:
        Tensor of shape ``(num_segments, d)`` with ``data``'s dtype and
        device; rows that receive no ids stay zero.
    """
    n_cols = data.size(1)
    # scatter_add needs an index with the same shape as the source rows.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    totals = data.new_zeros((num_segments, n_cols))
    totals.scatter_add_(0, index, data)
    return totals

def unsorted_segment_mean(data, segment_ids, num_segments):
    """Average the rows of ``data`` that share a segment id.

    Args:
        data: 2-D tensor of shape ``(n, d)``.
        segment_ids: 1-D integer tensor of length ``n`` mapping each row of
            ``data`` to an output row.
        num_segments: Number of output rows.

    Returns:
        Tensor of shape ``(num_segments, d)``. Segments with no members are
        zero, because the per-segment count is clamped to at least 1.
    """
    n_cols = data.size(1)
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    shape = (num_segments, n_cols)
    totals = data.new_zeros(shape).scatter_add(0, index, data)
    counts = data.new_zeros(shape).scatter_add(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)



class E_GCL(nn.Module):
    """Equivariant graph convolutional layer updating node features and coordinates.

    Each call computes one message per edge ``(i, j) = (row[k], col[k])``:

    * edge model: ``m_ij = edge_mlp([h_i, h_j, ||x_i - x_j||^2, a_ij])``,
      optionally gated by a learned sigmoid weight and by ``edge_mask``;
    * coordinate model: ``x_i <- x_i + agg_j((x_i - x_j) * coord_mlp(m_ij))``;
    * node model: ``h_i <- node_mlp([h_i, sum_j m_ij, node_attr_i])``,
      plus ``h_i`` itself when ``recurrent`` is set.

    Args:
        input_nf: Size of the input node features.
        output_nf: Size of the output node features. With ``recurrent=True``
            the residual add requires ``output_nf == input_nf``.
        hidden_nf: Hidden width of the edge, node and coordinate MLPs.
        edges_in_d: Size of the optional edge attributes ``edge_attr``.
        nodes_att_dim: Size of the optional extra node attributes ``node_attr``.
        act_fn: Activation module; the same instance is used in every MLP.
        recurrent: Add a residual connection to the node update.
        attention: Gate each message with a learned sigmoid weight.
        clamp: Stored for compatibility; not read by this class.
        norm_diff: Stored for compatibility; not read by this class.
        tanh: Squash the coordinate-MLP output with ``tanh`` and scale it by
            ``coords_range``.
        coords_range: Scale of the coordinate update when ``tanh`` is set.
        agg: Coordinate aggregation, ``'sum'`` or ``'mean'``.
        normalize: Normalize ``x_i - x_j`` before the coordinate update.
        norm_const: Constant used by ``normalize_type`` 1 and 2.
        normalize_type: 0 divides by ``sqrt(r + 1e-8)``, 1 by
            ``sqrt(r + 1e-8) + norm_const``, 2 by ``sqrt(r + norm_const)``,
            where ``r`` is the squared edge length.

    Runtime attributes:
        reg: Regularization term computed by the latest ``coord_model`` call.
        listener: Dict of diagnostics recorded during training forwards and
            by the gradient hook.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), recurrent=True, attention=False, clamp=False, norm_diff=True, tanh=False, coords_range=1, agg='sum', normalize=False, norm_const=1.0, normalize_type=0):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2  # source and target node features
        edge_coords_nf = 1  # one squared distance per edge
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff  # not read anywhere in this class
        self.agg_type = agg
        self.tanh = tanh
        self.coords_range = coords_range  # only read when ``tanh`` is set
        self.normalize = normalize
        self.norm_const = norm_const
        self.normalize_type = normalize_type
        self.clamp = clamp  # not read anywhere in this class

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Final coordinate projection with a tiny init gain, so coordinate
        # updates start small. It is created (and initialized) before the
        # first coord-MLP Linear, which keeps the original RNG consumption order.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = [nn.Linear(hidden_nf, hidden_nf), act_fn, layer]
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

        # Training diagnostics, filled by forward() and the *_grad hooks.
        self.listener = {}
        # Handle of the gradient hook on the first edge-MLP weight. It is kept
        # so the hook can be replaced rather than stacked on every forward.
        self._theta_hook_handle = None

    def edge_model(self, source, target, radial, edge_attr, edge_mask):
        """Compute one message per edge.

        Args:
            source: Features of each edge's ``row`` node.
            target: Features of each edge's ``col`` node.
            radial: Squared edge lengths, one column per edge.
            edge_attr: Optional edge attributes, or ``None``.
            edge_mask: Optional multiplicative mask on the messages, or ``None``.

        Returns:
            Messages of width ``hidden_nf``, one row per edge.
        """
        if edge_attr is None:
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val

        if edge_mask is not None:
            out = out * edge_mask
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update node features from the summed incoming messages.

        Args:
            x: Node features.
            edge_index: Pair ``(row, col)``; messages are summed per ``row``.
            edge_attr: Per-edge messages (``forward`` passes the edge-model
                output here, not the raw edge attributes).
            node_attr: Optional extra node attributes, or ``None``.

        Returns:
            ``(out, agg)``: the new node features, and the concatenated MLP
            input ``[x, summed messages(, node_attr)]``.
        """
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask):
        """Update coordinates with message-weighted relative positions.

        Also sets ``self.reg``. In tanh mode during training it is the norm of
        ``d coord_mlp(m) / d radial``; in tanh mode during eval it is zero.
        Otherwise it is a zero tensor attached to the graph of the
        coordinate-MLP output.

        Raises:
            ValueError: if ``self.agg_type`` is neither ``'sum'`` nor ``'mean'``.
        """
        row, col = edge_index
        num_nodes = coord.size(0)

        pvals = self.coord_mlp(edge_feat)

        if self.tanh:
            if self.training:
                # Double-backward through the coordinate MLP. This needs
                # ``radial`` to be part of the autograd graph. NOTE(review):
                # for the first layer that presumably means ``coord`` must
                # require grad — confirm callers guarantee it.
                ip = torch.autograd.grad(pvals, radial, torch.ones_like(pvals), create_graph=True)[0]
            else:
                ip = torch.zeros_like(pvals)
            self.reg = ip.norm() + 0.0 * pvals.norm()
        else:
            self.reg = 0.0 * pvals.norm()

        if self.tanh:
            trans = coord_diff * pvals * self.coords_range
        else:
            trans = coord_diff * pvals
        if edge_mask is not None:
            trans = trans * edge_mask

        if self.agg_type == 'sum':
            agg = unsorted_segment_sum(trans, row, num_segments=num_nodes)
        elif self.agg_type == 'mean':
            if node_mask is not None:
                agg = unsorted_segment_sum(trans, row, num_segments=num_nodes)
                # M counts, per node, the edges whose ``col`` endpoint is
                # unmasked. The ``- 1`` presumably discounts a self-loop —
                # confirm against the edge construction. Clamping avoids
                # inf/NaN (M == 1) and a sign flip (M == 0) on degenerate nodes.
                M = unsorted_segment_sum(node_mask[col], row, num_segments=num_nodes)
                agg = agg / (M - 1).clamp(min=1)
            else:
                agg = unsorted_segment_mean(trans, row, num_segments=num_nodes)
        else:
            raise ValueError("Wrong coordinates aggregation type")
        return coord + agg

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step.

        Args:
            h: Node features, indexed by ``edge_index`` entries.
            edge_index: Pair ``(row, col)`` of node-index tensors.
            coord: Node coordinates (assumed ``(N, D)`` — rows indexed like ``h``).
            edge_attr: Optional edge attributes.
            node_attr: Optional extra node attributes.
            node_mask: Optional mask multiplied into the outputs (presumably
                ``(N, 1)`` — TODO confirm).
            edge_mask: Optional mask multiplied into messages and updates.

        Returns:
            ``(h, coord, edge_attr)``; ``edge_attr`` is passed through unchanged.
        """
        if self.training:
            # Stores live tensors, so their graphs stay alive until the next forward.
            self.listener['h[l]'] = h

        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)
        if self.training:
            self.listener['r_ij'] = radial
            if radial.numel() > 0:  # ``max`` raises on an empty edge set
                self.listener['max_r_ij'] = radial.max().item()

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr, edge_mask)
        if self.training:
            self._register_theta_hook()

        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

        if node_mask is not None:
            h = h * node_mask
            coord = coord * node_mask
        return h, coord, edge_attr

    def _register_theta_hook(self):
        """Attach ``theta_grad`` to the first edge-MLP weight.

        A hook attached by an earlier call is removed first. Without this,
        every training forward would add another hook, and they would pile up.
        """
        handle = getattr(self, '_theta_hook_handle', None)
        if handle is not None:
            handle.remove()
        self._theta_hook_handle = self.edge_mlp[0].weight.register_hook(self.theta_grad)

    def coord2radial(self, edge_index, coord):
        """Return squared edge lengths and (optionally normalized) edge vectors.

        Returns:
            ``(radial, coord_diff)``. ``radial`` is the per-edge sum of
            squared components of ``coord[row] - coord[col]``, kept as a
            column. ``coord_diff`` is that difference, divided according to
            ``normalize_type`` when ``normalize`` is set. Any other
            ``normalize_type`` leaves it unnormalized.
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff)**2, 1).unsqueeze(1)

        if self.normalize:
            if self.normalize_type == 0:
                norm = torch.sqrt(radial+1e-8)
                coord_diff = coord_diff/(norm)
            elif self.normalize_type == 1:
                norm = torch.sqrt(radial+1e-8)
                coord_diff = coord_diff/(norm + self.norm_const)
            elif self.normalize_type == 2:
                norm = torch.sqrt(radial+self.norm_const)
                coord_diff = coord_diff/(norm)

        return radial, coord_diff

    # Gradient hooks: each records the norm of the incoming gradient into
    # ``self.listener`` and returns None, so the gradient itself is unchanged.
    def theta_grad(self, grad):
        self.listener['||dL/dtheta||'] = grad.norm().item()

    def mij_grad(self, grad):
        self.listener['||dL/dm_ij||'] = grad.norm().item()

    def x_grad(self, grad):
        self.listener['||dL/dx||'] = grad.norm().item()

    def phi_h_grad(self, grad):
        self.listener['||dL/dphi_h||'] = grad.norm().item()

    def phi_x_grad(self, grad):
        self.listener['||dL/dphi_x||'] = grad.norm().item()

    def get_listener(self):
        """Return the diagnostics dict.

        The dict is returned by reference, not copied. Its entries are
        overwritten by every training forward and by the gradient hook.
        """
        return self.listener


class EGNN(nn.Module):
    """Stack of ``n_layers`` E_GCL layers between linear node embeddings.

    Args:
        in_node_nf: Size of the input node features.
        in_edge_nf: Size of the edge attributes given to every layer.
        hidden_nf: Width of the node embedding and of every layer.
        act_fn: Activation module passed to every layer.
        n_layers: Number of E_GCL layers, registered as ``gcl_0`` ...
            ``gcl_{n_layers - 1}``.
        recurrent, attention, norm_diff, tanh, normalize, norm_const,
        normalize_type: Forwarded unchanged to every E_GCL layer.
        out_node_nf: Size of the output node features; defaults to ``in_node_nf``.
        coords_range: Total coordinate range, split evenly across the layers.
        agg: Coordinate aggregation used by every layer (``'sum'`` or ``'mean'``).
        reg_para: Slope of the per-layer regularization weight; layer ``i``'s
            ``reg`` is multiplied by ``reg_para * i + 1``.

    Attributes:
        reg_terms: Weighted per-layer regularizers from the latest forward
            (empty before the first forward).
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, act_fn=nn.SiLU(), n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, agg='sum', normalize=False, norm_const=1.0, normalize_type=0, reg_para=0):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)/self.n_layers
        self.reg_para = reg_para
        if agg == 'mean':
            # NOTE(review): 19 looks like a dataset-specific neighbor count that
            # compensates for averaging instead of summing — confirm before
            # reusing this model on graphs of a different size.
            self.coords_range_layer = self.coords_range_layer * 19
        # Defined up front so reading it before the first forward does not raise.
        self.reg_terms = []

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        # Layers are registered by name (not a ModuleList) to keep the
        # existing ``gcl_%d`` state_dict keys.
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=recurrent, attention=attention, norm_diff=norm_diff, tanh=tanh, coords_range=self.coords_range_layer, agg=agg, normalize=normalize, norm_const=norm_const, normalize_type=normalize_type))

    def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None):
        """Embed ``h``, run every layer on ``(h, x)``, and project ``h`` out.

        Args:
            h: Input node features.
            x: Node coordinates.
            edges: Edge index pair ``(row, col)``.
            edge_attr, node_mask, edge_mask: Forwarded to every layer.

        Returns:
            ``(h, x)``: output node features (re-masked with ``node_mask``
            if given) and updated coordinates.
        """
        h = self.embedding(h)
        self.reg_terms = []
        for i in range(0, self.n_layers):
            layer = self._modules["gcl_%d" % i]
            h, x, _ = layer(h, edges, x, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
            # Regularizer weight grows linearly with depth: reg_para * i + 1.
            self.reg_terms.append(layer.reg * (self.reg_para*i + 1))
        h = self.embedding_out(h)

        # The bias of the output projection may be non-zero, so re-apply the mask.
        if node_mask is not None:
            h = h * node_mask
        return h, x




class EGNN_output_h(nn.Module):
    """Thin wrapper around ``EGNN`` that returns only the final node features.

    Edge attributes are disabled (``in_edge_nf=0``). Every EGNN option not
    listed in the signature keeps its EGNN default.
    """

    def __init__(self, in_node_nf, out_node_nf, hidden_nf=64, 
                 act_fn=torch.nn.SiLU(), n_layers=4, recurrent=True,
                 attention=False, agg='sum'):
        super().__init__()
        egnn_kwargs = dict(
            in_node_nf=in_node_nf,
            in_edge_nf=0,
            hidden_nf=hidden_nf,
            act_fn=act_fn,
            n_layers=n_layers,
            recurrent=recurrent,
            attention=attention,
            out_node_nf=out_node_nf,
            agg=agg,
        )
        self.egnn = EGNN(**egnn_kwargs)

        self.in_node_nf = in_node_nf
        self.out_node_nf = out_node_nf
        # Kept for compatibility; not read anywhere in this class.
        self._edges_dict = {}

    def forward(self, h, x, edges, batch):
        """Run the wrapped EGNN and return its node features.

        ``batch`` is accepted for interface compatibility but is not used.
        The updated coordinates are discarded.
        """
        node_feats, _coords = self.egnn(h, x, edges)
        return node_feats