import torch
import torch.nn as nn

from models.modules import ConvModule

def simple_conv_layer(
    in_channels, out_channels,
    kernel_size, stride, norm_layer=None, padding=3):
    """Build the modules of a single conv -> norm -> ReLU stage.

    Args:
        in_channels (int): Number of channels of the stage input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Kernel size forwarded to ``nn.Conv2d``.
        stride (int or tuple): Stride forwarded to ``nn.Conv2d``.
        norm_layer (callable, optional): Factory called as
            ``norm_layer(out_channels)`` to create the normalization module.
            Defaults to ``nn.BatchNorm2d``.
        padding (int or tuple, optional): Padding forwarded to ``nn.Conv2d``.
            Defaults to 3, the value this function always used before the
            parameter existed. That value centres a 7x7 kernel; for other odd
            kernel sizes ``kernel_size // 2`` is the usual choice.

    Returns:
        list: ``[conv, norm, relu]`` as separate modules (not wrapped in a
        container) so the caller can extend its own layer list with them.
    """
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    # The convolution has no bias term; it is immediately followed by the
    # normalization layer.
    conv1 = nn.Conv2d(
        in_channels, out_channels, kernel_size, stride,
        padding=padding, bias=False)
    bn1 = norm_layer(out_channels)
    relu = nn.ReLU(inplace=True)
    return [conv1, bn1, relu]



class ConvEmbedding(nn.Module):
    """An embedding network which transforms the input data into a shared feature space.

    The network is a chain of conv -> norm -> ReLU stages followed by one
    3x3, stride-2 max-pool.

    Args:
        structures (Tuple[Tuple[int, int, int]]): ``[[channel, kernel_size, stride], ...]``.
            Each pair of consecutive entries defines one stage: entry ``i``
            supplies the input channel count, kernel size and stride, and
            entry ``i + 1`` supplies the output channel count. The kernel
            size and stride of the last entry are therefore never read.
    """

    def __init__(
        self,
        structures):
        super().__init__()
        self.conv_layers = self.build_model(structures)

    def build_model(self, structures):
        """Create the layer list described by ``structures``.

        Args:
            structures: See the class docstring.

        Returns:
            nn.ModuleList: The modules returned by ``simple_conv_layer`` for
            every consecutive pair of entries, followed by a final
            ``nn.MaxPool2d(kernel_size=3, stride=2, padding=1)``.
        """
        conv_layers = []
        # Walk consecutive (current, next) entries; fewer than two entries
        # yields no conv stage and only the max-pool is created.
        for spec, next_spec in zip(structures, structures[1:]):
            in_channels, out_channels = spec[0], next_spec[0]
            kernel_size, stride = spec[1], spec[2]
            conv_layers.extend(
                simple_conv_layer(in_channels, out_channels, kernel_size, stride))
        conv_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.ModuleList(conv_layers)

    def load_from_resnet(self, resnet_model):
        """Copy the state of a ResNet-style stem into this embedding.

        Layers are paired by position with
        ``[conv1, bn1, relu, maxpool]`` of ``resnet_model`` and pairing stops
        at the shorter of the two sequences, so the layouts line up only when
        the embedding was built with exactly one conv stage.

        Args:
            resnet_model: Object exposing ``conv1``, ``bn1``, ``relu`` and
                ``maxpool`` modules (presumably a torchvision-style ResNet;
                confirm against the caller).
        """
        stem_layers = (
            resnet_model.conv1,
            resnet_model.bn1,
            resnet_model.relu,
            resnet_model.maxpool,
        )
        for own_layer, stem_layer in zip(self.conv_layers, stem_layers):
            own_layer.load_state_dict(stem_layer.state_dict())

    def forward(self, x):
        """Run ``x`` through every layer of ``self.conv_layers`` in order.

        Args:
            x: Input tensor accepted by the first layer
                (``[N, C, H, W]`` for the ``nn.Conv2d`` stages built here).

        Returns:
            The output of the last layer.
        """
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
        return x

class VideoConvEmbedding(ConvEmbedding):
    """Apply the ``ConvEmbedding`` layers to every frame of a video batch."""

    def forward(self, x):
        """Embed each frame independently.

        Input:
            x [B, F, C, H, W]

        Returns:
            Tensor of shape [B, F, C', H', W'], where C', H', W' are the
            channel and spatial sizes produced by the layer stack.
        """
        batch, frames, channels, height, width = x.size()
        # Fold the frame axis into the batch axis so the 2D layers see a
        # plain image batch.
        per_frame = x.reshape(batch * frames, channels, height, width)
        # The parent forward applies self.conv_layers sequentially.
        features = super().forward(per_frame)
        _, out_channels, out_height, out_width = features.size()
        # Restore the separate batch and frame axes.
        return features.reshape(
            batch, frames, out_channels, out_height, out_width)