import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperNetwork(nn.Module):
    """Generate the two weight matrices of a bottleneck adapter from an embedding.

    A small MLP (``e_dim -> k -> 2 * e_dim * inner_k``) maps one conditioning
    embedding to a flat vector. That vector is split into two equal halves,
    which are reshaped into a down-projection and an up-projection matrix.

    Args:
        e_dim: Size of the conditioning embedding, and the feature size the
            generated matrices operate on.
        k: Hidden width of the generator MLP.
        inner_k: Inner (bottleneck) width of the generated adapter.

    With the defaults, ``fc2`` alone holds roughly 8.9M parameters
    (16 * 524288 weights + 524288 biases).
    """

    def __init__(self, e_dim=1024, k=16, inner_k=256):
        super().__init__()
        self.e_dim = e_dim
        self.k = k
        self.inner_k = inner_k

        # One flat chunk of e_dim * inner_k values per generated matrix.
        n_generated = 2 * e_dim * inner_k
        self.fc1 = nn.Linear(e_dim, k, bias=True)
        self.fc2 = nn.Linear(k, n_generated, bias=True)
        self.relu = nn.ReLU()

    def forward(self, e):
        """Return ``(W1, W2)`` generated from the embedding ``e``.

        ``W1`` has shape ``(inner_k, e_dim)`` and ``W2`` has shape
        ``(e_dim, inner_k)``. These are the ``(out_features, in_features)``
        layouts that ``F.linear`` expects for its ``weight`` argument.

        The generator output must contain exactly ``2 * e_dim * inner_k``
        elements, so ``e`` must hold a single embedding, e.g. shape
        ``(e_dim,)`` or ``(1, e_dim)``. A larger batch makes the reshape raise
        ``RuntimeError``.
        """
        hidden = self.relu(self.fc1(e))
        generated = self.fc2(hidden)

        per_matrix = self.e_dim * self.inner_k
        down_flat, up_flat = generated.reshape(2, per_matrix).unbind(0)
        W1 = down_flat.view(self.inner_k, self.e_dim)
        W2 = up_flat.view(self.e_dim, self.inner_k)
        return W1, W2

class PrimaryNetwork(nn.Module):
    """Residual adapter for image features whose weights come from a hypernetwork.

    The question embedding is fed to a ``HyperNetwork``, which produces a
    down-projection ``W1`` and an up-projection ``W2``. Each image feature
    vector ``x`` becomes ``x + W2 @ relu(W1 @ x)``. The result is
    L2-normalised and compared against L2-normalised text features with a
    scaled dot product, i.e. scaled cosine similarity.

    Args:
        e_dim: Feature size of the image/text features and of the question
            embedding.
        k: Hidden width of the hypernetwork's generator MLP.
        inner_k: Bottleneck width of the generated adapter.
        logit_scale: Multiplier applied to the cosine similarities. Defaults
            to 100.0, the value this module always used.
    """

    # Class-level fallback, so instances pickled before ``logit_scale`` became
    # a constructor argument still resolve the original value.
    logit_scale = 100.0

    def __init__(self, e_dim=1024, k=16, inner_k=256, logit_scale=100.0):
        super().__init__()
        self.e_dim = e_dim
        self.k = k
        self.inner_k = inner_k
        self.logit_scale = logit_scale
        self.hnet = HyperNetwork(e_dim=self.e_dim, k=self.k, inner_k=self.inner_k)
        self.relu = nn.ReLU()

    def forward(self, ques_emb, image_features, text_features):
        """Score ``text_features`` against question-adapted ``image_features``.

        Args:
            ques_emb: Question embedding handed to the hypernetwork. The
                hypernetwork's reshape only succeeds for a single embedding,
                so one weight pair is shared by every image in the call.
            image_features: Image features whose last dimension is ``e_dim``.
                Presumably ``(B, 1, e_dim)``, given the ``squeeze`` of dim 1
                below (TODO confirm with callers).
            text_features: Text features whose last dimension is ``e_dim``.
                Presumably ``(B, C, e_dim)``. A 2-D ``(C, e_dim)`` tensor is
                also accepted and broadcast across the batch.

        Returns:
            ``logit_scale * cos(adapted_image, text)`` logits, with size-1
            dim 1 squeezed away. With the presumed shapes above this is
            ``(B, C)``.
        """
        W1, W2 = self.hnet(ques_emb)
        hidden = self.relu(F.linear(image_features, W1))
        # Residual connection. Out-of-place, so no autograd-visible tensor
        # is mutated.
        out = F.linear(hidden, W2) + image_features

        # F.normalize divides by max(norm, eps). An all-zero vector therefore
        # yields zeros instead of the NaN that plain x / x.norm() produces.
        # For non-zero vectors the result is the same.
        adapted_image_features = F.normalize(out, dim=-1)

        # Match the image branch's dtype *before* normalising. A hard-coded
        # float32 cast breaks matmul for a half/bf16 model. Normalising first
        # and casting afterwards also loses precision on low-precision input.
        text_features = text_features.to(adapted_image_features.dtype)
        text_features = F.normalize(text_features, dim=-1)

        # Swap the last two dims. This equals transpose(1, 2) for 3-D input
        # and also supports 2-D class features shared across the batch.
        logits = self.logit_scale * adapted_image_features @ text_features.transpose(-2, -1)
        return torch.squeeze(logits, 1)
