from torch import nn
import pickle
import torch


class E_GCL(nn.Module):
    """
    E(n) Equivariant Convolutional Layer.

    A single message-passing step. For every edge it builds a message from the
    features of both endpoints, their squared distance and optional edge
    attributes; the messages are then used to update the node coordinates
    (``coord_model``) and the node features (``node_model``).
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, nodes_att_dim=0):
        """
        :param input_nf: Number of input node features
        :param output_nf: Number of output node features (must equal ``input_nf``
                    when ``residual`` is True, because the output is added to the input)
        :param hidden_nf: Number of hidden features, also the width of an edge message
        :param edges_in_d: Number of features of the optional ``edge_attr``
        :param act_fn: Non-linearity shared by all the MLPs of the layer
        :param residual: Add the input node features to the node MLP output
        :param attention: Scale every edge message by a learned sigmoid gate
        :param normalize: Divide coordinate differences by their (detached) norm
        :param coords_agg: How coordinate messages are aggregated per node: 'sum' or 'mean'
        :param tanh: Append a tanh to the coordinate MLP to bound its output
        :param nodes_att_dim: Number of features of the optional ``node_attr``
                    passed to ``forward``; leave at 0 when ``node_attr`` is not used
        """
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        # The radial feature produced by coord2radial has a single column.
        edge_coords_nf = 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # node_model feeds cat([h, aggregated messages, node_attr]) to this MLP,
        # so its input width has to account for node_attr as well.
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Last layer of the coordinate MLP, initialised with a very small gain.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    @staticmethod
    def _reject_nan(tensor, name):
        """Raise RuntimeError if ``tensor`` contains a NaN."""
        if torch.any(torch.isnan(tensor)):
            raise RuntimeError("NaN detected in %s" % name)

    @staticmethod
    def _reject_inf(tensor, name):
        """Raise RuntimeError if ``tensor`` contains an infinite value."""
        if torch.any(torch.isinf(tensor)):
            raise RuntimeError("inf detected in %s" % name)

    def edge_model(self, source, target, radial, edge_attr):
        """
        Compute one message per edge.

        :param source: Features of the first endpoint of every edge
        :param target: Features of the second endpoint of every edge
        :param radial: Squared distance of every edge, one column
        :param edge_attr: Optional extra edge features (``edges_in_d`` columns) or None
        :return: Edge messages with ``hidden_nf`` columns
        """
        if edge_attr is None:
            # Only consistent with edge_mlp when the layer was built with edges_in_d=0.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """
        Update node features from the summed incoming messages.

        :param x: Current node features
        :param edge_index: Pair ``(row, col)`` of edge endpoint indices;
                    messages are summed onto ``row``
        :param edge_attr: Edge messages produced by ``edge_model``
        :param node_attr: Optional extra node features (``nodes_att_dim`` columns) or None
        :return: Tuple ``(new node features, tensor that was fed to the node MLP)``
        """
        row, _ = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.residual:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """
        Update node coordinates from the edge messages.

        :param coord: Current node coordinates
        :param edge_index: Pair ``(row, col)`` of edge endpoint indices;
                    updates are aggregated onto ``row``
        :param coord_diff: Per-edge coordinate difference from ``coord2radial``
        :param edge_feat: Edge messages produced by ``edge_model``
        :return: Updated coordinates
        :raises RuntimeError: if ``edge_feat``, ``coord_diff`` or the scaled
                    differences contain an infinite value
        :raises ValueError: if ``self.coords_agg`` is neither 'sum' nor 'mean'
        """
        row, _ = edge_index
        self._reject_inf(edge_feat, "edge_feat")
        self._reject_inf(coord_diff, "coord_diff")
        # Each difference vector is scaled by a scalar predicted from its message.
        trans = coord_diff * self.coord_mlp(edge_feat)
        self._reject_inf(trans, "trans")
        if self.coords_agg == 'sum':
            agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        elif self.coords_agg == 'mean':
            agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        else:
            raise ValueError('Wrong coords_agg parameter %s' % self.coords_agg)
        return coord + agg

    def coord2radial(self, edge_index, coord):
        """
        Compute per-edge squared distances and coordinate differences.

        :param edge_index: Pair ``(row, col)`` of edge endpoint indices
        :param coord: Node coordinates
        :return: Tuple ``(radial, coord_diff)`` where ``radial`` has one column
                    holding the squared distance and ``coord_diff`` is
                    ``coord[row] - coord[col]`` (normalised when ``self.normalize``)
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff**2, 1).unsqueeze(1)

        if self.normalize:
            # The norm is detached, so no gradient flows through the normaliser.
            norm = torch.sqrt(radial).detach() + self.epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        """
        Run one equivariant message-passing step.

        :param h: Node features
        :param edge_index: Pair ``(row, col)`` of edge endpoint indices
        :param coord: Node coordinates
        :param edge_attr: Optional edge features
        :param node_attr: Optional node features (requires ``nodes_att_dim`` to match)
        :return: Tuple ``(updated h, updated coord, edge_attr)``; ``edge_attr``
                    is returned unchanged
        :raises RuntimeError: if an intermediate tensor contains NaN or inf
        """
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)
        self._reject_nan(radial, "radial")
        self._reject_nan(coord_diff, "coord_diff")
        self._reject_inf(radial, "radial")
        self._reject_inf(coord_diff, "coord_diff")

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        self._reject_inf(edge_feat, "edge_feat")
        self._reject_nan(edge_feat, "edge_feat")

        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        self._reject_nan(coord, "coord")

        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)
        self._reject_nan(h, "h")

        return h, coord, edge_attr


class EGNN(nn.Module):
    """Stack of E_GCL layers wrapped between an input and an output linear embedding."""

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False):
        '''
        Build the network and move it to ``device``.

        :param in_node_nf: Width of the node features 'h' given to ``forward``
        :param hidden_nf: Width of the hidden node features used by every layer
        :param out_node_nf: Width of the node features 'h' returned by ``forward``
        :param in_edge_nf: Width of the edge features
        :param device: Target device (e.g. 'cpu', 'cuda:0',...)
        :param act_fn: Non-linearity handed to every layer
        :param n_layers: How many E_GCL layers are stacked
        :param residual: Use residual connections; the original authors recommend leaving this on
        :param attention: Enable the attention gate on edge messages
        :param normalize: Normalize the coordinate messages, i.e.
                    instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
                    use:        x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
                    The original authors note it may help stability or generalization,
                    and that it was not used in their paper.
        :param tanh: Put a tanh at the output of phi_x(m_ij), bounding it. The original
                    authors note this improves stability but may cost accuracy, and that
                    it was not used in their paper.
        '''
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_in = nn.Linear(in_node_nf, hidden_nf)
        self.embedding_out = nn.Linear(hidden_nf, out_node_nf)

        # Every layer is configured identically and registered as "gcl_<index>".
        layer_options = dict(
            edges_in_d=in_edge_nf,
            act_fn=act_fn,
            residual=residual,
            attention=attention,
            normalize=normalize,
            tanh=tanh,
        )
        for layer_idx in range(n_layers):
            block = E_GCL(hidden_nf, hidden_nf, hidden_nf, **layer_options)
            self.add_module("gcl_%d" % layer_idx, block)

        self.to(device)

    def forward(self, h, x, edges, edge_attr):
        """
        Embed ``h``, run it with the coordinates ``x`` through every layer, and project ``h`` back.

        :param h: Node features
        :param x: Node coordinates
        :param edges: Edge index pair forwarded to every layer
        :param edge_attr: Edge features forwarded to every layer
        :return: Tuple ``(output node features, updated coordinates)``
        """
        h = self.embedding_in(h)
        for layer_idx in range(self.n_layers):
            gcl = getattr(self, "gcl_%d" % layer_idx)
            h, x, _ = gcl(h, edges, x, edge_attr=edge_attr)
        return self.embedding_out(h), x


def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Sum the rows of ``data`` that share the same segment id.

    :param data: 2-D tensor, one row per item
    :param segment_ids: 1-D index tensor giving the target segment of every row
    :param num_segments: Number of rows of the output
    :return: Tensor of shape ``(num_segments, data.size(1))``; segments that
                receive no row stay at zero
    """
    n_cols = data.size(1)
    # scatter_add_ needs an index entry for every element, so repeat the id per column.
    index = segment_ids.unsqueeze(-1).expand(-1, n_cols)
    totals = data.new_zeros((num_segments, n_cols))
    return totals.scatter_add_(0, index, data)


def unsorted_segment_mean(data, segment_ids, num_segments):
    """
    Average the rows of ``data`` that share the same segment id.

    :param data: 2-D tensor, one row per item
    :param segment_ids: 1-D index tensor giving the target segment of every row
    :param num_segments: Number of rows of the output
    :return: Tensor with ``num_segments`` rows and ``data.size(1)`` columns;
                segments that receive no row are zero
    :raises RuntimeError: if the per-segment sums contain NaN
    """
    result_shape = (num_segments, data.size(1))
    # scatter_add_ needs an index entry for every element, so repeat the id per column.
    index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_zeros(result_shape)
    count = data.new_zeros(result_shape)
    result.scatter_add_(0, index, data)
    count.scatter_add_(0, index, torch.ones_like(data))
    if torch.any(torch.isnan(result)):
        raise RuntimeError("NaN detected in unsorted_segment_mean segment sums")
    # Clamp the counts so that empty segments yield 0 instead of dividing by zero.
    return result / count.clamp(min=1)


def get_edges(n_nodes):
    """
    List the directed edges of a fully connected graph without self-loops.

    :param n_nodes: Number of nodes
    :return: ``[rows, cols]``, two equally long lists of node indices where
                ``rows[k] -> cols[k]`` is the k-th edge, ordered by row then column
    """
    pairs = [
        (src, dst)
        for src in range(n_nodes)
        for dst in range(n_nodes)
        if src != dst
    ]
    rows = [src for src, _ in pairs]
    cols = [dst for _, dst in pairs]
    return [rows, cols]


def get_edges_batch(n_nodes, batch_size):
    """
    Build fully connected edges for a batch of equally sized graphs.

    :param n_nodes: Number of nodes of every graph
    :param batch_size: Number of graphs in the batch
    :return: Tuple ``(edges, edge_attr)``. ``edges`` is ``[rows, cols]`` as two
                LongTensors in which graph ``i`` has its indices shifted by
                ``n_nodes * i``; ``edge_attr`` is an all-ones tensor with one
                column and ``edges_per_graph * batch_size`` rows
    """
    single = get_edges(n_nodes)
    edges_per_graph = len(single[0])
    edge_attr = torch.ones(edges_per_graph * batch_size, 1)
    base_rows = torch.LongTensor(single[0])
    base_cols = torch.LongTensor(single[1])

    # Nothing to replicate unless there is more than one graph.
    if batch_size <= 1:
        return [base_rows, base_cols], edge_attr

    offsets = [n_nodes * graph_idx for graph_idx in range(batch_size)]
    rows = torch.cat([base_rows + shift for shift in offsets])
    cols = torch.cat([base_cols + shift for shift in offsets])
    return [rows, cols], edge_attr


def get_edges_new(pos, cutoff=4):
    """
    Connect every pair of distinct points that lie closer than ``cutoff``.

    The full pairwise distance matrix is materialised, so memory grows
    quadratically with the number of points.

    :param pos: Point coordinates, expected as a 2-D float tensor with one row per point
    :param cutoff: Exclusive upper bound on the distance of connected points
    :return: Tuple ``(edges_index, edges_attr)``. ``edges_index`` is an int64
                tensor with two rows: row 1 holds the index ``i`` of the centre
                point and row 0 the index of a neighbour of ``i``; edges are
                ordered by ``i`` and then by neighbour index. ``edges_attr`` is
                an all-ones int64 tensor with one column and one row per edge.
                Both live on the same device as ``pos``.
    """
    dist = torch.cdist(pos, pos)
    within = (dist < cutoff) & (dist > 0)
    # Rounding can leave a tiny positive value on the diagonal, so drop
    # self-pairs by index instead of relying on dist > 0 alone.
    n_points = pos.shape[0]
    within &= ~torch.eye(n_points, dtype=torch.bool, device=pos.device)

    # nonzero returns indices in row-major order: grouped by centre, then neighbour.
    centre, neighbour = torch.nonzero(within, as_tuple=True)
    edges_index = torch.stack((neighbour, centre))
    edges_attr = torch.ones_like(edges_index[0]).unsqueeze(dim=1)
    return edges_index, edges_attr


if __name__ == "__main__":
    # Smoke test: run one forward pass on dummy data, then print the parameter
    # count of a 128-feature model.

    # Dummy parameters
    batch_size = 8
    n_nodes = 4
    n_feat = 1
    x_dim = 3

    # Dummy node features h and coordinates x
    h = torch.ones(batch_size * n_nodes, n_feat)
    x = torch.ones(batch_size * n_nodes, x_dim)

    # Edges are produced by torch_geometric's radius_graph with r=3; every edge
    # gets a single constant attribute.
    from torch_geometric.nn import radius_graph
    edges = radius_graph(x, r=3)
    edge_attr = torch.ones_like(edges[0]).unsqueeze(dim=1)

    # Forward pass through a minimal model; the output is discarded.
    egnn = EGNN(in_node_nf=n_feat, hidden_nf=n_feat, out_node_nf=n_feat, in_edge_nf=1)
    egnn(h, x, edges, edge_attr)

    # Report the number of parameters of a 128-feature model.
    n_feat = 128
    egnn = EGNN(in_node_nf=n_feat, hidden_nf=n_feat, out_node_nf=n_feat, in_edge_nf=1)
    num_params = sum(p.numel() for p in egnn.parameters())
    print(num_params)

