import torch
import torch.nn as nn
from torch_scatter import scatter
import torch.nn.functional as F


class BaseMLP(nn.Module):
    """Three-linear-layer perceptron with an optional skip connection.

    The network is ``Linear -> act -> Linear -> act -> Linear`` and, when
    ``last_act`` is set, one more activation after the final linear layer.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the two hidden layers (multiplied by 4 when
            ``flat`` is set).
        output_dim: Width of the output features.
        activation: Activation module placed between the linear layers. The
            same instance is reused at every position. Replaced by
            ``nn.Tanh()`` when ``flat`` is set.
        residual: If true, ``forward`` adds the input to the MLP output;
            requires ``output_dim == input_dim``.
        last_act: If true, apply the activation after the last linear layer.
        flat: If true, force a ``Tanh`` activation and a 4x wider hidden size.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, activation, residual=False, last_act=False, flat=False):
        super().__init__()
        self.residual = residual
        if flat:
            # "flat" mode overrides both the caller's activation and width.
            activation, hidden_dim = nn.Tanh(), 4 * hidden_dim
        if residual:
            assert output_dim == input_dim

        # Consecutive pairs of widths define the three linear layers.
        widths = (input_dim, hidden_dim, hidden_dim, output_dim)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(activation)
        if not last_act:
            # Drop the activation that trails the output layer.
            stack.pop()
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to ``x``, adding ``x`` back when ``residual`` is set."""
        y = self.mlp(x)
        if self.residual:
            return y + x
        return y


class GNNMessagePassingLayer(nn.Module):
    """One round of edge-conditioned message passing with mean aggregation.

    For every edge an MLP consumes the features of both endpoints plus the
    edge attributes. The resulting edge messages are averaged onto the node
    given by ``edge_index[0]``, concatenated with that node's own features
    and passed through a second MLP.

    Args:
        x_dim: Width of the incoming node features.
        edge_dim: Width of the edge attributes.
        hidden_dim: Hidden width of both MLPs and width of the output.
        activation: Activation module handed to both MLPs.
    """

    def __init__(self, x_dim, edge_dim, hidden_dim, activation):
        super(GNNMessagePassingLayer, self).__init__()
        self.x_dim, self.edge_dim = x_dim, edge_dim
        self.hidden_dim = hidden_dim
        self.activation = activation
        # Edge MLP: [x_i, x_j, e_ij] -> message of width hidden_dim.
        self.net = BaseMLP(input_dim=2 * x_dim + edge_dim,
                           hidden_dim=hidden_dim,
                           output_dim=hidden_dim,
                           activation=activation,
                           residual=False,
                           last_act=False,
                           flat=False)
        # Node MLP: forward() feeds it cat(x, aggregated message), whose width
        # is x_dim + hidden_dim. (Previously hard-coded as 2 * hidden_dim,
        # which only matched when x_dim == hidden_dim.)
        self.self_net = BaseMLP(input_dim=x_dim + hidden_dim,
                                hidden_dim=hidden_dim,
                                output_dim=hidden_dim,
                                activation=activation,
                                residual=False,
                                last_act=False,
                                flat=False)

    def forward(self, x, edge_index, edge_attr):
        """Compute updated node features.

        Args:
            x: Node features, ``[N, x_dim]``.
            edge_index: Index tensor whose rows 0 and 1 hold the two endpoint
                node indices of each edge (``[2, M]``).
            edge_attr: Edge attributes, ``[M, edge_dim]``.

        Returns:
            Tensor of shape ``[N, hidden_dim]``.
        """
        receivers, senders = edge_index[0], edge_index[1]
        edge_info = torch.cat((x[receivers], x[senders], edge_attr), dim=-1)
        edge_info = self.net(edge_info)
        # dim_size keeps one row per node even for nodes without any edge.
        message = scatter(edge_info, receivers, dim=0, reduce='mean', dim_size=x.shape[0])  # [N, H]
        return self.self_net(torch.cat((x, message), dim=-1))


class GNS(nn.Module):
    """Graph network that maps per-node state to two 3-vectors per node.

    Node inputs are embedded, refined by ``p_step`` applications of a single
    (weight-shared) message-passing layer, and read out into six values per
    node that ``forward`` returns as a ``(pos, vel)`` pair.

    Args:
        n_layer: Stored on the instance; not used by this class itself.
        p_step: Number of times the message-passing layer is applied.
        s_dim: Width of the per-node feature ``h_p`` passed to ``forward``.
        hidden_dim: Hidden width used throughout the network.
        activation: Activation module shared by all sub-networks.
        cutoff: Stored on the instance; not used by this class itself.
        gravity_axis: Accepted for interface compatibility; unused here.
    """

    def __init__(self, n_layer, p_step, s_dim, hidden_dim=128, activation=nn.SiLU(), cutoff=0.10, gravity_axis=None):
        super(GNS, self).__init__()
        self.cutoff = cutoff
        # A non-persistent buffer follows the module across devices
        # (.to()/.cuda()) without requiring CUDA at construction time, and
        # stays out of state_dict so existing checkpoints still load.
        self.register_buffer('quat_offset', torch.tensor([1., 0., 0., 0.]), persistent=False)
        self.dt = 0.01
        # Input width: s_dim + (zeros 3 + v 3 + force 3) + centred (x, v, force) 9.
        self.embedding = nn.Sequential(
            nn.Linear(s_dim + 9 + 9, hidden_dim),
            activation,
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.n_layer = n_layer
        self.p_step = p_step
        self.local_interaction = GNNMessagePassingLayer(x_dim=hidden_dim, edge_dim=4, hidden_dim=hidden_dim, activation=activation)
        self.readout = BaseMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=6, activation=activation)

    def forward(self, x_p, v_p, force_p, h_p, edge_index_inner, edge_index_inter, obj_id):
        """Run the network on one graph.

        Args:
            x_p: Node positions, ``[N, 3]`` (edge attributes of width 4 are
                built from position differences plus their norm).
            v_p: Per-node tensor, ``[N, 3]`` (presumably velocities).
            force_p: Per-node tensor, ``[N, 3]`` (presumably forces).
            h_p: Per-node features, ``[N, s_dim]``.
            edge_index_inner: Edge indices, ``[2, M_inner]``.
            edge_index_inter: Edge indices, ``[2, M_inter]``.
            obj_id: Group index per node, ``[N]``; used to compute per-group
                means of ``(x_p, v_p, force_p)``.

        Returns:
            Tuple ``(pos, vel)``: columns 0-2 and 3-5 of the readout, each
            ``[N, 3]``.
        """
        # The position slot of the node input is zero-filled; positions only
        # enter through the group-centred features below.
        h_p = torch.cat((h_p, torch.zeros_like(x_p), v_p, force_p), dim=-1)  # [N, s_dim + 9]

        f_p = torch.cat((x_p, v_p, force_p), dim=-1)  # [N, 9]
        f_o = scatter(f_p, obj_id, dim=0, reduce='mean')  # [N_obj, 9]
        h_p = torch.cat((h_p, f_p - f_o[obj_id]), dim=-1)  # [N, s_dim + 18]

        s_p = self.embedding(h_p)  # [N, H]

        edge_index = torch.cat((edge_index_inner, edge_index_inter), dim=1)
        edge_attr = x_p[edge_index[0]] - x_p[edge_index[1]]
        edge_attr = torch.cat((edge_attr, torch.norm(edge_attr, dim=-1, keepdim=True)), dim=-1)  # [M, 4]

        # The same layer (shared weights) is applied p_step times.
        for _ in range(self.p_step):
            s_p = self.local_interaction(s_p, edge_index, edge_attr)

        sum_out = self.readout(s_p)  # [N, 6]
        pos = sum_out[:, 0:3]
        vel = sum_out[:, 3:6]

        return pos, vel
    
    
class ForcePredictionLayer(nn.Module):
    """Graph network that maps per-node state to one 3-vector per node.

    Node inputs are embedded, refined by ``p_step`` applications of a single
    (weight-shared) message-passing layer, and read out into three values
    per node.

    Args:
        n_layer: Stored on the instance; not used by this class itself.
        p_step: Number of times the message-passing layer is applied.
        s_dim: Width of the per-node feature ``h_p`` passed to ``forward``.
        hidden_dim: Hidden width used throughout the network.
        activation: Activation module shared by all sub-networks.
        cutoff: Stored on the instance; not used by this class itself.
    """

    def __init__(self, n_layer, p_step, s_dim, hidden_dim=128, activation=nn.SiLU(), cutoff=0.10):
        super(ForcePredictionLayer, self).__init__()
        self.cutoff = cutoff
        # A non-persistent buffer follows the module across devices
        # (.to()/.cuda()) without requiring CUDA at construction time, and
        # stays out of state_dict so existing checkpoints still load.
        self.register_buffer('quat_offset', torch.tensor([1., 0., 0., 0.]), persistent=False)
        self.dt = 0.01
        # Input width: s_dim + (zeros 3 + v 3) + centred (x, v) 6.
        self.embedding = nn.Sequential(
            nn.Linear(s_dim + 6 + 6, hidden_dim),
            activation,
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.n_layer = n_layer
        self.p_step = p_step
        self.local_interaction = GNNMessagePassingLayer(x_dim=hidden_dim, edge_dim=4, hidden_dim=hidden_dim, activation=activation)
        self.readout = BaseMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=3, activation=activation)

    def forward(self, x_p, v_p, h_p, edge_index_inner, edge_index_inter, obj_id):
        """Run the network on one graph.

        Args:
            x_p: Node positions, ``[N, 3]`` (edge attributes of width 4 are
                built from position differences plus their norm).
            v_p: Per-node tensor, ``[N, 3]`` (presumably velocities).
            h_p: Per-node features, ``[N, s_dim]``.
            edge_index_inner: Edge indices, ``[2, M_inner]``.
            edge_index_inter: Edge indices, ``[2, M_inter]``.
            obj_id: Group index per node, ``[N]``; used to compute per-group
                means of ``(x_p, v_p)``.

        Returns:
            Readout tensor of shape ``[N, 3]``.
        """
        # The position slot of the node input is zero-filled; positions only
        # enter through the group-centred features below.
        h_p = torch.cat((h_p, torch.zeros_like(x_p), v_p), dim=-1)  # [N, s_dim + 6]

        f_p = torch.cat((x_p, v_p), dim=-1)  # [N, 6]
        f_o = scatter(f_p, obj_id, dim=0, reduce='mean')  # [N_obj, 6]
        h_p = torch.cat((h_p, f_p - f_o[obj_id]), dim=-1)  # [N, s_dim + 12]

        s_p = self.embedding(h_p)  # [N, H]

        edge_index = torch.cat((edge_index_inner, edge_index_inter), dim=1)
        edge_attr = x_p[edge_index[0]] - x_p[edge_index[1]]
        edge_attr = torch.cat((edge_attr, torch.norm(edge_attr, dim=-1, keepdim=True)), dim=-1)  # [M, 4]

        # The same layer (shared weights) is applied p_step times.
        for _ in range(self.p_step):
            s_p = self.local_interaction(s_p, edge_index, edge_attr)

        return self.readout(s_p)  # [N, 3]

class ConForcePredictionLayer(nn.Module):
    """Graph network that maps per-node state to one 3-vector per node.

    Structurally the same pipeline as the other models in this file
    (embedding, ``p_step`` rounds of one weight-shared message-passing layer,
    3-wide readout), but driven by its own edge lists and group ids.

    Args:
        n_layer: Stored on the instance; not used by this class itself.
        p_step: Number of times the message-passing layer is applied.
        s_dim: Width of the per-node feature ``h_p`` passed to ``forward``.
        hidden_dim: Hidden width used throughout the network.
        activation: Activation module shared by all sub-networks.
        cutoff: Stored on the instance; not used by this class itself.
    """

    def __init__(self, n_layer, p_step, s_dim, hidden_dim=128, activation=nn.SiLU(), cutoff=0.10):
        super(ConForcePredictionLayer, self).__init__()
        self.cutoff = cutoff
        # A non-persistent buffer follows the module across devices
        # (.to()/.cuda()) without requiring CUDA at construction time, and
        # stays out of state_dict so existing checkpoints still load.
        self.register_buffer('quat_offset', torch.tensor([1., 0., 0., 0.]), persistent=False)
        self.dt = 0.01
        # Input width: s_dim + (zeros 3 + v 3) + centred (x, v) 6.
        self.embedding = nn.Sequential(
            nn.Linear(s_dim + 6 + 6, hidden_dim),
            activation,
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.n_layer = n_layer
        self.p_step = p_step
        self.local_interaction = GNNMessagePassingLayer(x_dim=hidden_dim, edge_dim=4, hidden_dim=hidden_dim, activation=activation)
        self.readout = BaseMLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=3, activation=activation)

    def forward(self, x_p, v_p, h_p, edge_con_force_index_inner, edge_con_forceindex_inter, objcon_id):
        """Run the network on one graph.

        Args:
            x_p: Node positions, ``[N, 3]`` (edge attributes of width 4 are
                built from position differences plus their norm).
            v_p: Per-node tensor, ``[N, 3]`` (presumably velocities).
            h_p: Per-node features, ``[N, s_dim]``.
            edge_con_force_index_inner: Edge indices, ``[2, M_inner]``.
            edge_con_forceindex_inter: Edge indices, ``[2, M_inter]``.
            objcon_id: Group index per node, ``[N]``; used to compute
                per-group means of ``(x_p, v_p)``.

        Returns:
            Readout tensor of shape ``[N, 3]``.
        """
        # The position slot of the node input is zero-filled; positions only
        # enter through the group-centred features below.
        h_p = torch.cat((h_p, torch.zeros_like(x_p), v_p), dim=-1)  # [N, s_dim + 6]

        f_p = torch.cat((x_p, v_p), dim=-1)  # [N, 6]
        f_o = scatter(f_p, objcon_id, dim=0, reduce='mean')  # [N_obj, 6]
        h_p = torch.cat((h_p, f_p - f_o[objcon_id]), dim=-1)  # [N, s_dim + 12]

        s_p = self.embedding(h_p)  # [N, H]

        edge_index = torch.cat((edge_con_force_index_inner, edge_con_forceindex_inter), dim=1)
        edge_attr = x_p[edge_index[0]] - x_p[edge_index[1]]
        edge_attr = torch.cat((edge_attr, torch.norm(edge_attr, dim=-1, keepdim=True)), dim=-1)  # [M, 4]

        # The same layer (shared weights) is applied p_step times.
        for _ in range(self.p_step):
            s_p = self.local_interaction(s_p, edge_index, edge_attr)

        return self.readout(s_p)  # [N, 3]

