# models/text_encoder.py

import torch
from transformers import AutoModel

class DualTextEncoder(torch.nn.Module):
    """Pair of pretrained text encoders, one for English and one for Chinese.

    Each encoder is loaded with ``AutoModel.from_pretrained`` and may have its
    leading transformer layers frozen. ``forward`` routes a batch to the
    encoder selected by ``lang`` and returns the hidden state at sequence
    position 0.
    """

    def __init__(self, cfg_en, cfg_zh):
        """Load both encoders and freeze their leading layers if requested.

        Args:
            cfg_en: dict-like config for the English encoder. Must contain
                ``'pretrained'`` (passed to ``AutoModel.from_pretrained``);
                may contain ``'freeze_layers'``, the number of leading
                layers to freeze (missing or ``None`` means 0).
            cfg_zh: same structure, for the Chinese encoder.

        Note:
            Freezing walks ``<model>.encoder.layer``, so the loaded models
            must expose that attribute. Only those layers are touched; other
            parameters of the model (e.g. embeddings) are left trainable.
        """
        super().__init__()
        self.encoder_en = AutoModel.from_pretrained(cfg_en['pretrained'])
        self.encoder_zh = AutoModel.from_pretrained(cfg_zh['pretrained'])

        self._freeze_leading_layers(
            self.encoder_en.encoder.layer, cfg_en.get('freeze_layers', 0)
        )
        self._freeze_leading_layers(
            self.encoder_zh.encoder.layer, cfg_zh.get('freeze_layers', 0)
        )

    @staticmethod
    def _freeze_leading_layers(layers, count):
        """Disable gradients for the first ``count`` modules in ``layers``.

        Args:
            layers: iterable of ``torch.nn.Module`` objects, in order.
            count: how many leading modules to freeze. ``None``, zero and
                negative values freeze nothing; a value larger than the
                number of layers freezes all of them.
        """
        # A config entry such as `freeze_layers: null` arrives here as None;
        # treat it as "freeze nothing" instead of failing on `int < None`.
        count = count or 0
        for index, layer in enumerate(layers):
            if index >= count:
                break
            layer.requires_grad_(False)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                lang: str) -> torch.Tensor:
        """Encode a batch with the encoder for ``lang``.

        Args:
            input_ids: token ids, [B, L].
            attention_mask: attention mask, [B, L].
            lang: ``'en'`` or ``'zh'``.

        Returns:
            The hidden state at sequence position 0 for each item, [B, D].

        Raises:
            ValueError: if ``lang`` is neither ``'en'`` nor ``'zh'``.
        """
        # Resolve (and validate) the language before doing any encoder work.
        if lang == 'en':
            encoder = self.encoder_en
        elif lang == 'zh':
            encoder = self.encoder_zh
        else:
            raise ValueError(f"Unsupported language: {lang}")

        out = encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 is used as the sentence feature (the CLS feature,
        # assuming the tokenizer places [CLS] first).
        return out.last_hidden_state[:, 0]   # [B, D]