#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Bag-of-tokens model wrapper for the Implicit Embeddings Benchmark.
This wrapper provides a SentenceTransformer-like interface (an ``encode`` method) for bag-of-tokens models.
It uses a pre-trained BERT tokenizer to ensure a consistent vocabulary and embedding dimension.
"""

import logging
import numpy as np
import torch
from typing import List, Dict, Any, Union, Optional
from collections import Counter
from transformers import AutoTokenizer
from tqdm import tqdm

# Configure logging
# Module-level logging setup, executed once when this module is imported.
# NOTE(review): basicConfig configures the *root* logger (INFO level, a single
# StreamHandler, which writes to stderr by default) for every program that
# imports this module. It is a no-op if the root logger already has handlers.
# Consider moving this call to the benchmark's entry point.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.StreamHandler()]
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO, handlers=_LOG_HANDLERS)

# Named logger used by BagOfTokensModel for its informational messages.
logger = logging.getLogger("BagOfTokensModel")

class BagOfTokensModel:
    """
    Wrapper for bag-of-tokens models to be compatible with the benchmark infrastructure.

    Provides a SentenceTransformer-like ``encode`` method that maps each text to a
    vector of raw token counts over the full tokenizer vocabulary. A pre-trained
    Hugging Face tokenizer (``bert-base-uncased`` by default) is used, so the
    vocabulary, and therefore the embedding dimension, is fixed and consistent.
    """

    def __init__(self, batch_size: int = 32, bert_model: str = "bert-base-uncased"):
        """
        Initialize the bag-of-tokens model wrapper.

        Args:
            batch_size: Batch size for processing (has no effect but kept for API compatibility)
            bert_model: Name or path of the pre-trained model whose tokenizer is
                loaded with ``AutoTokenizer.from_pretrained``.
        """
        self.batch_size = batch_size
        self.bert_model = bert_model

        # Load BERT tokenizer
        logger.info(f"Loading BERT tokenizer from {bert_model}...")
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)

        # len(tokenizer) is defined for every Hugging Face tokenizer and counts the
        # base vocabulary plus any added tokens. ``tokenizer.vocab`` is not exposed
        # by every tokenizer class. On slow BERT tokenizers it also omits added
        # tokens, whose ids would then fall outside the embedding and be dropped.
        self.vocab_size = len(self.tokenizer)
        logger.info(f"Loaded tokenizer with vocabulary size: {self.vocab_size}")

        logger.info("Initialized BERT-based Bag-of-tokens model")

    def fit(self, sentences: List[str]):
        """
        No fitting needed for BERT tokenizer as it has a fixed vocabulary.
        Kept for API compatibility.

        Args:
            sentences: List of sentences (not used)

        Returns:
            self, to allow call chaining.
        """
        logger.info("BERT tokenizer has a fixed vocabulary, no fitting needed.")
        return self

    def encode(self, sentences: Union[str, List[str]], show_progress_bar: bool = True, batch_size: Optional[int] = None) -> np.ndarray:
        """
        Encode sentences into bag-of-tokens representations.

        Each row holds the count of every token id produced by the tokenizer for
        that sentence (no special tokens are added). Rows always span the full
        vocabulary dimension.

        Args:
            sentences: A single sentence, or any iterable of sentences, to encode.
                Non-string iterables (e.g. generators) are materialized into a list.
            show_progress_bar: When True and more than one sentence is given, log an
                informational message before encoding. No progress bar is drawn.
            batch_size: Ignored; accepted for API compatibility.

        Returns:
            float32 NumPy array of shape (num_sentences, vocab_size). A single
            string input still yields a 2-D array with one row.
        """
        # Wrap a single sentence; materialize other iterables so len() works.
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)

        num_sentences = len(sentences)
        vocab_size = self.vocab_size

        # Initialize embeddings with zeros for the full vocabulary.
        embeddings = np.zeros((num_sentences, vocab_size), dtype=np.float32)

        if show_progress_bar and num_sentences > 1:
            logger.info(f"Encoding {num_sentences} sentences with BERT-based bag-of-tokens model...")

        for row, sentence in enumerate(sentences):
            # No truncation argument is passed, so long inputs are not cut at the
            # model's maximum sequence length; every token is counted.
            token_ids = self.tokenizer.encode(sentence, add_special_tokens=False)
            for token_id, count in Counter(token_ids).items():
                # Keep only ids inside [0, vocab_size). A negative id would
                # otherwise index the row from the end and corrupt another column.
                if 0 <= token_id < vocab_size:
                    embeddings[row, token_id] = count

        return embeddings