"""
Local document processing tools for DrBench Agent
Handles bulk ingestion of document folders and intelligent file search
"""

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from .base import ResearchContext, Tool
from .content_processor import ContentProcessor

logger = logging.getLogger(__name__)

# Description text returned by ``LocalDocumentIngestionTool.purpose``.
# The text is a runtime string: edit it deliberately, not for formatting.
LOCAL_INGESTION_TOOL_PURPOSE = """Bulk ingestion of local document folders into the research knowledge base.
        IDEAL FOR: Processing entire folders of documents, initial setup of research corpus, bulk loading enterprise documents.
        USE WHEN: You have folders containing research documents, reports, papers, or other text-based files that should be searchable during research.
        PARAMETERS: folder_paths (list of folder paths), file_extensions (optional filter), recursive (default True)
        OUTPUTS: Ingestion statistics, processed file counts, and confirmation that documents are now searchable in the knowledge base."""  # noqa: E501

# Description text returned by ``LocalFileSearchTool.purpose``.
# The text is a runtime string: edit it deliberately, not for formatting.
LOCAL_FILE_SEARCH_PURPOSE = """Intelligent search within locally ingested documents using semantic similarity.
        IDEAL FOR: Finding specific information within your document collection, targeted retrieval from local files, contextual document search.
        USE WHEN: You need to find specific information from previously ingested local documents, want to focus search on certain file types, or need document excerpts with source references.
        PARAMETERS: query (search terms), file_type_filter (optional), folder_filter (optional), top_k (number of results)
        OUTPUTS: Relevant document excerpts with file paths, similarity scores, and synthesized findings from local document collection."""  # noqa: E501


@dataclass
class IngestionStats:
    """Counters and timings collected while ingesting document folders.

    Attributes:
        total_files: Number of files selected for ingestion.
        processed_files: Number of files reported as successfully processed.
        skipped_files: Number of files reported as skipped.
        failed_files: Number of files whose processing failed.
        total_size_mb: Combined size of the selected files, in MiB.
        supported_formats: Lower-cased suffixes of the files that were processed.
        processing_time_seconds: Wall-clock duration of the ingestion run.
    """

    total_files: int = 0
    processed_files: int = 0
    skipped_files: int = 0
    failed_files: int = 0
    total_size_mb: float = 0.0
    # A factory gives every instance its own set, which keeps the
    # ``Set[str]`` annotation truthful (the previous default was ``None``).
    supported_formats: Set[str] = field(default_factory=set)
    processing_time_seconds: float = 0.0

    def __post_init__(self) -> None:
        # Backward compatibility: callers may still pass ``supported_formats=None``
        # explicitly; treat that the same as "no formats yet".
        if self.supported_formats is None:
            self.supported_formats = set()


class LocalDocumentIngestionTool(Tool):
    """Tool for bulk ingestion of local document folders into vector store"""

    # Supported file extensions.
    # Entries are lower-case and include the leading dot because they are
    # compared against ``file_path.suffix.lower()``. This set is the default
    # filter when ``ingest_folders`` is called without ``file_extensions``.
    SUPPORTED_EXTENSIONS = {
        # Text formats
        ".txt",
        ".md",
        ".csv",
        ".tsv",
        ".log",
        # Document formats
        ".pdf",
        ".docx",
        ".doc",
        ".rtf",
        ".odt",
        # Spreadsheet formats
        ".xlsx",
        ".xls",
        ".ods",
        # Presentation formats
        ".pptx",
        ".ppt",
        ".odp",
        # Web formats
        ".html",
        ".htm",
        ".xml",
        ".json",
        ".jsonl",
    }

    def __init__(self, content_processor: ContentProcessor, max_workers: int = 4):
        """Store the collaborators used during ingestion.

        Args:
            content_processor: Object whose ``process_file`` method is called
                for every file; its optional ``session_cache`` attribute is
                consulted to skip files that were already handled.
            max_workers: Thread-pool size used by ``ingest_folders``.
        """
        self.content_processor = content_processor
        self.max_workers = max_workers

    @property
    def purpose(self) -> str:
        """Return the tool description text ``LOCAL_INGESTION_TOOL_PURPOSE``."""
        return LOCAL_INGESTION_TOOL_PURPOSE

    def execute(self, query: str, context: ResearchContext) -> Dict[str, Any]:
        """Ingest the folders named in ``query`` and report statistics.

        Args:
            query: Parameter string, e.g.
                ``"folder_paths=['path1', 'path2'] file_extensions=['.pdf'] recursive=True"``.
            context: Research context; when it has a ``files_created`` list,
                the paths found under the ingested folders are appended to it.

        Returns:
            The dictionary built by ``create_success_output`` on success, or by
            ``create_error_output`` when no folder paths were given or any
            exception was raised.
        """

        try:
            params = self._parse_ingestion_query(query)
            folder_paths = params.get("folder_paths")

            if not folder_paths:
                return self.create_error_output(
                    "local_document_ingestion",
                    query,
                    "No folder paths specified. Use format: folder_paths=['path1', 'path2']",
                )

            stats = self.ingest_folders(
                folder_paths=folder_paths,
                file_extensions=params.get("file_extensions"),
                recursive=params.get("recursive", True),
            )

            # Update context with ingested files.
            # NOTE(review): the helper is called without the extension/recursive
            # filters, so this list can include files this run did not ingest.
            if hasattr(context, "files_created"):
                context.files_created.extend(str(fp) for fp in self._get_processed_files(folder_paths))

            # Snapshot the stats instead of exposing the live ``__dict__``:
            # the set is turned into a sorted list so the output stays
            # JSON-serialisable, and callers cannot mutate the stats object.
            formats = sorted(stats.supported_formats)
            stats_snapshot = {**stats.__dict__, "supported_formats": formats}

            summary = f"Successfully ingested {stats.processed_files}/{stats.total_files} documents"
            if stats.skipped_files:
                summary += f" ({stats.skipped_files} already ingested, skipped)"

            return self.create_success_output(
                tool_name="local_document_ingestion",
                query=query,
                ingestion_stats=stats_snapshot,
                processed_files=stats.processed_files,
                total_files=stats.total_files,
                supported_formats=list(formats),
                processing_time=stats.processing_time_seconds,
                # Files skipped because they were ingested earlier are still
                # available for retrieval, so they count as retrieved data.
                data_retrieved=(stats.processed_files + stats.skipped_files) > 0,
                stored_in_vector=True,  # Prevent duplicate storage as research_finding
                results={
                    "summary": summary,
                    "stats": {**stats_snapshot, "supported_formats": list(formats)},
                },
            )

        except Exception as e:
            # Boundary handler: report the failure to the caller, but keep the
            # traceback in the log for debugging.
            logger.exception(f"Document ingestion failed: {e}")
            return self.create_error_output("local_document_ingestion", query, str(e))

    def ingest_folders(
        self, folder_paths: List[str], file_extensions: Optional[List[str]] = None, recursive: bool = True
    ) -> IngestionStats:
        """
        Ingest all documents from specified folders

        Files are collected first, then handed to ``_process_single_file`` on a
        thread pool of ``self.max_workers`` workers.

        Args:
            folder_paths: List of folder paths to process
            file_extensions: Optional filter for file extensions, with or
                without the leading dot; defaults to ``SUPPORTED_EXTENSIONS``
            recursive: Whether to process subdirectories

        Returns:
            IngestionStats with processing results. ``processed_files``,
            ``skipped_files`` and ``failed_files`` are disjoint counters: a
            result flagged as skipped is counted only as skipped.
        """
        start_time = datetime.now()
        stats = IngestionStats()

        # Normalise the filter to lower-case suffixes with a leading dot
        if file_extensions:
            extensions_set = {ext.lower() if ext.startswith(".") else f".{ext.lower()}" for ext in file_extensions}
        else:
            extensions_set = self.SUPPORTED_EXTENSIONS

        logger.info(f"Starting document ingestion from {len(folder_paths)} folders")
        logger.info(f"File extensions filter: {extensions_set}")

        # Collect all files to process
        files_to_process = []
        # Resolved paths already queued; overlapping or repeated folders must
        # not submit the same file twice (two workers would ingest it in parallel).
        seen_paths = set()
        pattern = "**/*" if recursive else "*"

        for folder_path in folder_paths:
            folder = Path(folder_path)
            if not folder.exists():
                logger.warning(f"Folder does not exist: {folder}")
                continue

            if not folder.is_dir():
                logger.warning(f"Path is not a directory: {folder}")
                continue

            for file_path in folder.glob(pattern):
                if not file_path.is_file() or file_path.suffix.lower() not in extensions_set:
                    continue

                resolved = file_path.resolve()
                if resolved in seen_paths:
                    continue
                seen_paths.add(resolved)

                files_to_process.append(file_path)
                stats.total_size_mb += file_path.stat().st_size / (1024 * 1024)

        stats.total_files = len(files_to_process)
        logger.info(f"Found {stats.total_files} files to process ({stats.total_size_mb:.2f} MB)")

        # Process files in parallel
        if files_to_process:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_file = {
                    executor.submit(self._process_single_file, file_path): file_path for file_path in files_to_process
                }

                # Tally results as they complete
                for future in as_completed(future_to_file):
                    file_path = future_to_file[future]
                    try:
                        result = future.result()
                        # Check "skipped" first: already-ingested files are
                        # reported with success=True *and* skipped=True.
                        if result.get("skipped"):
                            stats.skipped_files += 1
                        elif result["success"]:
                            stats.processed_files += 1
                            stats.supported_formats.add(file_path.suffix.lower())
                        else:
                            stats.failed_files += 1
                            logger.warning(f"Failed to process {file_path}: {result.get('error')}")
                    except Exception as e:
                        # One bad file must not abort the whole run
                        stats.failed_files += 1
                        logger.error(f"Error processing {file_path}: {e}")

        stats.processing_time_seconds = (datetime.now() - start_time).total_seconds()

        logger.info(f"Document ingestion completed in {stats.processing_time_seconds:.2f}s")
        logger.info(
            f"Results: {stats.processed_files} processed, {stats.skipped_files} skipped, {stats.failed_files} failed"
        )

        return stats

    def _process_single_file(self, file_path: Path) -> Dict[str, Any]:
        """Process a single file and return result"""

        try:
            # Check if file already exists in vector store to avoid duplicates
            if hasattr(self.content_processor, "session_cache") and self.content_processor.session_cache:
                cached_doc_id = self.content_processor.session_cache.check_source("local_file", str(file_path))
                if cached_doc_id:
                    return {
                        "success": True,
                        "skipped": True,
                        "reason": "Already processed",
                        "file_path": str(file_path),
                    }

            # Process file using existing ContentProcessor
            result = self.content_processor.process_file(
                file_path=str(file_path),
                query_context=f"Local document from: {file_path.parent}",
                additional_metadata={
                    "source_type": "local_document",
                    "file_path": str(file_path),
                    "ingestion_time": datetime.now().isoformat(),
                    "file_size_bytes": file_path.stat().st_size,
                    "folder_path": str(file_path.parent),
                    "relative_path": (
                        str(file_path.relative_to(file_path.parents[0]))
                        if len(file_path.parts) > 1
                        else str(file_path.name)
                    ),
                },
            )

            if result.get("success"):
                return {"success": True, "file_path": str(file_path), "doc_id": result.get("doc_id")}
            else:
                return {
                    "success": False,
                    "file_path": str(file_path),
                    "error": result.get("error", "Unknown processing error"),
                }

        except Exception as e:
            return {"success": False, "file_path": str(file_path), "error": str(e)}

    def _parse_ingestion_query(self, query: str) -> Dict[str, Any]:
        """Parse ingestion parameters from query string"""

        params = {}

        # Simple parsing for folder_paths, file_extensions, recursive
        # Expected format: "folder_paths=['path1', 'path2'] file_extensions=['.pdf'] recursive=True"

        # Extract folder_paths
        if "folder_paths=" in query:
            try:
                start = query.find("folder_paths=") + len("folder_paths=")
                end = query.find("]", start) + 1
                folder_part = query[start:end]
                # Simple eval for list parsing (in production, use safer parsing)
                folder_paths = eval(folder_part) if folder_part else []
                params["folder_paths"] = folder_paths
            except Exception as e:
                logger.warning(f"Could not parse folder_paths from query: {e}")

        # Extract file_extensions
        if "file_extensions=" in query:
            try:
                start = query.find("file_extensions=") + len("file_extensions=")
                end = query.find("]", start) + 1
                ext_part = query[start:end]
                file_extensions = eval(ext_part) if ext_part else None
                params["file_extensions"] = file_extensions
            except Exception as e:
                logger.warning(f"Could not parse file_extensions from query: {e}")

        # Extract recursive
        if "recursive=" in query:
            params["recursive"] = "recursive=True" in query

        return params

    def _get_processed_files(self, folder_paths: List[str]) -> List[Path]:
        """Get list of files that would be processed (for context updating)"""
        files = []
        for folder_path in folder_paths:
            folder = Path(folder_path)
            if folder.exists() and folder.is_dir():
                for file_path in folder.glob("**/*"):
                    if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS:
                        files.append(file_path)
        return files[:100]  # Limit for context


class LocalFileSearchTool(Tool):
    """Tool for intelligent search within locally ingested documents"""

    def __init__(self, vector_store, model: str):
        """Store the collaborators used for searching.

        Args:
            vector_store: Object whose ``search`` method is queried by
                ``execute`` and whose ``store_document`` method receives the
                generated synthesis.
            model: Model identifier passed to ``prompt_llm`` when synthesizing.
        """
        self.vector_store = vector_store
        self.model = model

    @property
    def purpose(self) -> str:
        """Return the tool description text ``LOCAL_FILE_SEARCH_PURPOSE``."""
        return LOCAL_FILE_SEARCH_PURPOSE

    def execute(self, query: str, context: ResearchContext) -> Dict[str, Any]:
        """Execute search within local documents

        Args:
            query: Search text, optionally followed by
                ``file_type_filter=[...]``, ``folder_filter=[...]`` and
                ``top_k=N`` parameters.
            context: Research context handed on to the synthesis step.

        Returns:
            The dictionary built by ``create_success_output`` with the
            synthesis, file statistics and up to five excerpts, or the one
            built by ``create_error_output`` when nothing matched or an
            exception was raised.
        """
        # Texts returned by ``_synthesize_local_results`` when no real synthesis
        # was produced. They must not be stored in the vector store as if they
        # were research content. Keep in sync with that method.
        placeholder_prefixes = (
            "Error generating synthesis:",
            "Retrieved local documents but no content available for synthesis.",
            "No local documents found matching the query.",
        )
        # Synthesized documents are excluded to avoid recursive synthesis
        synthesized_types = ("ai_synthesis_with_sources", "ai_synthesis", "research_finding")

        def excerpt(text: str, limit: int = 500) -> str:
            """Return ``text`` truncated to ``limit`` characters, with an ellipsis if cut."""
            return text[:limit] + "..." if len(text) > limit else text

        try:
            # Parse search parameters
            params = self._parse_search_query(query)
            search_query = params.get("query", query)
            file_type_filter = params.get("file_type_filter")
            folder_filter = params.get("folder_filter")
            # A zero or negative top_k would make the slicing below meaningless
            top_k = max(1, params.get("top_k", 10))

            suffixes = tuple(ext.lower() for ext in file_type_filter) if file_type_filter else ()

            # Search vector store; over-fetch because results are filtered below
            search_results = self.vector_store.search(query=search_query, top_k=top_k * 2)

            local_results = []
            for result in search_results:
                metadata = result.get("metadata") or {}

                # Only include local documents
                if metadata.get("source_type") != "local_document":
                    continue

                if metadata.get("type") in synthesized_types:
                    continue

                # Apply file type filter
                if suffixes and not (metadata.get("file_path") or "").lower().endswith(suffixes):
                    continue

                # Apply folder filter (substring match on the stored folder path)
                if folder_filter:
                    folder_path = metadata.get("folder_path") or ""
                    if not any(wanted in folder_path for wanted in folder_filter):
                        continue

                local_results.append(result)
                if len(local_results) == top_k:
                    break  # Limit to requested number

            if not local_results:
                return self.create_error_output(
                    "local_document_search",
                    query,
                    "No relevant local documents found. Make sure documents have been ingested first.",
                )

            # Synthesize results
            synthesis = self._synthesize_local_results(local_results, search_query, context)

            # Store synthesis in vector store with source tracking
            source_doc_ids = [result.get("doc_id") for result in local_results if result.get("doc_id")]
            is_real_synthesis = bool(synthesis) and not synthesis.startswith(placeholder_prefixes)
            if self.vector_store and is_real_synthesis and source_doc_ids:
                synthesis_metadata = {
                    "tool_used": "local_document_search",
                    "type": "ai_synthesis_with_sources",
                    "source": "local_synthesis",
                    "query_context": search_query,
                    "synthesis_method": "local_document_search",
                    "source_document_ids": source_doc_ids,
                    "timestamp": datetime.now().isoformat(),
                }

                self.vector_store.store_document(content=synthesis, metadata=synthesis_metadata)

            # Extract file statistics
            file_paths = set()
            file_types = set()
            folders = set()

            for result in local_results:
                metadata = result.get("metadata") or {}
                if metadata.get("file_path"):
                    file_paths.add(metadata["file_path"])
                    file_types.add(Path(metadata["file_path"]).suffix)
                if metadata.get("folder_path"):
                    folders.add(metadata["folder_path"])

            return self.create_success_output(
                tool_name="local_document_search",
                query=search_query,
                synthesis=synthesis,
                files_searched=len(file_paths),
                file_types_found=list(file_types),
                folders_searched=list(folders),
                results_count=len(local_results),
                data_retrieved=True,
                stored_in_vector=True,  # Prevent duplicate storage as research_finding
                results={
                    "synthesis": synthesis,
                    "local_documents": [
                        {
                            "file_path": (r.get("metadata") or {}).get("file_path"),
                            "content_excerpt": excerpt(r.get("content") or ""),
                            "relevance_score": r.get("score", 0.0),
                        }
                        for r in local_results[:5]  # Top 5 for display
                    ],
                },
            )

        except Exception as e:
            # Boundary handler: report the failure to the caller, but keep the
            # traceback in the log for debugging.
            logger.exception(f"Local file search failed: {e}")
            return self.create_error_output("local_document_search", query, str(e))

    def _parse_search_query(self, query: str) -> Dict[str, Any]:
        """Parse search parameters from query string"""

        params = {"query": query}

        # Extract filters if present
        # Format: "search_term file_type_filter=['.pdf', '.docx'] folder_filter=['folder1'] top_k=5"

        if "file_type_filter=" in query:
            try:
                start = query.find("file_type_filter=") + len("file_type_filter=")
                end = query.find("]", start) + 1
                filter_part = query[start:end]
                file_type_filter = eval(filter_part) if filter_part else None
                params["file_type_filter"] = file_type_filter
                # Remove filter from query
                params["query"] = query.replace(f"file_type_filter={filter_part}", "").strip()
            except Exception as e:
                logger.warning(f"Could not parse file_type_filter from query: {e}")

        if "folder_filter=" in query:
            try:
                start = query.find("folder_filter=") + len("folder_filter=")
                end = query.find("]", start) + 1
                filter_part = query[start:end]
                folder_filter = eval(filter_part) if filter_part else None
                params["folder_filter"] = folder_filter
                # Remove filter from query
                params["query"] = query.replace(f"folder_filter={filter_part}", "").strip()
            except Exception as e:
                logger.warning(f"Could not parse folder_filter from query: {e}")

        if "top_k=" in query:
            try:
                start = query.find("top_k=") + len("top_k=")
                # Find the next space or end of string
                end = query.find(" ", start)
                if end == -1:
                    end = len(query)
                top_k_str = query[start:end]
                params["top_k"] = int(top_k_str)
                # Remove from query
                params["query"] = query.replace(f"top_k={top_k_str}", "").strip()
            except Exception as e:
                logger.warning(f"Could not parse top_k from query: {e}")

        return params

    def _synthesize_local_results(self, results: List[Dict], query: str, context: ResearchContext) -> str:
        """Ask the LLM for a cited summary of the retrieved local documents.

        Args:
            results: Search hits; only the first ten are considered, and only
                those with both ``content`` and ``doc_id`` are sent to the model.
            query: Search text the hits were retrieved for.
            context: Research context; its ``original_question`` is quoted in
                the prompt.

        Returns:
            The model's answer with citation formatting repaired, or a short
            explanatory message when there is nothing to synthesize or the
            model call fails.
        """
        import json

        from drbench.agents.utils import prompt_llm

        if not results:
            return "No local documents found matching the query."

        # Collect the citable documents (capped to keep the prompt small)
        citable_docs = []
        for hit in results[:10]:
            text = hit.get("content", "")
            identifier = hit.get("doc_id")
            meta = hit.get("metadata", {})
            display_name = meta.get("relative_path", meta.get("filename", "Unknown"))

            if not (text and identifier):
                continue

            # Each document contributes at most 1000 characters
            citable_docs.append({"doc_id": identifier, "file_name": display_name, "content": text[:1000]})

        if not citable_docs:
            return "Retrieved local documents but no content available for synthesis."

        # Prompt with document attribution and the citation rules
        llm_prompt = f"""
Based on the research query: "{query}"
And the original research question: "{context.original_question}"

Documents available for analysis:
{json.dumps(citable_docs, indent=2)}

CITATION REQUIREMENTS:
- EXACT FORMAT: [DOC:doc_id] - with colon after DOC
- Use INDIVIDUAL citations: [DOC:doc_1][DOC:doc_2] NOT [DOC:doc_1; DOC:doc_2]
- Cite EVERY claim with source documents
- NEVER make claims without document support

Synthesize the following information into key insights:

Provide:
1. Key findings directly relevant to the query
2. Important patterns or trends identified
3. Contradictions or conflicts in the information
4. Gaps that still need research
5. Actionable insights or conclusions

Be comprehensive but concise. Focus on insights that directly address the query.
Every claim MUST have [DOC:doc_id] citations.
"""

        try:
            # Repair common citation formatting issues in the model's answer
            return self._fix_malformed_citations(prompt_llm(model=self.model, prompt=llm_prompt))
        except Exception as exc:
            return f"Error generating synthesis: {exc}"

    def _fix_malformed_citations(self, text: str) -> str:
        """Fix common citation formatting mistakes"""
        import re

        # Fix [DOC doc_id] -> [DOC:doc_id]
        text = re.sub(r"\[DOC\s+([^\]]+)\]", r"[DOC:\1]", text)

        # Fix [DOC_id] -> [DOC:id]
        text = re.sub(r"\[DOC_([^\]]+)\]", r"[DOC:\1]", text)

        return text
