# -*- coding: utf-8 -*-
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATv2Conv, GATConv
from torch_geometric.data import Data


class BaseGNN(torch.nn.Module):
    """Common scaffolding for a stack of graph-convolution layers.

    The base class creates an empty ``self.convs`` container and stores the
    hyper-parameters; it does not add any layers itself. Every entry placed
    in ``self.convs`` must be callable as ``layer(x, adj_t)`` and expose
    ``reset_parameters()``.

    Args:
        dropout: Drop probability handed to ``F.dropout`` after each ReLU.
        num_layers: Layer count stored on the instance. ``forward`` reads it
            only to detect the single-layer configuration.
    """

    def __init__(self, dropout, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.dropout = dropout
        self.num_layers = num_layers

    def reset_parameters(self):
        """Re-initialise every layer currently held in ``self.convs``."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Run ``x`` through all layers in order.

        Args:
            x: Node feature tensor fed to the first layer.
            adj_t: Graph connectivity, passed unchanged to every layer;
                presumably an edge tensor of shape (2, num_edges) -- confirm
                against the callers.

        Returns:
            The output of the last layer. ReLU + dropout follow every layer
            except the last one; when ``num_layers == 1`` they are applied
            to the output of that single layer as well.

        Raises:
            IndexError: If ``self.convs`` is empty.
        """
        def relu_then_drop(h):
            # Dropout is only active while the module is in training mode.
            return F.dropout(F.relu(h), p=self.dropout, training=self.training)

        h = x
        for layer in self.convs[:-1]:
            h = relu_then_drop(layer(h, adj_t))
        h = self.convs[-1](h, adj_t)
        # A one-layer stack still gets the non-linearity and dropout.
        return relu_then_drop(h) if self.num_layers == 1 else h
    


class SAGE(BaseGNN):
    """``BaseGNN`` stack whose layers are all ``SAGEConv``.

    Args:
        in_channels: Input width of the first layer.
        hidden_channels: Width used between consecutive layers.
        out_channels: Output width of the last layer.
        num_layers: Number of ``SAGEConv`` layers to create.
        dropout: Dropout probability forwarded to ``BaseGNN``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__(dropout, num_layers)
        final = num_layers - 1
        for depth in range(num_layers):
            # Only the first layer reads the raw input width, and only the
            # last one emits the output width; a single layer does both.
            width_in = hidden_channels if depth else in_channels
            width_out = hidden_channels if depth != final else out_channels
            self.convs.append(SAGEConv(width_in, width_out))


class GCN(BaseGNN):
    """``BaseGNN`` stack whose layers are all ``GCNConv``.

    Every layer is constructed with ``normalize=False``.
    NOTE(review): presumably the connectivity passed at call time is already
    normalised -- confirm against the callers.

    Args:
        in_channels: Input width of the first layer.
        hidden_channels: Width used between consecutive layers.
        out_channels: Output width of the last layer.
        num_layers: Number of ``GCNConv`` layers to create.
        dropout: Dropout probability forwarded to ``BaseGNN``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__(dropout, num_layers)
        final = num_layers - 1
        layers = [
            GCNConv(
                in_channels if depth == 0 else hidden_channels,
                out_channels if depth == final else hidden_channels,
                normalize=False,
            )
            for depth in range(num_layers)
        ]
        self.convs.extend(layers)


class GATv2(BaseGNN):
    """``BaseGNN`` stack whose layers are all ``GATv2Conv``.

    Each layer is created from its input and output width only; all other
    ``GATv2Conv`` options keep the library defaults.

    Args:
        in_channels: Input width of the first layer.
        hidden_channels: Width used between consecutive layers.
        out_channels: Output width of the last layer.
        num_layers: Number of ``GATv2Conv`` layers to create.
        dropout: Dropout probability forwarded to ``BaseGNN``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__(dropout, num_layers)
        for depth in range(num_layers):
            if depth == 0:
                width_in = in_channels
            else:
                width_in = hidden_channels

            if depth == num_layers - 1:
                width_out = out_channels
            else:
                width_out = hidden_channels

            self.convs.append(GATv2Conv(width_in, width_out))

class GAT(BaseGNN):
    """``BaseGNN`` stack whose layers are all ``GATConv``.

    Each layer is created from its input and output width only; all other
    ``GATConv`` options keep the library defaults.

    Args:
        in_channels: Input width of the first layer.
        hidden_channels: Width used between consecutive layers.
        out_channels: Output width of the last layer.
        num_layers: Number of ``GATConv`` layers to create.
        dropout: Dropout probability forwarded to ``BaseGNN``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__(dropout, num_layers)
        # Work out every (input width, output width) pair first, then
        # instantiate the layers in order.
        layer_widths = [
            (
                in_channels if k == 0 else hidden_channels,
                out_channels if k == num_layers - 1 else hidden_channels,
            )
            for k in range(num_layers)
        ]
        for width_in, width_out in layer_widths:
            self.convs.append(GATConv(width_in, width_out))

class MLPPredictor(torch.nn.Module):
    """Multi-layer perceptron made of ``num_layers`` linear layers.

    Args:
        in_channels: Input width of the first linear layer.
        hidden_channels: Width used between consecutive layers.
        out_channels: Output width of the last linear layer.
        num_layers: Number of linear layers to create.
        dropout: Drop probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__()
        final = num_layers - 1
        # The first layer consumes the input width and the last one produces
        # the output width; with a single layer it does both.
        self.lins = torch.nn.ModuleList([
            torch.nn.Linear(
                in_channels if depth == 0 else hidden_channels,
                out_channels if depth == final else hidden_channels,
            )
            for depth in range(num_layers)
        ])
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x):
        """Apply the MLP to ``x``.

        Hidden layers are each followed by ReLU and dropout; the last layer
        is purely linear.

        Args:
            x: Input tensor whose last dimension equals ``in_channels``.
                NOTE(review): an earlier inline note suggests this is the
                element-wise product ``x_i * x_j`` of two embeddings --
                confirm against the callers.

        Returns:
            The output of the final linear layer.

        Raises:
            IndexError: If the module was built without any layers.
        """
        h = x
        for hidden in self.lins[:-1]:
            h = F.dropout(F.relu(hidden(h)), p=self.dropout, training=self.training)
        return self.lins[-1](h)


# class MLPCatPredictor(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
#         super(MLPCatPredictor, self).__init__()
#         self.lins = torch.nn.ModuleList()
#         in_channels = 2 * in_channels
#         for i in range(num_layers):
#             first_channels = in_channels if i == 0 else hidden_channels
#             second_channels = out_channels if i == num_layers - 1 else hidden_channels
#             self.lins.append(torch.nn.Linear(first_channels, second_channels))
#         self.dropout = dropout

#     def reset_parameters(self):
#         for lin in self.lins:
#             lin.reset_parameters()

#     def forward(self, x_i, x_j):
#         x1 = torch.cat([x_i, x_j], dim=-1)
#         x2 = torch.cat([x_j, x_i], dim=-1)
#         for lin in self.lins[:-1]:
#             x1, x2 = lin(x1), lin(x2)
#             x1, x2 = F.relu(x1), F.relu(x2)
#             x1 = F.dropout(x1, p=self.dropout, training=self.training)
#             x2 = F.dropout(x2, p=self.dropout, training=self.training)
#         x1 = self.lins[-1](x1)
#         x2 = self.lins[-1](x2)
#         x = (x1 + x2)/2
#         return x


# class MLPDotPredictor(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, num_layers, dropout):
#         super(MLPDotPredictor, self).__init__()
#         self.lins = torch.nn.ModuleList()
#         self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
#         for _ in range(num_layers - 1):
#             self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
#         self.dropout = dropout

#     def reset_parameters(self):
#         for lin in self.lins:
#             lin.reset_parameters()

#     def forward(self, x_i, x_j):
#         for lin in self.lins:
#             x_i, x_j = lin(x_i), lin(x_j)
#             x_i, x_j = F.relu(x_i), F.relu(x_j)
#             x_i, x_j = F.dropout(x_i, p=self.dropout, training=self.training), \
#                 F.dropout(x_j, p=self.dropout, training=self.training)
#         x = torch.sum(x_i * x_j, dim=-1)
#         return x


# class MLPBilPredictor(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, num_layers, dropout):
#         super(MLPBilPredictor, self).__init__()
#         self.lins = torch.nn.ModuleList()
#         self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
#         for _ in range(num_layers - 1):
#             self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
#         self.bilin = torch.nn.Linear(hidden_channels, hidden_channels, bias=False)
#         self.dropout = dropout

#     def reset_parameters(self):
#         for lin in self.lins:
#             lin.reset_parameters()
#         self.bilin.reset_parameters()

#     def forward(self, x_i, x_j):
#         for lin in self.lins:
#             x_i, x_j = lin(x_i), lin(x_j)
#             x_i, x_j = F.relu(x_i), F.relu(x_j)
#             x_i, x_j = F.dropout(x_i, p=self.dropout, training=self.training), \
#                 F.dropout(x_j, p=self.dropout, training=self.training)
#         x = torch.sum(self.bilin(x_i) * x_j, dim=-1)
#         return x


# class DotPredictor(torch.nn.Module):
#     def __init__(self):
#         super(DotPredictor, self).__init__()

#     def reset_parameters(self):
#         return

#     def forward(self, x_i, x_j):
#         x = torch.sum(x_i * x_j, dim=-1)
#         return x


# class BilinearPredictor(torch.nn.Module):
#     def __init__(self, hidden_channels):
#         super(BilinearPredictor, self).__init__()
#         self.bilin = torch.nn.Linear(hidden_channels, hidden_channels, bias=False)

#     def reset_parameters(self):
#         self.bilin.reset_parameters()

#     def forward(self, x_i, x_j):
#         x = torch.sum(self.bilin(x_i) * x_j, dim=-1)
#         return x

class SemanticAttention(nn.Module):
    """Attention pooling over the second axis of a 3-D tensor.

    A small scoring network (Linear -> Tanh -> bias-free Linear to a single
    unit) assigns one scalar score to every vector; the scores are
    soft-maxed along axis 1 and used to form a weighted sum of the vectors.

    Args:
        in_size: Size of the last dimension of the input.
        hidden_size: Width of the scoring network's hidden layer.
    """

    def __init__(self, in_size, hidden_size=128):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        scorer = [
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*scorer)

    def forward(self, z):
        """Pool ``z`` along axis 1 with learned attention weights.

        Args:
            z: Tensor of shape (N, M, dim).

        Returns:
            Tensor of shape (N, dim): for each of the N entries, the
            attention-weighted sum of its M vectors.
        """
        scores = self.project(z)          # (N, M, 1)
        weights = scores.softmax(dim=1)   # (N, M, 1), sums to 1 over M
        # The trailing singleton axis broadcasts across the feature axis.
        return (z * weights).sum(dim=1)   # (N, dim)

    def reset_parameters(self):
        """Re-initialise the linear layers of the scoring network.

        Returns:
            This module, so the call can be chained.
        """
        linear_layers = (m for m in self.project if isinstance(m, nn.Linear))
        for layer in linear_layers:
            layer.reset_parameters()
        return self

class ConcatFuse(nn.Module):
    """Fuse two equally wide inputs by concatenation followed by an MLP.

    The MLP takes ``2 * in_channels`` features (the two inputs side by side)
    and maps them back down to ``in_channels`` features.

    Args:
        in_channels: Feature width of each individual input; also the
            output width of the module.
        hidden_channels: Width used between consecutive linear layers.
        num_layers: Number of linear layers to create.
        dropout: Drop probability applied after every hidden activation.
    """

    def __init__(self, in_channels, hidden_channels=128, num_layers=3, dropout=.5):
        super().__init__()
        fused_width = 2 * in_channels   # width after concatenation
        result_width = in_channels      # fused output matches one input
        final = num_layers - 1
        self.lins = torch.nn.ModuleList()
        for depth in range(num_layers):
            fan_in = fused_width if depth == 0 else hidden_channels
            fan_out = result_width if depth == final else hidden_channels
            self.lins.append(torch.nn.Linear(fan_in, fan_out))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Concatenate ``x_i`` and ``x_j`` on the last axis and run the MLP.

        Hidden layers are each followed by ReLU and dropout; the last layer
        is purely linear.

        Args:
            x_i: First input; its last dimension is ``in_channels`` wide.
            x_j: Second input with the same shape as ``x_i``.

        Returns:
            Tensor whose last dimension is ``in_channels`` wide.

        Raises:
            IndexError: If the module was built without any layers.
        """
        h = torch.cat((x_i, x_j), dim=-1)
        for hidden in self.lins[:-1]:
            h = F.dropout(F.relu(hidden(h)), p=self.dropout, training=self.training)
        return self.lins[-1](h)