import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Flatten

class KinovaPretextNetwork(nn.Module):
    """Two-branch encoder producing L2-normalized image and sound embeddings.

    ``imgBranch`` and ``soundBranch`` are independent convolutional stacks
    that each end in a linear layer of width ``config.representationDim``,
    so both modalities are embedded into vectors of the same size.

    The fully connected layers hard-code the flattened feature sizes
    (``64*9`` for images, ``160`` for sound), which correspond to inputs of
    shape ``(N, 3, 96, 96)`` and ``(N, 1, 100, 40)`` respectively; other
    spatial sizes will generally fail at the first ``nn.Linear``.
    """

    def __init__(self, config):
        """Build both branches.

        Args:
            config: object exposing ``representationDim``, the size of the
                embedding produced by each branch.
        """
        super(KinovaPretextNetwork, self).__init__()
        # 3x3 convs with stride 2 and padding 1 halve each spatial dimension.
        self.imgBranch = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),   # (3, 96, 96)->(32, 48, 48)
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ReLU(),  # (32, 48, 48)->(32, 24, 24)
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # (32, 24, 24)->(64, 12, 12)
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),  # (64, 12, 12)->(64, 6, 6)
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),  # (64, 6, 6)->(64, 3, 3)
            Flatten(),
            nn.Linear(64*9, 128), nn.ReLU(),
            nn.Linear(128, config.representationDim)
        )

        # The first kernel spans the full width (40), collapsing that axis
        # to 1; the remaining unpadded (3, 1) convs only shrink the height.
        self.soundBranch = nn.Sequential(
            nn.Conv2d(1, 32, (5, 40), stride=(2, 1)), nn.ReLU(),  # (1, 100, 40)->(32, 48, 1)
            nn.Conv2d(32, 32, (3, 1), stride=(2, 1)), nn.ReLU(),  # (32, 48, 1)->(32, 23, 1)
            nn.Conv2d(32, 32, (3, 1), stride=(2, 1)), nn.ReLU(),  # (32, 23, 1)->(32, 11, 1)
            nn.Conv2d(32, 32, (3, 1), stride=(2, 1)), nn.ReLU(),  # (32, 11, 1)->(32, 5, 1)
            Flatten(),
            nn.Linear(160, 128), nn.ReLU(),  # 160 = 32 channels * 5 rows
            nn.Linear(128, config.representationDim)
        )

    @staticmethod
    def _embed(branch, x):
        """Return the unit-L2-norm (along dim 1) output of ``branch(x)``, or None if ``x`` is None."""
        if x is None:
            return None
        return F.normalize(branch(x), p=2, dim=1)

    def forward(self, image, sound_positive, sound_negative):
        """Embed whichever inputs are provided.

        Every argument is optional: passing ``None`` skips that input and
        yields ``None`` in the corresponding output slot. Both sound inputs
        go through the same ``soundBranch`` (shared weights).

        Args:
            image: image batch for ``imgBranch``, or ``None``.
            sound_positive: sound batch for ``soundBranch``, or ``None``.
            sound_negative: second sound batch for ``soundBranch``, or
                ``None``.

        Returns:
            Tuple ``(image_feat, sound_feat_positive, sound_feat_negative)``;
            each entry is an embedding L2-normalized along dim 1, or
            ``None`` when the matching input was ``None``.
        """
        image_feat = self._embed(self.imgBranch, image)
        # Previously sound_positive was the only input not guarded against
        # None; it is now handled like the other two.
        sound_feat_positive = self._embed(self.soundBranch, sound_positive)
        sound_feat_negative = self._embed(self.soundBranch, sound_negative)

        return image_feat, sound_feat_positive, sound_feat_negative
