import torch
import torch.nn as nn
import numpy as np
from vtb_ops import vtb
from hlb_utils import cosine_similarity

class Network(nn.Module):
    """MLP encoder whose outputs are unbound with VTB and scored against class vectors.

    The encoder maps ``in_features -> hidden -> out_features``. For scoring,
    the encoder output is passed through ``self.vtb.unbinding`` together with
    the learnable vector ``self.pos``, L2-normalised, and compared (dot
    product, i.e. cosine similarity) with the L2-normalised rows of
    ``self.cls``.

    Attributes:
        vtb: Second element of the pair returned by ``vtb(...)``; must expose
            ``unbinding(logits, pos)``.
        pos: Parameter of shape ``(1, out_features)`` handed to ``unbinding``.
        cls: Parameter of shape ``(labels, out_features)``, one row per label.
        logit_scale: Scalar parameter holding the log of the temperature
            multiplier applied to the similarities in :meth:`loss`.
        network: The encoder (Linear, LayerNorm, LeakyReLU, Dropout, Linear,
            LayerNorm).
    """

    def __init__(self, in_features, hidden, out_features, labels, requires_grad=True, negative=False,
                 factor=1, drop_rate=0.1, batch_size=64):
        """Build the encoder and the learnable scoring vectors.

        Args:
            in_features: Size of the encoder input.
            hidden: Width of the hidden layer.
            out_features: Size of the encoder output and of ``pos`` / ``cls`` rows.
            labels: Number of class vectors (rows of ``cls``).
            requires_grad: Whether ``pos`` and ``cls`` are trainable.
            negative: Unused by this class; kept for interface compatibility.
            factor: Unused by this class; kept for interface compatibility.
            drop_rate: Dropout probability inside the encoder.
            batch_size: Forwarded to ``vtb(...)``.
        """
        super().__init__()
        # vtb() returns a pair; only the second element is needed here.
        _sampler, vtb_mod = vtb(batch_size=batch_size, input_dim=out_features)
        self.vtb = vtb_mod

        # Standard-normal draws scaled by 1/sqrt(out_features). The creation
        # order (pos, then cls, then the encoder layers) is kept stable so a
        # fixed RNG seed yields the same initial weights.
        scale = 1.0 / np.sqrt(out_features)
        self.pos = nn.Parameter(torch.randn(1, out_features) * scale, requires_grad=requires_grad)

        # Log-temperature, initialised to log(1 / 0.07) as in CLIP-style
        # contrastive models.
        # NOTE(review): always trainable (ignores ``requires_grad``) and not
        # clamped, so exp(logit_scale) can grow without bound — confirm intended.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.cls = nn.Parameter(torch.randn(labels, out_features) * scale, requires_grad=requires_grad)

        self.network = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.LayerNorm(hidden),  # normalises per sample, so it does not depend on batch size
            nn.LeakyReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden, out_features),
            nn.LayerNorm(out_features),
        )

    def forward(self, x):
        """Run ``x`` through the encoder and return its output."""
        return self.network(x)

    def _cosine_scores(self, logits):
        """Return cosine similarities between unbound ``logits`` and every class vector.

        Shared by :meth:`loss` and :meth:`inference` so both score identically.
        The last dimension of the result has size ``labels``.
        """
        # presumably unbinding returns (batch, out_features) — verify against vtb_ops
        unbound = self.vtb.unbinding(logits, self.pos)
        unbound = nn.functional.normalize(unbound, dim=-1)
        class_dirs = nn.functional.normalize(self.cls, dim=-1)
        return unbound @ class_dirs.t()

    def loss(self, logits, true):
        """Mean sigmoid binary cross-entropy over temperature-scaled similarities.

        Args:
            logits: Encoder output (what :meth:`forward` returns).
            true: Multi-hot targets with the same shape as the similarity
                matrix; any dtype convertible to floating point.

        Returns:
            Scalar loss tensor (mean over all elements).
        """
        scaled = self._cosine_scores(logits) * self.logit_scale.exp()
        # Match the targets to the logits' dtype instead of a hard-coded
        # float32, so float64 / reduced-precision models do not mix dtypes
        # inside the loss. For float32 models this equals ``true.float()``.
        targets = true.to(dtype=scaled.dtype)
        # Functional form of nn.BCEWithLogitsLoss() (same default 'mean'
        # reduction) without building a module object on every call.
        return nn.functional.binary_cross_entropy_with_logits(scaled, targets)

    def inference(self, logits):
        """Return unscaled cosine similarities to each class vector, for ranking.

        The temperature is not applied: it rescales every score by the same
        positive factor and so cannot change their order.
        """
        return self._cosine_scores(logits)