"""
Module for labeling instructions in text using OpenAI's GPT-4 or Anthropic's Claude models.

This module provides functionality to automatically identify and tag instructions 
within text using either OpenAI's GPT-4 API or Anthropic's Claude API. It processes JSONL files containing text 
data and outputs the same data with instruction tags added.
"""

import json
import openai
import anthropic
from dotenv import load_dotenv
import os
import re
import argparse
from typing import Optional, Dict, Any, List, Union

# Load variables from a .env file into the process environment so that the
# API keys read with os.getenv() in main() can be supplied that way.
load_dotenv()

# Module-level API clients. Both start as None; main() assigns the one selected
# by --api, and tag_instructions() treats a None client as an error for that provider.
openai_client: Optional[openai.OpenAI] = None
anthropic_client: Optional[anthropic.Anthropic] = None

def tag_instructions(text: str, prompt: str, api_provider: str) -> str:
    """
    Ask the selected model to wrap instructions in the text with <instruction></instruction> tags.

    The function is best-effort: any error raised while querying the model
    (including an uninitialized client or an unsupported provider) is caught,
    reported with ``print`` and answered by returning ``text`` unchanged, so
    nothing propagates to the caller.

    Args:
        text (str): The input text to analyze for instructions
        prompt (str): The prompt template placed in front of the text in the user message
        api_provider (str): Either 'openai' or 'anthropic' to specify which API to use

    Returns:
        str: The model's reply with surrounding whitespace stripped, or the
            original ``text`` if the reply is empty or the query fails
    """
    # Single copy of the system instruction so both providers always receive the same wording.
    system_prompt = "You are a precise text tagger. You only add <instruction></instruction> tags around instructions in the text provided to you and return the exact same text otherwise."

    # The indentation inside this literal is part of the prompt sent to the model.
    final_prompt = f"""{prompt} Here is the text you will be analyzing:
            <text>
            {text}
            </text>"""

    try:
        if api_provider == "openai":
            if openai_client is None:
                raise ValueError("OpenAI client not initialized")
            response = openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": final_prompt}
                ],
                temperature=0.0,
                max_tokens=5000
            )
            content = response.choices[0].message.content
            return content.strip() if content else text

        elif api_provider == "anthropic":
            if anthropic_client is None:
                raise ValueError("Anthropic client not initialized")
            response = anthropic_client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=10000,
                # NOTE(review): the OpenAI branch uses temperature 0.0; 0.5 here makes
                # verbatim reproduction less likely - confirm this is intentional.
                temperature=0.5,
                system=system_prompt,
                messages=[
                    {"role": "user", "content": final_prompt}
                ]
            )
            # A reply is a list of content blocks and only text blocks carry `.text`.
            # Join every text block instead of assuming the first block is one, so a
            # non-text leading block no longer discards the whole answer.
            content = "".join(
                block.text
                for block in (response.content or [])
                if getattr(block, "type", None) == "text"
            )
            return content.strip() if content else text

        else:
            raise ValueError(f"Unsupported API provider: {api_provider}")

    except Exception as e:
        # Deliberate catch-all: one failed request must not abort a whole batch run.
        print(f"Error querying {api_provider} model: {e}")
        return text  # Return original text if API call fails

def main() -> None:
    """
    Main function to process a JSONL file and label instructions in its text.

    Parses the command line, reads the prompt file, then handles the input file
    record by record: the 'complete_text' field is sent to the model (or copied
    as-is in control mode), the reply is sanity-checked, and the record is
    extended with two fields:

        sanity_check: True when the reply, with <instruction>/<text> tags removed
            and whitespace collapsed, equals the whitespace-collapsed original
        label_text: the whitespace-collapsed tagged text, or the
            whitespace-collapsed original text when the sanity check failed

    Blank lines in the input are skipped. All files are read and written as UTF-8.
    Results are written to the output file only after every record was processed.

    Command line arguments:
        --input/-i: Path to input JSONL file containing text data
        --output/-o: Path to output JSONL file for labeled results
        --prompt/-p: Path to text file containing the prompt template
        --api/-a: API provider to use ('openai' or 'anthropic', default: 'openai')
        --control: Copy 'complete_text' instead of querying a model; no API
            client is created, so no API key is needed

    Prints:
        Processing status and statistics including number of examples processed
        and number of sanity check failures

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON
        KeyError: If a record has no 'complete_text' field
    """
    global openai_client, anthropic_client

    # Set up command line argument parsing
    parser = argparse.ArgumentParser(description='Label instructions in text using GPT-4o or Claude Sonnet')
    parser.add_argument('--input', '-i', required=True, help='Path to input JSONL file')
    parser.add_argument('--output', '-o', required=True, help='Path to output JSONL file')
    parser.add_argument('--prompt', '-p', required=True, help='Path to prompt text file')
    parser.add_argument('--api', '-a', choices=['openai', 'anthropic'], default='openai', help='API provider to use (openai or anthropic)')
    parser.add_argument('--control', action='store_true', help='control data, output same as input')

    args = parser.parse_args()

    input_path: str = args.input
    output_path: str = args.output
    prompt_path: str = args.prompt
    api_provider: str = args.api
    control: bool = args.control

    # Create only the client that will actually be used. Control mode never calls
    # a model, so it must not fail just because no API key is configured.
    if control:
        print("Control mode: copying 'complete_text' without querying a model")
    elif api_provider == "openai":
        openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        print("Using OpenAI GPT-4o model")
    elif api_provider == "anthropic":
        anthropic_client = anthropic.Anthropic(api_key=os.getenv("CLAUDE_API_KEY"))
        print("Using Anthropic Claude Sonnet model")

    # Read the prompt
    with open(prompt_path, 'r', encoding='utf-8') as f:
        prompt: str = f.read()

    print(f"Reading from: {input_path}")
    print(f"Output to: {output_path}")
    print(f"Using prompt from: {prompt_path}")
    if not control:
        print(f"Processing with {api_provider} model")

    # Compile the patterns once instead of on every record.
    instruction_tag = re.compile(r'</?instruction>')
    text_tag = re.compile(r'</?text>')
    whitespace_run = re.compile(r'\s+')

    labeled_data: List[Dict[str, Any]] = []
    failed_sanity_checks: int = 0

    with open(input_path, 'r', encoding='utf-8') as f:
        # `i` stays the zero-based line index, so messages still point at the input line.
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of crashing in json.loads.
                continue
            data: Dict[str, Any] = json.loads(line)

            print(f"Processing example {i}...")

            original_text: str = data['complete_text']

            # Query AI model to tag instructions (control mode keeps the text untouched)
            labeled_text: str
            if control:
                labeled_text = original_text
            else:
                labeled_text = tag_instructions(original_text, prompt, api_provider)

            # Sanity check: apart from the tags and whitespace layout, the model
            # must have returned exactly the text it was given.
            cleaned_text: str = instruction_tag.sub('', labeled_text)
            cleaned_text = text_tag.sub('', cleaned_text)
            refined_original_text: str = whitespace_run.sub(' ', original_text).strip()
            refined_cleaned_text: str = whitespace_run.sub(' ', cleaned_text).strip()

            data['sanity_check'] = refined_cleaned_text == refined_original_text
            if data['sanity_check']:
                # The check above forgives an echoed <text> wrapper from the prompt;
                # drop that wrapper from the stored label too, so it holds only the
                # original text plus instruction tags.
                unwrapped_text: str = text_tag.sub('', labeled_text)
                if unwrapped_text != labeled_text:
                    labeled_text = unwrapped_text.strip()
            else:
                print(f"\n⚠️  SANITY CHECK FAILED for example {i}:")
                print(f"Original: {repr(refined_original_text)}")
                print(f"Cleaned:  {repr(refined_cleaned_text)}")
                labeled_text = refined_original_text  # Use original text instead of model output
                failed_sanity_checks += 1

            # Add the label_text column
            data['label_text'] = whitespace_run.sub(' ', labeled_text)
            labeled_data.append(data)

            print(f"✓ Example {i} processed")

    # Save to new JSONL file
    with open(output_path, 'w', encoding='utf-8') as f:
        for data in labeled_data:
            json.dump(data, f)
            f.write('\n')

    print(f"\nLabeled data saved to: {output_path}")
    print(f"Processed {len(labeled_data)} examples")
    print(f"Sanity check failures: {failed_sanity_checks}")

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
