import os
from transformers import AutoModel, AutoTokenizer
import torch  

class EmbeddingModel:
    def __init__(self, model_name=None, device=None):
        """Load the tokenizer and encoder, picking a default checkpoint if none is given.

        Args:
            model_name: Local checkpoint directory or Hugging Face repo id. When
                ``None``, the first existing directory among ``/home/bge-base-en``,
                the ``BGE_MODEL_PATH`` environment variable and
                ``<parent of this file's directory>/LLMCache/bge-base-en`` is used;
                if none exists, the repo id ``BAAI/bge-base-en-v1.5`` is used.
            device: Optional device to move the model to after loading. When
                ``None`` the model stays wherever ``from_pretrained`` placed it.
        """
        if model_name is None:
            preferred_abs = "/home/bge-base-en"
            if os.path.isdir(preferred_abs):
                model_name = preferred_abs
            elif (env_path := os.environ.get('BGE_MODEL_PATH')) and os.path.isdir(env_path):
                model_name = os.path.normpath(env_path)
            else:
                # Two levels up from this file, then LLMCache/bge-base-en.
                parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                bundled_path = os.path.join(parent_dir, "LLMCache", "bge-base-en")
                if os.path.isdir(bundled_path):
                    model_name = os.path.normpath(bundled_path)
                else:
                    # No local checkpoint found: fall back to the hub repo id.
                    model_name = "BAAI/bge-base-en-v1.5"

        if not os.path.isdir(model_name):
            # Not a local directory, so model_name is treated as a hub repo id.
            # Downloads are forced through hf-mirror.com unless HF_ENDPOINT
            # already points there.
            # NOTE(review): huggingface_hub may read HF_ENDPOINT when it is first
            # imported; if so, assigning it here would not affect this process
            # -- TODO confirm.
            hf_endpoint = os.environ.get('HF_ENDPOINT', 'https://huggingface.co')
            if 'hf-mirror.com' not in hf_endpoint:
                print(f"[WARNING] HF_ENDPOINT is not set to hf-mirror.com, current value: {hf_endpoint}")
                print("[INFO] Setting HF_ENDPOINT to https://hf-mirror.com")
                os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

            print(f"[INFO] Loading model {model_name} from {os.environ.get('HF_ENDPOINT')}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
        else:
            # Local checkpoint: never touch the network, and point the tokenizer
            # at tokenizer.json explicitly when that file is present.
            tokenizer_kwargs = {"local_files_only": True}
            tok_json = os.path.join(model_name, "tokenizer.json")
            if os.path.isfile(tok_json):
                tokenizer_kwargs["tokenizer_file"] = tok_json
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
            self.model = AutoModel.from_pretrained(model_name, local_files_only=True)

        # Report where the model lives, moving it first if a device was requested.
        if device is None:
            current_device = next(self.model.parameters()).device
            print(f"[DEVICE] EmbeddingModel loaded on {current_device}")
        else:
            self.model.to(device)
            print(f"[DEVICE] EmbeddingModel moved to {device}")

        self.model.eval()

    def _safe_max_length(self) -> int | None:
        """
        Pick a safe max_length for tokenization to avoid exceeding the model's position embeddings.

        Some tokenizers (e.g. for BERT-like models) may have `model_max_length` unset/very large,
        which can lead to sequences > 512 and runtime errors when adding positional embeddings.
        """
        tok = getattr(self, "tokenizer", None)
        model = getattr(self, "model", None)
        cfg = getattr(model, "config", None)

        # Prefer model config max_position_embeddings when available (common: 512 for BERT family).
        mpe = getattr(cfg, "max_position_embeddings", None)
        try:
            mpe = int(mpe) if mpe is not None else None
        except Exception:
            mpe = None

        # Tokenizer max length can be huge (e.g., 1000000000000000019884624838656).
        tml = getattr(tok, "model_max_length", None)
        try:
            tml = int(tml) if tml is not None else None
        except Exception:
            tml = None

        # If tokenizer max is "effectively infinite", ignore it.
        if tml is not None and tml > 100_000:
            tml = None

        if mpe is None and tml is None:
            return None
        if mpe is None:
            return tml
        if tml is None:
            return mpe
        return int(min(mpe, tml))

    @staticmethod
    def _masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Mean-pool token embeddings with attention mask (excludes padding).

        Args:
            last_hidden_state: [B, L, H]
            attention_mask: [B, L] with 1 for real tokens, 0 for padding
        Returns:
            pooled: [B, H]
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.to(dtype=last_hidden_state.dtype).unsqueeze(-1)  # [B,L,1]
        masked = last_hidden_state * mask
        denom = mask.sum(dim=1).clamp(min=1e-8)  # [B,1]
        return masked.sum(dim=1) / denom

    def get_embedding(self, text):
      
        device = next(self.model.parameters()).device
        max_len = self._safe_max_length()
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
        )
        
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
        
        pooled = self._masked_mean_pool(outputs.last_hidden_state, inputs.get("attention_mask", None))
        return pooled.squeeze().cpu().numpy()
    def get_token_embeddings(self, texts, max_length=None, device=None, return_device=None):
      
        batch_size_total = len(texts)
        if batch_size_total == 0:
            raise ValueError("text")

        chunk_size = min(8, max(1, batch_size_total // 2)) 

        hidden_states_list = []
        input_ids_list = []
        attention_mask_list = []

        # NOTE: When RL4CO builds datasets, it may call the generator with very large batch_size
        # (e.g., thousands). Keeping all token embeddings on GPU will OOM. To avoid this, we
        # optionally stream chunk outputs to `return_device` (often CPU) before concatenation.
        with torch.no_grad():
            for start_idx in range(0, batch_size_total, chunk_size):
                end_idx = min(start_idx + chunk_size, batch_size_total)
                texts_chunk = texts[start_idx:end_idx]

                if max_length is not None:
                    inputs = self.tokenizer(
                        texts_chunk, return_tensors="pt", padding='max_length',
                        truncation=True, max_length=max_length
                    )
                else:
                    inputs = self.tokenizer(texts_chunk, return_tensors="pt", padding=True, truncation=True)

                if device is not None:
                    inputs = {k: v.to(device) for k, v in inputs.items()}

                if device is not None and torch.cuda.is_available():
                    with torch.amp.autocast("cuda", dtype=torch.float16):
                        outputs = self.model(**inputs)
                else:
                    outputs = self.model(**inputs)

                hs = outputs.last_hidden_state
                ids = inputs["input_ids"]
                attn = inputs.get("attention_mask", None)

                if return_device is not None:
                    hs = hs.to(return_device)
                    ids = ids.to(return_device)
                    if attn is not None:
                        attn = attn.to(return_device)

                hidden_states_list.append(hs)
                input_ids_list.append(ids)
                attention_mask_list.append(attn)

                
                del outputs, hs
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

        last_hidden_state = torch.cat(hidden_states_list, dim=0)
        input_ids = torch.cat(input_ids_list, dim=0)
        attention_mask = None
        if all(m is not None for m in attention_mask_list):
            attention_mask = torch.cat(attention_mask_list, dim=0)

        return {
            'last_hidden_state': last_hidden_state,
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
    
    def encode(self, texts, convert_to_tensor=False, device=None):
        if isinstance(texts, str):
            texts = [texts]
            single_input = True
        else:
            single_input = False
        
        batch_size_total = len(texts)
        chunk_size = min(16, max(1, batch_size_total // 2))
        emb_list = []

        # Decide the device to run on.
        # If caller didn't specify, use the model's current device.
        run_device = device
        if run_device is not None:
            self.model.to(run_device)
        else:
            run_device = next(self.model.parameters()).device

        with torch.no_grad():
            for start_idx in range(0, batch_size_total, chunk_size):
                end_idx = min(start_idx + chunk_size, batch_size_total)
                texts_chunk = texts[start_idx:end_idx]
                max_len = self._safe_max_length()
                inputs = self.tokenizer(
                    texts_chunk,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_len,
                )
                # Always move inputs to the same device as the model to avoid CPU/GPU mismatch.
                inputs = {k: v.to(run_device) for k, v in inputs.items()}

                if run_device is not None and str(run_device).startswith("cuda") and torch.cuda.is_available():
                    with torch.amp.autocast("cuda", dtype=torch.float16):
                        outputs = self.model(**inputs)
                else:
                    outputs = self.model(**inputs)

                pooled = self._masked_mean_pool(outputs.last_hidden_state, inputs.get("attention_mask", None))
                emb_list.append(pooled)
                del outputs
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

        embeddings = torch.cat(emb_list, dim=0)
        
        if convert_to_tensor:
            return embeddings.squeeze(0) if single_input else embeddings
        else:
            result = embeddings.cpu().numpy()
            return result.squeeze(0) if single_input else result

if __name__ == "__main__":
    # Smoke test: load the default model and print the shape of one embedding.
    sample_vector = EmbeddingModel().get_embedding("What is the capital of France?")
    print(sample_vector.shape)
