#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Random baseline model wrapper for the Implicit Embeddings Benchmark.
This wrapper provides a model that predicts labels by sampling from the label distribution observed in the training data.
"""

import logging
import numpy as np
from typing import List, Dict, Any, Union, Optional
from tqdm import tqdm
from collections import Counter

# Configure logging: send INFO-and-above records to a stream handler (stderr by
# default) with timestamped lines. Note that basicConfig is a no-op when the
# root logger already has handlers, e.g. if the host application configured it.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("RandomBaselineModel")

class RandomBaselineModel:
    """
    Random baseline model that predicts labels based on label distribution.
    Implements a subset of the SentenceTransformer interface for compatibility.

    Predictions are drawn independently from the label distribution fitted with
    ``fit_label_distribution`` (or from a uniform binary distribution when no
    distribution has been fitted). Embeddings returned by ``encode`` are
    Gaussian noise and carry no information about the input sentences.
    """

    def __init__(self, embedding_dim: int = 2, batch_size: int = 32, seed: Optional[int] = 42):
        """
        Initialize the random baseline model.

        Args:
            embedding_dim: Dimension of random embeddings (not functionally important)
            batch_size: Batch size (has no effect but kept for API compatibility)
            seed: Seed for the internal random generator. Defaults to 42 so runs
                are reproducible; pass ``None`` to seed from OS entropy.
        """
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        # Maps label -> probability; None until a non-empty label set is fitted.
        self.label_distribution: Optional[Dict[Any, float]] = None
        # A private RandomState (instead of numpy's global RNG) keeps this
        # model's output reproducible regardless of other numpy usage.
        self.rng = np.random.RandomState(seed)

        logger.info("Initialized Random Baseline model with embedding_dim=%s", embedding_dim)

    def fit_label_distribution(self, labels: List[int]) -> None:
        """
        Fit the label distribution for random predictions.

        Any iterable of hashable labels is accepted (including generators).
        Fitting replaces any previously fitted distribution. If ``labels`` is
        empty, the model is left unfitted (``label_distribution`` is ``None``)
        so that ``predict_random`` falls back to its uniform default instead of
        failing on an empty distribution.

        Args:
            labels: Training labels
        """
        counter = Counter(labels)
        # Sum the counts instead of len(labels) so one-shot iterables work.
        total = sum(counter.values())

        if total == 0:
            logger.warning("No labels provided. Label distribution left unfitted.")
            self.label_distribution = None
            return

        self.label_distribution = {label: count / total for label, count in counter.items()}

        logger.info("Fitted label distribution: %s", self.label_distribution)

    def encode(self, sentences: Union[str, List[str]], show_progress_bar: bool = True, batch_size: Optional[int] = None) -> np.ndarray:
        """
        Generate random embeddings for sentences. The actual values don't matter
        as the RandomBaseline will make predictions based on the label distribution.

        Args:
            sentences: Single sentence or list of sentences
            show_progress_bar: If True, log an informational message (no real
                progress bar is shown; kept for API compatibility)
            batch_size: Ignored (kept for API compatibility)

        Returns:
            NumPy array of shape (len(sentences), embedding_dim) with standard
            normal entries; a single string yields shape (1, embedding_dim).
        """
        # A bare string counts as one sentence, not a sequence of characters.
        if isinstance(sentences, str):
            sentences = [sentences]

        n_samples = len(sentences)

        if show_progress_bar:
            logger.info("Generating random embeddings for %s sentences...", n_samples)

        # Values are noise; only the shape matters to callers.
        return self.rng.randn(n_samples, self.embedding_dim)

    def predict_random(self, n_samples: int) -> np.ndarray:
        """
        Generate random predictions based on the learned label distribution.

        When no distribution has been fitted, labels are drawn uniformly from
        ``{0, 1}``.

        Args:
            n_samples: Number of samples to generate predictions for

        Returns:
            NumPy array of ``n_samples`` sampled labels
        """
        if self.label_distribution is None:
            logger.warning("No label distribution fitted. Using uniform distribution.")
            labels = list(range(2))  # Default to binary classification
            probs = [1 / len(labels)] * len(labels)
        else:
            labels = list(self.label_distribution.keys())
            probs = list(self.label_distribution.values())

        # Sample i.i.d. labels according to the chosen distribution.
        return self.rng.choice(labels, size=n_samples, p=probs)