import os
import pandas as pd
import torch
from .base import ColumnVectorizer
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

from sklearn.metrics.pairwise import cosine_similarity
from collections import OrderedDict
import sys
from typing import List, Dict
import warnings

# Turn off the Hugging Face tokenizers library's internal parallelism for this
# process (a setting commonly used to avoid its fork-related warnings).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Set to True to print device / batch progress diagnostics while embedding.
DEBUGGING = False

class LRUCache:
    """Bounded mapping that evicts the least recently used entry when full.

    ``get`` returns ``None`` for a missing key, so a stored ``None`` value is
    indistinguishable from a miss. A non-positive ``capacity`` disables storage
    entirely: every ``put`` is a no-op.

    Args:
        capacity (int): Maximum number of entries kept.
    """

    def __init__(self, capacity):
        self.cache = OrderedDict()
        self.capacity = capacity

    def __len__(self):
        return len(self.cache)

    def __contains__(self, key):
        # Pure membership test: unlike get(), this does not count as a use.
        return key in self.cache

    def get(self, key):
        """Return the value for ``key`` and mark it most recently used, or ``None`` if absent."""
        try:
            value = self.cache[key]
        except KeyError:
            return None
        self.cache.move_to_end(key)
        return value

    def put(self, key, value):
        """Insert or update ``key``; evict the least recently used entry if the cache is full."""
        if self.capacity <= 0:
            # Nothing can be stored; popitem() on an empty dict would raise KeyError.
            return
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # drop the oldest (least recently used) entry
        self.cache[key] = value

class CategoricalVectorizer(ColumnVectorizer):
    """Vectorize categorical (``object``/``bool``) columns with a frozen pre-trained LM.

    Forward direction: every value listed in ``config["categories"]`` is embedded
    by averaging the language model's last hidden state over its non-padding
    tokens; any other value maps to a zero vector. Inverse direction: input rows
    and category embeddings pass through a shared trainable projection head and
    are compared by dot product divided by ``temperature``.

    Embeddings are produced on the CPU (optionally kept in an LRU cache keyed by
    the raw value) and moved to ``self.device`` only when a result is assembled.
    """

    # Default divisor for similarity logits. Declared at class level so objects
    # restored without running ``__init__`` still resolve the attribute.
    temperature = 0.1

    def __init__(self, output_dim=1024, projection_dim=1024,
                 model_name="Alibaba-NLP/gte-large-en-v1.5",
                 cache_embedding=True,
                 max_cache_size=100000,
                 temperature=0.1):
        """
        Initialize the categorical vectorizer.

        Args:
            output_dim (int): The dimension of the output vectors (D). Must equal
                the language model's ``hidden_size``.
            projection_dim (int): The dimension to project embeddings to for similarity computation.
            model_name (str): The name of the pre-trained model to use for embeddings.
            cache_embedding (bool): Whether to cache embeddings across different calls.
            max_cache_size (int): Maximum number of embeddings to cache.
            temperature (float): Positive divisor applied to projected dot-product
                similarities before they are used as logits.

        Raises:
            ValueError: If ``temperature`` is not positive or the model's hidden
                size differs from ``output_dim``.
        """
        super().__init__(output_dim=output_dim, accepted_dtype=["object", "bool"])

        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}.")
        self.temperature = temperature

        self.MODEL_NAME = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModel.from_pretrained(self.MODEL_NAME, trust_remote_code=True)
        self.embedding_dim = self.model.config.hidden_size

        # Fail fast, before building layers sized on output_dim.
        if self.embedding_dim != output_dim:
            raise ValueError(
                f"GTE model embedding dimension ({self.embedding_dim}) "
                f"does not match output_dim ({output_dim})."
            )

        self.cache_embedding = cache_embedding
        self.embedding_cache = LRUCache(max_cache_size) if cache_embedding else None

        # Approximate bytes per cached embedding (4 bytes per float32 element).
        self.memory_per_embedding = output_dim * 4

        # The LM is only a frozen feature extractor.
        # NOTE(review): if the base class is an nn.Module, a parent .train() call
        # would also switch self.model to train mode (dropout) — confirm intended.
        for param in self.model.parameters():
            param.requires_grad = False

        # Trainable projection head shared by input rows and category embeddings.
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(output_dim, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, projection_dim),
        )

    def to(self, device):
        """Move the base state, the language model and the projection head to ``device``."""
        super().to(device)
        self.model = self.model.to(device)
        self.projection = self.projection.to(device)
        return self

    def is_trainable(self):
        return True

    def required_config_keys(self):
        """
        Define the required keys for the configuration.

        Returns:
            list: List of required keys.
        """
        return ["categories"]

    # ------------------------------------------------------------------ #
    # Embedding helpers
    # ------------------------------------------------------------------ #

    def _compute_embeddings(self, values):
        """Embed each value with the frozen LM (no cache involved).

        Args:
            values (Sequence): Values to embed; each is passed through ``str()``
                before tokenization. Must support slicing.

        Returns:
            dict: ``{value: 1-D CPU tensor}``, one entry per distinct value.
        """
        if DEBUGGING:
            print("Computing embeddings for:", values)
            print(f"Model device: {next(self.model.parameters()).device}")
        embeddings = {}
        batch_size = 32  # Process in batches to bound peak memory
        num_batches = (len(values) + batch_size - 1) // batch_size

        for start in range(0, len(values), batch_size):
            batch_values = values[start:start + batch_size]
            inputs = self.tokenizer(
                [str(v) for v in batch_values],
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            if DEBUGGING:
                print(f"Initial tokenizer outputs devices: {[v.device for v in inputs.values()]}")

            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            if DEBUGGING:
                print(f"After moving to device {self.device}, input devices: {[v.device for v in inputs.values()]}")

            with torch.no_grad():
                hidden = self.model(**inputs).last_hidden_state
                if DEBUGGING:
                    print(f"Model output device: {hidden.device}")
                attention_mask = inputs.get("attention_mask")
                if attention_mask is None:
                    batch_embeddings = hidden.mean(dim=1)
                else:
                    # Average real tokens only. With padding=True a plain mean over
                    # dim 1 includes pad positions, so a value's embedding would
                    # depend on the longest string sharing its batch.
                    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
                    batch_embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                # Keep embeddings on the CPU; results are moved to the device only when assembled.
                batch_embeddings = batch_embeddings.cpu()

            for value, embedding in zip(batch_values, batch_embeddings):
                embeddings[value] = embedding

            if DEBUGGING:
                print(f"Completed batch {start // batch_size + 1} of {num_batches}")

        return embeddings

    def _get_embeddings(self, values):
        """Return ``{value: CPU embedding}`` for every value in ``values``.

        Uses and refreshes the LRU cache when caching is enabled. The result is
        assembled in a local dict so callers never re-read the cache afterwards:
        if more distinct values are requested than the cache can hold, entries
        inserted earlier in the same call may already be evicted.
        """
        unique_values = list(dict.fromkeys(values))  # de-duplicate, keep first-seen order
        if not self.cache_embedding:
            return self._compute_embeddings(unique_values)

        found = {}
        missing = []
        for value in unique_values:
            cached = self.embedding_cache.get(value)
            if cached is None:
                missing.append(value)
            else:
                found[value] = cached
        if missing:
            fresh = self._compute_embeddings(missing)
            for value, embedding in fresh.items():
                self.embedding_cache.put(value, embedding)
            found.update(fresh)
        return found

    def _embed_column(self, column, category_set, lookup):
        """Build the ``(len(column), output_dim)`` tensor for one column on ``self.device``.

        Rows whose value is not in ``category_set`` (or missing from ``lookup``)
        get a zero vector.
        """
        if len(column) == 0:
            return torch.zeros((0, self.output_dim), device=self.device)
        zero_embedding = torch.zeros(self.output_dim)
        rows = [
            lookup.get(value, zero_embedding) if value in category_set else zero_embedding
            for value in column
        ]
        # Stack on the CPU, then transfer once.
        return torch.stack(rows).to(self.device)

    def _category_logits(self, tensor, categories, lookup):
        """Score rows of ``tensor`` against every category.

        Both ``tensor`` and the category embeddings go through ``self.projection``.
        Their dot products are divided by ``self.temperature`` and returned as
        unnormalized logits (no softmax: ``F.cross_entropy`` expects logits).
        For a 2-D ``tensor`` of shape ``(N, output_dim)`` the result is
        ``(N, len(categories))``.
        """
        category_embeddings = torch.stack([lookup[cat] for cat in categories]).to(self.device)
        projected_categories = self.projection(category_embeddings)
        projected_tensor = self.projection(tensor)
        similarities = torch.matmul(projected_tensor, projected_categories.T)
        return similarities / self.temperature

    @staticmethod
    def _decode(logits, categories):
        """Map each row's highest-scoring index back to its category label."""
        closest_indices = logits.argmax(dim=1).tolist()
        return pd.Series([categories[idx] for idx in closest_indices])

    # ------------------------------------------------------------------ #
    # Forward / inverse transforms
    # ------------------------------------------------------------------ #

    def _vectorize(self, column, config):
        """
        Transform the categorical column into pre-trained LM embeddings.

        Values not listed in ``config["categories"]`` become zero vectors. They
        are deliberately not cached, because the same value may be valid under
        another config.

        Args:
            column (pandas.Series): The categorical column to transform.
            config (dict): Configuration dictionary, must include "categories".

        Returns:
            torch.Tensor: Transformed tensor of shape (N, D) on ``self.device``.
        """
        category_set = set(config["categories"])
        valid_values = [v for v in column.unique() if v in category_set]
        lookup = self._get_embeddings(valid_values)
        return self._embed_column(column, category_set, lookup)

    def _vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Batch implementation for transforming multiple categorical columns into vectors.

        Valid values from all columns are embedded in a single pass. Each column
        then zeroes the values that are invalid under its own config.

        Args:
            columns (List[pandas.Series]): The columns to transform.
            configs (List[dict]): The configuration dictionaries for each column.

        Returns:
            List[torch.Tensor]: One (N_i, D) tensor per column on ``self.device``.
        """
        category_sets = [set(config["categories"]) for config in configs]
        all_valid_values = [
            value
            for column, category_set in zip(columns, category_sets)
            for value in column.unique()
            if value in category_set
        ]
        lookup = self._get_embeddings(all_valid_values)
        return [
            self._embed_column(column, category_set, lookup)
            for column, category_set in zip(columns, category_sets)
        ]

    def _compute_loss(self, probabilities, target_column, config):
        """
        Compute cross entropy loss between predicted scores and target categories.

        Args:
            probabilities (torch.Tensor): Unnormalized scores (logits) over
                ``config["categories"]``, e.g. the train-mode output of
                ``_inverse_vectorize``. One row per target entry.
            target_column (pd.Series): Target categories.
            config (dict): Configuration dictionary, must include "categories".

        Returns:
            torch.Tensor: Cross entropy over the rows whose target is a known
            category. If there are none, a zero scalar is returned and a warning
            is issued.
        """
        categories = config["categories"]
        cat_to_idx = {cat: idx for idx, cat in enumerate(categories)}

        valid_mask = target_column.isin(categories).to_numpy()

        if not valid_mask.any():
            warnings.warn("No valid categories found in target_column. Returning zero loss.")
            return torch.tensor(0.0, device=probabilities.device, requires_grad=True)

        # Unknown targets are reported (once per call) and excluded from the loss.
        invalid_cats = target_column[~valid_mask].unique()
        if len(invalid_cats):
            warnings.warn(f"Categories {list(invalid_cats)} not found in the provided categories list. These entries will be ignored in loss calculation.")

        target_indices = torch.tensor(
            [cat_to_idx[cat] for cat in target_column[valid_mask]],
            dtype=torch.long,
            device=probabilities.device,
        )
        row_mask = torch.as_tensor(valid_mask, dtype=torch.bool, device=probabilities.device)
        return F.cross_entropy(probabilities[row_mask], target_indices)

    def _inverse_vectorize(self, tensor, config, mode='inference'):
        """
        Reverse the embedding transformation back to categories.

        Args:
            tensor (torch.Tensor): The tensor to inverse transform.
            config (dict): Configuration dictionary, must include "categories".
            mode (str): 'inference' or 'train' mode.

        Returns:
            Union[pd.Series, torch.Tensor]: Decoded categories in inference mode.
            In 'train' mode, the temperature-scaled logits over the categories
            (no softmax applied).

        Raises:
            ValueError: If ``tensor`` is not a ``torch.Tensor``.
        """
        if not isinstance(tensor, torch.Tensor):
            raise ValueError("Input tensor must be a torch.Tensor.")

        # Store config for loss computation
        self.current_config = config
        categories = config["categories"]

        lookup = self._get_embeddings(categories)
        logits = self._category_logits(tensor, categories, lookup)

        if mode == "train":
            return logits
        return self._decode(logits, categories)

    def _inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict], mode: str) -> List:
        """
        Batch implementation for inverse transformation of vectors to original values.

        Category embeddings for the union of all configs are computed once.

        Args:
            tensors (List[torch.Tensor]): The tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries for transformations.
            mode (str): 'inference' or 'train'

        Returns:
            List: Per tensor, the logits ('train') or a pd.Series of decoded categories.
        """
        all_categories = [cat for config in configs for cat in config["categories"]]
        lookup = self._get_embeddings(all_categories)

        result = []
        for tensor, config in zip(tensors, configs):
            # Store config for loss computation
            self.current_config = config
            categories = config["categories"]
            logits = self._category_logits(tensor, categories, lookup)
            result.append(logits if mode == "train" else self._decode(logits, categories))
        return result

    # ------------------------------------------------------------------ #
    # Cache management
    # ------------------------------------------------------------------ #

    def estimate_cache_memory(self):
        """
        Estimate current cache memory usage in MB (0 when caching is disabled).
        """
        if not self.cache_embedding:
            return 0

        return (len(self.embedding_cache.cache) * self.memory_per_embedding) / (1024 * 1024)

    def clear_cache(self):
        """
        Clear the embedding cache, keeping its capacity.
        """
        if self.cache_embedding:
            self.embedding_cache = LRUCache(self.embedding_cache.capacity)

if __name__ == "__main__":
    # Smoke tests plus a rough benchmark. Loads the pre-trained model.
    import random
    import time

    random.seed(0)  # deterministic synthetic data for Test 6

    # Initialize the vectorizer
    vectorizer = CategoricalVectorizer()

    # Sample data
    column = pd.Series(['Hand Crimpers & Strippers', 'Vintage Stereo Receivers', 'PC Laptops & Netbooks', 'Art Pens & Markers', 'PC Laptops & Netbooks'])
    config = {"categories": ['Art Pens & Markers', 'Hand Crimpers & Strippers', 'Vintage Stereo Receivers', 'PC Laptops & Netbooks', 'Network Cards']}

    # Test 1: Basic vectorization and inference mode
    vectors = vectorizer.vectorize(column, config)
    print("\nTest 1: Basic vectorization")
    print("Vectors shape:", vectors.shape)
    decoded_column = vectorizer.inverse_vectorize(vectors, config)
    print("Decoded column:\n", decoded_column)

    # Test 2: Training mode. The first return value is assumed to be the
    # temperature-scaled scores from _inverse_vectorize (logits, not
    # probabilities), so normalize with softmax before reporting a distribution.
    logits, loss = vectorizer.inverse_vectorize(vectors, config, mode='train', target_column=column)
    probs = logits.softmax(dim=1)
    print("\nTest 2: Training mode probabilities")
    print("Probability shape:", probs.shape)
    print("Sum of probabilities for each row:", probs.sum(dim=1))
    print("Loss:", loss)
    # Verify highest probability matches the original categories
    max_prob_indices = probs.argmax(dim=1).tolist()
    predicted_categories = [config["categories"][idx] for idx in max_prob_indices]
    print("\nVerifying predictions match between modes:")
    print("Original:", column.tolist())
    print("Inference mode:", decoded_column.tolist())
    print("Train mode (highest prob):", predicted_categories)

    # Test 3: Probability distribution properties
    print("\nTest 3: Probability distribution for first item")
    for cat, prob in zip(config["categories"], probs[0].tolist()):
        print(f"{cat}: {prob:.4f}")

    # Test 4: Batch vectorization and inference
    print("\nTest 4: Batch vectorization and inference")
    columns = [
        pd.Series(['Hand Crimpers & Strippers', 'Vintage Stereo Receivers']),
        pd.Series(['PC Laptops & Netbooks', 'Art Pens & Markers']),
        pd.Series(['Network Cards', 'PC Laptops & Netbooks'])
    ]

    configs = [
        {"categories": ['Hand Crimpers & Strippers', 'Vintage Stereo Receivers', 'Network Cards']},
        {"categories": ['Art Pens & Markers', 'PC Laptops & Netbooks', 'Vintage Stereo Receivers']},
        {"categories": ['Network Cards', 'PC Laptops & Netbooks', 'Hand Crimpers & Strippers']}
    ]

    batch_vectors = vectorizer.vectorize_batch(columns, configs)
    print(f"Number of tensors: {len(batch_vectors)}")
    for i, tensor in enumerate(batch_vectors):
        print(f"Tensor {i} shape: {tensor.shape}")

    decoded_columns = vectorizer.inverse_vectorize_batch(batch_vectors, configs)
    print("\nDecoded columns:")
    for i, col in enumerate(decoded_columns):
        print(f"Column {i}:\n{col}")

    # Test 5: Batch training mode (same columns reused as targets for illustration)
    print("\nTest 5: Batch training mode")
    target_columns = columns

    batch_logits, combined_loss = vectorizer.inverse_vectorize_batch(
        batch_vectors, configs, mode='train', target_columns=target_columns
    )

    print(f"Combined loss: {combined_loss.item()}")
    print(f"Number of probability tensors: {len(batch_logits)}")

    for i, scores in enumerate(batch_logits):
        print(f"Probability tensor {i} shape: {scores.shape}")
        # argmax is the same for logits and their softmax
        max_prob_indices = scores.argmax(dim=1).tolist()
        predicted_categories = [configs[i]["categories"][idx] for idx in max_prob_indices]
        print(f"Original column {i}: {columns[i].tolist()}")
        print(f"Predicted (highest prob): {predicted_categories}")

    # Test 6: Compare performance between batch and sequential processing
    print("\nTest 6: Performance comparison")

    def timed(label, fn):
        """Run fn() from a cold embedding cache and print its wall-clock time."""
        # Clearing before every run keeps the comparison fair: otherwise a later
        # run reuses embeddings computed by an earlier one.
        vectorizer.clear_cache()
        start = time.perf_counter()
        output = fn()
        elapsed = time.perf_counter() - start
        print(f"{label}: {elapsed:.4f}s")
        return output, elapsed

    def speedup(slow, fast):
        return slow / fast if fast > 0 else float("inf")

    num_columns = 10
    column_length = 100
    all_categories = ['Hand Crimpers & Strippers', 'Vintage Stereo Receivers',
                      'PC Laptops & Netbooks', 'Art Pens & Markers', 'Network Cards']

    large_columns = [
        pd.Series([random.choice(all_categories) for _ in range(column_length)])
        for _ in range(num_columns)
    ]
    large_configs = [{"categories": all_categories} for _ in range(num_columns)]

    sequential_vectors, sequential_time = timed(
        "Sequential processing time",
        lambda: [vectorizer._vectorize(col, cfg) for col, cfg in zip(large_columns, large_configs)],
    )
    batch_vectors, batch_time = timed(
        "Batch processing time",
        lambda: vectorizer._vectorize_batch(large_columns, large_configs),
    )
    print(f"Speedup: {speedup(sequential_time, batch_time):.2f}x")

    sequential_inverse, sequential_inverse_time = timed(
        "Sequential inverse processing time",
        lambda: [vectorizer._inverse_vectorize(tensor, cfg, "inference")
                 for tensor, cfg in zip(sequential_vectors, large_configs)],
    )
    batch_inverse, batch_inverse_time = timed(
        "Batch inverse processing time",
        lambda: vectorizer._inverse_vectorize_batch(batch_vectors, large_configs, "inference"),
    )
    print(f"Inverse speedup: {speedup(sequential_inverse_time, batch_inverse_time):.2f}x")
