import os

import einops
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms

from CrossModal.R2Plus1D import R2Plus1DClassifier


class VisualNet(nn.Module):
    """Visual branch built around an ``R2Plus1DClassifier`` backbone.

    ``forward`` has two modes, selected by the constructor flags:

    * ``vis_only=False`` (default): the input must be 5-D, laid out as
      ``(B, T, H, W, C)`` per the rearrange pattern; it is rearranged to
      ``(B, C, T, H, W)`` and fed to the backbone, and both backbone outputs
      are returned.
    * ``vis_only=True`` and ``direct_classify=True``: the input goes through
      ``self.split`` and only its second result is fed to the backbone; only
      the first backbone output is returned.

    Any other flag combination has no forward path and raises.

    Args:
        num_classes: Stored as ``self.num_classes``.
            NOTE(review): the backbone below is constructed with a literal
            ``1`` rather than this value -- confirm that is intended.
        img_width: Input width; must equal ``img_height``.
        img_height: Input height; must equal ``img_width``.
        auto_mode: Accepted for interface compatibility; not used here.
        is_load_pretrained_pth: Accepted for interface compatibility; not
            used here.
        vis_only: Selects the forward mode (see above).
        direct_classify: Selects the forward mode when ``vis_only`` is true.

    Raises:
        ValueError: If ``img_width`` and ``img_height`` differ.
    """

    def __init__(self, num_classes: int = 1, img_width: int = 224, img_height: int = 224,
                 auto_mode: bool = False, is_load_pretrained_pth: bool = True,
                 vis_only: bool = False, direct_classify: bool = False):
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if img_width != img_height:
            raise ValueError(
                f"VisualNet expects square inputs, got "
                f"img_width={img_width}, img_height={img_height}"
            )
        self.num_classes = num_classes
        self.vis_only = vis_only
        self.direct_classify = direct_classify
        self.vis_net = R2Plus1DClassifier(1, (2, 2, 2, 2))

    def forward(self, x: torch.Tensor):
        """Run the backbone in the mode selected at construction time.

        Args:
            x: Input batch. In the default mode it must have exactly five
                dimensions ordered ``(B, T, H, W, C)``.

        Returns:
            ``(out, fea)`` -- both backbone outputs -- in the default mode;
            only ``out`` in the ``vis_only`` + ``direct_classify`` mode.

        Raises:
            ValueError: If the default mode receives a non-5-D input.
            NotImplementedError: If ``vis_only`` is set without
                ``direct_classify``; that combination has no forward path
                (it previously fell through and returned ``None``).
        """
        if not self.vis_only:
            # Fail early with a readable message instead of an unpacking error.
            if len(x.shape) != 5:
                raise ValueError(
                    f"expected a 5-D input (B, T, H, W, C), got shape {tuple(x.shape)}"
                )
            clip = einops.rearrange(x, 'B T H W C -> B C T H W')
            out, fea = self.vis_net(clip)
            return out, fea

        if self.direct_classify:
            # NOTE(review): ``split`` is not defined in this class; unless a
            # subclass or mixin supplies it, this line raises AttributeError.
            _, high = self.split(x)
            out, _ = self.vis_net(high)
            return out

        raise NotImplementedError(
            "VisualNet.forward has no path for vis_only=True with direct_classify=False"
        )


# Script entry point: intentionally a no-op placeholder, so running this
# module directly only executes the imports and class definition above.
if __name__ == '__main__':
    pass
