import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptPairTensor
from ChebnetII_pro import ChebnetII_prop, Bern_prop, GPR_prop


def reset(nn):
    """Re-initialise a module's direct children, or the module itself.

    When ``nn`` exposes ``children()`` and has at least one child, every direct
    child that defines ``reset_parameters`` is reset; grandchildren are not
    visited. Otherwise ``nn`` itself is reset if it defines
    ``reset_parameters``. Objects without ``reset_parameters`` (e.g. ReLU) are
    skipped silently, and ``None`` is a no-op.
    """
    if nn is None:
        return

    kids = list(nn.children()) if hasattr(nn, 'children') else []
    targets = kids if kids else [nn]
    for target in targets:
        if hasattr(target, 'reset_parameters'):
            target.reset_parameters()


class WGINConv(MessagePassing):
    """GIN convolution whose neighbour messages may be scaled by per-edge weights.

    Computes ``nn(aggr_{j in N(i)} w_ji * x_j + (1 + eps) * x_i)``. When
    ``edge_weight`` is omitted, messages are the raw neighbour features.
    Aggregation defaults to ``'add'`` but can be overridden through ``aggr=``
    in ``kwargs``.

    Args:
        nn: module applied to the aggregated node features.
        eps: initial value of the self-term coefficient ``eps``.
        train_eps: if True, ``eps`` is a learnable ``Parameter``; otherwise it
            is registered as a buffer.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(self, nn, eps: float = 0., train_eps: bool = False, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``self.nn`` and restore ``eps`` to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(self, x, edge_index, edge_weight=None, size=None):
        """Aggregate (optionally weighted) neighbour features and apply ``self.nn``.

        Args:
            x: node feature tensor, or a ``(source, target)`` pair whose target
                entry may be None.
            edge_index: graph connectivity passed to ``propagate``.
            edge_weight: optional per-edge weights, reshaped to ``(-1, 1)`` in
                ``message``.
            size: optional size forwarded to ``propagate``.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)
        x_r = x[1]
        if x_r is not None:
            # Out-of-place: the tensor returned by propagate may be needed by
            # autograd for non-'add' aggregations, so it must not be mutated.
            out = out + (1 + self.eps) * x_r
        return self.nn(out)

    def message(self, x_j, edge_weight):
        """Return neighbour features, scaled per edge when weights are given."""
        if edge_weight is None:
            return x_j
        return x_j * edge_weight.view(-1, 1)

    def __repr__(self):
        return '{}(nn={})'.format(self.__class__.__name__, self.nn)
      

class GIN(torch.nn.Module):
    """Stack of GIN layers producing graph-level and node-level embeddings.

    Each layer is ``GINConv`` (wrapping a Linear-ReLU-Linear MLP) followed by
    ReLU and ``BatchNorm1d``. The graph embedding concatenates the
    sum-pooled output of every layer.

    Args:
        num_features: input node feature dimension (consumed by layer 0 only).
        dim: hidden dimension of every layer.
        num_gc_layers: number of GIN layers.
        device: device for the synthesized all-ones features and for batches
            moved in ``get_embeddings``.
    """

    def __init__(self, num_features, dim, num_gc_layers, device):
        super(GIN, self).__init__()
        self.num_gc_layers = num_gc_layers
        self.device = device
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        for i in range(num_gc_layers):
            # Only the first layer sees the raw input features.
            in_dim = dim if i else num_features
            # Named ``mlp`` so the module-level ``nn`` (torch.nn) alias is not shadowed.
            mlp = Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))
            self.convs.append(GINConv(mlp))
            self.bns.append(torch.nn.BatchNorm1d(dim))

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: node features, or None to use a constant all-ones feature per node.
            edge_index: graph connectivity.
            batch: per-node graph assignment vector, as consumed by
                ``global_add_pool``; its length gives the node count when x is None.

        Returns:
            ``(graph_emb, node_emb)``: the per-layer sum-pooled features
            concatenated along dim 1, and the per-layer node features
            concatenated along dim 1.
        """
        if x is None:
            x = torch.ones((batch.shape[0], 1)).to(self.device)
        xs = []
        for conv, bn in zip(self.convs, self.bns):
            x = bn(F.relu(conv(x, edge_index)))
            xs.append(x)
        xpool = [global_add_pool(h, batch) for h in xs]
        return torch.cat(xpool, 1), torch.cat(xs, 1)

    def get_embeddings(self, loader):
        """Compute graph embeddings and labels for every batch in ``loader``.

        Each item yielded by ``loader`` is indexed with ``[0]`` to obtain the
        batch object, which must expose ``x``, ``edge_index``, ``batch`` and ``y``.
        Runs under ``torch.no_grad()``.

        NOTE(review): this method does not switch the module to ``eval()``;
        presumably callers do so beforehand -- confirm.

        Returns:
            ``(embeddings, labels)`` as NumPy arrays concatenated along axis 0.
        """
        ret = []
        y = []
        with torch.no_grad():
            for item in loader:
                # Rebinding keeps this correct whether ``.to`` moves in place or returns a copy.
                data = item[0].to(self.device)
                # forward() supplies the all-ones features when data.x is None.
                x, _ = self.forward(data.x, data.edge_index, data.batch)
                ret.append(x.cpu().numpy())
                y.append(data.y.cpu().numpy())
        return np.concatenate(ret, 0), np.concatenate(y, 0)
    


class ChebNetII_V2(torch.nn.Module):
    """Graph encoder built on ``ChebnetII_prop`` with per-output sum pooling.

    ``forward`` applies optional input dropout and runs ``prop1``. Each tensor
    it returns then goes through ReLU, optional BatchNorm and dropout before
    being sum-pooled per graph.

    Args:
        num_node_features: feature dimension given to the BatchNorm layer.
        args: config object; ``dprate``, ``dropout``, ``K`` and ``use_bn`` are read.
        device: device for the synthesized all-ones features and for batches
            moved in ``get_embeddings``.
    """

    def __init__(self, num_node_features, args, device):
        super(ChebNetII_V2, self).__init__()
        self.device = device
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.K = args.K
        self.use_bn = args.use_bn

        self.bn = nn.BatchNorm1d(num_node_features)
        self.prop1 = ChebnetII_prop(args.K)
        # Consistent with BernNet_V2 / GPRGNN_V2, which also reset at construction.
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``prop1`` and the BatchNorm layer."""
        self.prop1.reset_parameters()
        self.bn.reset_parameters()

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: node features, or None to use a constant all-ones feature per node.
            edge_index: graph connectivity passed to ``prop1``.
            batch: per-node graph assignment vector, as consumed by ``global_add_pool``.

        Returns:
            ``(graph_emb, node_emb)``: the sum-pooled, post-processed ``prop1``
            outputs concatenated along dim 1, and the raw ``prop1`` outputs
            concatenated along dim 1.
        """
        if x is None:
            x = torch.ones((batch.shape[0], 1)).to(self.device)
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)

        # Presumably a list of tensors that keep the input feature dimension
        # (the shared BatchNorm expects it) -- confirm against ChebnetII_prop.
        xs = self.prop1(x, edge_index)
        xpool = []
        for h in xs:
            h = F.relu(h)
            if self.use_bn:
                # A single BatchNorm is shared by every tensor in ``xs``.
                h = self.bn(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            xpool.append(global_add_pool(h, batch))
        # NOTE(review): ``xs`` still holds the raw prop1 outputs, not the
        # ReLU/BN/dropout versions pooled above (GIN returns its processed
        # tensors) -- confirm this is intended.
        return torch.cat(xpool, 1), torch.cat(xs, 1)

    def get_embeddings(self, loader):
        """Compute graph embeddings and labels for every batch in ``loader``.

        Each item yielded by ``loader`` is indexed with ``[0]`` to obtain the
        batch object, which must expose ``x``, ``edge_index``, ``batch`` and ``y``.
        Runs under ``torch.no_grad()``; the train/eval mode is left unchanged.

        Returns:
            ``(embeddings, labels)`` as NumPy arrays concatenated along axis 0.
        """
        ret = []
        y = []
        with torch.no_grad():
            for item in loader:
                # Rebinding keeps this correct whether ``.to`` moves in place or returns a copy.
                data = item[0].to(self.device)
                # forward() supplies the all-ones features when data.x is None.
                x, _ = self.forward(data.x, data.edge_index, data.batch)
                ret.append(x.cpu().numpy())
                y.append(data.y.cpu().numpy())
        return np.concatenate(ret, 0), np.concatenate(y, 0)
    


class BernNet_V2(torch.nn.Module):
    """Graph encoder built on ``Bern_prop`` with per-output sum pooling.

    Args:
        num_node_features: feature dimension handed to the BatchNorm layer.
        args: config object; ``dprate``, ``dropout``, ``use_bn`` and ``K`` are read.
        device: device for the synthesized all-ones features and for batches
            moved in ``get_embeddings``.
    """

    def __init__(self, num_node_features, args, device):
        super().__init__()
        self.device = device
        self.use_bn = args.use_bn
        self.dropout = args.dropout
        self.dprate = args.dprate

        # prop1 is registered before bn; this fixes parameter / state_dict order.
        self.prop1 = Bern_prop(args.K)
        self.bn = nn.BatchNorm1d(num_node_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``prop1`` first, then the BatchNorm layer."""
        for module in (self.prop1, self.bn):
            module.reset_parameters()

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: node features, or None for a constant all-ones feature per node.
            edge_index: graph connectivity passed to ``prop1``.
            batch: per-node graph assignment vector, as consumed by ``global_add_pool``.

        Returns:
            ``(graph_emb, node_emb)``: sum-pooled, post-processed ``prop1``
            outputs and the raw ``prop1`` outputs, each concatenated along dim 1.
        """
        if x is None:
            x = torch.ones((batch.shape[0], 1)).to(self.device)
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)

        hop_outputs = self.prop1(x, edge_index)
        pooled = []
        for hop in hop_outputs:
            hop = F.relu(hop)
            # One BatchNorm layer is shared across all prop1 outputs.
            hop = self.bn(hop) if self.use_bn else hop
            hop = F.dropout(hop, p=self.dropout, training=self.training)
            pooled.append(global_add_pool(hop, batch))
        # NOTE(review): the node-level result is the raw prop1 output, not the
        # processed tensors pooled above -- confirm this is intended.
        return torch.cat(pooled, 1), torch.cat(hop_outputs, 1)

    def get_embeddings(self, loader):
        """Collect ``(graph_embeddings, labels)`` NumPy arrays over ``loader``.

        Every loader item is indexed with ``[0]`` to get the batch object
        (needs ``x``, ``edge_index``, ``batch``, ``y``). Gradients are disabled;
        the module's train/eval mode is not touched.
        """
        feats, labels = [], []
        with torch.no_grad():
            for item in loader:
                graph = item[0]
                graph.to(self.device)
                node_x = graph.x
                if node_x is None:
                    node_x = torch.ones((graph.batch.shape[0], 1)).to(self.device)
                pooled, _ = self.forward(node_x, graph.edge_index, graph.batch)
                feats.append(pooled.cpu().numpy())
                labels.append(graph.y.cpu().numpy())
        return np.concatenate(feats, 0), np.concatenate(labels, 0)
    

class GPRGNN_V2(torch.nn.Module):
    """Graph encoder built on ``GPR_prop`` with per-output sum pooling.

    Args:
        num_node_features: feature dimension handed to the BatchNorm layer.
        args: config object; ``Init``, ``dprate``, ``dropout``, ``use_bn``,
            ``K`` and ``alpha`` are read from it.
        device: device for the synthesized all-ones features and for batches
            moved in ``get_embeddings``.
    """

    def __init__(self, num_node_features, args, device):
        super().__init__()
        self.device = device
        self.Init = args.Init
        self.use_bn = args.use_bn
        self.dropout = args.dropout
        self.dprate = args.dprate

        # prop1 is registered before bn; this fixes parameter / state_dict order.
        self.prop1 = GPR_prop(args.K, args.alpha, args.Init)
        self.bn = nn.BatchNorm1d(num_node_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``prop1`` and then the BatchNorm layer."""
        for layer in (self.prop1, self.bn):
            layer.reset_parameters()

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: node features, or None for a constant all-ones feature per node.
            edge_index: graph connectivity passed to ``prop1``.
            batch: per-node graph assignment vector, as consumed by ``global_add_pool``.

        Returns:
            ``(graph_emb, node_emb)``: sum-pooled, post-processed ``prop1``
            outputs and the raw ``prop1`` outputs, each concatenated along dim 1.
        """
        feats = torch.ones((batch.shape[0], 1)).to(self.device) if x is None else x
        if self.dprate != 0.0:
            feats = F.dropout(feats, p=self.dprate, training=self.training)

        hops = self.prop1(feats, edge_index)
        graph_parts = []
        for raw in hops:
            h = F.relu(raw)
            if self.use_bn:
                # The same BatchNorm layer normalizes every prop1 output.
                h = self.bn(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            graph_parts.append(global_add_pool(h, batch))
        graph_emb = torch.cat(graph_parts, 1)
        # NOTE(review): node-level output uses the raw prop1 tensors, not the
        # processed ones pooled above -- confirm this is intended.
        node_emb = torch.cat(hops, 1)
        return graph_emb, node_emb

    def get_embeddings(self, loader):
        """Gather graph embeddings and labels from ``loader`` as NumPy arrays.

        Each loader item is indexed with ``[0]`` to reach the batch object
        (needs ``x``, ``edge_index``, ``batch``, ``y``). Runs without gradients
        and leaves the train/eval mode as it is.

        Returns:
            ``(embeddings, labels)`` concatenated along axis 0.
        """
        embeddings = []
        targets = []
        with torch.no_grad():
            for entry in loader:
                sample = entry[0]
                sample.to(self.device)
                feats, ei, assign = sample.x, sample.edge_index, sample.batch
                if feats is None:
                    feats = torch.ones((assign.shape[0], 1)).to(self.device)
                graph_emb, _ = self.forward(feats, ei, assign)
                embeddings.append(graph_emb.cpu().numpy())
                targets.append(sample.y.cpu().numpy())
        embeddings = np.concatenate(embeddings, 0)
        targets = np.concatenate(targets, 0)
        return embeddings, targets
    