# -*- coding: utf-8 -*-
# @File : space_gcn.py
# @Author : 王军
# @Time : 2022/11/13 21:32
# @Software : PyCharm
import torch.nn.functional as F
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
class Mish(nn.Module):
    """Mish activation function, applied element-wise: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        # Smooth, non-monotonic gate in (-1, 1) computed from softplus(x).
        gate = torch.tanh(F.softplus(x))
        return x * gate

class linear(nn.Module):
    """Per-position linear map over channels, implemented as a 1x1 ``Conv2d`` with bias.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # Attribute name ``mlp`` is kept so saved state_dict keys stay valid.
        self.mlp = nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=True,
        )

    def forward(self, x):
        out = self.mlp(x)
        return out

class adpnconv(nn.Module):
    """Propagate node features through an adjacency matrix with ``torch.einsum``.

    ``x`` is indexed as ``(b, c, v, t)``. A 2-D ``A`` of shape ``(v, w)`` is shared
    by all samples and time steps. Any other ``A`` is indexed as ``(b, t, v, w)``,
    giving one adjacency per sample and time step. The result is ``(b, c, w, t)``
    and is returned contiguous.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        static_adj = A.dim() == 2
        spec = 'ncvl,vw->ncwl' if static_adj else 'bcvt,btvw->bcwt'
        mixed = torch.einsum(spec, x, A)
        return mixed.contiguous()

class adp_gcn_conv(nn.Module):
    """Multi-hop graph convolution whose last support is re-weighted per sample and time step.

    For every support matrix ``A`` the input is propagated ``order`` times
    (``x·A``, ``x·A·A``, ...). The raw input and all propagated tensors are
    concatenated along the channel dimension and mixed by a 1x1 convolution,
    then dropout is applied.

    The last entry of ``support`` is handled specially. It is multiplied by
    ``alpha`` (one scalar per batch element and time step) and then
    softmax-normalized over its first node axis, so each column sums to 1.

    Args:
        c_in: number of input channels of ``x``.
        c_out: number of output channels.
        dropout: dropout probability applied to the output in training mode.
        support_len: number of support matrices that ``forward`` will receive.
        order: number of propagation hops per support.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(adp_gcn_conv, self).__init__()
        self.nconv = adpnconv()
        # Channels after concatenation: the raw input plus `order` propagated
        # copies for each of the `support_len` supports.
        c_in = (order * support_len + 1) * c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, alpha, support):
        """Apply the graph convolution.

        Args:
            x: 4-D input, indexed as ``(b, c, v, t)`` by the einsum in ``adpnconv``.
            alpha: per-(batch, time) scale of shape ``(b, t)``.
            support: non-empty sequence of 2-D ``(v, v)`` adjacency tensors. The
                last one is the one re-weighted by ``alpha``. The sequence is
                NOT modified.

        Returns:
            Tensor of shape ``(b, c_out, v, t)``.
        """
        # (b, t) -> (b, t, 1, 1) so it broadcasts against the (1, 1, v, w) adjacency.
        alpha = rearrange(alpha, 'b t->b t 1 1')
        adaptive = F.softmax(rearrange(support[-1], 's h->1 1 s h') * alpha, dim=-2)
        # Build a local list rather than overwriting support[-1] in place. The
        # previous swap-and-restore left a *clone* in the caller's list, which
        # replaced e.g. an nn.Parameter with a copy that does not see later
        # optimizer updates. It also left the list corrupted if an exception
        # was raised mid-forward.
        supports = [*support[:-1], adaptive]

        out = [x]
        for a in supports:
            hop = x
            # Exactly `order` hops per support, matching the channel count in __init__.
            for _ in range(self.order):
                hop = self.nconv(hop, a)
                out.append(hop)

        h = self.mlp(torch.cat(out, dim=1))
        return F.dropout(h, self.dropout, training=self.training)

class alpha_gcn(nn.Module):
    """Predict one scalar per (sample, time step) from a 4-D feature map.

    Two 1x1 convolutions reduce the channels to 1, with an ELU and a BatchNorm
    in between and a bottleneck width of ``int(in_channels / r)``. The result is
    then average-pooled to size 1 along the second-to-last axis, keeping the
    last axis. Finally the singleton axes are dropped, giving a ``(b, t)``
    tensor.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.r = 2  # channel reduction ratio of the bottleneck
        hidden = int(in_channels / self.r)
        # Layer order (and therefore state_dict indices) is significant.
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=hidden, kernel_size=(1, 1)),
            nn.ELU(),
            nn.BatchNorm2d(hidden),
            nn.Conv2d(in_channels=hidden, out_channels=1, kernel_size=(1, 1)),
            nn.AdaptiveAvgPool2d((1, None)),
            Rearrange('b 1 1 t->b t'),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

class adp_gcn(nn.Module):
    """Graph convolution whose adaptive support is scaled by a learned per-(batch, time) alpha.

    ``alpha_model`` computes ``alpha`` from the input ``x``. ``alpha``, ``x`` and the
    support list are then passed to ``gcn_model``.

    Args:
        c_in: input channels.
        c_out: output channels.
        dropout: dropout probability forwarded to the graph convolution.
        support_len: number of support matrices passed to ``forward``.
        order: propagation hops per support.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        # Registration order (gcn_model, then alpha_model) is kept for parameter ordering.
        self.gcn_model = adp_gcn_conv(c_in, c_out, dropout, support_len=support_len, order=order)
        self.alpha_model = alpha_gcn(c_in)

    def forward(self, x, support):
        weights = self.alpha_model(x)
        return self.gcn_model(x, weights, support)

if __name__ == '__main__':
    # Shape smoke test with random inputs and a single (adaptive) support.
    batch, channels, nodes, steps = 64, 32, 207, 12
    net = adp_gcn(c_in=channels, c_out=22, dropout=0.3, support_len=1, order=2)
    inputs = torch.randn((batch, channels, nodes, steps))
    adjacency = [torch.randn(((nodes, nodes)))]
    result = net(inputs, adjacency)
    print(result.shape)