import torch.nn as nn
import torch
import torch.nn.functional as F
from base_operations import TensorMaxPool, TensorReLU, TensorBatchNorm, TopkPooling 
from base_operations import TensorConvLayer2d
from base_operations import TensorBatchNorm2d, TensorBatchNorm3d, TensorBatchNorm4d


class tcnn4(nn.Module):
    """Residual-style tensor CNN built from ``TensorConvLayer2d`` layers.

    Three (features, block) pairs operate on a ``[9, 9]`` tensor layout; the
    output of each ``block`` is added to its input before the next
    ``features`` stage.  A final 4-channel convolution is followed by global
    average pooling and a 10-way linear classifier.  ``forward`` returns the
    raw classifier output (no softmax / log-softmax is applied).

    Args:
        device: Forwarded as the first positional argument to every
            ``TensorConvLayer2d``.
    """

    def __init__(self, device):
        super(tcnn4, self).__init__()

        def norm():
            return TensorBatchNorm2d(81)

        def same_conv():
            # 3x3 kernel, stride 1, padding 1, [9, 9] -> [9, 9].
            return TensorConvLayer2d(device, kh=3, kw=3, pad_w=1, pad_h=1, stride_w=1, stride_h=1, in_channel=1, out_channel=1, input_tensor=[9, 9], output_tensor=[9, 9], out_type=2)

        def strided_conv(input_tensor):
            # 3x3 kernel, stride 2; padding is left at the layer's default.
            return TensorConvLayer2d(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=input_tensor, output_tensor=[9, 9], out_type=2)

        def stage(head):
            # conv-BN-PReLU, conv-BN-PReLU, conv-BN, with `head` as first conv.
            return nn.Sequential(
                head, norm(), nn.PReLU(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        def residual_branch():
            # Two (conv-BN-PReLU, conv-BN) pairs back to back.
            return nn.Sequential(
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        # Attribute names and the position of each layer inside its
        # Sequential define the state_dict keys, so keep both stable.
        self.features1 = stage(strided_conv([3, 1]))
        self.block1 = residual_branch()
        self.features2 = stage(strided_conv([9, 9]))
        self.block2 = residual_branch()
        self.features3 = stage(strided_conv([9, 9]))
        self.block3 = residual_branch()

        self.features4 = nn.Sequential(
            TensorConvLayer2d(device, kh=3, kw=3, in_channel=1, out_channel=4, input_tensor=[9, 9], output_tensor=[9, 9], out_type=2),
            TensorBatchNorm2d(81*4),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features=81*4, out_features=10, bias=True)
        )

    def forward(self, x):
        """Run the network on a 4-D input tensor.

        Args:
            x: 4-D tensor; presumably (batch, channels, height, width) --
                confirm against the data pipeline.

        Returns:
            Tensor with ``x.shape[0]`` rows and 10 columns holding the raw
            classifier scores.
        """
        # Move axis 1 to the end, then add singleton axes at positions 1 and 5
        # to obtain the 6-D layout consumed by the tensor layers.
        x = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)

        x1 = self.features1(x)
        x2 = self.features2(x1 + self.block1(x1))
        x3 = self.features3(x2 + self.block2(x2))
        x4 = self.features4(x3 + self.block3(x3))

        xf = self.flatten(x4)            # merge axes 4.. into a single axis
        xf = torch.transpose(xf, 2, 4)   # bring that merged axis to position 2
        xf1 = self.flatten1(xf)          # fold axis 1 and 2 into one feature axis
        xa = self.avgpool(xf1)           # trailing two axes are pooled to 1 x 1
        # Remove only the pooled 1 x 1 axes.  A bare squeeze() would also drop
        # axis 0 when it has size 1 and hand the classifier a 1-D tensor.
        xa = torch.flatten(xa, start_dim=1)
        return self.classifier(xa)

class TCNN3_3(nn.Module):
    """Tensor CNN on a ``[6, 6, 6, 6]`` layout with a convolutional head.

    Every stage is a conv-BN-PReLU, conv-BN-PReLU, conv-BN stack.  The output
    of each ``block`` is added to its input before the next ``features``
    stage.  The last convolution has 10 output channels; its output is
    flattened and passed through ``log_softmax``.

    NOTE(review): ``TensorConvLayerInput`` and ``TensorConvLayer4d`` are not
    among the names imported at the top of this file -- confirm they are in
    scope before instantiating this class.

    Args:
        device: Forwarded as the first positional argument to every tensor
            convolution layer.
    """

    def __init__(self, device):
        super().__init__()

        def unit_norm():
            return TensorBatchNorm(1)

        def keep_conv():
            # 3x3 kernel, stride 1, padding 1 on the [6, 6, 6, 6] layout.
            return TensorConvLayer4d(device, kh=3, kw=3, pad_w=1, pad_h=1, stride_w=1, stride_h=1, in_channel=1, out_channel=1, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6])

        def stride2_conv():
            # 3x3 kernel, stride 2; no padding arguments are passed.
            return TensorConvLayer4d(device, kh=3, kw=3, stride_h=2, stride_w=2, in_channel=1, out_channel=1, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6])

        def triple(first_conv):
            # `first_conv` is built by the caller, so it is created before
            # the remaining layers of the stack.
            layers = [first_conv, unit_norm(), nn.PReLU()]
            layers += [keep_conv(), unit_norm(), nn.PReLU()]
            layers += [keep_conv(), unit_norm()]
            return nn.Sequential(*layers)

        entry_conv = TensorConvLayerInput(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=[3, 1], output_tensor=[6, 6, 6, 6], compress_tensor=True)
        self.features1 = triple(entry_conv)
        self.block1 = triple(keep_conv())

        self.features2 = triple(stride2_conv())
        self.block2 = triple(keep_conv())

        self.features3 = triple(stride2_conv())
        self.block3 = triple(keep_conv())

        head_conv = TensorConvLayer4d(device, kh=3, kw=3, in_channel=1, out_channel=10, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6], compress_tensor=True)
        self.features4 = nn.Sequential(head_conv, TensorBatchNorm(10), nn.PReLU())

        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

    def forward(self, x):
        """Return log-softmax scores for a 4-D input tensor.

        Args:
            x: 4-D tensor; presumably (batch, channels, height, width) --
                confirm against the data pipeline.

        Returns:
            2-D tensor: everything after axis 0 is flattened and normalised
            with ``log_softmax`` along axis 1.
        """
        # Axis 1 goes last, then singleton axes are inserted at 1 and 5.
        hidden = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)
        hidden = self.features1(hidden)

        skip_then_next = (
            (self.block1, self.features2),
            (self.block2, self.features3),
            (self.block3, self.features4),
        )
        for branch, following in skip_then_next:
            hidden = following(hidden + branch(hidden))

        return F.log_softmax(self.flatten(hidden), dim=1)


class new_TCNN(nn.Module):
    """Two-channel residual-style tensor CNN on a ``[3, 3, 3, 3]`` layout.

    Three (features, block) pairs are followed by a 4-channel convolution,
    global average pooling and a 10-way linear classifier.  ``forward``
    returns log-softmax scores.

    NOTE(review): ``TensorConvLayerInput`` and ``TensorConvLayer4d`` are not
    among the names imported at the top of this file -- confirm they are in
    scope before instantiating this class.

    Args:
        device: Forwarded as the first positional argument to every tensor
            convolution layer.
    """

    def __init__(self, device):
        super(new_TCNN, self).__init__()

        def norm():
            return TensorBatchNorm(2)

        def same_conv():
            # 3x3 kernel, stride 1, padding 1, layout unchanged.
            return TensorConvLayer4d(device, kh=3, kw=3, pad_w=1, pad_h=1, stride_w=1, stride_h=1, in_channel=2, out_channel=2, input_tensor=[3, 3, 3, 3], output_tensor=[3, 3, 3, 3])

        def strided_conv():
            # 3x3 kernel, stride 3; padding is left at the layer's default.
            return TensorConvLayer4d(device, kh=3, kw=3, stride_w=3, stride_h=3, in_channel=2, out_channel=2, input_tensor=[3, 3, 3, 3], output_tensor=[3, 3, 3, 3])

        def stage(head):
            # conv-BN-PReLU, conv-BN-PReLU, conv-BN, with `head` as first conv.
            return nn.Sequential(
                head, norm(), nn.PReLU(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        def residual_branch():
            # Two (conv-BN-PReLU, conv-BN) pairs back to back.
            return nn.Sequential(
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        # Attribute names and the position of each layer inside its
        # Sequential define the state_dict keys, so keep both stable.
        self.features1 = stage(
            TensorConvLayerInput(device, kh=3, kw=3, stride_w=3, stride_h=3, in_channel=2, out_channel=2, input_tensor=[3, 1], output_tensor=[3, 3, 3, 3])
        )
        self.block1 = residual_branch()
        self.features2 = stage(strided_conv())
        self.block2 = residual_branch()
        self.features3 = stage(strided_conv())
        self.block3 = residual_branch()

        self.features4 = nn.Sequential(
            TensorConvLayer4d(device, kh=3, kw=3, in_channel=2, out_channel=4, input_tensor=[3, 3, 3, 3], output_tensor=[3, 3, 3, 3], compress_tensor=False),
            TensorBatchNorm(4),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=324, out_features=10, bias=True)
        )

    def forward(self, x):
        """Return log-softmax class scores for a 4-D input tensor.

        Args:
            x: 4-D tensor; presumably (batch, channels, height, width) --
                confirm against the data pipeline.

        Returns:
            Tensor with ``x.shape[0]`` rows and 10 columns of log-softmax
            scores (normalised along axis 1).
        """
        # Move axis 1 to the end, then add singleton axes at positions 1 and 5
        # to obtain the 6-D layout consumed by the tensor layers.
        x = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)

        x1 = self.features1(x)
        x2 = self.features2(x1 + self.block1(x1))
        x3 = self.features3(x2 + self.block2(x2))
        x4 = self.features4(x3 + self.block3(x3))

        xf = self.flatten(x4)            # merge axes 4.. into a single axis
        xf = torch.transpose(xf, 2, 4)   # bring that merged axis to position 2
        xf1 = self.flatten1(xf)          # fold axis 1 and 2 into one feature axis
        xa = self.avgpool(xf1)           # trailing two axes are pooled to 1 x 1
        # Remove only the pooled 1 x 1 axes.  A bare squeeze() would also drop
        # axis 0 when it has size 1, making log_softmax(dim=1) fail.
        xa = torch.flatten(xa, start_dim=1)
        xc = self.classifier(xa)
        return F.log_softmax(xc, dim=1)

class new_TCNN3_1(nn.Module):
    """Single-channel residual-style tensor CNN on a ``[6, 6, 6, 6]`` layout.

    Three (features, block) pairs are followed by one more convolution,
    global average pooling and a 10-way linear classifier over 1296 features.
    ``forward`` returns log-softmax scores.

    NOTE(review): ``TensorConvLayerInput`` and ``TensorConvLayer4d`` are not
    among the names imported at the top of this file -- confirm they are in
    scope before instantiating this class.

    Args:
        device: Forwarded as the first positional argument to every tensor
            convolution layer.
    """

    def __init__(self, device):
        super(new_TCNN3_1, self).__init__()

        def norm():
            return TensorBatchNorm4d(1296)

        def same_conv():
            # 3x3 kernel, stride 1, padding 1, layout unchanged.
            return TensorConvLayer4d(device, kh=3, kw=3, pad_w=1, pad_h=1, stride_w=1, stride_h=1, in_channel=1, out_channel=1, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6])

        def strided_conv():
            # 3x3 kernel, stride 2; padding is left at the layer's default.
            return TensorConvLayer4d(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6])

        def stage(head):
            # conv-BN-PReLU, conv-BN-PReLU, conv-BN, with `head` as first conv.
            return nn.Sequential(
                head, norm(), nn.PReLU(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        def residual_branch():
            # Two (conv-BN-PReLU, conv-BN) pairs back to back.
            return nn.Sequential(
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
                same_conv(), norm(), nn.PReLU(),
                same_conv(), norm(),
            )

        # Attribute names and the position of each layer inside its
        # Sequential define the state_dict keys, so keep both stable.
        self.features1 = stage(
            TensorConvLayerInput(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=[3, 1], output_tensor=[6, 6, 6, 6], compress_tensor=True)
        )
        self.block1 = residual_branch()
        self.features2 = stage(strided_conv())
        self.block2 = residual_branch()
        self.features3 = stage(strided_conv())
        self.block3 = residual_branch()

        self.features4 = nn.Sequential(
            TensorConvLayer4d(device, kh=3, kw=3, in_channel=1, out_channel=1, input_tensor=[6, 6, 6, 6], output_tensor=[6, 6, 6, 6], compress_tensor=False),
            TensorBatchNorm4d(1296),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1296, out_features=10, bias=True)
        )

    def forward(self, x):
        """Return log-softmax class scores for a 4-D input tensor.

        Args:
            x: 4-D tensor; presumably (batch, channels, height, width) --
                confirm against the data pipeline.

        Returns:
            Tensor with ``x.shape[0]`` rows and 10 columns of log-softmax
            scores (normalised along axis 1).
        """
        # Move axis 1 to the end, then add singleton axes at positions 1 and 5
        # to obtain the 6-D layout consumed by the tensor layers.
        x = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)

        x1 = self.features1(x)
        x2 = self.features2(x1 + self.block1(x1))
        x3 = self.features3(x2 + self.block2(x2))
        x4 = self.features4(x3 + self.block3(x3))

        xf = self.flatten(x4)            # merge axes 4.. into a single axis
        xf = torch.transpose(xf, 2, 4)   # bring that merged axis to position 2
        xf1 = self.flatten1(xf)          # fold axis 1 and 2 into one feature axis
        xa = self.avgpool(xf1)           # trailing two axes are pooled to 1 x 1
        # Remove only the pooled 1 x 1 axes.  A bare squeeze() would also drop
        # axis 0 when it has size 1, making log_softmax(dim=1) fail.
        xa = torch.flatten(xa, start_dim=1)
        xc = self.classifier(xa)
        return F.log_softmax(xc, dim=1)

class new_TCNN3_2(nn.Module):
    """Residual-style tensor CNN whose layout grows from 2-D to 4-D.

    The three (features, block) pairs work on ``[8, 8]``, ``[8, 8, 8]`` and
    ``[8, 8, 8, 8]`` layouts respectively; the first convolution of
    ``features2`` / ``features3`` performs the layout change.  A final
    convolution is followed by global average pooling and a 10-way linear
    classifier over 4096 features.  ``forward`` returns log-softmax scores.

    NOTE(review): ``TensorConvLayerInput``, ``TensorConvLayer3d`` and
    ``TensorConvLayer4d`` are not among the names imported at the top of this
    file -- confirm they are in scope before instantiating this class.

    Args:
        device: Forwarded as the first positional argument to every tensor
            convolution layer.
    """

    def __init__(self, device):
        super(new_TCNN3_2, self).__init__()

        def builders(conv_cls, norm_cls, layout, num_features):
            # Return (stage, residual_branch) factories for one tensor layout.

            def norm():
                return norm_cls(num_features)

            def same_conv():
                # 3x3 kernel, stride 1, padding 1, layout unchanged.  Fresh
                # lists are passed to every layer.
                return conv_cls(device, kh=3, kw=3, pad_w=1, pad_h=1, stride_w=1, stride_h=1, in_channel=1, out_channel=1, input_tensor=list(layout), output_tensor=list(layout))

            def stage(head):
                # conv-BN-PReLU, conv-BN-PReLU, conv-BN, `head` is first conv.
                return nn.Sequential(
                    head, norm(), nn.PReLU(),
                    same_conv(), norm(), nn.PReLU(),
                    same_conv(), norm(),
                )

            def residual_branch():
                # Two (conv-BN-PReLU, conv-BN) pairs back to back.
                return nn.Sequential(
                    same_conv(), norm(), nn.PReLU(),
                    same_conv(), norm(),
                    same_conv(), norm(), nn.PReLU(),
                    same_conv(), norm(),
                )

            return stage, residual_branch

        # Attribute names and the position of each layer inside its
        # Sequential define the state_dict keys, so keep both stable.

        # Stage 1: [8, 8] layout.
        stage2d, branch2d = builders(TensorConvLayer2d, TensorBatchNorm2d, (8, 8), 64)
        self.features1 = stage2d(
            TensorConvLayerInput(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=[3, 1], output_tensor=[8, 8], compress_tensor=True)
        )
        self.block1 = branch2d()

        # Stage 2: [8, 8] -> [8, 8, 8] layout.
        stage3d, branch3d = builders(TensorConvLayer3d, TensorBatchNorm3d, (8, 8, 8), 512)
        self.features2 = stage3d(
            TensorConvLayer2d(device, kh=3, kw=3, stride_h=2, stride_w=2, in_channel=1, out_channel=1, input_tensor=[8, 8], output_tensor=[8, 8, 8], out_type=1)
        )
        self.block2 = branch3d()

        # Stage 3: [8, 8, 8] -> [8, 8, 8, 8] layout.
        stage4d, branch4d = builders(TensorConvLayer4d, TensorBatchNorm4d, (8, 8, 8, 8), 4096)
        self.features3 = stage4d(
            TensorConvLayer3d(device, kh=3, kw=3, stride_w=2, stride_h=2, in_channel=1, out_channel=1, input_tensor=[8, 8, 8], output_tensor=[8, 8, 8, 8])
        )
        self.block3 = branch4d()

        self.features4 = nn.Sequential(
            TensorConvLayer4d(device, kh=3, kw=3, in_channel=1, out_channel=1, input_tensor=[8, 8, 8, 8], output_tensor=[8, 8, 8, 8], compress_tensor=False),
            TensorBatchNorm4d(4096),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=4096, out_features=10, bias=True)
        )

    def forward(self, x):
        """Return log-softmax class scores for a 4-D input tensor.

        Args:
            x: 4-D tensor; presumably (batch, channels, height, width) --
                confirm against the data pipeline.

        Returns:
            Tensor with ``x.shape[0]`` rows and 10 columns of log-softmax
            scores (normalised along axis 1).
        """
        # Move axis 1 to the end, then add singleton axes at positions 1 and 5
        # to obtain the 6-D layout consumed by the tensor layers.
        x = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)

        x1 = self.features1(x)
        x2 = self.features2(x1 + self.block1(x1))
        x3 = self.features3(x2 + self.block2(x2))
        x4 = self.features4(x3 + self.block3(x3))

        xf = self.flatten(x4)            # merge axes 4.. into a single axis
        xf = torch.transpose(xf, 2, 4)   # bring that merged axis to position 2
        xf1 = self.flatten1(xf)          # fold axis 1 and 2 into one feature axis
        xa = self.avgpool(xf1)           # trailing two axes are pooled to 1 x 1
        # Remove only the pooled 1 x 1 axes.  A bare squeeze() would also drop
        # axis 0 when it has size 1, making log_softmax(dim=1) fail.
        xa = torch.flatten(xa, start_dim=1)
        xc = self.classifier(xa)
        return F.log_softmax(xc, dim=1)

class new_TCNN3_4(nn.Module):
    """Tensor-convolution classifier built on [7, 7, 7, 7] tensor layers.

    The network is four feature stages (``features1`` .. ``features4``)
    interleaved with three blocks (``block1`` .. ``block3``). In ``forward``
    each block's output is added to the output of the stage before it. The
    result is average-pooled and passed to a dropout + linear classifier with
    10 outputs.

    Submodule names, layer order and constructor arguments are kept stable so
    that ``state_dict`` keys do not change.
    """

    def __init__(self, device):
        """Build all layers.

        Args:
            device: Passed through unchanged to every tensor-convolution
                layer constructor.
        """
        super(new_TCNN3_4, self).__init__()

        modes = [7, 7, 7, 7]
        width = 2401  # 7 * 7 * 7 * 7, the feature count used by every norm layer

        def conv(**extra):
            # 3x3 single-channel layer on [7, 7, 7, 7] tensors. Fresh shape
            # lists are handed to every layer so no two layers share one.
            return TensorConvLayer4d(
                device, kh=3, kw=3, in_channel=1, out_channel=1,
                input_tensor=list(modes), output_tensor=list(modes), **extra
            )

        def same_conv():
            # Stride-1 layer with padding 1 in both directions.
            return conv(pad_w=1, pad_h=1, stride_w=1, stride_h=1)

        def norm():
            return TensorBatchNorm4d(width)

        def conv_pair():
            # conv -> norm -> PReLU -> conv -> norm. Elements are created
            # left to right, matching the original construction order.
            return [same_conv(), norm(), nn.PReLU(), same_conv(), norm()]

        def strided_stage():
            # Stride-2 layer with no explicit padding, followed by a pair.
            return nn.Sequential(
                conv(stride_h=2, stride_w=2), norm(), nn.PReLU(),
                *conv_pair()
            )

        def double_pair():
            # Two consecutive pairs with no activation between them.
            return nn.Sequential(*conv_pair(), *conv_pair())

        self.features1 = nn.Sequential(
            TensorConvLayerInput(
                device, kh=3, kw=3, stride_w=2, stride_h=2,
                in_channel=1, out_channel=1,
                input_tensor=[3, 1], output_tensor=[7, 7, 7, 7],
                compress_tensor=True
            ),
            norm(),
            nn.PReLU(),
            *conv_pair()
        )
        self.block1 = double_pair()

        self.features2 = strided_stage()
        self.block2 = double_pair()

        self.features3 = strided_stage()
        self.block3 = double_pair()

        # Final layer: no explicit stride or padding, compress_tensor=False.
        self.features4 = nn.Sequential(
            conv(compress_tensor=False),
            norm(),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=width, out_features=10, bias=True)
        )

    def forward(self, x):
        """Run the network on ``x`` and return per-class log-probabilities.

        Args:
            x: Input tensor. Its dim 1 is moved to the end and two singleton
                axes are inserted before the first layer runs.
                NOTE(review): presumably laid out as
                (batch, channels, height, width) -- confirm against the caller.

        Returns:
            ``log_softmax`` of the classifier output, taken over dim 1, with
            the batch axis kept as dim 0.
        """
        # Move dim 1 to the end, then insert singleton axes at positions 1 and 5.
        x = x.permute(0, 2, 3, 1).unsqueeze(1).unsqueeze(5)

        # Each block's output is added back onto the preceding stage output.
        stage1 = self.features1(x)
        stage2 = self.features2(stage1 + self.block1(stage1))
        stage3 = self.features3(stage2 + self.block2(stage2))
        stage4 = self.features4(stage3 + self.block3(stage3))

        # Collapse dims 4.. into one axis, swap it with dim 2, merge dims 1-2
        # and average the two trailing dims down to 1x1.
        merged = self.flatten(stage4)
        merged = torch.transpose(merged, 2, 4)
        pooled = self.avgpool(self.flatten1(merged))

        # Drop only the trailing 1x1 dims. A bare torch.squeeze() would also
        # remove the batch axis when the batch holds a single sample, and
        # log_softmax(dim=1) would then fail on the resulting 1-D tensor.
        pooled = torch.flatten(pooled, 1)

        logits = self.classifier(pooled)
        return F.log_softmax(logits, dim=1)
