import argparse
import json
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from decimal import Decimal

# Add parent directory to path for utils imports.
# SCRIPT_DIR is also used later to locate the prompt templates and the output
# directory; PARENT_DIR is where the project-local `utils` package (providing
# calculate_cost / format_cost) and the perturbed query data live.
SCRIPT_DIR = Path(__file__).parent.resolve()
PARENT_DIR = SCRIPT_DIR.parent
if str(PARENT_DIR) not in sys.path:
    sys.path.insert(0, str(PARENT_DIR))

from utils.model_pricing import calculate_cost, format_cost

# Load environment variables (e.g. the API keys, endpoints and AWS credentials
# read by get_api_client) from a .env file, if one is present.
load_dotenv()   

# Import API clients. Each SDK is optional: when it is missing a warning is
# printed and the client names are bound to None, so get_api_client() only
# raises ImportError if that provider is actually requested.
try:
    from anthropic import Anthropic, AnthropicBedrock
except ImportError:
    print("Warning: anthropic library not installed. Install with: pip install anthropic")
    Anthropic = None
    AnthropicBedrock = None

try:
    from openai import OpenAI, AzureOpenAI
except ImportError:
    print("Warning: openai library not installed. Install with: pip install openai")
    OpenAI = None
    AzureOpenAI = None

def get_api_client(api_type: str):
    """Create an SDK client for ``api_type`` using credentials from the environment.

    Supported providers and the environment variables each one reads:

    * ``"openai"``: ``OPENAI_API_KEY``
    * ``"azure"``: ``AZURE_API_KEY``, ``AZURE_ENDPOINT`` and optionally
      ``AZURE_API_VERSION`` (default ``"2024-02-15-preview"``)
    * ``"anthropic"``: ``ANTHROPIC_API_KEY``
    * ``"anthropic_bedrock"``: ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``
      and the region from ``AWS_DEFAULT_REGION``, falling back to
      ``AWS_REGION`` (default ``"us-east-1"``)

    Returns:
        An ``OpenAI``, ``AzureOpenAI``, ``Anthropic`` or ``AnthropicBedrock``
        client instance.

    Raises:
        ImportError: The SDK for the requested provider is not installed.
        ValueError: Required credentials are missing, or ``api_type`` is not
            one of the names above.
    """
    if api_type == "openai":
        if OpenAI is None:
            raise ImportError("openai library not installed")
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY not found in environment")
        return OpenAI(api_key=openai_key)

    if api_type == "azure":
        if AzureOpenAI is None:
            raise ImportError("openai library not installed")
        azure_key = os.getenv("AZURE_API_KEY")
        azure_endpoint = os.getenv("AZURE_ENDPOINT")
        azure_version = os.getenv("AZURE_API_VERSION", "2024-02-15-preview")
        if not (azure_key and azure_endpoint):
            raise ValueError("AZURE_API_KEY and AZURE_ENDPOINT must be set in environment")
        return AzureOpenAI(api_key=azure_key, api_version=azure_version, azure_endpoint=azure_endpoint)

    if api_type == "anthropic":
        if Anthropic is None:
            raise ImportError("anthropic library not installed")
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if not anthropic_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment")
        return Anthropic(api_key=anthropic_key)

    if api_type == "anthropic_bedrock":
        if AnthropicBedrock is None:
            raise ImportError("anthropic library not installed")
        # Same environment variable names as the AWS SDK and AnthropicAgent.
        access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
        secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
        region = os.getenv("AWS_DEFAULT_REGION") or os.getenv("AWS_REGION", "us-east-1")
        if not (access_key_id and secret_access_key):
            raise ValueError("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY must be set")
        return AnthropicBedrock(
            aws_access_key=access_key_id,
            aws_secret_key=secret_access_key,
            aws_region=region,
        )

    raise ValueError(f"Invalid API type: {api_type}")

def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line arguments for the categorization script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits:
        Through ``parser.error`` (status 2) when arguments are invalid. This
        includes a non-positive batch size or ``--max_tokens``.
    """
    parser = argparse.ArgumentParser(
        description="Generate categorization of elicitation run summaries",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Using task list with all refinement models
    python elicitation_run_categorization.py \\
        --task_list_path task_list_claude_haiku_0pct_baseline_human_filtered.json \\
        --perturbation_model o4-mini-2025-04-16 \\
        --agent claude-haiku-4-5-20251001

    # Using task list filtered to specific refinement model
    python elicitation_run_categorization.py \\
        --task_list_path task_list_claude_haiku_0pct_baseline_human_filtered.json \\
        --perturbation_model o4-mini-2025-04-16 \\
        --agent claude-haiku-4-5-20251001 \\
        --refinement_model_filter gpt-5-2025-08-07
        """
    )

    parser.add_argument('--task_list_path', type=str, required=True,
                       help='Path to task list JSON file (generated by generate_successful_task_list.py)')
    parser.add_argument('--resume_categorization_path', type=str, default=None,
                       help='Path to existing categorization JSON to resume from (skips initial categorization)')

    parser.add_argument('--initial_categorization_batch_size', type=int, default=10,
                       help='Number of elicitation summaries for initial categorization (default: 10)')
    parser.add_argument('--iterative_categorization_batch_size', type=int, default=5,
                       help='Number of elicitation summaries per iteration (default: 5)')

    parser.add_argument('--perturbation_model', type=str, default="o4-mini-2025-04-16",
                       help='Perturbation model used (default: o4-mini-2025-04-16)')
    parser.add_argument('--refinement_model_filter', type=str, default=None,
                       help='Filter to process only tasks for a specific refinement model')
    parser.add_argument('--agent', type=str, required=True,
                       help='The execution agent (e.g., "claude-haiku-4-5-20251001")')

    # API configuration
    parser.add_argument("--api", type=str, choices=["openai", "azure", "anthropic", "anthropic_bedrock"],
                       default="openai",
                       help="API provider for categorization (default: openai)")
    parser.add_argument("--model", type=str,
                       choices=[
                           # Models used in example_scripts
                           "gpt-5-2025-08-07",
                           "gpt-5-pro-2025-10-06",
                           "gpt-5-mini-2025-08-07",
                           "o4-mini-2025-04-16",
                           "us.anthropic.claude-sonnet-4-20250514-v1:0",
                           "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
                           "us.anthropic.claude-haiku-4-5-20251001-v1:0",
                           "us.anthropic.claude-opus-4-1-20250805-v1:0",
                           # Open Source models (e.g, vllm, huggingface)
                           "openai/gpt-oss-20b",
                           "Qwen/Qwen3-30B-A3B-Instruct-2507",
                           "Qwen/Qwen3-Next-80B-A3B-Instruct"
                       ],
                       default="gpt-5-2025-08-07",
                       help="Model name for categorization")
    parser.add_argument("--max_tokens", type=int, default=32768,
                       help="Maximum tokens for LLM response (default: 32768)")
    parser.add_argument("--temperature", type=float, default=1.0,
                       help="Temperature for LLM sampling (default: 1.0)")

    args = parser.parse_args(argv)

    # Batch sizes are used as slice lengths and as the step of range(). Zero
    # crashes (ValueError / ZeroDivisionError), and a negative initial size
    # silently slices from the end of the task list. Reject them up front.
    for option in ("initial_categorization_batch_size",
                   "iterative_categorization_batch_size",
                   "max_tokens"):
        value = getattr(args, option)
        if value < 1:
            parser.error(f"--{option} must be a positive integer (got {value})")

    return args


def load_elicitation_summaries(task_list: List[dict], perturbed_queries_dir: Path, perturbation_model: str, agent: str) -> str:
    """
    Load elicitation summaries for a list of tasks and format them for a prompt.

    Each summary is read from
    ``<perturbed_queries_dir>/<domain>/<task_id>/<perturbation_model>/
    perturbed_query_<perturbed_id>/iterative_refinement_<refinement_model>/
    agent_<agent>/elicitation_run_summary.json``. Tasks whose summary file does
    not exist are reported with a warning and skipped.

    Args:
        task_list: List of task dicts with keys: domain, task_id, perturbed_id, refinement_model
        perturbed_queries_dir: Base directory for perturbed queries (``Path`` or ``str``)
        perturbation_model: The perturbation model used
        agent: The execution agent

    Returns:
        One dashed-separator-delimited record per loaded summary. Each record
        holds the task/perturbed IDs, refinement model, original and perturbed
        instructions, and summary text. Returns ``""`` if nothing was loaded.

    Raises:
        KeyError: If a task dict or a summary file lacks one of the keys used.
    """
    base_dir = Path(perturbed_queries_dir)
    separator = "-----------------------------------\n"
    records: List[str] = []

    for task in task_list:
        domain = task["domain"]
        task_id = task["task_id"]
        perturbed_id = task["perturbed_id"]
        refinement_model = task["refinement_model"]

        # IDs may be non-strings in some task lists; Path "/" requires str.
        summary_path = (
            base_dir / str(domain) / str(task_id) / perturbation_model /
            f"perturbed_query_{perturbed_id}" / f"iterative_refinement_{refinement_model}" /
            f"agent_{agent}" / "elicitation_run_summary.json"
        )

        if not summary_path.exists():
            print(f"  Warning: Elicitation run summary not found for task {task_id}:{perturbed_id}, skipping...")
            continue

        with open(summary_path, 'r', encoding='utf-8') as f:
            summary_data = json.load(f)

        records.append(
            separator
            + f"Task ID: {task_id}\n"
            + f"Perturbed ID: {perturbed_id}\n"
            + f"Refinement Model: {refinement_model}\n"
            + f"Original Instruction: {summary_data['original_instruction']}\n"
            + f"Perturbed Instruction: {summary_data['perturbed_instruction']}\n"
            + f"Summary: {summary_data['summary']}\n"
            + separator
            + "\n"
        )

    # Join once rather than growing a string with += (quadratic in general).
    return "".join(records)

def load_initial_categorization_prompt(elicitation_summaries: str) -> str:
    """Return the initial-categorization prompt with the summaries filled in.

    Reads ``prompts/elicitation_run_categorization_prompt.md`` (relative to
    this script) as UTF-8. Every ``{ELICITATION_SUMMARIES}`` placeholder is
    replaced with ``elicitation_summaries``.
    """
    template_path = SCRIPT_DIR / "prompts" / "elicitation_run_categorization_prompt.md"
    template = template_path.read_text(encoding='utf-8')
    return template.replace("{ELICITATION_SUMMARIES}", elicitation_summaries)

def load_iterative_categorization_prompt(existing_definitions_json: List[dict], elicitation_summaries: str) -> str:
    """Return the taxonomy-update prompt with definitions and summaries filled in.

    Reads ``prompts/elicitation_run_categorization_update_prompt.md``
    (relative to this script) as UTF-8. It first replaces
    ``{EXISTING_DEFINITIONS_JSON}`` with the definitions serialized as
    4-space-indented JSON. It then replaces ``{ELICITATION_SUMMARIES}``.

    NOTE: substitution is sequential. If the serialized definitions happen to
    contain the literal text ``{ELICITATION_SUMMARIES}``, it is replaced as well.
    """
    template_path = SCRIPT_DIR / "prompts" / "elicitation_run_categorization_update_prompt.md"
    template = template_path.read_text(encoding='utf-8')

    definitions_text = json.dumps(existing_definitions_json, indent=4)
    with_definitions = template.replace("{EXISTING_DEFINITIONS_JSON}", definitions_text)
    return with_definitions.replace("{ELICITATION_SUMMARIES}", elicitation_summaries)

def generate_categorization(categorization_api_client, api_type, model_name, categorization_prompt, max_tokens, temperature):
    """Send a single-turn prompt to the configured LLM and return its reply.

    Args:
        categorization_api_client: Client returned by ``get_api_client``.
        api_type: ``"openai"``/``"azure"`` (Chat Completions API) or
            ``"anthropic"``/``"anthropic_bedrock"`` (Messages API).
        model_name: Model identifier passed through to the API.
        categorization_prompt: User message content.
        max_tokens: Output token limit.
        temperature: Sampling temperature.

    Returns:
        Tuple ``(text, input_tokens, output_tokens, total_tokens)``. ``text`` is
        always a ``str``; it is empty if the model returned no text content.

    Raises:
        ValueError: If ``api_type`` is not one of the supported providers.
    """
    messages = [{"role": "user", "content": categorization_prompt}]

    if api_type in ("openai", "azure"):
        model_lower = model_name.lower()
        # OpenAI reasoning models (matched here by "gpt-5" / "o4" in the name)
        # take ``max_completion_tokens`` instead of ``max_tokens``.
        if "gpt-5" in model_lower or "o4" in model_lower:
            token_limit = {"max_completion_tokens": max_tokens}
        else:
            token_limit = {"max_tokens": max_tokens}

        response = categorization_api_client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            **token_limit,
        )

        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        total_tokens = response.usage.total_tokens
        # ``content`` may be None (e.g. on a refusal); normalize to str so the
        # downstream JSON parsing never receives None.
        categorization = response.choices[0].message.content or ""

    elif api_type in ("anthropic", "anthropic_bedrock"):
        response = categorization_api_client.messages.create(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        total_tokens = input_tokens + output_tokens
        # Concatenate every text block rather than assuming content[0] exists
        # and is the only text block.
        categorization = "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )

    else:
        # Previously this fell through and raised UnboundLocalError at return.
        raise ValueError(f"Unsupported api_type for categorization: {api_type}")

    return categorization, input_tokens, output_tokens, total_tokens

def parse_json_response(response_text: str, required_fields: List[str] = None) -> Dict[str, Any]:
    """
    Parse the JSON object embedded in an LLM response, with error handling.

    Decoding starts at the first ``{`` in the text and reads exactly one JSON
    object. Leading prose, Markdown code fences, and trailing remarks are
    ignored, even when the trailing text contains braces.

    Args:
        response_text: Raw LLM response text (``None`` is treated as empty)
        required_fields: Optional list of required fields to validate

    Returns:
        The parsed dict. Any missing ``required_fields`` are added with value
        ``None``. If parsing fails, returns
        ``{"error": <reason>, "raw_response": <text>}`` instead.
    """
    text = response_text or ""

    start = text.find("{")
    if start == -1:
        print("    Warning: No JSON found in response")
        print(f"    Response text (first 200 chars): {text[:200]}")
        return {"error": "No JSON found in response", "raw_response": text}

    try:
        # raw_decode stops at the end of the first complete value, so text after
        # the object cannot break parsing. Starting at "{" guarantees a dict.
        parsed, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError as e:
        print(f"    Warning: Failed to parse JSON: {e}")
        print(f"    Response text (first 200 chars): {text[:200]}")
        return {"error": f"JSON decode error: {e}", "raw_response": text}

    # Validate required fields if specified
    if required_fields:
        for field in required_fields:
            if field not in parsed:
                print(f"    Warning: Missing field '{field}' in response")
                parsed[field] = None  # Add missing field with None value

    return parsed


def _list_field(container: dict, key: str) -> list:
    """Return ``container[key]`` as a list, replacing a missing or non-list value with ``[]`` in place."""
    value = container.get(key)
    if not isinstance(value, list):
        value = []
        container[key] = value
    return value


def _merge_categorization_update(master_taxonomy: dict, update: dict) -> None:
    """Apply one iteration's ``assignments`` and ``new_categories`` to ``master_taxonomy`` in place."""
    categories = _list_field(master_taxonomy, "categories")

    # Index categories by name once instead of scanning per assignment. The
    # first category with a given name wins, matching a front-to-back search.
    by_name = {}
    for category in categories:
        by_name.setdefault(category.get("category_name"), category)

    for assignment in update.get("assignments") or []:
        name = assignment.get("category_name")
        target = by_name.get(name)
        if target is None:
            print(f"  Warning: Category '{name}' not found in master taxonomy")
            continue
        _list_field(target, "examples").append({
            "id": assignment.get("id"),
            "trigger_phrase": assignment.get("trigger_phrase"),
            "justification": assignment.get("justification"),
        })

    for new_category in update.get("new_categories") or []:
        name = new_category.get("category_name")
        examples = new_category.get("examples")
        examples = list(examples) if isinstance(examples, list) else []

        existing = by_name.get(name)
        if existing is not None:
            # The model re-proposed a category that already exists. Fold its
            # examples in rather than adding a second entry with the same name.
            print(f"  Warning: New category '{name}' already exists in master taxonomy; merging examples")
            _list_field(existing, "examples").extend(examples)
            continue

        entry = {
            "category_name": name,
            "definition": new_category.get("definition"),
            "examples": examples,
        }
        categories.append(entry)
        by_name[name] = entry


def process_batch_categorization(
    categorization_api_client,
    api_type: str,
    model_name: str,
    task_list: List[dict],
    perturbed_queries_dir: Path,
    perturbation_model: str,
    agent: str,
    initial_categorization_batch_size: int,
    iterative_categorization_batch_size: int,
    max_tokens: int,
    temperature: float,
    existing_master_taxonomy: Optional[dict] = None
) -> Tuple[dict, int, int, int]:
    """
    Process all tasks in batches to generate categorization.

    Without ``existing_master_taxonomy``, the first
    ``initial_categorization_batch_size`` tasks seed a taxonomy from scratch.
    The remaining tasks are then processed in batches of
    ``iterative_categorization_batch_size``. Each batch's assignments and new
    categories are merged into the taxonomy. When resuming, every task is
    processed iteratively against a copy of the given taxonomy.

    Args:
        categorization_api_client: API client for LLM calls
        api_type: Type of API (openai, anthropic, etc.)
        model_name: Model name for categorization
        task_list: List of task dicts with domain, task_id, perturbed_id, refinement_model
        perturbed_queries_dir: Base directory for perturbed queries
        perturbation_model: The perturbation model used
        agent: The execution agent
        initial_categorization_batch_size: Batch size for initial categorization
        iterative_categorization_batch_size: Batch size for iterative updates
        max_tokens: Max tokens for LLM
        temperature: Temperature for LLM
        existing_master_taxonomy: Taxonomy to resume from. It is deep-copied
            and never mutated.

    Returns:
        Tuple of (master_taxonomy, input_tokens, output_tokens, total_tokens)

    Raises:
        ValueError: If a batch size that will be used is smaller than 1.
    """
    import copy

    if iterative_categorization_batch_size < 1:
        raise ValueError(
            f"iterative_categorization_batch_size must be >= 1 (got {iterative_categorization_batch_size})"
        )

    input_tokens = 0
    output_tokens = 0
    total_tokens = 0

    print(f"Total tasks to categorize: {len(task_list)}")

    ## Initial Categorization ##
    if existing_master_taxonomy:
        # Work on a copy so the caller's taxonomy object is left untouched.
        master_taxonomy = copy.deepcopy(existing_master_taxonomy)
        start_index = 0
        print(f"\nResuming from existing categorization with {len(master_taxonomy.get('categories') or [])} categories")
        print("Initial categorization step skipped.")
        print(f"Initial taxonomy: {json.dumps(master_taxonomy, indent=4)}")
    else:
        if initial_categorization_batch_size < 1:
            raise ValueError(
                f"initial_categorization_batch_size must be >= 1 (got {initial_categorization_batch_size})"
            )

        # Load the initial batch of elicitation summaries
        print(f"\nLoading initial batch of {initial_categorization_batch_size} elicitation summaries...")
        initial_batch = task_list[:initial_categorization_batch_size]
        initial_batch_elicitation_summaries = load_elicitation_summaries(
            initial_batch, perturbed_queries_dir, perturbation_model, agent
        )
        print(f"Loaded {len(initial_batch)} elicitation summaries")

        # Load the initial categorization prompt
        print("Loading initial categorization prompt...")
        initial_categorization_prompt = load_initial_categorization_prompt(initial_batch_elicitation_summaries)
        print("Loaded initial categorization prompt")

        # Generate the initial categorization
        print(f"Generating initial categorization with {model_name}...")
        initial_categorization, initial_input_tokens, initial_output_tokens, initial_total_tokens = generate_categorization(
            categorization_api_client, api_type, model_name, initial_categorization_prompt, max_tokens, temperature
        )
        print("Generated initial categorization")

        # Parse the initial categorization
        print("Parsing initial categorization...")
        master_taxonomy = parse_json_response(initial_categorization)
        print("Parsed initial categorization")
        if not isinstance(master_taxonomy.get("categories"), list):
            print("  Warning: Initial categorization has no 'categories' list; starting from an empty taxonomy")

        # Update the token counts
        input_tokens += initial_input_tokens
        output_tokens += initial_output_tokens
        total_tokens += initial_total_tokens

        print(f"\nInitial taxonomy has {len(master_taxonomy.get('categories') or [])} categories")
        print(f"Initial taxonomy: {json.dumps(master_taxonomy, indent=4)}")
        start_index = initial_categorization_batch_size

    # Guarantee a mutable categories list so later merges can append to it,
    # even if the initial response failed to parse.
    _list_field(master_taxonomy, "categories")

    ## Iterative Categorization ##
    remaining_tasks = max(len(task_list) - start_index, 0)
    # Ceiling division: a final partial batch still counts as one iteration.
    num_iterations = (remaining_tasks + iterative_categorization_batch_size - 1) // iterative_categorization_batch_size

    print(f"\nProcessing {remaining_tasks} remaining tasks in {num_iterations} iterations...")

    batch_starts = range(start_index, len(task_list), iterative_categorization_batch_size)
    for iteration_num, batch_start in enumerate(batch_starts, start=1):
        # Load the iterative batch of elicitation summaries
        batch = task_list[batch_start:batch_start + iterative_categorization_batch_size]
        batch_elicitation_summaries = load_elicitation_summaries(
            batch, perturbed_queries_dir, perturbation_model, agent
        )
        print(f"\n[Iteration {iteration_num}/{num_iterations}] Loaded {len(batch)} elicitation summaries")

        if not batch_elicitation_summaries:
            # Every summary file in this batch was missing; nothing to categorize.
            print("  No elicitation summaries available for this batch, skipping LLM call...")
            continue

        # Prepare definitions JSON from the master taxonomy
        definitions_json = [
            {"name": c.get("category_name"), "definition": c.get("definition")}
            for c in master_taxonomy["categories"]
        ]

        # Load the iterative categorization prompt
        iterative_categorization_prompt = load_iterative_categorization_prompt(definitions_json, batch_elicitation_summaries)

        # Generate the iterative categorization
        print(f"  Generating iterative categorization with {model_name}...")
        raw_update, iterative_input_tokens, iterative_output_tokens, iterative_total_tokens = generate_categorization(
            categorization_api_client, api_type, model_name, iterative_categorization_prompt, max_tokens, temperature
        )
        print("  Generated iterative categorization")

        # Parse the iterative categorization
        update = parse_json_response(raw_update)

        # Update the token counts
        input_tokens += iterative_input_tokens
        output_tokens += iterative_output_tokens
        total_tokens += iterative_total_tokens

        ## Update Master Taxonomy ##
        if "error" in update:
            print("  Warning: Could not parse iterative categorization; results for this batch were not merged")
        else:
            _merge_categorization_update(master_taxonomy, update)

        categories = master_taxonomy["categories"]
        print(f"  Updated taxonomy: {len(categories)} categories, "
              f"{sum(len(c.get('examples') or []) for c in categories)} total examples")

        print(f"Updated taxonomy: {json.dumps(master_taxonomy, indent=4)}")

    return master_taxonomy, input_tokens, output_tokens, total_tokens

def _tasks_from_specs(specs: List[str], refinement_model: str) -> List[dict]:
    """Convert ``"domain:task_id:perturbed_id"`` spec strings into task dicts, warning on malformed ones."""
    tasks = []
    for spec in specs:
        parts = spec.split(":")
        if len(parts) < 3:
            # Surface malformed specs instead of dropping them silently.
            print(f"  Warning: Skipping malformed task spec {spec!r} (expected 'domain:task_id:perturbed_id')")
            continue
        # Any fields after the third are ignored.
        tasks.append({
            "domain": parts[0],
            "task_id": parts[1],
            "perturbed_id": parts[2],
            "refinement_model": refinement_model,
        })
    return tasks


def load_task_list(task_list_path: str, refinement_model_filter: Optional[str] = None) -> Tuple[List[dict], List[str]]:
    """
    Load task list from JSON file.

    Three layouts are accepted, checked in this order:

    * ``task_details_by_refinement_model``:
      ``{model: [{"domain": ..., "task_id": ..., "perturbed_id": ...}, ...]}``
    * ``task_lists_by_refinement_model``:
      ``{model: ["domain:task_id:perturbed_id", ...]}``
    * legacy ``task_list``: ``["domain:task_id:perturbed_id", ...]``. The file
      does not name a refinement model, so ``refinement_model_filter`` is
      required and used as that model.

    Args:
        task_list_path: Path to task list JSON file
        refinement_model_filter: Optional filter for specific refinement model

    Returns:
        Tuple of (task_list as list of dicts, list of refinement models included)

    Raises:
        ValueError: For an unrecognized layout, or a legacy file without
            ``refinement_model_filter``.
    """
    with open(task_list_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    def _selected(refinement_model: str) -> bool:
        return not refinement_model_filter or refinement_model == refinement_model_filter

    task_list: List[dict] = []
    refinement_models_included: List[str] = []

    # Check if it's the new format with task_details_by_refinement_model
    if "task_details_by_refinement_model" in data:
        for refinement_model, tasks in data["task_details_by_refinement_model"].items():
            if not _selected(refinement_model):
                continue
            refinement_models_included.append(refinement_model)
            task_list.extend(
                {
                    "domain": task["domain"],
                    "task_id": task["task_id"],
                    "perturbed_id": task["perturbed_id"],
                    "refinement_model": refinement_model,
                }
                for task in tasks
            )

    # Alternative format with task_lists_by_refinement_model (just task specs)
    elif "task_lists_by_refinement_model" in data:
        for refinement_model, task_specs in data["task_lists_by_refinement_model"].items():
            if not _selected(refinement_model):
                continue
            refinement_models_included.append(refinement_model)
            task_list.extend(_tasks_from_specs(task_specs, refinement_model))

    # Legacy format with flat task_list
    elif "task_list" in data:
        if not refinement_model_filter:
            raise ValueError("Legacy task_list format requires --refinement_model_filter argument")
        refinement_models_included.append(refinement_model_filter)
        task_list.extend(_tasks_from_specs(data["task_list"], refinement_model_filter))

    else:
        raise ValueError(f"Unknown task list format in {task_list_path}")

    return task_list, refinement_models_included


def load_existing_categorization(resume_path: str) -> Tuple[dict, dict]:
    """
    Load an existing categorization JSON file to resume from.

    Args:
        resume_path: Path to a JSON file containing a top-level
            ``"categorization"`` object.

    Returns:
        Tuple of (categorization taxonomy, full JSON metadata)

    Raises:
        ValueError: If the ``categorization`` key is missing, or if it is not
            an object with a ``categories`` list. The latter happens, for
            example, in the error record left by a failed parse, which has
            nothing to resume from.
    """
    with open(resume_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if not isinstance(data, dict) or "categorization" not in data:
        raise ValueError(f"Missing 'categorization' key in {resume_path}")

    categorization = data["categorization"]
    if not isinstance(categorization, dict) or not isinstance(categorization.get("categories"), list):
        raise ValueError(f"'categorization' in {resume_path} has no 'categories' list to resume from")

    return categorization, data


def main():
    """Command-line entry point: categorize elicitation-run summaries and save the taxonomy.

    Steps:

    1. Parse arguments.
    2. Load the task list, plus an existing categorization if resuming.
    3. Create the API client.
    4. Run batched LLM categorization.
    5. Report token usage and cost.
    6. Write the result to a timestamped JSON file under
       ``SCRIPT_DIR/elicitation_run_categorization/<perturbation_model>/<refinement_label>/<agent>/``.

    Exits with status 1 on task-list, client-setup or resume-file errors.
    """
    args = parse_args()

    print("=" * 70)
    print("Elicitation Run Categorization")
    print("=" * 70)
    print(f"Task List: {args.task_list_path}")
    print(f"Perturbation Model: {args.perturbation_model}")
    print(f"Refinement Model Filter: {args.refinement_model_filter or 'All'}")
    print(f"Agent: {args.agent}")
    print(f"Categorization Model: {args.model}")
    print(f"Initial Batch Size: {args.initial_categorization_batch_size}")
    print(f"Iterative Batch Size: {args.iterative_categorization_batch_size}")
    if args.resume_categorization_path:
        print(f"Resume Categorization: {args.resume_categorization_path}")
    print("=" * 70)

    # Configuration from arguments
    perturbation_model = args.perturbation_model
    agent = args.agent
    initial_categorization_batch_size = args.initial_categorization_batch_size
    iterative_categorization_batch_size = args.iterative_categorization_batch_size
    max_tokens = args.max_tokens
    temperature = args.temperature
    api_type = args.api
    model_name = args.model

    # perturbed_queries is in the parent directory (perturbation_generation)
    perturbed_queries_dir = PARENT_DIR / "perturbed_queries_revised"

    # Load the task list
    try:
        task_list, refinement_models_included = load_task_list(
            args.task_list_path, args.refinement_model_filter
        )
    except Exception as e:
        print(f"Error loading task list: {e}")
        sys.exit(1)

    if not task_list:
        print("Error: No tasks found in task list")
        sys.exit(1)

    print(f"\nLoaded {len(task_list)} tasks")
    print(f"Refinement models included: {refinement_models_included}")

    # Missing SDKs or credentials get the same clean exit as the other setup
    # steps instead of a traceback.
    try:
        categorization_api_client = get_api_client(api_type)
    except (ImportError, ValueError) as e:
        print(f"Error creating API client for '{api_type}': {e}")
        sys.exit(1)

    existing_master_taxonomy = None
    resume_metadata = None
    if args.resume_categorization_path:
        try:
            existing_master_taxonomy, resume_metadata = load_existing_categorization(
                args.resume_categorization_path
            )
        except Exception as e:
            print(f"Error loading resume categorization file: {e}")
            sys.exit(1)

    # Record the resumed category count before categorization runs. The
    # resumed taxonomy object may be extended in place while processing, and
    # counting afterwards would report the final size instead.
    resumed_from_categories = (
        len(existing_master_taxonomy.get("categories") or []) if existing_master_taxonomy else None
    )

    master_taxonomy, input_tokens, output_tokens, total_tokens = process_batch_categorization(
        categorization_api_client,
        api_type,
        model_name,
        task_list,
        perturbed_queries_dir,
        perturbation_model,
        agent,
        initial_categorization_batch_size,
        iterative_categorization_batch_size,
        max_tokens,
        temperature,
        existing_master_taxonomy=existing_master_taxonomy
    )

    print("\n" + "=" * 70)
    print("CATEGORIZATION COMPLETE")
    print("=" * 70)
    print(f"Input tokens: {input_tokens:,}")
    print(f"Output tokens: {output_tokens:,}")
    print(f"Total tokens: {total_tokens:,}")

    # calculate_cost takes (model, input_tokens, output_tokens) in that order.
    cost = calculate_cost(model_name, input_tokens, output_tokens)
    print(f"Cost: {format_cost(cost)}")

    # A single timestamp keeps the file name and "generated_at" consistent.
    finished_at = datetime.now()
    timestamp = finished_at.strftime('%Y%m%d_%H%M%S')
    refinement_model_label = args.refinement_model_filter if args.refinement_model_filter else "all_refinement_models"

    output_json_path = (
        SCRIPT_DIR / "elicitation_run_categorization" /
        perturbation_model / refinement_model_label / agent /
        f"elicitation_run_categorization_{agent}_{timestamp}.json"
    )
    output_json_path.parent.mkdir(parents=True, exist_ok=True)

    categories = master_taxonomy.get("categories") or []
    output_data = {
        "categorization": master_taxonomy,
        "statistics": {
            "total_tasks": len(task_list),
            "total_categories": len(categories),
            "total_examples": sum(len(c.get("examples") or []) for c in categories),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "cost": f"{cost:.8f}"
        },
        "metadata": {
            "task_list_path": args.task_list_path,
            "perturbation_model": perturbation_model,
            "refinement_models_included": refinement_models_included,
            "agent": agent,
            "api_type": api_type,
            "model_name": model_name,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "initial_categorization_batch_size": initial_categorization_batch_size,
            "iterative_categorization_batch_size": iterative_categorization_batch_size,
            "resume_categorization_path": args.resume_categorization_path,
            "resumed_from_categories": resumed_from_categories,
            "resumed_from_metadata": resume_metadata.get("metadata") if resume_metadata else None,
            "generated_at": finished_at.isoformat()
        }
    }

    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)

    print(f"\nCategorization saved to: {output_json_path}")
    print(f"Categories: {output_data['statistics']['total_categories']}")
    print(f"Examples: {output_data['statistics']['total_examples']}")
    print("=" * 70)


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()