import torch as th
import torch.nn as nn
from torch.nn import LSTM

from dgl.nn import GATConv


class GeniePathConv(nn.Module):
    """One GeniePath layer: a breadth (graph attention) step then a depth (LSTM) step.

    The breadth step aggregates neighbour features with ``GATConv``; the depth
    step feeds the head-averaged result through a single-step LSTM that
    carries an ``(h, c)`` state supplied by the caller.

    Args:
        in_dim: Size of the input node features.
        hid_dim: Per-head output size of the attention step; also the LSTM
            input size.
        out_dim: LSTM hidden size, i.e. the width of the returned features and
            of the ``h``/``c`` states.
        num_heads: Number of attention heads; their outputs are averaged.
        residual: Forwarded unchanged to ``GATConv``.
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):
        super().__init__()
        # These attribute names become state_dict keys, so they stay fixed.
        self.breadth_func = GATConv(
            in_dim, hid_dim, num_heads=num_heads, residual=residual
        )
        self.depth_func = LSTM(hid_dim, out_dim)

    def forward(self, graph, x, h, c):
        """Apply one breadth + depth step.

        Args:
            graph: Graph handed to ``GATConv``.
            x: Node features, one row per node.
            h: LSTM hidden state, shape ``(1, num_nodes, out_dim)``.
            c: LSTM cell state, same shape as ``h``.

        Returns:
            ``(out, (h, c))``: one ``out_dim`` row per node, plus the updated
            LSTM states.
        """
        # Breadth: attend over neighbours, squash with tanh, then average the
        # attention heads (dim 1 of the GATConv output).
        per_head = th.tanh(self.breadth_func(graph, x))
        fused = per_head.mean(dim=1)

        # Depth: the LSTM is not batch_first, so a leading length-1 axis makes
        # this a single time step with the nodes acting as the batch.
        seq_out, (h_next, c_next) = self.depth_func(fused.unsqueeze(0), (h, c))
        return seq_out.squeeze(0), (h_next, c_next)


class GeniePath(nn.Module):
    """GeniePath model: stacked ``GeniePathConv`` layers sharing one LSTM state.

    Node features are projected to ``hid_dim``, passed through ``num_layers``
    breadth/depth layers that thread a single LSTM ``(h, c)`` state from one
    layer to the next, and finally projected to ``out_dim``.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the per-node output.
        hid_dim: Width of the hidden features and of the LSTM state.
        num_layers: Number of ``GeniePathConv`` layers.
        num_heads: Attention heads per layer (averaged inside each layer).
        residual: Forwarded to each layer's ``GATConv``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super().__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                GeniePathConv(
                    hid_dim,
                    hid_dim,
                    hid_dim,
                    num_heads=num_heads,
                    residual=residual,
                )
            )

    def forward(self, graph, x):
        """Compute per-node outputs.

        Args:
            graph: Graph passed to every ``GeniePathConv`` layer.
            x: Node features, one row per node, ``in_dim`` columns.

        Returns:
            Tensor with one ``out_dim`` row per node.
        """
        # Zero initial LSTM state (num_layers * directions = 1, nodes, hid_dim).
        # ``new_zeros`` inherits x's dtype and device, so the state matches the
        # inputs (e.g. float64 under ``model.double()``) instead of always
        # being float32, and it is allocated directly on x's device.
        h = x.new_zeros((1, x.shape[0], self.hid_dim))
        c = x.new_zeros((1, x.shape[0], self.hid_dim))

        x = self.linear1(x)
        for layer in self.layers:
            x, (h, c) = layer(graph, x, h, c)
        return self.linear2(x)


class GeniePathLazy(nn.Module):
    """GeniePath-Lazy model: all breadth (attention) steps first, then all depth steps.

    Unlike ``GeniePath``, every attention layer reads the same projected input
    features. The per-layer attention summaries are then consumed in order by
    a chain of LSTMs. Each LSTM step sees the concatenation of one summary
    with the running features (initially the projected input, afterwards the
    previous LSTM output) and threads a shared ``(h, c)`` state.

    Args:
        in_dim: Size of the input node features.
        out_dim: Size of the per-node output.
        hid_dim: Width of the hidden features and of the LSTM state.
        num_layers: Number of attention layers (and of LSTMs).
        num_heads: Attention heads per layer (averaged).
        residual: Forwarded to each ``GATConv``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super().__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        # NOTE: "breaths" (sic) is kept on purpose: ModuleList attribute names
        # become state_dict key prefixes, so renaming would break checkpoints.
        self.breaths = nn.ModuleList()
        self.depths = nn.ModuleList()
        for _ in range(num_layers):
            self.breaths.append(
                GATConv(hid_dim, hid_dim, num_heads=num_heads, residual=residual)
            )
            # Input is [attention summary ; running features], each hid_dim wide.
            self.depths.append(LSTM(hid_dim * 2, hid_dim))

    def forward(self, graph, x):
        """Compute per-node outputs.

        Args:
            graph: Graph passed to every ``GATConv``.
            x: Node features, one row per node, ``in_dim`` columns.

        Returns:
            Tensor with one ``out_dim`` row per node.
        """
        # Zero initial LSTM state; ``new_zeros`` inherits x's dtype and device
        # so it matches the inputs (e.g. float64 under ``model.double()``).
        h = x.new_zeros((1, x.shape[0], self.hid_dim))
        c = x.new_zeros((1, x.shape[0], self.hid_dim))

        x = self.linear1(x)
        # Breadth: every attention layer sees the same projected features;
        # heads (dim 1 of the GATConv output) are averaged after tanh.
        summaries = [
            th.mean(th.tanh(gat(graph, x)), dim=1) for gat in self.breaths
        ]

        # Depth: LSTMs are not batch_first, so the leading length-1 axis is a
        # single time step with the nodes acting as the batch.
        x = x.unsqueeze(0)
        for summary, lstm in zip(summaries, self.depths):
            step_in = th.cat((summary.unsqueeze(0), x), dim=-1)
            x, (h, c) = lstm(step_in, (h, c))
        return self.linear2(x[0])
