"""
Document management for the QA system
"""

import os
import logging
from typing import Dict, List, Optional
from pathlib import Path

from embeddings import RobertaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader, TextLoader, UnstructuredMarkdownLoader, 
    CSVLoader, UnstructuredExcelLoader
)

logger = logging.getLogger(__name__)

class DocumentManager:
    """Manage document loading, chunking, FAISS indexing and retrieval.

    Instance state:
      * ``document_chunks``: chunks produced by the most recent
        ``load_documents`` call that loaded at least one document.
      * ``document_source_map``: source path -> positions of that source's
        chunks inside ``document_chunks``.
      * ``document_store``: FAISS index built by the most recent successful
        ``index_documents`` call (``None`` until then).
      * ``document_indices``: source path -> FAISS index holding only the
        chunks of that source.
    """

    # Extensions accepted by load_document_content (a subset of what
    # load_documents can read).
    _FULL_CONTENT_SUFFIXES = ('.pdf', '.txt', '.md')

    def __init__(self, embedding_model=None, faiss_index_path: Optional[str] = None):
        """Set up the splitter, embedding model and empty index state.

        Args:
            embedding_model: Embedding object handed to FAISS. When ``None``
                a ``RobertaEmbeddings()`` instance is created.
            faiss_index_path: Directory the indices are saved to. Falls back
                to ``"./faiss_index"`` when missing or empty.
        """
        # Compare against None explicitly so a caller-supplied model is never
        # replaced just because it happens to evaluate as falsy.
        if embedding_model is None:
            embedding_model = RobertaEmbeddings()
        self.embedding_model = embedding_model
        self.document_store = None
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
        self.document_chunks = []
        self.faiss_index_path = faiss_index_path or "./faiss_index"
        self.document_source_map = {}
        self.document_indices = {}

    @staticmethod
    def _create_loader(path_obj: Path, allowed_suffixes=None):
        """Return a loader for *path_obj*, or ``None`` if its extension is unsupported.

        Args:
            path_obj: Path of the file to load.
            allowed_suffixes: Optional collection of lower-case extensions
                (with leading dot) to restrict the selection to.
        """
        suffix = path_obj.suffix.lower()
        if allowed_suffixes is not None and suffix not in allowed_suffixes:
            return None

        loader_classes = {
            '.pdf': PyPDFLoader,
            '.txt': TextLoader,
            '.md': UnstructuredMarkdownLoader,
            '.csv': CSVLoader,
            '.xlsx': UnstructuredExcelLoader,
            '.xls': UnstructuredExcelLoader,
        }
        loader_class = loader_classes.get(suffix)
        if loader_class is None:
            return None
        return loader_class(str(path_obj))

    @staticmethod
    def _group_chunk_indices(chunks) -> Dict[str, List[int]]:
        """Map each ``metadata['source']`` value to the positions of its chunks.

        Chunks whose metadata has no (or an empty) ``'source'`` are skipped.
        """
        grouped: Dict[str, List[int]] = {}
        for position, chunk in enumerate(chunks):
            source = chunk.metadata.get('source')
            if source:
                grouped.setdefault(source, []).append(position)
        return grouped

    @staticmethod
    def _index_dir_names(sources) -> Dict[str, str]:
        """Assign every source path a distinct directory name for its saved index.

        The first source with a given basename keeps that basename; later
        sources sharing it get ``_1``, ``_2``, ... appended so that files with
        the same name in different folders do not overwrite each other.
        """
        names: Dict[str, str] = {}
        taken = set()
        for source in sources:
            base = os.path.basename(source)
            candidate = base
            suffix = 1
            while candidate in taken:
                candidate = f"{base}_{suffix}"
                suffix += 1
            taken.add(candidate)
            names[source] = candidate
        return names

    def load_documents(self, document_paths: List[str]) -> List[Document]:
        """Load documents from various file formats and split them into chunks.

        Missing files, unsupported extensions and loader errors are logged and
        skipped. Every loaded segment gets ``metadata['source']`` set to the
        string form of its path.

        Args:
            document_paths: Paths of the files to load.

        Returns:
            ``self.document_chunks``. When at least one document was loaded
            this is the freshly split chunk list; otherwise the previously
            stored list (possibly empty) is returned unchanged.
        """
        documents = []

        for path in document_paths:
            try:
                path_obj = Path(path)
                if not path_obj.exists():
                    logger.warning("Document not found: %s", path)
                    continue

                loader = self._create_loader(path_obj)
                if loader is None:
                    logger.warning("Unsupported file format: %s", path)
                    continue

                doc_objects = loader.load()
                for doc in doc_objects:
                    doc.metadata['source'] = str(path_obj)

                documents.extend(doc_objects)
                logger.info("Loaded %d segments from %s", len(doc_objects), path)

            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                logger.error("Error loading document %s: %s", path, e)

        if documents:
            self.document_chunks = self.text_splitter.split_documents(documents)
            logger.info("Split documents into %d chunks", len(self.document_chunks))

            # The chunk list was replaced, so the map is rebuilt from scratch;
            # extending the old map would leave positions that refer to the
            # previous chunk list.
            self.document_source_map = self._group_chunk_indices(self.document_chunks)

        return self.document_chunks

    def index_documents(self, documents: List[Document]):
        """Create a searchable FAISS index from document chunks.

        Builds one index over all of *documents* plus one dedicated index per
        distinct ``metadata['source']`` found in them, then tries to save all
        of them under ``self.faiss_index_path``.

        Args:
            documents: Chunks to index. When empty, ``self.document_chunks``
                is indexed instead.

        Returns:
            The global FAISS store, or ``None`` when there is nothing to index
            or building an index failed. On failure the previously stored
            indices are left untouched. A failure while saving to disk is
            logged but does not discard the in-memory indices.
        """
        if not documents:
            documents = self.document_chunks
        if not documents:
            logger.error("No documents to index")
            return None

        try:
            document_store = FAISS.from_documents(
                documents=documents,
                embedding=self.embedding_model
            )

            # Group the chunks that are actually being indexed, so the
            # dedicated indices always match the global one.
            document_indices = {}
            for source, chunk_indices in self._group_chunk_indices(documents).items():
                source_chunks = [documents[i] for i in chunk_indices]
                document_indices[source] = FAISS.from_documents(
                    documents=source_chunks,
                    embedding=self.embedding_model
                )
                logger.info("Created dedicated index for %s with %d chunks", source, len(source_chunks))

        except Exception as e:
            logger.error("Error indexing documents in FAISS: %s", e)
            return None

        # Commit both structures together, only after everything was built.
        self.document_store = document_store
        self.document_indices = document_indices

        self._save_indices()

        logger.info("Indexed %d document chunks in FAISS", len(documents))
        return self.document_store

    def _save_indices(self) -> None:
        """Save the global and per-source indices to disk on a best-effort basis."""
        try:
            # Create the index directory itself; taking dirname() of a bare
            # name such as "faiss_index" yields "" which makedirs rejects.
            os.makedirs(self.faiss_index_path, exist_ok=True)
            self.document_store.save_local(self.faiss_index_path)

            if self.document_indices:
                source_dir = os.path.join(self.faiss_index_path, "document_indices")
                os.makedirs(source_dir, exist_ok=True)
                dir_names = self._index_dir_names(self.document_indices)
                for source, index in self.document_indices.items():
                    index.save_local(os.path.join(source_dir, dir_names[source]))

        except Exception as e:
            logger.error("Error saving FAISS index to %s: %s", self.faiss_index_path, e)

    def retrieve_relevant_documents(self, query: str, k: int = 5, document_paths: Optional[List[str]] = None) -> List[Document]:
        """Retrieve the most relevant document chunks for a query using FAISS.

        Args:
            query: The search query.
            k: Number of chunks to retrieve per selected document, or in total
                when searching all documents.
            document_paths: Optional list of document paths to search in. If
                None or empty, search in all documents. Repeated paths are
                searched once.

        Returns:
            The matching chunks. When none of the selected documents yields a
            result, the global index is searched instead with
            ``k * <number of selected documents>``. An empty list is returned
            when no index exists yet or the search raises.
        """
        if self.document_store is None:
            logger.error("FAISS document store is not initialized")
            return []

        try:
            if not document_paths:
                logger.info("Retrieving from all documents")
                docs = self.document_store.similarity_search(query, k=k)
                logger.info("Retrieved %d chunks from global index", len(docs))
                return docs

            # Normalise the same way load_documents builds its keys and drop
            # repeats (order preserved) so no chunk is returned twice.
            selected = list(dict.fromkeys(str(Path(p)) for p in document_paths))
            logger.info("Retrieving from %d selected documents", len(selected))
            all_docs = []

            for path in selected:
                index = self.document_indices.get(path)
                if index is None:
                    logger.warning("No dedicated index found for %s", path)
                    continue
                docs = index.similarity_search(query, k=k)
                all_docs.extend(docs)
                logger.info("Retrieved %d chunks from document: %s", len(docs), path)

            if not all_docs:
                # Fallback to global search if no results from dedicated indices
                all_docs = self.document_store.similarity_search(query, k=k * len(selected))
                logger.info("Retrieved %d chunks from global index as fallback", len(all_docs))

            return all_docs

        except Exception as e:
            logger.error("Error retrieving documents: %s", e)
            return []

    def load_document_content(self, document_path: str) -> str:
        """Load the full content of a document without chunking.

        Only ``.pdf``, ``.txt`` and ``.md`` files are supported here.

        Args:
            document_path: Path of the file to read.

        Returns:
            The ``page_content`` of all loaded segments joined by blank lines,
            or ``""`` when the file is missing, unsupported or fails to load.
        """
        path_obj = Path(document_path)
        if not path_obj.exists():
            logger.error("Document not found: %s", document_path)
            return ""

        try:
            loader = self._create_loader(path_obj, self._FULL_CONTENT_SUFFIXES)
            if loader is None:
                logger.error("Unsupported file format for full content loading: %s", document_path)
                return ""

            docs = loader.load()

            # Concatenate all pages/sections
            content = "\n\n".join(doc.page_content for doc in docs)
            logger.info("Loaded full content from %s: %d characters", document_path, len(content))

            return content

        except Exception as e:
            logger.error("Error loading document content: %s", e)
            return ""