import os
import time

import einops
import torch.nn as nn
import numpy as np
import torch

from CrossModal.AudioNet import AudioNet
from CrossModal.MBT import MBT
from CrossModal.VisualNet import VisualNet
from CrossModal.Perceiver import Perceiver


class CrossNet(nn.Module):
    """Two-branch audio/visual network with a fused synchronisation head.

    The model owns a ``VisualNet`` branch, an ``AudioNet`` branch, an ``MBT``
    mixer and a final linear classifier. Each branch is expected to return a
    ``(prediction, feature)`` pair; the two feature outputs are handed to the
    mixer and the mixer output is classified by the linear layer.

    Args:
        n_fft: Forwarded unchanged to ``AudioNet``.
        frame: Multiplied by ``t`` and forwarded to ``AudioNet`` as ``frame``.
        num_classes: Forwarded to ``VisualNet`` and used as the output width
            of the synchronisation classifier.
        img_width: Forwarded unchanged to ``VisualNet``.
        img_height: Forwarded unchanged to ``VisualNet``.
        auto_mode: Forwarded unchanged to both ``VisualNet`` and ``AudioNet``.
        depth: Accepted for interface compatibility but not used here; the
            mixer is always built with ``depth=6``.
        t: Stored as ``self.T``, forwarded to ``AudioNet`` as ``time`` and
            used to scale ``frame``. The mixer itself is built with ``t=1``.
        dim: The synchronisation classifier takes ``2 * dim`` input features.
    """

    def __init__(self, n_fft=512, frame=10, num_classes=1, img_width=224, img_height=224,
                 auto_mode=False, depth=4, t=8, dim=512):
        super().__init__()
        self.T = t

        # Single-modality branches.
        self.vis = VisualNet(
            num_classes=num_classes,
            img_width=img_width,
            img_height=img_height,
            auto_mode=auto_mode,
        )
        audio_frames = frame * t
        self.aud = AudioNet(
            n_fft=n_fft,
            frame=audio_frames,
            time=t,
            auto_mode=auto_mode,
        )

        # Cross-modal mixer. An earlier variant used
        # Perceiver(num_freq_bands=6, max_freq=10., depth=6, input_channels=1)
        # here instead of MBT.
        # NOTE(review): the `depth` constructor argument is ignored; the mixer
        # depth is fixed at 6 -- confirm whether that is intentional.
        self.mixer = MBT(depth=6, t=1)

        # With MBT the classifier input is twice `dim`; the Perceiver variant
        # used nn.Linear(dim, num_classes).
        fused_width = 2 * dim
        self.syn_bi_classify = nn.Linear(fused_width, num_classes)

    def forward(self, vis, aud):
        """Run both branches, fuse their features and classify the result.

        Args:
            vis: Input passed to the visual branch.
            aud: Input passed to the audio branch.

        Returns:
            A tuple ``(aud_pred, vis_pred, syn_pred)``: the audio branch
            prediction, the visual branch prediction, and the classifier
            output computed from the mixed features.
        """
        visual_out = self.vis(vis)
        audio_out = self.aud(aud)
        vis_pred, vis_fea = visual_out
        aud_pred, aud_fea = audio_out

        # The mixer takes the audio features first, then the visual features.
        # (The Perceiver variant stacked both features on a trailing axis and
        # rearranged them with einops before mixing.)
        fused = self.mixer(aud_fea, vis_fea)
        syn_pred = self.syn_bi_classify(fused)
        return aud_pred, vis_pred, syn_pred

