from typing import Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.utils import dense_to_sparse

from models.EGNN_GRU.egnns import EGNN

__all__ = [
    'EGNN_GRU'
]

class EGNN_GRU(nn.Module):
    """EGNN encoder applied per time step, followed by bidirectional GRU heads.

    For every (sample, time step) pair the interaction graph is encoded by an
    ``EGNN``; the flattened per-node outputs are concatenated with the
    (location-centred) input sequence and fed to a bidirectional GRU whose
    output, again concatenated with the input, is projected back to
    ``input_size`` by ``read_out``.

    When ``repara`` is true two extra heads exist:

    * ``phy_out`` projects the same features to ``repara_size`` values per
      time step, which ``forward`` reshapes to ``(batch, n_time, n_system, 2)``
      (so ``repara_size`` has to equal ``2 * n_system``);
    * ``edge_rnn`` / ``edge_linear`` produce a 2-vector for every edge
      returned by ``dense_to_sparse``.
    """

    def __init__(self, 
            input_size: int, gnn_hidden_size: int, rnn_hidden_size: int,
            n_layers: int, repara: bool, repara_size: int, n_system: int, device,
            n_diff_time: int, t_embed_size: int, **kwargs
        ):
        """Build the embeddings, the EGNN encoder and the recurrent heads.

        Args:
            input_size: Width of one time step of ``x``; ``forward`` asserts
                it equals ``4 * n_system``.
            gnn_hidden_size: Hidden and output node width of the EGNN.
            rnn_hidden_size: Hidden width (per direction) of the main GRU.
            n_layers: Layer count used for BOTH the EGNN and the main GRU.
            repara: Whether to create ``phy_out``, ``edge_rnn`` and
                ``edge_linear``.
            repara_size: Output width of ``phy_out``.
            n_system: Number of nodes per graph.
            device: Forwarded to ``EGNN`` only.
            n_diff_time: Number of rows of the time-step embedding table.
            t_embed_size: Width of the time-step embedding.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        node_embed_size = 16
        self.n_system = n_system
        self.t_embed = nn.Embedding(n_diff_time, t_embed_size)
        # Indexed by the row sums of `edge` in forward(), so those sums must
        # lie in [0, n_system - 1].
        self.node_embed = nn.Embedding(n_system, node_embed_size)

        self.gnn = EGNN(
            # 2 velocity components + time embedding + degree embedding.
            in_node_nf = 2 + t_embed_size + node_embed_size,
            hidden_nf = gnn_hidden_size,
            out_node_nf = gnn_hidden_size,
            in_edge_nf = 1,
            n_layers = n_layers,
            device = device
        )

        self.rnn = nn.GRU(
            # Flattened per-node EGNN output plus the raw time-step vector.
            input_size=gnn_hidden_size * n_system + input_size,
            hidden_size=rnn_hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True
        )
        self.gnn_hidden_size = gnn_hidden_size

        # Both heads consume [x, GRU output]; the GRU is bidirectional, hence
        # the factor 2.
        self.read_out = nn.Linear(
            2 * rnn_hidden_size + input_size,
            input_size
        )

        self.phy_out = nn.Linear(
            2 * rnn_hidden_size + input_size,
            repara_size
        ) if repara else None

        if repara:
            # The bidirectional output (2 * edge_size) is later split evenly
            # across the n_system nodes, so edge_size is a multiple of it.
            edge_size = (rnn_hidden_size // 8) * n_system
            self.edge_rnn = nn.GRU(
                input_size=gnn_hidden_size * n_system + input_size,
                hidden_size=edge_size,
                num_layers=3,
                batch_first=True,
                bidirectional=True
            )
            self.edge_linear = nn.Linear(edge_size * 2 // n_system, 2)

    @staticmethod
    def _split_and_center(x: Tensor, n_system: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Split ``x`` into location/velocity parts and centre the locations.

        ``x`` of shape ``(batch, n_time, 4 * n_system)`` is viewed as
        ``(batch, n_time, 2, 2, n_system)`` and split along axis 3 of that
        view.  The first half has its mean over the time and node axes
        removed (separately per sample and per remaining component).

        Everything is done out-of-place, so the caller's tensor is never
        modified and autograd-tracked inputs are supported.

        NOTE(review): the original inline comment described the trailing
        layout as (loc/vel, xy, n_system); under that layout axis 3 would be
        xy rather than loc/vel.  The split axis is kept as it was - confirm
        the intended memory layout of ``x``.

        Returns:
            ``(loc_feature, vel_feature, centered_x)`` where the two features
            have shape ``(batch, n_time, n_system, 2)`` and ``centered_x`` has
            the shape and layout of ``x`` with the centred locations written
            back in place of the originals.
        """
        batch_size, n_time, input_size = x.shape
        parts = x.reshape(batch_size, n_time, 2, 2, n_system).permute(3, 0, 1, 4, 2)
        loc_feature, vel_feature = parts[0], parts[1]
        loc_feature = loc_feature - loc_feature.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)

        # Undo the permutation above so the result lines up with `x` again.
        centered_x = (
            torch.stack([loc_feature, vel_feature], dim=0)
            .permute(1, 2, 4, 0, 3)
            .reshape(batch_size, n_time, input_size)
        )
        return loc_feature, vel_feature, centered_x

    def assemble_atom_pair_feature(self, node_attr: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        """Return the element-wise product of the two endpoint features of each edge.

        Args:
            node_attr: Per-node features, indexed along dim 0.
            edge_index: Index tensor whose row 0 / row 1 select the two
                endpoints of every edge.
            edge_attr: Unused; kept so the signature stays unchanged.

        Returns:
            ``node_attr[edge_index[0]] * node_attr[edge_index[1]]``.
        """
        h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]
        return h_row*h_col

    def forward(self, x: Tensor, t: Tensor, edge: Tensor, energy: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        """Run the EGNN encoder and the recurrent heads.

        Args:
            x: ``(batch_size, n_time, input_size)`` with
                ``input_size == 4 * n_system``.  Not modified.
            t: ``(batch_size,)`` integer indices into the time-step embedding.
            edge: ``(batch_size, n_system, n_system)`` adjacency; its row sums
                (cast to long) index the degree embedding and its non-zero
                entries become the EGNN edges and edge attributes.
            energy: Unused by this method.

        Returns:
            Without reparameterisation: the ``read_out`` projection of shape
            ``(batch_size, n_time, input_size)``.

            With reparameterisation: a tuple ``(read_out, phy_out, edge_energy)``
            where ``phy_out`` is reshaped to ``(batch_size, n_time, n_system, 2)``
            and ``edge_energy`` has one 2-vector per row of the sparse edge
            list built from the time-repeated adjacency.

        Raises:
            AssertionError: If ``input_size != 4 * n_system``.
        """
        # t.shape == (batch_size, )
        # x.shape == (batch_size, n_time, input_size)
        # edge.shape == (batch_size, n_system, n_system)

        batch_size, n_time, input_size = x.shape
        assert input_size == self.n_system * 4, (
            f"expected input_size == 4 * n_system ({4 * self.n_system}), got {input_size}"
        )
        n_system = self.n_system

        # xxx_feature.shape == (batch_size, n_time, n_system, 2)
        # From here on `x` is a new tensor carrying the centred locations; the
        # previous in-place update produced the same values for contiguous
        # inputs but also overwrote the caller's tensor.
        loc_feature, vel_feature, x = self._split_and_center(x, n_system)

        # Broadcast the per-sample time embedding to every (time, node) slot.
        t_embed = self.t_embed(t).expand(n_time, n_system, batch_size, -1).permute(2, 0, 1, 3)

        # Embed each node's adjacency row sum and repeat it along time.
        edge_degree = self.node_embed(edge.sum(dim=-1).to(torch.long))
        edge_degree = edge_degree.expand(n_time, batch_size, n_system, -1).permute(1, 0, 2, 3)
        node_feature = torch.cat([vel_feature, edge_degree, t_embed], dim=-1)

        # One copy of the adjacency per time step, flattened to
        # (batch_size * n_time, n_system, n_system) in the same
        # (batch, time) order used for the node features.
        time_edge = edge.expand(n_time, *edge.shape).permute(1, 0, 2, 3).reshape(-1, n_system, n_system)
        edge_index, edge_type = dense_to_sparse(adj=time_edge)
        edge_type = edge_type.unsqueeze(dim=-1)

        # Element 0 of the EGNN result is used as the node features
        # (presumably the updated `h`; confirm against EGNN.forward).
        gnn_features = self.gnn(
            h=node_feature.reshape(-1, node_feature.size(-1)),
            x=loc_feature.reshape(-1, loc_feature.size(-1)),
            edges=edge_index,
            edge_attr=edge_type
        )[0].reshape(batch_size, n_time, -1)

        # Shared input of both GRUs, built once.
        seq_input = torch.cat([gnn_features, x], dim=-1)
        rnn_features = torch.cat([x, self.rnn(seq_input)[0]], dim=-1)

        if self.phy_out is None:
            return self.read_out(rnn_features)

        # Split the edge GRU output into one feature row per node so it can be
        # indexed with the same node numbering as `edge_index`.
        node_energy = self.edge_rnn(seq_input)[0]
        node_energy = node_energy.reshape(batch_size * n_time * n_system, -1)
        edge_energy = self.edge_linear(
            self.assemble_atom_pair_feature(node_energy, edge_index, edge_type)
        )

        return self.read_out(rnn_features), \
                self.phy_out(rnn_features).reshape(batch_size, n_time, n_system, 2), \
                edge_energy
