import torch
import torchvision
import torch.nn as nn
import random

class BasicBlock3D(nn.Module):
    """3D-convolutional residual block: two 3x3x3 conv + BatchNorm layers.

    Computes ``relu(shortcut + residual)``, where ``residual`` is
    conv -> BN -> ReLU -> conv -> BN applied to ``x`` and ``shortcut`` is
    ``downsample(x)`` when a downsample module is given, otherwise ``x``.

    Args:
        inplanes: Number of input channels.
        outplanes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to ``x`` on the shortcut path. Its
            output must match the residual branch's shape for the addition,
            i.e. it is needed whenever ``stride != 1`` or
            ``inplanes != outplanes``.
    """

    def __init__(self, inplanes, outplanes, stride=1, downsample=None):
        super().__init__()
        # The position of each module inside the Sequential determines the
        # state_dict keys (conv.0, conv.1, conv.3, conv.4); keep this order.
        residual_layers = [
            nn.Conv3d(inplanes, outplanes, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm3d(outplanes),
            nn.ReLU(True),
            nn.Conv3d(outplanes, outplanes, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(outplanes),
        ]
        self.conv = nn.Sequential(*residual_layers)
        self.downsample = downsample
        self.relu = nn.ReLU(True)

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual(x))``."""
        residual = self.conv(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + residual)


class FusionModel(nn.Module):
    """Audio-visual classifier that fuses two CNN branches by concatenation.

    Video branch: a 3D-convolutional ResNet-style network built from
    ``BasicBlock3D`` stages. Audio branch: the stages of a torchvision
    ResNet-18 behind a single-input-channel stem. Each branch is pooled to a
    512-dim vector; the two vectors are concatenated (video first), passed
    through dropout and a linear classifier.

    Args:
        num_classes: Number of logits produced by the fusion head ``proj``.
    """

    def __init__(self, num_classes):
        super(FusionModel, self).__init__()
        # ---- Video branch: 3-channel 3D stem + four BasicBlock3D stages ----
        self.VideoInconv = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=2, padding=[1, 3, 3], bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1),
        )
        self.VideoLayer1 = self._make_layer(BasicBlock3D, 2, 64, 64, stride=1)
        self.VideoLayer2 = self._make_layer(BasicBlock3D, 2, 64, 128, stride=2)
        self.VideoLayer3 = self._make_layer(BasicBlock3D, 2, 128, 256, stride=2)
        self.VideoLayer4 = self._make_layer(BasicBlock3D, 2, 256, 512, stride=2)
        self.VideoAvgpool = nn.AdaptiveAvgPool3d(1)

        # ---- Audio branch: modules borrowed from a torchvision ResNet-18 ----
        # resnet.conv1 and resnet.fc are not attached to this model; only
        # bn1/relu/maxpool, layer1-4 and avgpool are reused.
        resnet = torchvision.models.resnet18(num_classes=num_classes)

        # Single-channel replacement for resnet.conv1. The original code passed
        # `padding=True`; a bool is not a documented value for `padding`
        # (int, tuple or str) and, depending on the PyTorch version, is either
        # rejected by the argument parser or coerced to 1. Use 3, the padding
        # torchvision's ResNet-18 applies with this same 7x7 / stride-2 conv in
        # front of the bn1/relu/maxpool reused here.
        # NOTE(review): checkpoints trained with a coerced padding of 1 still
        # load (weight shapes are unchanged) but will see different feature-map
        # sizes; use padding=1 if bit-exact reproduction of such runs matters.
        self.AudioInconv = nn.Sequential(
            nn.Conv2d(1, 64, stride=2, kernel_size=7, padding=3, bias=False),
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
        )
        self.AudioLayer1 = resnet.layer1
        self.AudioLayer2 = resnet.layer2
        self.AudioLayer3 = resnet.layer3
        self.AudioLayer4 = resnet.layer4
        self.AudioAvgpool = resnet.avgpool
        self.dropout = nn.Dropout(p=0.5)

        # Both branches end with 512 channels; the head sees their concatenation.
        self.proj = nn.Linear(512 + 512, num_classes)

    def _make_layer(self, block, num_block, inplanes, outplanes, stride):
        """Build one video stage as an ``nn.Sequential`` of ``block`` modules.

        The first block maps ``inplanes`` -> ``outplanes`` with the given
        ``stride``; when channels or stride change it receives a 1x1x1
        conv + BatchNorm3d shortcut. It is followed by ``num_block`` further
        ``outplanes`` -> ``outplanes`` blocks with identity shortcuts.

        NOTE(review): this yields ``num_block + 1`` blocks per stage (3 for the
        ``num_block=2`` used in ``__init__``), whereas torchvision's ResNet
        ``_make_layer`` builds ``blocks`` in total. Kept unchanged because fixing
        it would alter the parameter layout / state_dict keys of existing
        checkpoints; confirm whether the extra block is intended.
        """
        downsample = None
        if inplanes != outplanes or stride != 1:
            downsample = nn.Sequential(
                nn.Conv3d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(outplanes),
            )
        layers = [block(inplanes, outplanes, stride, downsample)]
        layers.extend(block(outplanes, outplanes) for _ in range(num_block))
        return nn.Sequential(*layers)

    def forward(self, audio, video, modal_drop=False):
        """Run both branches and the fusion head.

        Args:
            audio: Input to the 2D audio branch; its stem expects 1 input
                channel (presumably a (B, 1, freq, time) spectrogram - TODO
                confirm against the data loader).
            video: Video clip. It is permuted with ``(0, 2, 1, 3, 4)`` before
                the 3-channel Conv3d stem, so it is presumably laid out as
                (B, T, C, H, W) - TODO confirm.
            modal_drop: If True, one of three cases is picked with equal
                probability (once per call, via Python's ``random``): keep both
                modalities, zero the audio features, or zero the video features
                in the fused vector. The returned ``a_f``/``v_f`` are never
                zeroed.

        Returns:
            Tuple ``(a_f, v_f, y_hat)``: pooled audio features, pooled video
            features (each flattened to (B, channels)), and the fused logits.
        """
        hv = self.VideoInconv(video.permute(0, 2, 1, 3, 4))
        v1 = self.VideoLayer1(hv)
        v2 = self.VideoLayer2(v1)
        v3 = self.VideoLayer3(v2)
        v4 = self.VideoLayer4(v3)
        # AdaptiveAvgPool3d(1) gives (B, C, 1, 1, 1); flatten to (B, C).
        v_f = self.VideoAvgpool(v4).flatten(1)

        ha = self.AudioInconv(audio)
        a1 = self.AudioLayer1(ha)
        a2 = self.AudioLayer2(a1)
        a3 = self.AudioLayer3(a2)
        a4 = self.AudioLayer4(a3)
        a_f = self.AudioAvgpool(a4).flatten(1)

        if modal_drop:
            # torch.zeros_like already matches the device and dtype of the
            # feature it replaces; the former `.cuda().float()` crashed on CPU
            # and targeted the current CUDA device rather than the features'.
            prob = random.random()
            if prob < 1 / 3:
                av_f = torch.cat((v_f, a_f), 1)
            elif prob < 2 / 3:
                av_f = torch.cat((v_f, torch.zeros_like(a_f)), 1)  # audio dropped
            else:
                av_f = torch.cat((torch.zeros_like(v_f), a_f), 1)  # video dropped
        else:
            av_f = torch.cat((v_f, a_f), 1)

        av_f = self.dropout(av_f)
        y_hat = self.proj(av_f)

        return a_f, v_f, y_hat
