import torch
import torch.nn as nn


class conv(nn.Module):
    """Downsampling unit: 4x4 convolution, batch norm, then LeakyReLU(0.2).

    The convolution uses kernel size 4, stride 2 and padding 1, so it halves
    the height and width of even-sized feature maps while mapping ``nin``
    input channels to ``nout`` output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        layers = [
            nn.Conv2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            # The activation works in place on the batch-norm output.
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the conv -> batch-norm -> activation stack."""
        out = self.net(input)
        return out


class encConv(nn.Module):
    """Convolutional frame encoder producing ``k_dim`` feature channels.

    Four stride-2 ``conv`` stages widen the channels from ``n_channels`` to
    ``conv_dim * 8`` while halving the spatial size each time. A final
    un-padded 4x4 convolution with batch norm and ``tanh`` maps to ``k_dim``
    channels. A 64x64 frame is reduced to 4x4 by the first four stages and
    to 1x1 by the last one (presumably the intended input size; confirm
    against the caller).

    ``args`` must provide ``n_frames``, ``n_channels``, ``n_height``,
    ``n_width``, ``conv_dim`` and ``k_dim``.
    """

    def __init__(self, args):
        super().__init__()

        # Keep the full configuration plus the individual fields on the module.
        self.args = args
        self.n_frames = args.n_frames
        self.n_channels = args.n_channels
        self.n_height = args.n_height
        self.n_width = args.n_width
        self.conv_dim = args.conv_dim
        self.k_dim = args.k_dim

        base = self.conv_dim

        # Downsampling stages; channel width doubles at every step.
        self.c1 = conv(self.n_channels, base)
        self.c2 = conv(base, base * 2)
        self.c3 = conv(base * 2, base * 4)
        self.c4 = conv(base * 4, base * 8)

        # Projection to the feature dimension, squashed into [-1, 1] by tanh.
        projection = [
            nn.Conv2d(base * 8, self.k_dim, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(self.k_dim),
            nn.Tanh(),
        ]
        self.c5 = nn.Sequential(*projection)

    def forward(self, x):
        """Encode frames.

        ``x`` is reshaped to ``[-1, n_channels, n_height, n_width]``, so any
        leading dimensions are folded into the batch axis before the
        convolution stages are applied in order.
        """
        feat = x.reshape(-1, self.n_channels, self.n_height, self.n_width)
        for stage in (self.c1, self.c2, self.c3, self.c4, self.c5):
            feat = stage(feat)
        return feat


class Classifier_CNN(nn.Module):
    """Video attribute/action classifier: per-frame CNN + bidirectional LSTM.

    Every frame is embedded by ``encConv``; the sequence of frame embeddings
    is fed to a single-layer bidirectional LSTM, and the final state of each
    direction is concatenated and passed to five independent MLP heads.

    ``opt`` must provide ``k_dim``, ``n_channels``, ``hidden_dim`` and
    ``n_frames`` (plus whatever ``encConv`` reads from it).
    """

    def __init__(self, opt):
        super(Classifier_CNN, self).__init__()
        self.g_dim = opt.k_dim  # frame feature size fed to the LSTM
        self.channels = opt.n_channels  # channels per input frame
        self.hidden_dim = opt.hidden_dim
        # Configured clip length; kept for callers, forward() no longer needs it.
        self.frames = opt.n_frames
        self.encoder = encConv(opt)
        self.bilstm = nn.LSTM(self.g_dim, self.hidden_dim, 1, bidirectional=True, batch_first=True)
        # Heads are created in the original order so seeded initialisation
        # and state-dict keys are unchanged.
        self.cls_skin = self._make_head(self.hidden_dim, 6)
        self.cls_top = self._make_head(self.hidden_dim, 6)
        self.cls_pant = self._make_head(self.hidden_dim, 6)
        self.cls_hair = self._make_head(self.hidden_dim, 6)
        self.cls_action = self._make_head(self.hidden_dim, 9)

    @staticmethod
    def _make_head(hidden_dim, n_classes):
        """Build a two-layer MLP mapping the BiLSTM summary to ``n_classes`` logits."""
        return nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, n_classes))

    def encoder_frame(self, x):
        """Embed every frame of a clip tensor.

        Args:
            x: tensor whose first two dimensions are (batch, frames) and whose
                last three are the per-frame (channels, height, width).

        Returns:
            Tensor of shape ``[batch, frames, embed_dim]``.
        """
        x_shape = x.shape
        # Fold batch and frames together so the 2-D encoder sees one frame per row.
        # reshape (not view) so non-contiguous inputs, e.g. a transposed clip, work.
        x = x.reshape(-1, x_shape[-3], x_shape[-2], x_shape[-1])
        x_embed = self.encoder(x)
        # Back to [batch_size, frames, embed_dim].
        return x_embed.reshape(x_shape[0], x_shape[1], -1)

    def forward(self, x):
        """Classify a clip.

        Returns:
            A 5-tuple of logits in the order
            ``(action, skin, pant, top, hair)``; the action head emits 9
            logits and each of the other heads emits 6.
        """
        conv_x = self.encoder_frame(x)
        # Pass the frame embeddings through the bidirectional LSTM.
        lstm_out, _ = self.bilstm(conv_x)
        # The backward direction finishes at the first time step ...
        backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
        # ... and the forward direction at the last one. Index with -1 rather
        # than self.frames - 1 so clips of any length use the true final step.
        frontal = lstm_out[:, -1, 0:self.hidden_dim]
        lstm_out_f = torch.cat((frontal, backward), dim=1)
        return self.cls_action(lstm_out_f), self.cls_skin(lstm_out_f), self.cls_pant(lstm_out_f), \
               self.cls_top(lstm_out_f), self.cls_hair(lstm_out_f)

