# -*- coding: utf-8 -*-
# @File : tcn.py
# @Author : 王军
# @Time : 2022/11/13 10:05
# @Software : PyCharm
import torch
import torch.nn as nn
from tadconv import TAdFeatureCNN
from space_gcn import Mish
from tadconv import TAdFeatureCNN
class Hight_order_Gate(nn.Module):
    """Iterated tanh/sigmoid gated convolution followed by a TAdFeatureCNN block.

    A tanh "filter" branch is computed once from the input and is then
    re-gated ``order`` times by a sigmoid gate applied to the running result::

        f   = tanh(conv_f(x))
        h_0 = x
        h_k = sigmoid(conv_g(h_{k-1})) * f      for k = 1..order
        out = dy_cnn(h_order + x)

    Both convolutions use a ``(1, kernel_size)`` kernel, so they only mix
    along the last input dimension (presumably the time axis -- TODO confirm
    against callers).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the filter/gate convolutions. Must
            equal ``in_channels``: the gate is re-applied to its own output
            and the result is added back onto the input.
        kernel_size: Width of the ``(1, kernel_size)`` convolution kernels.
        padding: Padding along the last dimension. It must keep that
            dimension's size unchanged (e.g. ``kernel_size=3, padding=1``)
            for the residual addition to line up.
        dilation: Passed to ``TAdFeatureCNN`` as ``(1, dilation)``.
        order: Number of gating rounds. The default of 2 matches the value
            previously hard-coded in ``forward``.

    Raises:
        ValueError: If ``in_channels != out_channels`` or ``order < 1``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, padding=1, dilation=1, order=2):
        super(Hight_order_Gate, self).__init__()
        # Fail early with a clear message instead of a shape error in forward().
        if in_channels != out_channels:
            raise ValueError(
                "Hight_order_Gate requires in_channels == out_channels "
                f"(got {in_channels} and {out_channels})")
        if order < 1:
            raise ValueError(f"order must be >= 1 (got {order})")
        self.order = order

        # Filter branch: tanh-activated (1, kernel_size) convolution.
        self.filters = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=(1, kernel_size),
                      padding=(0, padding)),
            nn.Tanh(),
        )

        # Gate branch: sigmoid-activated (1, kernel_size) convolution.
        self.gates = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=(1, kernel_size),
                      padding=(0, padding)),
            nn.Sigmoid(),
        )

        # Post-processing applied to the residual sum.
        self.dy_cnn = nn.Sequential(
            TAdFeatureCNN(in_channels=out_channels,
                          out_channels=out_channels,
                          kernel_size=(1, 2),
                          dilation=(1, dilation)),
            Mish(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``order`` gating rounds, add the input back, then ``dy_cnn``.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``in_channels`` channels;
                presumably ``(batch, channels, nodes, time)`` -- TODO confirm.

        Returns:
            ``dy_cnn(h + x)``, where ``h`` is the gated representation.
        """
        # NOTE: earlier revisions multiplied by ``self.gates(x)`` here, so the
        # tanh filter branch was built but never used. It is now the filter
        # term of the gated product. The term depends only on x, so it is
        # computed once rather than once per gating round.
        filtered = self.filters(x)
        gated = x
        for _ in range(self.order):
            gated = self.gates(gated) * filtered
        return self.dy_cnn(gated + x)
