import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.models as models

class ImageEncoder(nn.Module):
    """Encode a batch of images into fixed-size embedding vectors.

    A torchvision backbone (constructed with ``weights=None``) has its final
    classification layer replaced by an identity, and the resulting features
    are projected to ``out_dim`` dimensions by a linear layer.

    The backbone is selected at construction time from the
    ``REDUCE_IMAGE_BACKBONE`` environment variable: the exact value ``'1'``
    selects ``mobilenet_v2``; any other value, or an unset variable, selects
    ``convnext_tiny``.

    Args:
        out_dim: Number of output features of the projection layer.
    """

    def __init__(self, out_dim: int = 128) -> None:
        super().__init__()
        reduce = os.environ.get('REDUCE_IMAGE_BACKBONE', '0') == '1'
        if reduce:
            # Use a lighter model to reduce memory
            backbone = models.mobilenet_v2(weights=None)
            feat_dim = backbone.classifier[-1].in_features
            # The whole classifier head is replaced by a pass-through.
            backbone.classifier = nn.Identity()
        else:
            backbone = models.convnext_tiny(weights=None)
            feat_dim = backbone.classifier[-1].in_features
            # Only the last classifier layer is replaced; the layers before
            # it in the head stay in place.
            backbone.classifier[-1] = nn.Identity()
        self.backbone = backbone

        # Projection starts from Xavier-uniform weights and a zero bias.
        self.proj = nn.Linear(feat_dim, out_dim)
        init.xavier_uniform_(self.proj.weight)
        init.constant_(self.proj.bias, 0.0)

    @staticmethod
    def _replace_non_finite(
        tensor: torch.Tensor, message: str, *, random_fill: bool
    ) -> torch.Tensor:
        """Return ``tensor`` with every NaN/Inf entry replaced.

        When all entries are finite the tensor is returned untouched and
        nothing is printed. Otherwise ``message`` is printed and the bad
        entries are replaced with zeros, or with ``randn * 0.01`` noise when
        ``random_fill`` is true.
        """
        bad = ~torch.isfinite(tensor)
        if not bad.any():
            return tensor
        print(message)
        # The fill tensor is only built once a bad value has been found, so
        # the global RNG is left alone on the clean path.
        if random_fill:
            fill = torch.randn_like(tensor) * 0.01
        else:
            fill = torch.zeros_like(tensor)
        return torch.where(bad, fill, tensor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute embeddings for ``x``.

        Non-finite input values are zeroed and the input is clamped to
        ``[-10, 10]`` before the backbone runs. Non-finite values in the
        backbone features or in the projected output are replaced with small
        random values. A warning line is printed each time a replacement
        happens.

        Args:
            x: Image batch; presumably ``(batch, channels, height, width)``,
                as in this file's smoke test -- confirm against callers.

        Returns:
            The projected features, with ``out_dim`` as the last dimension.
        """
        x = self._replace_non_finite(
            x,
            "Warning: ImageEncoder input contains NaN/Inf, replacing with zeros",
            random_fill=False,
        )
        x = torch.clamp(x, min=-10.0, max=10.0)

        feats = self.backbone(x)
        if feats.dim() > 2:
            # flatten(1) also accepts non-contiguous features, on which
            # .view() would raise.
            feats = feats.flatten(1)
        feats = self._replace_non_finite(
            feats,
            "Warning: ImageEncoder backbone output contains NaN/Inf, replacing with small random values",
            random_fill=True,
        )

        output = self.proj(feats)
        return self._replace_non_finite(
            output,
            "Warning: ImageEncoder output contains NaN/Inf, replacing with small random values",
            random_fill=True,
        )

def test_image_encoder():
    """Smoke-test ImageEncoder on a random batch and check the output shape.

    Feeds two random 3x224x224 tensors through an encoder built with
    ``out_dim=256``, prints the input and output shapes, and asserts that
    the output shape is ``(2, 256)``.
    """
    batch_size, embed_dim = 2, 256
    # The random batch is drawn before the encoder is constructed.
    images = torch.randn(batch_size, 3, 224, 224)
    model = ImageEncoder(out_dim=embed_dim)
    embeddings = model(images)
    print("Input shape:", images.shape)
    print("Output shape:", embeddings.shape)
    assert embeddings.shape == (batch_size, embed_dim), "Error output"

# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    test_image_encoder()
