import torch
import torch.nn as nn
import torch.nn.functional as F
from loss import l_mv_cl, l_global, l_or

class MultiViewEncoder(nn.Module):
    """Encode several input views and fuse them into one embedding.

    Every view ``v`` has its own projection MLP ``f_theta[v]``
    (Linear -> ReLU -> Linear) that maps the view to a ``D``-dimensional
    vector, which is then L2-normalized along its last dimension. The
    per-view vectors are concatenated along the last dimension, in the order
    of ``view_names``, and projected back to ``D`` dimensions by the linear
    layer ``g_phi``. The fused vector is L2-normalized as well.
    """

    def __init__(self, input_dims, hidden_dim=128, D=64):
        """Build the per-view encoders and the fusion layer.

        Args:
            input_dims: Dict mapping each view name to that view's input
                feature size, e.g.
                ``{"text": dim_text, "clinical": dim_clinical, ...}``.
                Its key order fixes the order in which views are
                concatenated before fusion.
            hidden_dim: Hidden width of each per-view projection MLP.
            D: Size of every per-view embedding and of the fused embedding.

        Raises:
            ValueError: If ``input_dims`` contains no view.
        """
        super().__init__()
        self.view_names = list(input_dims.keys())

        # Fail at construction time: with no views, forward() would only
        # break later when torch.cat receives an empty list.
        if not self.view_names:
            raise ValueError("input_dims must contain at least one view")

        # f_theta^(v): one projection MLP per view.
        self.f_theta = nn.ModuleDict({
            v: nn.Sequential(
                nn.Linear(input_dims[v], hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, D),
            )
            for v in self.view_names
        })

        # g_phi: fuses the concatenated per-view embeddings back to D dims.
        self.g_phi = nn.Linear(len(self.view_names) * D, D)

    def forward(self, x):
        """Encode every view and fuse the results.

        Args:
            x: Object indexable by view name; ``x[v]`` is fed to
                ``f_theta[v]``, so its last dimension must equal
                ``input_dims[v]``. All views must share the same leading
                dimensions so they can be concatenated (presumably
                ``(batch, input_dims[v])`` -- confirm against the caller).

        Returns:
            Tuple ``(e_k, z_views)``: ``e_k`` is the fused, L2-normalized
            embedding with last dimension ``D``; ``z_views`` is a dict
            mapping each view name to its L2-normalized ``D``-dimensional
            embedding.
        """
        z_views = {
            v: F.normalize(self.f_theta[v](x[v]), p=2, dim=-1)
            for v in self.view_names
        }

        # Concatenate in view_names order, then apply the linear projection.
        fused = torch.cat([z_views[v] for v in self.view_names], dim=-1)
        e_k = F.normalize(self.g_phi(fused), p=2, dim=-1)

        return e_k, z_views