import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import othermodels


class Normalize(nn.Module):
    """Scale each row of the input by its p-norm taken over dimension 1.

    Args:
        power: exponent ``p`` used for the norm (default 2).
    """

    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, x):
        """Return ``x`` divided by its norm along dim 1 (kept for broadcasting)."""
        p = self.power
        # sum(x ** p) ** (1 / p); no abs() and no epsilon are applied.
        powered_sum = torch.sum(x.pow(p), 1, keepdim=True)
        return x / powered_sum.pow(1. / p)


class Flatten(nn.Module):
    """Shape adapter: collapse every dimension after the first into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a ``(x.shape[0], -1)`` view of ``x``."""
        leading = x.size(0)
        return x.view(leading, -1)


class Unsqueeze(nn.Module):
    """Shape adapter: append a trailing singleton dimension to the input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with one extra dimension of size 1 at the end."""
        return torch.unsqueeze(x, -1)


class Identity(nn.Module):
    """Pass-through module; used in this file to replace a backbone's ``fc``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the input object unchanged."""
        return x


def random_weight_init(model):
    """Re-initialise the 3D conv and 3D batch-norm layers of ``model`` in place.

    ``nn.Conv3d`` weights get Kaiming-normal init (``mode='fan_out'``) and any
    conv bias is zeroed; ``nn.BatchNorm3d`` weights are set to 1 and biases
    to 0. Every other module type is left untouched. Returns ``None``.
    """
    for layer in model.modules():
        if isinstance(layer, nn.Conv3d):
            # In-place initialiser: the existing Parameter object is reused.
            nn.init.kaiming_normal_(layer.weight, mode='fan_out')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
            continue
        if isinstance(layer, nn.BatchNorm3d):
            nn.init.ones_(layer.weight)
            nn.init.zeros_(layer.bias)


def get_video_feature_extractor(vid_base_arch='r2plus1d_18', pretrained=False):
    """Build a video backbone whose ``fc`` attribute is replaced by ``Identity``.

    Args:
        vid_base_arch: ``'r2plus1d_18'`` (torchvision), ``'r2plus1d_34'`` or
            ``'r2plus1d_50'`` (``othermodels``). Any other value falls back to
            ``othermodels.S3DG``.
        pretrained: only consulted for ``'r2plus1d_18'``; when false that model
            is additionally passed through ``random_weight_init``.

    Returns:
        The constructed model.
    """
    if vid_base_arch == 'r2plus1d_18':
        builder = torchvision.models.video.__dict__[vid_base_arch]
        model = builder(pretrained=pretrained)
        if not pretrained:
            print("Randomy initializing models")
            random_weight_init(model)
    elif vid_base_arch == 'r2plus1d_34':
        model = othermodels.r2plus1d_34()
    elif vid_base_arch == 'r2plus1d_50':
        model = othermodels.r2plus1d_50()
    else:
        # Every unrecognised name ends up here.
        model = othermodels.S3DG()

    # Common to all branches: drop the classification layer.
    model.fc = Identity()
    return model


def get_audio_feature_extractor(aud_base_arch='resnet18', pretrained=False):
    """Build an audio backbone.

    ResNet variants get a fresh 1-input-channel ``conv1`` and an ``Identity``
    ``fc``; ``'vgg_audio'`` returns ``othermodels.VGG16AudioNet()`` as is.

    Args:
        aud_base_arch: one of ``'resnet9'``, ``'resnet18'``, ``'resnet34'``,
            ``'resnet50'``, ``'vgg_audio'`` (checked with an assert).
        pretrained: accepted but not used; ResNets are always created with
            ``pretrained=False``.

    Returns:
        The constructed model.
    """
    supported = ['resnet9', 'resnet18', 'resnet34', 'resnet50', 'vgg_audio']
    assert(aud_base_arch in supported)

    if aud_base_arch == 'vgg_audio':
        return othermodels.VGG16AudioNet()

    if aud_base_arch == 'resnet9':
        # One BasicBlock per stage, built via torchvision's private helper.
        resnet = torchvision.models.resnet
        model = resnet._resnet('resnet9', resnet.BasicBlock, [1, 1, 1, 1],
                               pretrained=False, progress=False)
    elif aud_base_arch in ('resnet18', 'resnet34', 'resnet50'):
        model = torchvision.models.__dict__[aud_base_arch](pretrained=False)
    else:
        # Only reachable when asserts are disabled (python -O).
        return None

    # Replace the stem convolution with a 1-input-channel version.
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.fc = Identity()
    return model


class VideoBaseNetwork(nn.Module):
    """Video backbone wrapper with optional L2 normalisation of its output.

    Args:
        vid_base_arch: architecture name forwarded to
            ``get_video_feature_extractor``.
        pretrained: forwarded to ``get_video_feature_extractor``.
        norm_feat: when true, L2-normalise the squeezed features.
    """

    def __init__(self, vid_base_arch='r2plus1d_18', pretrained=False, norm_feat=False):
        super(VideoBaseNetwork, self).__init__()
        self.base = get_video_feature_extractor(
            vid_base_arch,
            pretrained=pretrained
        )
        self.norm_feat = norm_feat

    def forward(self, x):
        """Run the backbone, squeeze singleton dims and optionally normalise.

        ``squeeze()`` removes every size-1 dimension, so a batch of one comes
        back without its batch axis.
        """
        x = self.base(x).squeeze()
        if self.norm_feat:
            # After squeeze() a single-sample batch is 1-D and has no dim 1;
            # normalise along its only axis instead of raising.
            feat_dim = 1 if x.dim() > 1 else 0
            x = F.normalize(x, p=2, dim=feat_dim)
        return x


class AudioBaseNetwork(nn.Module):
    """Audio backbone wrapper with optional L2 normalisation of its output.

    Args:
        aud_base_arch: architecture name forwarded to
            ``get_audio_feature_extractor``.
        pretrained: forwarded to ``get_audio_feature_extractor``.
        norm_feat: when true, L2-normalise the squeezed features.
    """

    def __init__(self, aud_base_arch='resnet18', pretrained=False, norm_feat=False):
        super(AudioBaseNetwork, self).__init__()
        self.base = get_audio_feature_extractor(
            aud_base_arch,
            pretrained=pretrained
        )
        self.norm_feat = norm_feat

    def forward(self, x):
        """Run the backbone, squeeze singleton dims and optionally normalise.

        ``squeeze()`` removes every size-1 dimension, so a batch of one comes
        back without its batch axis.
        """
        x = self.base(x).squeeze()
        if self.norm_feat:
            # After squeeze() a single-sample batch is 1-D and has no dim 1;
            # normalise along its only axis instead of raising.
            feat_dim = 1 if x.dim() > 1 else 0
            x = F.normalize(x, p=2, dim=feat_dim)
        return x


class AV_GDT(nn.Module):
    """Two-stream audio-visual model with optional projection heads.

    Args:
        vid_base_arch: video backbone name (see ``VideoBaseNetwork``).
        aud_base_arch: audio backbone name (see ``AudioBaseNetwork``).
        pretrained: forwarded to both backbones.
        norm_feat: L2-normalise head outputs along dim 1. Only applied on the
            code paths that actually run a head (see ``forward``).
        use_mlp: use MLP-style heads instead of plain linear / no heads.
        mlptype: selects the head class when ``use_mlp`` is true
            (0 and 1: ``MLPv2``, 2: ``MLPv3``, other: ``nn.Linear`` for a
            single head, ``MLP_residual`` for several heads).
        headcount: number of heads per modality.
        num_classes: output size of every head.
        use_max_pool: replace both backbones' ``avgpool`` attribute with
            adaptive max pooling.

    Attributes set here: ``return_features`` (default ``False``) makes
    ``forward`` return the raw backbone features.
    """

    def __init__(self, vid_base_arch='r2plus1d_18', aud_base_arch='resnet18',
                 pretrained=False, norm_feat=True, use_mlp=False,
                 mlptype=0, headcount=1, num_classes=256, use_max_pool=False):
        super(AV_GDT, self).__init__()
        self.video_network = VideoBaseNetwork(
            vid_base_arch,
            pretrained=pretrained
        )
        self.audio_network = AudioBaseNetwork(
            aud_base_arch,
            pretrained=pretrained
        )
        if use_max_pool:
            self.video_network.base.avgpool = torch.nn.AdaptiveMaxPool3d(output_size=(1, 1, 1))
            self.audio_network.base.avgpool = torch.nn.AdaptiveMaxPool2d(output_size=(1, 1))
        self.use_mlp = use_mlp
        self.hc = headcount
        self.norm_feat = norm_feat
        self.return_features = False
        # Backbone output widths as assumed by this file.
        # NOTE(review): not verifiable here for the othermodels backbones.
        encoder_dim = 512 if vid_base_arch in ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50'] else 2048
        encoder_dim_a = 2048 if aud_base_arch in ['resnet50'] else 512
        self._build_heads(use_mlp, mlptype, encoder_dim, encoder_dim_a, num_classes)

    def _build_heads(self, use_mlp, mlptype, encoder_dim, encoder_dim_a, num_classes):
        """Create and register the projection heads for ``self.hc`` heads.

        Video heads take ``encoder_dim`` inputs and audio heads take
        ``encoder_dim_a`` inputs, for single- and multi-head setups alike.
        Nothing is registered when ``self.hc == 1`` without ``use_mlp`` or
        when ``self.hc < 1``.
        """
        if self.hc == 1:
            if not use_mlp:
                return
            if mlptype == 0:
                print("Using Regular MLP")
                head_cls = MLPv2
            elif mlptype == 1:
                print("Using MLP to be combined with SyncBN")
                head_cls = MLPv2
            elif mlptype == 2:
                print("Using 3-Layer MLP")
                head_cls = MLPv3
            else:
                print("Using Linear Layer")
                head_cls = nn.Linear
            self.mlp_v = head_cls(encoder_dim, num_classes)
            self.mlp_a = head_cls(encoder_dim_a, num_classes)
        elif self.hc > 1:
            if not use_mlp:
                head_cls = nn.Linear
            elif mlptype == 0:
                print("Using Regular MLP")
                head_cls = MLPv2
            elif mlptype == 1:
                print("Using MLP to be combined with SyncBN")
                head_cls = MLPv2
            elif mlptype == 2:
                print("Using MLP 3-layers")
                head_cls = MLPv3
            else:
                print("Using Residual MLP")
                head_cls = MLP_residual
            # Registration order (audio head, then video head, per index) is
            # kept so parameter ordering stays stable.
            for head in range(self.hc):
                setattr(self, "mlp_a%d" % head, head_cls(encoder_dim_a, num_classes))
                setattr(self, "mlp_v%d" % head, head_cls(encoder_dim, num_classes))

    def forward(self, img, spec, whichhead=0):
        """Embed a video clip and an audio spectrogram.

        Args:
            img: input for the video backbone.
            spec: input for the audio backbone.
            whichhead: accepted for interface compatibility; not used.

        Returns:
            ``(video_out, audio_out)``. With ``return_features`` set these are
            the squeezed backbone features. With several heads each element
            is a list with one tensor per head. With a single head and
            ``use_mlp`` false the backbone features are returned without
            normalisation.
        """
        img_features = self.video_network(img).squeeze()
        aud_features = self.audio_network(spec).squeeze()
        if self.return_features:
            return img_features, aud_features
        # squeeze() drops the batch axis for a batch of one; restore it.
        if aud_features.dim() == 1:
            aud_features = aud_features.unsqueeze(0)
        if img_features.dim() == 1:
            img_features = img_features.unsqueeze(0)

        if self.use_mlp and self.hc == 1:
            img_features = self.mlp_v(img_features)
            aud_features = self.mlp_a(aud_features)
            if self.norm_feat:
                img_features = F.normalize(img_features, p=2, dim=1)
                aud_features = F.normalize(aud_features, p=2, dim=1)
        elif self.hc > 1:
            # note: will return lists here.
            img_features = [getattr(self, "mlp_v%d" % head)(img_features) for head in range(self.hc)]
            aud_features = [getattr(self, "mlp_a%d" % head)(aud_features) for head in range(self.hc)]
            if self.norm_feat:
                img_features = [F.normalize(imgf, p=2, dim=1) for imgf in img_features]
                aud_features = [F.normalize(audf, p=2, dim=1) for audf in aud_features]

        return img_features, aud_features


class MLP(nn.Module):
    """Dropout-regularised classifier head: linear, or one hidden layer.

    Args:
        n_input: size of the flattened input.
        n_classes: output size.
        n_hidden: hidden width; ``None`` selects the purely linear variant.
        p: dropout probability.
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        layers = [Flatten(), nn.Dropout(p=p)]
        if n_hidden is None:
            # Linear classifier only.
            layers.append(nn.Linear(n_input, n_classes, bias=True))
        else:
            # Linear -> BN -> ReLU -> Dropout -> Linear.
            layers.extend([
                nn.Linear(n_input, n_hidden, bias=False),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ])
        self.block_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x``."""
        return self.block_forward(x)

class MLPv3(nn.Module):
    """Three-linear-layer head with two hidden BN/ReLU stages.

    Each batch-norm is wrapped as ``Unsqueeze -> BatchNorm1d -> Flatten``, so
    the norm layer receives an input with a trailing singleton dimension.

    Args:
        n_input: size of the flattened input.
        n_classes: output size.
        n_hidden: width of both hidden layers.
        p: dropout probability.
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        # Layer positions are kept exactly as before so that the Sequential
        # indices (and therefore state_dict keys) do not move.
        first_stage = [
            Flatten(),
            nn.Dropout(p=p),
            nn.Linear(n_input, n_hidden, bias=False),
            Unsqueeze(),
            nn.BatchNorm1d(n_hidden),
            Flatten(),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            Flatten(),
            nn.Dropout(p=p),
            nn.Linear(n_hidden, n_hidden, bias=True),
            Unsqueeze(),
            nn.BatchNorm1d(n_hidden),
            Flatten(),
            nn.ReLU(inplace=True),
        ]
        output_stage = [
            nn.Dropout(p=p),
            nn.Linear(n_hidden, n_classes, bias=True),
        ]
        self.block_forward = nn.Sequential(*first_stage, *second_stage, *output_stage)

    def forward(self, x):
        """Apply the head to ``x``."""
        return self.block_forward(x)

class MLPv2(nn.Module):
    """Variant of ``MLP`` whose batch-norm sits between Unsqueeze and Flatten.

    Args:
        n_input: size of the flattened input.
        n_classes: output size.
        n_hidden: hidden width; ``None`` selects the purely linear variant.
        p: dropout probability.
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        layers = [Flatten(), nn.Dropout(p=p)]
        if n_hidden is None:
            # Linear classifier only.
            layers.append(nn.Linear(n_input, n_classes, bias=True))
        else:
            # Linear -> (Unsqueeze, BN, Flatten) -> ReLU -> Dropout -> Linear.
            layers.extend([
                nn.Linear(n_input, n_hidden, bias=False),
                Unsqueeze(),
                nn.BatchNorm1d(n_hidden),
                Flatten(),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ])
        self.block_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x``."""
        return self.block_forward(x)


class MLP_residual(nn.Module):
    """Two-layer MLP block with a residual (shortcut) connection.

    Args:
        in_planes: input feature size.
        planes: output feature size (also the hidden size).
    """

    def __init__(self, in_planes, planes):
        super().__init__()
        self.block = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(in_planes, planes, bias=True),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(planes, planes, bias=True),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True)
        )

        if in_planes == planes:
            # Sizes already agree: an empty Sequential passes x straight through.
            self.shortcut = nn.Sequential()
        else:
            # Project the input to the output size before adding.
            self.shortcut = nn.Sequential(
                nn.Linear(in_planes, planes, bias=False),
                nn.BatchNorm1d(planes)
            )

    def forward(self, x):
        """Return ``block(x) + shortcut(x)``."""
        transformed = self.block(x)
        skip = self.shortcut(x)
        return transformed + skip

if __name__ == '__main__':
    # Smoke test: build the model and push one random clip / spectrogram
    # through it. The class defined in this file is AV_GDT.
    l3 = AV_GDT(vid_base_arch='r2plus1d_18', aud_base_arch='vgg_audio', pretrained=False, use_max_pool=True, mlptype=2)
    img = torch.rand(1, 3, 25, 112, 112)
    aud = torch.rand(1, 1, 40, 99)
    # forward() returns (video_features, audio_features), in that order.
    out_v, out_a = l3(img, aud)
    print(out_v.shape)
    print(out_a.shape)
