from mmaction.apis import init_recognizer
import torch
import torch.nn as nn

class SlowFast(nn.Module):
    """Feature extractor built on an MMAction2 SlowFast recognizer.

    The recognizer is loaded with ``init_recognizer``. :meth:`forward` runs its
    backbone plus the pooling/dropout part of its classification head, and
    returns the pooled, concatenated slow/fast pathway features instead of
    class scores.

    Args:
        config_file: Path to the MMAction2 config. Defaults to the SlowFast
            R101 8x8x1 Kinetics-400 config. This is a relative path, resolved
            against the current working directory.
        checkpoint_file: Path to the checkpoint matching ``config_file``.
        **kwargs: Forwarded unchanged to ``init_recognizer``.
    """

    DEFAULT_CONFIG_FILE = 'pretrained/configs/slowfast_r101_8x8x1_256e_kinetics400_rgb.py'
    DEFAULT_CHECKPOINT_FILE = (
        'pretrained/slowfast_r101_8x8x1_256e_kinetics400_rgb_20210218-0dd54025.pth'
    )
    # Feature width for the default R101 config. It is used only when the
    # loaded head does not expose ``in_channels``.
    DEFAULT_OUT_DIM = 2304

    def __init__(
        self,
        config_file=DEFAULT_CONFIG_FILE,
        checkpoint_file=DEFAULT_CHECKPOINT_FILE,
        **kwargs
    ):
        super().__init__()

        self.model = init_recognizer(config_file, checkpoint_file, **kwargs)

        # Width of the vector returned by forward(). The head consumes exactly
        # the concatenated pooled features, so its in_channels is that width.
        # Fall back to the default config's value if the attribute is absent.
        # NOTE(review): assumes the head stores in_channels (MMAction2 BaseHead
        # does) -- confirm for custom heads.
        self.out_dim = getattr(
            self.model.cls_head, 'in_channels', self.DEFAULT_OUT_DIM
        )

    def forward(self, x):
        """Return pooled SlowFast features, one row per batch element.

        Args:
            x: Input clip batch, in whatever layout the loaded backbone
                expects. It is not validated here.

        Returns:
            A 2-D tensor ``[N, C]``. ``C`` is the fast-pathway channels
            followed by the slow-pathway channels.
        """
        x_slow, x_fast = self.model.backbone(x)

        # ([N, channel_slow, 1, 1, 1], [N, channel_fast, 1, 1, 1])
        x_slow = self.model.cls_head.avg_pool(x_slow)
        x_fast = self.model.cls_head.avg_pool(x_fast)
        # Fast features come first, mirroring the concat order the head uses.
        # [N, channel_fast + channel_slow, 1, 1, 1]
        x = torch.cat((x_fast, x_slow), dim=1)

        # Apply the head's dropout (if configured) so the features match what
        # the classifier layer would see. If it is an nn.Dropout, it is only
        # active in training mode.
        if self.model.cls_head.dropout is not None:
            x = self.model.cls_head.dropout(x)

        # [N, C]
        return torch.flatten(x, 1)

