import asyncio
from pathlib import Path
from typing import List, Optional

from false_facts.knowledge_cutoff.data_models import UniverseContext
from false_facts.knowledge_cutoff.universe import get_key_facts
from false_facts.synth_doc_generation import batch_generate_documents
from nest_asyncio import apply
from safetytooling.apis import InferenceAPI

# Patch asyncio via nest_asyncio at import time so the asyncio.run() call below
# also works when an event loop is already running (e.g. inside a notebook).
apply()


async def process_key_facts(context_strs: List[str], api: InferenceAPI):
    """
    Extract key facts for every context string concurrently.

    Args:
        context_strs: List of context strings to extract key facts from
        api: InferenceAPI instance passed through to get_key_facts

    Returns:
        The results of get_key_facts, in the same order as context_strs.
    """
    # One get_key_facts coroutine per context; gather awaits them together
    # and preserves the input order in its result list.
    pending = (get_key_facts(context_str, api) for context_str in context_strs)
    return await asyncio.gather(*pending)


def run_synth_doc_generation_pipeline(
    context_strs: List[str],
    context_ids: List[str],
    universe_contexts_dir: Path,
    synth_docs_dir: Optional[Path] = None,
    num_threads: int = 20,
    doc_generation_config: Optional[dict] = None,
) -> None:
    """
    Process universe contexts and generate synthetic documents.

    Key facts are extracted for each context string, each context is written
    to ``<universe_contexts_dir>/<context_id>.jsonl``, and, when
    ``doc_generation_config`` is provided, ``batch_generate_documents`` is run
    once per saved context file.

    Args:
        context_strs: List of context strings to process
        context_ids: List of IDs corresponding to each context (must have the
            same length as context_strs)
        universe_contexts_dir: Directory to save universe context files
        synth_docs_dir: Directory to save synthetic documents (required if
            doc_generation_config is provided; created if given)
        num_threads: Number of threads for the InferenceAPI
        doc_generation_config: Configuration for document generation
            - num_doc_types: Number of document types to generate (default 200)
            - num_doc_ideas: Number of document ideas per type (default 20)
            - doc_repeat_range: Range for document repetition (default 3)
            - num_threads: Number of threads for generation (default 5)

    Raises:
        ValueError: If doc_generation_config is specified without
            synth_docs_dir, or if context_strs and context_ids differ in
            length.
    """
    # Validate inputs before touching the filesystem or the API.
    if doc_generation_config and not synth_docs_dir:
        raise ValueError("synth_docs_dir must be provided when doc_generation_config is specified")
    if len(context_strs) != len(context_ids):
        # zip() below would otherwise silently drop the unmatched entries.
        raise ValueError(
            f"context_strs and context_ids must have the same length "
            f"(got {len(context_strs)} and {len(context_ids)})"
        )

    # Create folders if they don't exist. synth_docs_dir is optional, so it
    # must not be dereferenced when it was left as None.
    universe_contexts_dir.mkdir(parents=True, exist_ok=True)
    if synth_docs_dir is not None:
        synth_docs_dir.mkdir(parents=True, exist_ok=True)

    # Initialize API
    api = InferenceAPI(anthropic_num_threads=num_threads)

    # Extract key facts
    key_facts = asyncio.run(process_key_facts(context_strs, api))

    # Create universe contexts
    universe_contexts = []
    for context, key_fact, context_id in zip(context_strs, key_facts, context_ids):
        universe_context = UniverseContext(universe_context=context, key_facts=key_fact, is_true=False, id=context_id)
        universe_contexts.append(universe_context)
        print(f"Generated {len(key_fact)} key facts for {context_id}")
        print(f"Key facts: {key_fact}")

    # Save universe contexts, one single-line JSONL file per context, and
    # remember the paths so document generation can reuse them.
    universe_context_paths = []
    for context in universe_contexts:
        path = universe_contexts_dir / f"{context.id}.jsonl"
        with open(path, "w") as f:
            f.write(context.model_dump_json() + "\n")
        universe_context_paths.append(path)

    # Generate documents if config is provided
    if doc_generation_config:
        for universe_context_path in universe_context_paths:
            batch_generate_documents(
                universe_contexts_path=str(universe_context_path),
                output_path=str(synth_docs_dir),
                num_doc_types=doc_generation_config.get("num_doc_types", 200),
                num_doc_ideas=doc_generation_config.get("num_doc_ideas", 20),
                doc_repeat_range=doc_generation_config.get("doc_repeat_range", 3),
                num_threads=doc_generation_config.get("num_threads", 5),
            )
