import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class GTELargeEN:
    """Text embedder backed by a GTE-large-en-v1.5 checkpoint.

    Texts are tokenized and run through the model in mini-batches. Each text
    is represented by the hidden state of its first token
    (``last_hidden_state[:, 0]``), which is optionally L2-normalized. All
    returned embeddings live on the CPU.
    """

    def __init__(self,
                 device,
                 batch_size=32,
                 normalize=True,
                 model_path='retrieve/smallm/gte_large_en_v1.5',
                 max_length=8192):
        """Load the tokenizer and model.

        Args:
            device: Device that the model and tokenized inputs are moved to
                (anything accepted by ``Module.to``).
            batch_size: Number of texts encoded per forward pass; must be >= 1.
            normalize: If True, L2-normalize every embedding.
            model_path: Local directory (or hub id) of the checkpoint. The
                default is the path this class has always used.
            max_length: Maximum number of tokens per text; longer texts are
                truncated.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step makes range() raise an opaque error, and a negative
            # step yields no batches, so torch.cat([]) fails later on.
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.device = device
        self.batch_size = batch_size
        self.normalize = normalize
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            # The two options below are forwarded to the checkpoint's
            # remote modeling code.
            unpad_inputs=True,
            use_memory_efficient_attention=False).to(device)
        # Inference-only wrapper: make sure dropout and similar layers are off.
        self.model.eval()
        # Embedding width. It shapes the result for empty input so that it
        # matches what the model produces for non-empty input.
        self.embedding_dim = getattr(self.model.config, 'hidden_size', 1024)

    @torch.no_grad()
    def embed(self, text_list):
        """Embed a sequence of texts.

        Args:
            text_list: Sequence of strings (must support ``len`` and slicing).

        Returns:
            A CPU tensor with one row per input text, in input order. Rows
            are unit-length when ``self.normalize`` is True. An empty input
            yields a tensor of shape ``(0, self.embedding_dim)``.
        """
        if len(text_list) == 0:
            return torch.zeros(0, self.embedding_dim)

        embeddings = []
        for start in range(0, len(text_list), self.batch_size):
            batch_texts = text_list[start:start + self.batch_size]
            batch_dict = self.tokenizer(
                batch_texts,
                max_length=self.max_length,
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)

            outputs = self.model(**batch_dict).last_hidden_state
            # Pool each text by taking the hidden state of its first token.
            emb = outputs[:, 0]

            if self.normalize:
                emb = F.normalize(emb, p=2, dim=1)

            embeddings.append(emb.cpu())

            # Release the device tensors and the allocator cache before the
            # next batch. Padded lengths vary per batch, so this presumably
            # keeps peak memory down for long inputs.
            del batch_dict, outputs, emb
            torch.cuda.empty_cache()

        return torch.cat(embeddings, dim=0)

    def _embed_to_dict(self, texts, text2id):
        """Embed ``texts`` and key each row by ``text2id[text]``, skipping unmapped texts."""
        embs = self.embed(texts)
        # Iterating a tensor yields its rows (same as embs[i]). For duplicate
        # texts, the last occurrence wins.
        return {text2id[text]: emb
                for text, emb in zip(texts, embs)
                if text in text2id}

    def __call__(self, q_text, text_entity_list, relation_list, entity2id, rel2id):
        """Embed a query together with entity and relation texts.

        Args:
            q_text: The query string.
            text_entity_list: Entity texts to embed.
            relation_list: Relation texts to embed.
            entity2id: Mapping from entity text to the key used in the output.
            rel2id: Mapping from relation text to the key used in the output.

        Returns:
            ``(q_emb, entity_emb_dict, relation_emb_dict)``:

            - ``q_emb`` is the one-row tensor ``self.embed([q_text])``.
            - ``entity_emb_dict`` maps ``entity2id[text]`` to that entity's
              embedding row.
            - ``relation_emb_dict`` maps ``rel2id[text]`` to that relation's
              embedding row.

            Texts missing from their id mapping are embedded but omitted
            from the dicts.
        """
        q_emb = self.embed([q_text])
        entity_emb_dict = self._embed_to_dict(text_entity_list, entity2id)
        relation_emb_dict = self._embed_to_dict(relation_list, rel2id)
        return q_emb, entity_emb_dict, relation_emb_dict
