import torch
from torch import nn

import timm

class ConvEncoder(nn.Module):
    """Spatio-temporal encoder: a 3D-conv patch stem followed by a 2D backbone.

    The stem (``proj``) is a strided ``Conv3d`` + ``BatchNorm3d`` + ``ReLU``
    that produces 64 channels. Every temporal slice of the stem output is then
    run independently through ``encoder`` (stages taken from a timm
    ``resnet34``) and the result is pooled with ``AdaptiveAvgPool3d``.

    Args:
        num_channels: Number of input channels consumed by the stem.
        temporal_patches: Temporal kernel size *and* stride of the stem, i.e.
            the number of input frames merged into one temporal step.
        pool_size: Output size passed to ``nn.AdaptiveAvgPool3d`` as
            ``(T, H, W)``; a ``None`` entry keeps the input size along that
            axis.

    Attributes:
        num_features: Channel count produced by the last encoder block.
        out_features: Size of the last dimension of ``outputs`` returned by
            ``forward`` (``H * W * num_features``), or ``None`` when a spatial
            entry of ``pool_size`` is ``None`` and the size therefore depends
            on the input resolution.
    """

    def __init__(self, num_channels=2, temporal_patches=4, pool_size=(None, 1, 1)):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv3d(
                num_channels,
                64,
                (temporal_patches, 8, 8),
                stride=(temporal_patches, 4, 4),
                padding=(0, 3, 3),
            ),
            nn.BatchNorm3d(64),
            nn.ReLU(),
        )
        module = timm.create_model('resnet34', pretrained=False)
        layers = list(module.children())
        # Skip the first five children and the last one, then keep three of the
        # remaining modules.
        # NOTE(review): presumably these are layer2..layer4 of timm's ResNet-34
        # (the 64-channel stem above standing in for the earlier stages) --
        # confirm against the installed timm version.
        blocks = layers[5:-1]
        self.encoder = torch.nn.Sequential(*blocks[:3])
        self.avg_pool = nn.AdaptiveAvgPool3d(pool_size)
        self.num_features = blocks[2][-1].bn2.num_features
        pooled_h, pooled_w = pool_size[1], pool_size[2]
        if pooled_h is None or pooled_w is None:
            # A ``None`` spatial size means "same as input", so the flattened
            # feature width cannot be known before seeing data.
            self.out_features = None
        else:
            self.out_features = pooled_h * pooled_w * self.num_features

    def forward(self, x, return_feature_maps=False):
        """Encode a clip into per-time-step feature vectors.

        Args:
            x: 5-D tensor laid out as (batch, time, channels, height, width);
                axes 1 and 2 are swapped before the ``Conv3d`` stem.
            return_feature_maps: If true, also return the un-pooled encoder
                feature maps.

        Returns:
            A dict with:
                ``'outputs'``: tensor of shape (batch, T, features), where T is
                    the temporal length after pooling and ``features`` is the
                    flattened (channels, H, W) of the pooled maps.
                ``'feature_maps'``: tensor of shape
                    (batch, t, channels, h, w) with the encoder output before
                    pooling, or ``None`` when not requested.
        """
        b = x.size(0)
        # (b, t, c, h, w) -> (b, c, t, h, w), the layout Conv3d expects.
        x = x.permute((0, 2, 1, 3, 4)).contiguous()
        y = self.proj(x)
        # Bring time next to batch so both can be folded into one batch axis
        # for the 2D encoder.
        y = y.permute((0, 2, 1, 3, 4)).contiguous()
        t = y.size(1)
        y = y.reshape(b * t, *y.shape[2:])
        y = self.encoder(y)
        y = y.reshape(b, t, *y.shape[1:])
        # Back to (b, c, t, h, w) for the 3D adaptive pooling.
        y = y.permute((0, 2, 1, 3, 4)).contiguous()
        y_pool = self.avg_pool(y)
        y_pool = y_pool.permute((0, 2, 1, 3, 4)).contiguous()
        # Flatten everything after the time axis. Reusing the pre-pooling ``t``
        # here would be wrong whenever pool_size[0] changes the temporal
        # length, silently mixing channel features into the time axis.
        y_pool = y_pool.flatten(2)
        fmaps = None
        if return_feature_maps:
            fmaps = y.permute((0, 2, 1, 3, 4)).contiguous()
        out = {'outputs': y_pool, 'feature_maps': fmaps}
        return out
