import torch.nn as nn
import torch
import torch.nn.functional as F
from base_operations import TensorConvLayer2d
from base_operations import Squash, TensorReLU
from base_operations import TensorBatchNorm2d, TensorBatchNorm3d, TensorBatchNorm4d


class tcnn4(nn.Module):
    """Residual-style classifier built from ``TensorConvLayer2d`` layers.

    The network is three "stage + residual block" pairs followed by a
    widening convolution, global average pooling and a linear classifier
    with 100 outputs:

    * ``features1/2/3`` -- a stride-2 convolution followed by two padded
      stride-1 convolutions, each followed by ``TensorBatchNorm2d(81)``.
    * ``block1/2/3`` -- four padded stride-1 convolutions whose output is
      added to the output of the preceding stage in ``forward``.
    * ``features4`` -- one convolution with ``out_channel=4`` and
      ``TensorBatchNorm2d(81 * 4)``.

    NOTE(review): ``TensorConvLayer2d`` and ``TensorBatchNorm2d`` come from
    ``base_operations`` and are not visible here; the meaning of
    ``input_tensor``/``output_tensor``/``out_type`` and the relation
    81 == 9 * 9 are presumed from the arguments -- confirm against that
    module.

    Args:
        device: Forwarded unchanged as the first positional argument of
            every ``TensorConvLayer2d``.
    """

    def __init__(self, device):
        super().__init__()

        def conv(**overrides):
            # Arguments shared by every convolution in the network. Padding
            # and stride are deliberately NOT defaulted here so that layers
            # which omitted them originally still fall back to whatever
            # defaults TensorConvLayer2d defines.
            settings = dict(
                kh=3,
                kw=3,
                in_channel=1,
                out_channel=1,
                input_tensor=[9, 9],
                output_tensor=[9, 9],
                out_type=2,
            )
            settings.update(overrides)
            return TensorConvLayer2d(device, **settings)

        def same_conv_bn():
            # Padded 3x3 stride-1 convolution followed by its batch norm.
            return [
                conv(pad_w=1, pad_h=1, stride_w=1, stride_h=1),
                TensorBatchNorm2d(81),
            ]

        def stage(input_tensor):
            # Stride-2 (unpadded) convolution, then two padded stride-1
            # convolutions. No activation after the final batch norm.
            layers = [
                conv(stride_w=2, stride_h=2, input_tensor=input_tensor),
                TensorBatchNorm2d(81),
                nn.PReLU(),
            ]
            layers += same_conv_bn()
            layers.append(nn.PReLU())
            layers += same_conv_bn()
            return nn.Sequential(*layers)

        def residual_block():
            # Two (conv, bn, PReLU, conv, bn) groups back to back. There is
            # intentionally no activation between the two groups, matching
            # the original layer list (and its Sequential indices).
            layers = []
            for _ in range(2):
                layers += same_conv_bn()
                layers.append(nn.PReLU())
                layers += same_conv_bn()
            return nn.Sequential(*layers)

        # Modules are created in the same order as before so that parameter
        # ordering, state_dict keys and RNG consumption are unchanged.
        self.features1 = stage([3, 1])
        self.block1 = residual_block()

        self.features2 = stage([9, 9])
        self.block2 = residual_block()

        self.features3 = stage([9, 9])
        self.block3 = residual_block()

        self.features4 = nn.Sequential(
            conv(out_channel=1 * 4),
            TensorBatchNorm2d(81 * 4),
            nn.PReLU(),
        )

        self.flatten = nn.Flatten(start_dim=4, end_dim=-1)
        self.flatten1 = nn.Flatten(start_dim=1, end_dim=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features=81 * 4, out_features=100, bias=True),
        )

    def forward(self, x):
        """Run the network on a 4-D input and return the classifier output.

        Args:
            x: 4-D tensor. Presumably ``(batch, channels, height, width)``
                with 3 channels, given ``input_tensor=[3, 1]`` on the first
                convolution -- TODO confirm against the data pipeline.

        Returns:
            The output of ``self.classifier``; the leading (batch) dimension
            is always kept, including for a batch of size 1.
        """
        # Move dim 1 to the end, then add singleton dims at positions 1 and
        # 5, producing a 6-D tensor for the tensor-convolution layers.
        x = x.permute(0, 2, 3, 1)
        x = x.unsqueeze(1)
        x = x.unsqueeze(5)

        # Each stage output is summed with its residual block output before
        # being passed to the next stage.
        x1 = self.features1(x)
        b1 = self.block1(x1)
        x2 = self.features2(x1 + b1)
        b2 = self.block2(x2)
        x3 = self.features3(x2 + b2)
        b3 = self.block3(x3)
        x4 = self.features4(x3 + b3)

        # Merge all dims from index 4 onward, swap dims 2 and 4, then merge
        # dims 1 and 2 so the result is 4-D for AdaptiveAvgPool2d.
        xf = self.flatten(x4)
        xf = torch.transpose(xf, 2, 4)
        xf1 = self.flatten1(xf)
        xa = self.avgpool(xf1)

        # Drop the 1x1 pooled dims but keep the batch dim. A bare
        # torch.squeeze() would also remove a batch dimension of size 1 and
        # return a 1-D result for single-sample batches.
        xa = torch.flatten(xa, start_dim=1)
        xc = self.classifier(xa)
        return xc

