import torch.nn as nn
import torch
import torch.nn.functional as F

from backbones import Linear_fw, Conv2d_fw, BatchNorm2d_fw


class MLP2(nn.Module):
    """
    Small fully connected network used for fitting a sine function.

    Two ``Linear_fw`` + ReLU stages make up ``feature_backbone``; a final
    ``Linear_fw`` layer (``output_layer``) maps the features to the prediction.
    ``forward`` is deliberately disabled: call ``forward_feature`` or
    ``forward_pred`` instead.
    """
    def __init__(self, in_dim=1, out_dim=1, hid_dim=40):
        """
        :param in_dim: size passed as the input width of the first layer
        :param out_dim: size passed as the output width of the output layer
        :param hid_dim: width shared by both hidden layers
        """
        super().__init__()
        self.in_dim, self.out_dim, self.hid_dim = in_dim, out_dim, hid_dim

        # Keep the layer order (and hence the Sequential indices) fixed.
        hidden_stack = [
            Linear_fw(in_dim, hid_dim),
            nn.ReLU(),
            Linear_fw(hid_dim, hid_dim),
            nn.ReLU(),
        ]
        self.feature_backbone = nn.Sequential(*hidden_stack)
        self.output_layer = Linear_fw(hid_dim, out_dim)

    def forward(self, x):
        """
        Not supported; always raises.
        :param x: ignored
        :raises ValueError: on every call
        """
        raise ValueError("no implement forward, use forward_feature or forward_pred")

    def forward_feature(self, x):
        """
        Run the hidden layers only.
        :param x: input passed to ``feature_backbone``
        :return: output of ``feature_backbone``
        """
        hidden = self.feature_backbone(x)
        return hidden

    def forward_pred(self, x):
        """
        Run the hidden layers and then the output layer.
        :param x: input passed to ``forward_feature``
        :return: output of ``output_layer`` applied to the features
        """
        return self.output_layer(self.forward_feature(x))


class Conv3(nn.Module):
    """
    Three-layer convolutional network, used in QMUL, regression.

    Each ``Conv2d_fw`` layer uses kernel size 3, stride 2 and dilation 2 and is
    followed by a ReLU. ``forward`` is deliberately disabled: call
    ``forward_feature`` or ``forward_pred`` instead.
    """
    def __init__(self, embedding_dim=2916):
        """
        :param embedding_dim: size passed as the input width of the output
            layer; it has to equal the flattened size produced by
            ``forward_feature`` for ``forward_pred`` to work
        """
        super().__init__()
        channels = 36
        conv_kwargs = dict(stride=2, dilation=2)
        self.layer1 = Conv2d_fw(3, channels, 3, **conv_kwargs)
        self.layer2 = Conv2d_fw(channels, channels, 3, **conv_kwargs)
        self.layer3 = Conv2d_fw(channels, channels, 3, **conv_kwargs)
        self.out_layer = Linear_fw(embedding_dim, 1)

    def forward(self, x):
        """
        Not supported; always raises.
        :param x: ignored
        :raises ValueError: on every call
        """
        raise ValueError("no implement forward, use forward_feature or forward_pred")

    def forward_pred(self, x):
        """
        Predict the label: flattened features passed through the output layer.
        :param x: input passed to ``forward_feature``
        :return: output of ``out_layer``
        """
        return self.out_layer(self.forward_feature(x))

    def forward_feature(self, x):
        """
        Extract the feature from the last hidden layer.
        :param x: input passed to the first conv layer
        :return: activations of the third layer, flattened to (batch, -1)
        """
        hidden = x
        for conv in (self.layer1, self.layer2, self.layer3):
            hidden = F.relu(conv(hidden))
        # Collapse everything except the leading (batch) dimension.
        return hidden.view(hidden.shape[0], -1)

def build_conv_block(in_channels: int, out_channels: int, conv_kernel=3, padding=1, max_pool_kernel_size=2):
    """
    Assemble one conv -> batch-norm -> ReLU -> max-pool block.

    :param in_channels: input channel count given to ``Conv2d_fw``
    :param out_channels: output channel count given to ``Conv2d_fw`` and ``BatchNorm2d_fw``
    :param conv_kernel: ``kernel_size`` of the convolution
    :param padding: ``padding`` of the convolution
    :param max_pool_kernel_size: ``kernel_size`` of the max-pool layer
    :return: ``nn.Sequential`` holding the four layers in that order
    """
    conv = Conv2d_fw(in_channels, out_channels, kernel_size=conv_kernel, padding=padding)
    norm = BatchNorm2d_fw(out_channels)
    activation = nn.ReLU()
    pool = nn.MaxPool2d(kernel_size=max_pool_kernel_size)
    return nn.Sequential(conv, norm, activation, pool)


class Conv4(nn.Module):
    """
    Four-block convolutional encoder followed by a linear output layer.

    The encoder stacks four blocks from ``build_conv_block`` with 64 hidden
    channels each. ``forward`` is deliberately disabled: call
    ``forward_feature`` or ``forward_pred`` instead.
    """
    def __init__(self, input_channel=3, embedding_dim=64, out_class_num=5):
        """
        :param input_channel: input channel count of the first conv block
        :param embedding_dim: size passed as the input width of the output
            layer (also stored as ``output_feature_dim``); it has to equal the
            flattened encoder output size for ``forward_pred`` to work, which
            is not checked here
        :param out_class_num: size passed as the output width of the output layer
        """
        super(Conv4, self).__init__()
        hidden_dim = 64
        self.encoder = nn.Sequential(
            build_conv_block(input_channel, hidden_dim),
            build_conv_block(hidden_dim, hidden_dim),
            build_conv_block(hidden_dim, hidden_dim),
            build_conv_block(hidden_dim, hidden_dim)
        )
        self.output_feature_dim = embedding_dim
        self.out_layer = Linear_fw(embedding_dim, out_class_num)

    def forward(self, x):
        """
        Not supported; always raises.
        :param x: ignored
        :raises ValueError: on every call
        """
        raise ValueError("no implement forward, use forward_feature or forward_pred")

    def forward_pred(self, x):
        """
        Apply the output layer to the last-layer embedding.
        :param x: input passed to the encoder
        :return: output of ``out_layer``
        """
        feature = self.forward_feature(x, is_last_2=False)
        return self.out_layer(feature)

    def forward_feature_last_layer(self, x):
        """
        output embedding
        :param x: input passed to the encoder
        :return: encoder output flattened to (batch_size, feature_d)
        """
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # batch_size * feature_d

    def forward_feature(self, x, is_last_2=False):
        """
        Extract features from the encoder.
        :param x: input passed to the encoder
        :param is_last_2:
            * true, == r2d2: concatenate the flattened outputs of the last 2 blocks
            * false, == maml: flattened output of the last block only
        :return: 2-D feature tensor with the batch as first dimension
        """
        if is_last_2:
            # The original called a non-existent ``concat_last_2`` here, so
            # this branch always raised AttributeError.
            return self.concat_last_2_layers(x)
        return self.forward_feature_last_layer(x)

    def concat_last_2_layers(self, x):
        """
        concat feature from the last 2 layers
        :param x: input passed to the encoder
        :return: flattened output of the third block followed by the flattened
            output of the fourth block, concatenated along dim 1
        """
        batch_size = x.size(0)
        # Run the first three blocks once and feed the result to the rest of
        # the encoder instead of running the whole encoder a second time.
        third_block_out = self.encoder[0:3](x)
        last_block_out = self.encoder[3:](third_block_out)
        x1 = third_block_out.view(batch_size, -1)
        x2 = last_block_out.view(batch_size, -1)
        return torch.cat([x1, x2], dim=1)
