                                                              
import logging
import uuid
from typing import Union, Dict, Any, Optional, Tuple, List
import concurrent.futures
import os                           

import torch
import math                          
import numpy as np                           
from fortress.core.vector_store_interface import VectorStoreInterface               

from fortress.core.embedding_model import EmbeddingModel
from fortress.core.nlp_analyzer import NLPAnalyzer
from fortress.common.data_models import InputPromptRecord, DatabasePromptRecord, QueryFeatures
from fortress.config import get_config

logger = logging.getLogger(__name__)

class PromptProcessor:
    """
    Processes input prompts to generate embeddings and extract NLP features.
    """
    def __init__(self, embedding_model: EmbeddingModel, nlp_analyzer: NLPAnalyzer, vector_store: VectorStoreInterface):
        """
        Initializes the PromptProcessor.

        Args:
            embedding_model: An instance of EmbeddingModel.
            nlp_analyzer: An instance of NLPAnalyzer.
            vector_store: An instance of VectorStoreInterface.
        """
        self.embedding_model = embedding_model
        self.nlp_analyzer = nlp_analyzer
        self.vector_store = vector_store
        self.config = get_config()
        self.long_context_config = self.config.get('long_context', {})
        self.max_seq_length = self.long_context_config.get('max_sequence_length_for_embedding', 512)
        self.chunk_overlap = self.long_context_config.get('chunk_overlap', 64)
        
        clustering_config = self.config.get('clustering', {})
        self.cluster_field_name_for_assignment = clustering_config.get('cluster_field_name_for_assignment', "prompt_category")
        
        detection_pipeline_config = self.config.get('detection_pipeline', {})
        self.cluster_assignment_top_k = clustering_config.get('cluster_assignment_top_k')
        if self.cluster_assignment_top_k is None:
            self.cluster_assignment_top_k = detection_pipeline_config.get('top_k_semantic_search', 7)
            logger.info(f"Using detection_pipeline's top_k_semantic_search ({self.cluster_assignment_top_k}) for cluster assignment")
        
                                                         
                                                                                                               
                                                                                   
        self.max_workers = self.config.get('prompt_processor_max_workers', 10)                                                                 
        logger.info(f"PromptProcessor initialized with max_workers for ThreadPool: {self.max_workers}")


    def _chunk_text(self, text: str) -> list[str]:
        words = text.split()
        if len(words) <= self.max_seq_length:
            return [text]
        chunks = []
        current_chunk_start = 0
        while current_chunk_start < len(words):
            end_index = min(current_chunk_start + self.max_seq_length, len(words))
            chunk_words = words[current_chunk_start:end_index]
            chunks.append(" ".join(chunk_words))
            if end_index == len(words):
                break
            current_chunk_start += (self.max_seq_length - self.chunk_overlap)
            if current_chunk_start >= len(words):
                break
        logger.debug(f"Chunked text into {len(chunks)} chunks.")
        return chunks

    def _process_single_item_features(self, query_text: str, query_embedding_list: List[float]) -> Tuple[Optional[QueryFeatures], Optional[List[float]], Optional[str]]:
        """
        Helper function to process CPU-bound tasks for a single item from a batch.
        This function will be called in parallel by the ThreadPoolExecutor.
        Ensure methods called here (NLP, vector_store interaction, log_probs) are thread-safe
        and ideally release the GIL if they are CPU-bound Python code.
        """
        try:
            prompt_category = None
            prompt_categories_with_weights = None
            if self.vector_store:
                try:
                    prompt_categories_with_weights = self.vector_store.assign_clusters_with_weights_to_new_prompt(
                        query_embedding_list,
                        cluster_field_name=self.cluster_field_name_for_assignment,
                        top_k_neighbors=self.cluster_assignment_top_k
                    )
                    if prompt_categories_with_weights:
                        prompt_category = prompt_categories_with_weights[0][0]
                except Exception as e_cat:
                    logger.warning(f"Failed to assign prompt categories for '{query_text[:50]}...': {e_cat}")
            
            token_log_probs = []
            try:
                                                                                           
                                                                                                        
                token_log_probs = self.embedding_model.get_token_source_log_probabilities(query_text)
                if token_log_probs is None: token_log_probs = []
                if not isinstance(token_log_probs, list) or not all(isinstance(x, float) for x in token_log_probs):
                    logger.warning(f"Token source log probabilities are not a list of floats for '{query_text[:50]}...'. Setting to empty list.")
                    token_log_probs = []
            except Exception as e_log_probs:
                logger.error(f"Error generating token log probabilities for '{query_text[:50]}...': {e_log_probs}")
                token_log_probs = []

                                                                                                           
            nlp_features_dict = self.nlp_analyzer.extract_all_features(query_text)

            query_features = QueryFeatures(
                **nlp_features_dict,
                prompt_category=prompt_category,
                prompt_categories_with_weights=prompt_categories_with_weights,
                token_source_log_probabilities=token_log_probs
            )
            return query_features, query_embedding_list, None

        except Exception as e:
            error_message = f"Unexpected error during single item feature processing for '{query_text[:50]}...': {e}"
            logger.exception(error_message)
            return None, query_embedding_list, error_message


    def _process_single_item_features_v2(self, query_text: str, query_embedding_list: List[float], token_log_probs: List[float]) -> Tuple[Optional[QueryFeatures], Optional[List[float]], Optional[str]]:
        """
        Helper function to process CPU-bound tasks for a single item from a batch.
        Receives pre-computed token_log_probs.
        """
        try:
            prompt_category = None
            prompt_categories_with_weights = None
            if self.vector_store:
                try:
                    prompt_categories_with_weights = self.vector_store.assign_clusters_with_weights_to_new_prompt(
                        query_embedding_list,
                        cluster_field_name=self.cluster_field_name_for_assignment,
                        top_k_neighbors=self.cluster_assignment_top_k
                    )
                    if prompt_categories_with_weights:
                        prompt_category = prompt_categories_with_weights[0][0]
                except Exception as e_cat:
                    logger.warning(f"Failed to assign prompt categories for '{query_text[:50]}...': {e_cat}")

            nlp_features_dict = self.nlp_analyzer.extract_all_features(query_text)

            query_features = QueryFeatures(
                **nlp_features_dict,
                prompt_category=prompt_category,
                prompt_categories_with_weights=prompt_categories_with_weights,
                token_source_log_probabilities=token_log_probs
            )
            return query_features, query_embedding_list, None

        except Exception as e:
            error_message = f"Unexpected error during single item feature processing for '{query_text[:50]}...': {e}"
            logger.exception(error_message)
            return None, query_embedding_list, error_message

    def process_for_query_batch(self, query_texts: List[str]) -> List[Tuple[Optional[QueryFeatures], Optional[List[float]], Optional[str]]]:
        logger.info(f"PromptProcessor: Starting process_for_query_batch for {len(query_texts)} items.")
        
        if not query_texts:
            return []

                                                                         
        logger.debug(f"PromptProcessor: Calling embedding_model.get_embedding for {len(query_texts)} texts.")
        all_query_embeddings_tensor = self.embedding_model.get_embedding(query_texts)
        logger.debug(f"PromptProcessor: embedding_model.get_embedding returned.")

        if all_query_embeddings_tensor is None or all_query_embeddings_tensor.nelement() == 0:
            logger.error("PromptProcessor: Failed to generate embeddings for the batch.")
            return [(None, None, "Batch embedding generation failed.") for _ in query_texts]
        
        all_query_embedding_lists = []
        if all_query_embeddings_tensor.ndim == 1 and len(query_texts) == 1:
            all_query_embedding_lists.append(all_query_embeddings_tensor.tolist())
        elif all_query_embeddings_tensor.ndim == 2:
            for i in range(all_query_embeddings_tensor.shape[0]):
                all_query_embedding_lists.append(all_query_embeddings_tensor[i].tolist())
        else:
            logger.error(f"PromptProcessor: Unexpected embedding tensor shape: {all_query_embeddings_tensor.shape}")
            return [(None, None, "Unexpected batch embedding tensor shape.") for _ in query_texts]
        logger.debug(f"PromptProcessor: Successfully converted tensor embeddings to lists for {len(all_query_embedding_lists)} items.")

                                                                                             
        logger.debug(f"PromptProcessor: Calling embedding_model.get_token_source_log_probabilities_batch for {len(query_texts)} texts.")
        all_token_log_probs_batch = self.embedding_model.get_token_source_log_probabilities_batch(query_texts)
        logger.debug(f"PromptProcessor: embedding_model.get_token_source_log_probabilities_batch returned {len(all_token_log_probs_batch)} sets of log_probs.")

        if not all_token_log_probs_batch or len(all_token_log_probs_batch) != len(query_texts):
            logger.error("PromptProcessor: Failed to generate token_source_log_probabilities for the batch or count mismatch.")
            all_token_log_probs_batch = [[] for _ in query_texts]

                                                                       
        batch_results_ordered = [None] * len(query_texts)

        tasks_to_submit = []
        for i, query_text in enumerate(query_texts):
            tasks_to_submit.append({
                'index': i,
                'text': query_text,
                'embedding': all_query_embedding_lists[i],
                'log_probs': all_token_log_probs_batch[i]
            })
        
        logger.info(f"PromptProcessor: Submitting {len(tasks_to_submit)} items to ThreadPoolExecutor with {self.max_workers} workers for NLP and category assignment.")

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_task_info = {
                executor.submit(self._process_single_item_features_v2, task['text'], task['embedding'], task['log_probs']): task 
                for task in tasks_to_submit
            }

            completed_count = 0
            for future in concurrent.futures.as_completed(future_to_task_info):
                task_info = future_to_task_info[future]
                original_index = task_info['index']
                query_text_for_log = task_info['text']
                try:
                    result_item = future.result()
                    batch_results_ordered[original_index] = result_item
                    completed_count +=1
                    if completed_count % 10 == 0 or completed_count == len(tasks_to_submit):
                        logger.info(f"PromptProcessor: Completed CPU-bound tasks for {completed_count}/{len(tasks_to_submit)} items in current batch.")
                except Exception as exc:
                    logger.error(f"PromptProcessor: Task for item index {original_index} (text: '{query_text_for_log[:50]}...') generated an exception in ThreadPool: {exc}", exc_info=True)
                    batch_results_ordered[original_index] = (None, all_query_embedding_lists[original_index], str(exc))
        
        logger.info(f"PromptProcessor: Finished processing batch of {len(query_texts)} items.")
        return batch_results_ordered


    def process_for_query(self, query_text: str, metadata: Optional[Dict[str, Any]] = None) -> Tuple[Optional[QueryFeatures], Optional[List[float]], Optional[str]]:
        logger.debug(f"Processing single query (via batch): '{query_text[:100]}...'")
        results = self.process_for_query_batch([query_text])
        if results:
            return results[0]
        return None, None, "Failed to process query using batch method."

    
    def process_for_database(self, input_prompt: InputPromptRecord) -> Union[DatabasePromptRecord, None]:
        text_to_process = input_prompt.original_prompt
        embedding_tensor = self.embedding_model.get_embedding([text_to_process]) 
        embedding_list: Optional[List[float]] = None
        if embedding_tensor is not None and embedding_tensor.nelement() > 0 : 
            try:
                if embedding_tensor.ndim == 2 and embedding_tensor.shape[0] == 1:
                    raw_list = embedding_tensor.tolist()[0]
                    embedding_list = [float(x) for x in raw_list]
                elif embedding_tensor.ndim == 1: 
                    raw_list = embedding_tensor.tolist()
                    embedding_list = [float(x) for x in raw_list]
                else:
                    logger.error(f"Failed to generate embedding (unexpected tensor shape: {embedding_tensor.shape}) for prompt ID {input_prompt.prompt_id}. Expected (1, dim) or (dim,).")
                    embedding_list = None 
            except Exception as e_conv:
                logger.error(f"Error converting embedding tensor to list for prompt ID {input_prompt.prompt_id}: {e_conv}", exc_info=True)
                embedding_list = None 
        else:
            logger.error(f"Failed to generate embedding (received None or empty tensor) for prompt ID {input_prompt.prompt_id}.")
            embedding_list = None 

        nlp_features_dict = self.nlp_analyzer.extract_all_features(text_to_process)
        calculated_perplexity: Optional[float] = None
        try:
            calculated_perplexity = self.embedding_model.get_perplexity(text_to_process)
            if calculated_perplexity is not None:
                logger.debug(f"Calculated perplexity for prompt ID {input_prompt.prompt_id}: {calculated_perplexity}")
            else:
                logger.warning(f"Perplexity calculation returned None for prompt ID {input_prompt.prompt_id}.")
        except AttributeError: 
            logger.error(f"EmbeddingModel does not have 'get_perplexity' method. Please check the EmbeddingModel implementation.", exc_info=True)
        except Exception as e:
            logger.error(f"Error calculating perplexity for prompt ID {input_prompt.prompt_id}: {e}", exc_info=True)
        
        nlp_features_dict['perplexity'] = calculated_perplexity

        try:
            record_data = {
                **input_prompt.model_dump(),
                **nlp_features_dict,
                "embedding": embedding_list, 
            }
            db_record = DatabasePromptRecord(**record_data)
            return db_record
        except Exception as e:
            logger.error(f"Error creating DatabasePromptRecord for prompt ID {input_prompt.prompt_id}: {e}", exc_info=True)
            return None
