import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.glob import global_add_pool, global_mean_pool, global_max_pool

from typing import List, Optional, Union, Callable
from torch import Tensor
from torch_scatter import scatter
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.nn.models import MLP
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_sparse import SparseTensor, matmul
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_geometric.nn.inits import reset


class GINConv(MessagePassing):
    """GIN-style convolution layer with optional per-edge weights.

    The layer aggregates (optionally weighted) source features via
    ``propagate``, adds ``(1 + eps)`` times the target node's own features,
    and feeds the sum through the user-supplied network ``nn``.

    Args:
        nn: Callable applied to the combined features (e.g. an MLP).
        eps: Initial value of the self-loop weighting term.
        train_eps: If ``True``, ``eps`` is a learnable parameter; otherwise
            it is registered as a fixed buffer.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``'add'`` when not given.
    """

    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        # float() keeps the tensor floating point even when eps is given as
        # an int, which a trainable parameter requires.
        eps_tensor = torch.tensor([float(eps)])
        if train_eps:
            self.eps = torch.nn.Parameter(eps_tensor)
        else:
            self.register_buffer('eps', eps_tensor)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the wrapped network and restore ``eps`` to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, edge_weight: OptTensor = None) -> Tensor:
        """Run one convolution step.

        Args:
            x: Node features, or a ``(source, target)`` pair of feature
                tensors. A single tensor is used for both roles.
            edge_index: Graph connectivity passed to ``propagate``.
            size: Optional size hint passed to ``propagate``.
            edge_weight: Optional per-edge scalars multiplied into each
                message (see ``message``).

        Returns:
            The output of ``self.nn`` applied to the combined features.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, size=size, edge_weight=edge_weight)

        x_r = x[1]
        if x_r is not None:
            # Out-of-place on purpose: mutating the aggregation result in
            # place can invalidate tensors autograd saved for backward
            # (aggr is configurable through kwargs) and does not allow
            # dtype promotion.
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor, edge_weight: OptTensor = None) -> Tensor:
        """Return source features, scaled per edge when weights are given."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        """Fused message + aggregation for a sparse adjacency."""
        # NOTE(review): this path discards any values stored on adj_t and
        # takes no edge_weight, so weights passed to forward() appear to be
        # ignored whenever propagate() dispatches here (presumably for
        # SparseTensor adjacencies) -- confirm this is intended.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'

    
class GIN(torch.nn.Module):
    """Graph classifier: a stack of ``GINConv`` layers, pooling, and a head.

    Args:
        in_dim: Input node feature dimension.
        num_classes: Output dimension of the final linear layer.
        num_layers: Total number of ``GINConv`` layers (``conv1`` plus
            ``num_layers - 1`` further layers in ``convs``).
        hidden: Hidden width used by every conv MLP and by ``lin1``.
        dropout: Dropout probability applied in the classifier head.
        pool_type: Graph readout, one of ``'mean'``, ``'sum'`` or ``'max'``.

    Raises:
        ValueError: If ``pool_type`` is not one of the supported readouts.
    """

    def __init__(self, in_dim, num_classes, num_layers, hidden, dropout=0.5, pool_type='mean'):
        super().__init__()

        # Resolve the readout up front so a bad pool_type fails here rather
        # than as a missing `self.pool` attribute on the first forward pass.
        pool_fns = {
            'mean': global_mean_pool,
            'sum': global_add_pool,
            'max': global_max_pool,
        }
        if pool_type not in pool_fns:
            raise ValueError(
                f"Unsupported pool_type {pool_type!r}; "
                f"expected one of {sorted(pool_fns)}")
        self.pool = pool_fns[pool_type]

        # Creation order (conv1, convs, lin1, lin2) is kept stable so that
        # seeded initialisation and state-dict keys are unchanged.
        self.conv1 = self._make_conv(in_dim, hidden)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(self._make_conv(hidden, hidden))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)
        self.dropout = dropout

    @staticmethod
    def _make_conv(in_channels, hidden):
        """Build one GINConv with a trainable eps wrapping a 2-layer MLP + BN."""
        mlp = Sequential(
            Linear(in_channels, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
            BN(hidden),
        )
        return GINConv(mlp, train_eps=True)

    def reset_parameters(self):
        """Re-initialise every conv layer and both linear layers."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def _encode(self, x, edge_index, edge_weight=None):
        """Run node features through conv1 and every layer in convs."""
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        for conv in self.convs:
            x = conv(x, edge_index, edge_weight=edge_weight)
        return x

    def _classify(self, graph_emb):
        """Apply the lin1 -> ReLU -> dropout -> lin2 head to pooled embeddings."""
        out = F.relu(self.lin1(graph_emb))
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)

    def forward(self, data, edge_weights=None):
        """Embed and classify a batch of graphs.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``,
                and optionally ``edge_weights``.
            edge_weights: Optional per-edge weights. When given they take
                precedence over ``data.edge_weights``; either source is cast
                to float before use.

        Returns:
            Tuple ``(h, out, node_emb)``: the pooled graph embedding, the
            output of ``lin2`` (no softmax / log-softmax is applied), and the
            node embeddings produced by the last conv layer.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if edge_weights is not None:
            edge_weight = edge_weights.float()
        elif 'edge_weights' in data:
            edge_weight = data.edge_weights.float()
        else:
            edge_weight = None

        node_emb = self._encode(x, edge_index, edge_weight)
        h = self.pool(node_emb, batch)
        return h, self._classify(h), node_emb

    def M_mixup_forward(self, data1, data2, lambd):
        """Classify a convex mix of the pooled embeddings of two batches.

        Both batches are encoded without edge weights, pooled, and combined
        as ``lambd * embed1 + (1 - lambd) * embed2`` before the head.

        Args:
            data1: First batch, exposing ``x``, ``edge_index`` and ``batch``.
            data2: Second batch with the same attributes.
            lambd: Mixing coefficient applied to the first batch's embedding.

        Returns:
            Tuple ``(mixup_embed, out)``: the mixed graph embedding and the
            output of ``lin2`` (no softmax / log-softmax is applied).
        """
        embed1 = self.pool(self._encode(data1.x, data1.edge_index), data1.batch)
        embed2 = self.pool(self._encode(data2.x, data2.edge_index), data2.batch)

        mixup_embed = lambd * embed1 + (1.0 - lambd) * embed2
        return mixup_embed, self._classify(mixup_embed)

    def __repr__(self):
        return self.__class__.__name__

