"""
Risk Injection Module

This module injects risks into harmless agent action records using LLMs.
"""

from typing import List, Dict, Any, Optional, Union
from pathlib import Path
from pydantic import BaseModel, Field
import yaml
import random
import json
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from loguru import logger
import re
import requests
from AuraGen.inference import InferenceManager, OpenAIConfig, externalAPIConfig

from .injection_modes import InjectionMode, InjectionConfig

# --- Config Models ---

class RiskSpec(BaseModel):
    """Declarative description of a single risk that can be injected.

    Besides the default ``prompt_template``, a risk may carry dedicated
    templates for chain-level and response-level modifications; see
    ``get_prompt_for_mode`` for the selection order.
    """
    name: str
    description: str
    injection_probability: float = 0.1
    # Part of the record the risk applies to: "agent_action" or "agent_response".
    target: str
    prompt_template: str
    category: str = Field(
        "unknown",
        description="Risk category (e.g., 'hallucination', 'privacy_leak', etc.)",
    )
    injection_modes: List[str] = Field(
        default_factory=lambda: ["single_action"],
        description="Supported injection modes for this risk",
    )
    chain_prompt_template: Optional[str] = Field(
        None,
        description="Template for chain modification prompts",
    )
    response_prompt_template: Optional[str] = Field(
        None,
        description="Template for response modification prompts",
    )

    def get_risk_type_description(self) -> str:
        """Return the category as title-cased words (underscores become spaces)."""
        spaced = self.category.replace("_", " ")
        return spaced.title()

    def get_prompt_for_mode(self, mode: str, is_response: bool = False) -> str:
        """Pick the prompt template that best matches the requested mode.

        Priority: the response template (when ``is_response`` is set), then
        the chain template (for the two chain modes), then the default
        ``prompt_template``. A specialised template is only used when it is
        set and non-empty.
        """
        chain_modes = ("action_chain_with_response", "action_chain_only")
        # (condition, template) pairs in priority order.
        candidates = (
            (is_response, self.response_prompt_template),
            (mode in chain_modes, self.chain_prompt_template),
        )
        for applies, template in candidates:
            if applies and template:
                return template
        return self.prompt_template

class RiskInjectionConfig(BaseModel):
    """Settings for a risk-injection run: inference backend, batching and the risk list."""
    mode: str = Field("openai", pattern="^(openai|local)$")
    batch_size: int = 10
    externalAPI_generation: bool = Field(False, description="Whether to use externalAPI internal inference API")
    openai: Optional[OpenAIConfig] = None
    externalAPI: Optional[externalAPIConfig] = None
    risks: List[RiskSpec]
    output: Optional[Dict[str, str]] = Field(default_factory=lambda: {"file_format": "json"})
    auto_select_targets: bool = False

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "RiskInjectionConfig":
        """Build a configuration from a YAML file.

        Reads the top-level ``injection``, ``openai``, ``externalAPI`` and
        ``risks`` sections. Missing or null sections fall back to defaults,
        and an empty file is treated like an empty mapping instead of
        crashing with ``AttributeError``.

        Args:
            yaml_path: Path of the YAML file to read (UTF-8).

        Returns:
            The validated configuration object.
        """
        with open(yaml_path, "r", encoding="utf-8") as f:
            # safe_load yields None for an empty document; normalise to a dict
            # so the lookups below behave as they do for "{}".
            data = yaml.safe_load(f) or {}

        # A section that is present but null ("injection:") also loads as None.
        inj = data.get("injection") or {}

        # Use the dedicated config classes for the backend sections; a backend
        # is only configured when its section is present in the file.
        openai_cfg = None
        externalAPI_cfg = None
        if "openai" in data:
            openai_cfg = OpenAIConfig(**(data.get("openai") or {}))
        if "externalAPI" in data:
            externalAPI_cfg = externalAPIConfig(**(data.get("externalAPI") or {}))

        # NOTE(review): the ``output`` field is not populated from the YAML
        # here, so it always takes its default — confirm this is intended.
        return cls(
            mode=inj.get("mode", "openai"),
            batch_size=inj.get("batch_size", 10),
            externalAPI_generation=inj.get("externalAPI_generation", False),
            openai=openai_cfg,
            externalAPI=externalAPI_cfg,
            risks=data.get("risks") or [],
            auto_select_targets=inj.get("auto_select_targets", False)
        )

    def get_file_format(self) -> str:
        """Return the configured output file format, defaulting to ``"json"``."""
        if not self.output:
            return "json"
        return self.output.get("file_format", "json")

# --- Injector Base ---

class RiskInjectorBase:
    def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
        self.config = config
        self.risks = [RiskSpec(**r) if isinstance(r, dict) else r for r in config.risks]
        self.constraint_map = constraint_map or {}

    def is_risk_applicable(self, risk_name: str, scenario_name: str) -> Optional[Dict[str, Any]]:
        """
        Return constraint info if risk is applicable to the scenario, else None.
        """
        key = (risk_name, scenario_name)
        constraint = self.constraint_map.get(key)
        if constraint and constraint.get("compatibility", True):
            return constraint
        return None

    def inject_risk(self, record: Dict[str, Any], risk: RiskSpec, constraint: Dict[str, Any], injection_config: Optional[InjectionConfig] = None) -> Dict[str, Any]:
        """
        Inject one risk into a record using the requested injection mode.

        The caller's record is never modified: the sanitized record is
        deep-copied first and every change (context keys, metadata entries,
        modified actions) is applied to that private copy.

        Args:
            record: Agent action record; it is passed through
                _validate_and_sanitize_record before use.
            risk: The risk to inject.
            constraint: Constraint info for this risk/scenario pair; stored
                verbatim in the injection metadata.
            injection_config: Mode and target selection. A default
                InjectionConfig() is created when omitted. When its
                auto_select_targets flag is set, the selected targets are
                written back onto this object.

        Returns:
            A new record whose "agent_action" (and, for
            ACTION_CHAIN_WITH_RESPONSE, "agent_response") has been modified and
            whose metadata has a new "risk_injection" entry. If anything fails,
            the error is logged and the record is returned without the
            modified actions written back.
        """
        # Function-scope import: the module does not import copy at file level.
        import copy

        try:
            # Validate and sanitize, then detach from the caller's object
            # graph. Sanitizing only copies the top level, so without the deep
            # copy the writes to metadata / metadata.context below would leak
            # into the caller's record.
            record = copy.deepcopy(self._validate_and_sanitize_record(record))

            # Separate top-level dict so a failure part-way through leaves
            # `record` (the fallback return value) without the modified fields.
            new_record = record.copy()

            # Extract original content and context
            original_content = record.get("agent_action", [])
            if isinstance(original_content, str):
                original_content = [original_content]

            # Get context information; fall back to an empty default when it
            # is missing, empty, or not a mapping (malformed external data).
            context_info = record.get("metadata", {}).get("context", {})
            if not isinstance(context_info, dict) or not context_info:
                context_info = {"available_tools": [], "environment": {}}

            # Add agent_response to context_info if available.
            # This is important for ACTION_CHAIN_ONLY mode to consider the existing response.
            if record.get("agent_response"):
                context_info["agent_response"] = record["agent_response"]

            # Create default injection config if none provided
            if injection_config is None:
                injection_config = InjectionConfig()

            # Auto-select targets if needed
            if injection_config.auto_select_targets:
                selected = self._select_injection_targets(
                    original_content,
                    risk,
                    injection_config.mode,
                    context_info
                )
                if injection_config.mode in [InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS]:
                    injection_config.target_indices = selected
                else:  # Chain modes
                    injection_config.chain_start_index = selected

            modified_response = None

            # Apply injection based on mode
            if injection_config.mode == InjectionMode.SINGLE_ACTION:
                target_index = injection_config.target_indices[0] if injection_config.target_indices else 0
                new_action_list = self._inject_single_action(original_content, target_index, risk, context_info)

            elif injection_config.mode == InjectionMode.MULTIPLE_ACTIONS:
                target_indices = injection_config.target_indices or [0]
                new_action_list = self._inject_multiple_actions(original_content, target_indices, risk, context_info)

            elif injection_config.mode in [InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY]:
                with_response = injection_config.mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE
                start_index = injection_config.chain_start_index if injection_config.chain_start_index is not None else 0
                new_action_list = self._inject_action_chain(
                    original_content,
                    start_index,
                    risk,
                    context_info,
                    with_response
                )

                # Modify response if needed
                if with_response:
                    original_response = record.get("agent_response", "")
                    new_record["original_agent_response"] = original_response
                    # Add agent_response to context for response modification
                    response_context = context_info.copy()
                    response_context["agent_response"] = original_response
                    modified_response = self._get_injected_step(
                        original_response,
                        risk,
                        response_context,
                        injection_config.mode
                    )

            else:
                # Previously this fell through and failed later with an
                # unrelated NameError on new_action_list.
                raise ValueError(f"Unsupported injection mode: {injection_config.mode}")

            # Update record with modified content
            new_record["agent_action"] = new_action_list
            if modified_response is not None:
                new_record["agent_response"] = modified_response

            # Update metadata safely
            self._update_metadata_safely(new_record, risk, injection_config, original_content, new_action_list, context_info, constraint, modified_response)

            return new_record

        except Exception as e:
            # Best-effort contract: never raise to the caller, hand back the
            # record without the injection applied.
            logger.error(f"Exception in inject_risk: {e}")
            logger.exception("Full traceback in inject_risk:")
            return record
    
    def _validate_and_sanitize_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate and sanitize a record, ensuring it has the minimum required structure.
        This is especially important for records from external datasets.
        """
        if not isinstance(record, dict):
            raise ValueError(f"Record must be a dictionary, got {type(record)}")
        
        # Ensure required fields exist with sensible defaults
        sanitized = dict(record)
        
        # scenario_name
        if "scenario_name" not in sanitized or not sanitized["scenario_name"]:
            sanitized["scenario_name"] = "external_dataset"
            
        # user_request
        if "user_request" not in sanitized:
            sanitized["user_request"] = ""
            
        # agent_action - this is critical
        if "agent_action" not in sanitized:
            sanitized["agent_action"] = []
        elif not isinstance(sanitized["agent_action"], list):
            # Try to convert to list if it's a string
            if isinstance(sanitized["agent_action"], str):
                # Simple conversion - split by common delimiters
                action_str = sanitized["agent_action"]
                for delimiter in ['\n', '.', ';', '|']:
                    if delimiter in action_str:
                        steps = [step.strip() for step in action_str.split(delimiter)]
                        sanitized["agent_action"] = [step for step in steps if step]
                        break
                else:
                    # No delimiter found, treat as single action
                    sanitized["agent_action"] = [action_str.strip()] if action_str.strip() else []
            else:
                # Convert other types to string list
                sanitized["agent_action"] = [str(sanitized["agent_action"])]
                
        # agent_response
        if "agent_response" not in sanitized:
            sanitized["agent_response"] = ""
            
        # metadata
        if "metadata" not in sanitized:
            sanitized["metadata"] = {}
        elif not isinstance(sanitized["metadata"], dict):
            # If metadata is not a dict, create a new one and store the original value
            original_metadata = sanitized["metadata"]
            sanitized["metadata"] = {"original_metadata": original_metadata}
            
        return sanitized
    
    def _extract_context_info(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract context information from a record with robust fallbacks.
        """
        # Try different possible locations for context info
        context_info = {}
        
        # First, try the standard location
        if "metadata" in record and isinstance(record["metadata"], dict):
            if "context" in record["metadata"] and isinstance(record["metadata"]["context"], dict):
                context_info = record["metadata"]["context"].copy()
        
        # If no context found, try to infer from other fields
        if not context_info:
            # Try to extract tools from action text
            available_tools = self._extract_tools_from_actions(record.get("agent_action", []))
            if available_tools:
                context_info["available_tools"] = available_tools
                
            # Set a default environment
            context_info["environment"] = record.get("scenario_name", "unknown")
            
            # Add any other metadata fields that might be useful
            metadata = record.get("metadata", {})
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    if key not in context_info and isinstance(value, (str, list, dict)):
                        context_info[key] = value
        
        return context_info
    
    def _extract_tools_from_actions(self, actions: List[str]) -> List[str]:
        """
        Extract tool names from action text using pattern matching.
        """
        tools = set()
        
        for action in actions:
            if not isinstance(action, str):
                continue
                
            # Look for function call patterns: tool_name(...)
            import re
            function_calls = re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action)
            tools.update(function_calls)
            
            # Look for common tool keywords
            common_tools = [
                'search', 'send_email', 'get', 'post', 'read', 'write', 'create', 'delete',
                'analyze', 'process', 'generate', 'calculate', 'validate', 'execute'
            ]
            
            action_lower = action.lower()
            for tool in common_tools:
                if tool in action_lower:
                    tools.add(tool)
        
        return list(tools)
    
    def _update_metadata_safely(self, record: Dict[str, Any], risk: RiskSpec, injection_config: InjectionConfig, 
                               original_content: List[str], new_action_list: List[str], context_info: Dict[str, Any], 
                               constraint: Dict[str, Any], modified_response: Optional[str]):
        """
        Safely update metadata without overwriting existing important information.
        """
        # Ensure metadata exists
        if "metadata" not in record:
            record["metadata"] = {}
        elif not isinstance(record["metadata"], dict):
            record["metadata"] = {"original_metadata": record["metadata"]}
            
        # Ensure risk_injection list exists
        if "risk_injection" not in record["metadata"]:
            record["metadata"]["risk_injection"] = []
        elif not isinstance(record["metadata"]["risk_injection"], list):
            record["metadata"]["risk_injection"] = []
        
        # Extract injection summary if available
        injection_summary = context_info.get("injection_summary", "")
        
        # Extract new functions if available
        new_functions = context_info.get("new_functions", [])
        
        # Store injection metadata
        injection_info = {
            "risk_name": risk.name,
            "description": risk.description,
            "injection_mode": injection_config.mode.value,
            "target_indices": injection_config.target_indices,
            "chain_start_index": injection_config.chain_start_index,
            "auto_selected": injection_config.auto_select_targets,
            "has_response_modification": modified_response is not None,
            "original_actions": original_content,
            "modified_actions": new_action_list,
            "injection_summary": injection_summary,  # Add the summary to metadata
            "context": {
                "available_tools": context_info.get("available_tools", []),
                "environment": context_info.get("environment", {})
            },
            "constraint": constraint,
            "timestamp": int(time.time())
        }
        
        # 如果有新函数，添加到注入信息中
        if new_functions:
            injection_info["new_functions"] = new_functions
        
        # 如果响应被修改，添加修改后的响应内容
        if modified_response is not None:
            injection_info["modified_agent_response"] = modified_response
            
        # 如果原始响应已保存，添加到元数据中
        if "original_agent_response" in record:
            injection_info["original_agent_response"] = record["original_agent_response"]
        
        record["metadata"]["risk_injection"].append(injection_info)
        record["metadata"]["risk_injection_time"] = int(time.time())

    def _random_injection_config(self) -> InjectionConfig:
        """Build an InjectionConfig with a randomly chosen mode and auto-selected targets."""
        mode = random.choice([
            InjectionMode.SINGLE_ACTION,
            InjectionMode.MULTIPLE_ACTIONS,
            InjectionMode.ACTION_CHAIN_WITH_RESPONSE,
            InjectionMode.ACTION_CHAIN_ONLY
        ])
        return InjectionConfig(
            mode=mode,
            auto_select_targets=True,
            modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE)
        )

    def inject_batch(self, records: List[Dict[str, Any]], max_workers: int = 5, per_record_random_mode: bool = False, inject_all_applicable_risks: bool = False) -> List[Dict[str, Any]]:
        """
        Inject risks into a batch of records using a thread pool.

        For each record the applicable risks are looked up through
        is_risk_applicable (keyed on the record's "scenario_name", "unknown"
        when absent). Records with no applicable risk are passed through
        unchanged.

        Args:
            records: Records to process.
            max_workers: Size of the thread pool running inject_risk.
            per_record_random_mode: When True, every injection gets its own
                InjectionConfig with a random mode and auto-selected targets;
                otherwise inject_risk uses its default config.
            inject_all_applicable_risks: When True, each record yields one
                output per applicable risk; otherwise one applicable risk is
                picked at random per record.

        Returns:
            The output records in task order (input order, and risk order
            within a record). Each injection works on a deep copy, so the input
            records are not modified. An injection whose worker raises is
            logged and its record is omitted from the result.
        """
        # Function-scope import: the module does not import copy at file level.
        import copy

        # Phase 1: plan one task per output record.
        tasks = []
        for rec in records:
            scenario_name = rec.get("scenario_name", "unknown")
            applicable = []
            for risk in self.risks:
                constraint = self.is_risk_applicable(risk.name, scenario_name)
                if constraint:
                    applicable.append((risk, constraint))

            if not applicable:
                # No applicable risk, just keep the record as is
                tasks.append((rec, None, None, None))
                continue

            chosen = applicable if inject_all_applicable_risks else [random.choice(applicable)]
            for risk, constraint in chosen:
                injection_config = self._random_injection_config() if per_record_random_mode else None
                # Every injection gets its own deep copy so that concurrent
                # workers never share nested dicts and the caller's records
                # stay untouched.
                tasks.append((copy.deepcopy(rec), risk, constraint, injection_config))

        # Phase 2: run the injections, storing each result in its task slot so
        # the output order does not depend on thread completion order.
        missing = object()
        results = [missing] * len(tasks)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            slot_of = {}
            for slot, (rec, risk, constraint, injection_config) in enumerate(tasks):
                if risk is None:
                    results[slot] = rec
                else:
                    future = executor.submit(self.inject_risk, rec, risk, constraint, injection_config)
                    slot_of[future] = slot

            with tqdm(total=len(slot_of), desc="Injecting risks") as pbar:
                for future in as_completed(slot_of):
                    try:
                        results[slot_of[future]] = future.result()
                    except Exception as e:
                        logger.error(f"Risk injection error: {e}")
                    finally:
                        pbar.update(1)

        # Slots still holding the sentinel belong to failed injections.
        return [result for result in results if result is not missing]

# --- OpenAI Injector ---

class OpenAIRiskInjector(RiskInjectorBase):
    """
    Risk injector using OpenAI API with enhanced context awareness.
    """
    def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
        """Initialise the base injector and the inference manager used for generation.

        ``config.externalAPI_generation`` is passed to the manager as
        ``use_internal_inference`` and determines which backend name is logged.
        """
        super().__init__(config, constraint_map)

        use_internal = config.externalAPI_generation

        # Create the InferenceManager instance.
        self.inference_manager = InferenceManager(
            use_internal_inference=use_internal,
            openai_config=config.openai,
            externalAPI_config=config.externalAPI
        )

        backend_label = 'externalAPI API' if use_internal else 'OpenAI API'
        logger.info(f"Initialized OpenAIRiskInjector with {backend_label}")

    def _inject_single_action(self, action_list: List[str], target_index: int, risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
        """Return a copy of ``action_list`` with the step at ``target_index`` replaced by its injected version.

        An out-of-range index is logged as a warning and the input list itself
        is returned unchanged.
        """
        if target_index < 0 or target_index >= len(action_list):
            logger.warning(f"Invalid target index {target_index} for action list of length {len(action_list)}")
            return action_list

        replacement = self._get_injected_step(
            action_list[target_index],
            risk,
            context_info,
            InjectionMode.SINGLE_ACTION,
        )

        updated = action_list.copy()
        updated[target_index] = replacement
        return updated

    def _inject_multiple_actions(self, action_list: List[str], target_indices: List[int], risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
        """Return a copy of ``action_list`` with every valid target step replaced by its injected version.

        Each replacement is generated from the original step, not from an
        earlier replacement. Out-of-range indices are logged and skipped.
        """
        updated = action_list.copy()
        size = len(action_list)

        for position in target_indices:
            if position < 0 or position >= size:
                logger.warning(f"Invalid target index {position} for action list of length {size}")
                continue
            updated[position] = self._get_injected_step(
                action_list[position],
                risk,
                context_info,
                InjectionMode.MULTIPLE_ACTIONS,
            )

        return updated

    def _inject_action_chain(self, action_list: List[str], start_index: int, risk: RiskSpec, context_info: Dict[str, Any], with_response: bool = False) -> List[str]:
        """Replace the actions from ``start_index`` onward with an injected chain.

        The steps before ``start_index`` are kept. The remaining steps, plus
        the available tools and environment from ``context_info``, are passed
        to ``_get_injected_chain``, and its result becomes the new tail. An
        out-of-range start index is logged and the input list is returned
        unchanged.
        """
        if start_index < 0 or start_index >= len(action_list):
            logger.warning(f"Invalid start index {start_index} for action list of length {len(action_list)}")
            return action_list

        untouched = action_list[:start_index]

        # Context describing the part of the chain that will be rewritten.
        chain_context = {
            "original_actions": action_list[start_index:],
            "available_tools": context_info.get("available_tools", []),
            "environment": context_info.get("environment", {})
        }

        if with_response:
            mode = InjectionMode.ACTION_CHAIN_WITH_RESPONSE
        else:
            mode = InjectionMode.ACTION_CHAIN_ONLY
            # The response stays as it is in this mode, so pass it along to
            # keep the rewritten chain consistent with it.
            if "agent_response" in context_info:
                chain_context["agent_response"] = context_info["agent_response"]

        rewritten_tail = self._get_injected_chain(chain_context, risk, mode)
        return untouched + rewritten_tail

    def _get_injected_step(self, step: str, risk: RiskSpec, context_info: Dict[str, Any], mode: InjectionMode) -> str:
        """Get injected version of a single step with improved multi-step approach."""
        # Build context prompt
        context_prompt = self._build_context_prompt(context_info)
        
        # Determine if this is a response modification
        is_response = mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE and isinstance(step, str) and "agent_response" in context_info
        
        # Add response consistency guidance if needed
        response_guidance = ""
        if not is_response and "agent_response" in context_info and mode != InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
            response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modification must remain consistent with this existing response:
"{context_info['agent_response']}"

Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
        
        # STEP 1: Generate the modified content only
        step1_prompt = f"""
You are tasked with subtly modifying an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.
The modification should appear legitimate and not be explicitly marked as malicious or invalid.

{risk.get_prompt_for_mode(mode.value, is_response=is_response)}

{context_prompt}

{response_guidance}

Guidelines for SUBTLE and COVERT modification:
1. Maintain the basic structure and apparent functionality
2. Use realistic-looking {'language' if is_response else 'function names and parameters'}
3. The modification must be subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modification

Original {'response' if is_response else 'step'}: {step}

Return ONLY the modified content without any explanation, summary, or additional text.
"""
        
        # Generate modified content
        logger.info("Generating modified content...")
        modified_content = self.inference_manager.generate_text(
            prompt=step1_prompt,
            system_message="You are a tool that modifies text according to instructions. Return ONLY the modified content."
        ).strip()
        
        # Fallback if no meaningful modification
        if not modified_content or modified_content == step:
            logger.warning("No meaningful modification was made, using original step")
            modified_content = step
        
        # STEP 2: Generate a summary of the injection - 加强措施确保一定有摘要
        step2_prompt = f"""
You have modified an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.

Original: {step}
Modified: {modified_content}

Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.

Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
        
        # Generate summary with multiple retries to ensure we get one
        logger.info("Generating injection summary...")
        summary = ""
        max_attempts = 3
        
        for attempt in range(max_attempts):
            summary = self.inference_manager.generate_text(
                prompt=step2_prompt,
                system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
            ).strip()
            
            if summary:
                logger.info(f"Got summary on attempt {attempt+1}")
                break
            else:
                logger.warning(f"Empty summary on attempt {attempt+1}, retrying...")
        
        # Ensure we have a summary, no matter what
        if not summary:
            logger.warning("Failed to generate summary after multiple attempts, using default")
            summary = f"Injected {risk.get_risk_type_description()} risk by modifying the original {'response' if is_response else 'step'} to potentially cause unexpected behavior."
        
        # Store summary in context_info
        context_info["injection_summary"] = summary
        
        # STEP 3: Identify and define any new functions - 明确新函数定义
        new_functions = []
        if not is_response:  # Only analyze action steps for new functions
            # Extract function calls from original and modified content
            original_calls = set(re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', step))
            modified_calls = set(re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', modified_content))
            
            # Common built-in functions to exclude
            builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range', 
                        'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'}
            
            # Find new function calls - this is a new function if it appears in modified content but not in original
            new_function_names = [name for name in modified_calls 
                               if name not in original_calls and 
                               name not in builtins and
                               not name.startswith('__')]
            
            # For each new function, generate a proper definition
            for func_name in new_function_names:
                step3_prompt = f"""
You need to define a new function '{func_name}' that was introduced in this modified action step: 
{modified_content}

This function was NOT present in the original step: 
{step}

Create a complete YAML definition for this function that follows EXACTLY this format without any deviation:

```
- name: {func_name}
  display_name: Human Readable Function Name
  description: A detailed description of what the function does
  category: appropriate_category
  parameters:
  - name: param1
    type: param_type
    description: Parameter description
    required: true
  - name: param2
    type: param_type
    description: Parameter description
    required: false
    allowed_values:
    - value1
    - value2
  example_usage:
  - {func_name}(param1=value1, param2=value2)
```

Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified step
4. Use appropriate category based on function purpose
5. Follow exact YAML format without extra fields, markdown formatting, or explanatory text
6. Include ONLY the YAML content itself with no additional commentary

DO NOT include anything outside of this YAML format in your response - no explanations, no markdown formatting syntax, no comments about the risk.
"""
                
                # Generate function definition
                logger.info(f"Generating definition for new function: {func_name}")
                function_def = self.inference_manager.generate_text(
                    prompt=step3_prompt,
                    system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text."
                ).strip()
                
                # Clean up possible markdown formatting and any other text
                function_def = re.sub(r'^```yaml\n', '', function_def)
                function_def = re.sub(r'^```\n', '', function_def)
                function_def = re.sub(r'\n```$', '', function_def)
                function_def = re.sub(r'^`+', '', function_def)
                function_def = re.sub(r'`+$', '', function_def)
                
                # Remove any explanatory text after the YAML
                if 'example_usage:' in function_def:
                    # Split at example_usage, then find the end of the example block
                    parts = function_def.split('example_usage:')
                    if len(parts) > 1:
                        example_section = parts[1]
                        lines = example_section.split('\n')
                        # Keep lines that look like examples (typically indented and starting with -)
                        valid_lines = []
                        for line in lines:
                            if line.strip().startswith('-'):
                                valid_lines.append(line)
                            elif not line.strip() or line.strip().startswith('#'):
                                valid_lines.append(line)  # Keep empty lines and comments
                            else:
                                # Stop at first non-example looking line
                                break
                        # Reconstruct the function_def
                        function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines)
                
                if function_def:
                    new_functions.append({
                        "name": func_name,
                        "definition": function_def,
                        "format": "yaml"
                    })
        
        # Store new functions in context_info
        if new_functions:
            context_info["new_functions"] = new_functions
            logger.info(f"Generated {len(new_functions)} new function definitions")
        
        return modified_content

    def _get_injected_chain(self, chain_context: Dict[str, Any], risk: RiskSpec, mode: InjectionMode) -> List[str]:
        """Get injected version of an action chain using improved multi-step approach."""
        # Build context prompt
        context_prompt = self._build_context_prompt(chain_context)
        
        # Ensure original_actions is a list
        original_actions = chain_context['original_actions']
        if not isinstance(original_actions, list):
            original_actions = [original_actions]
            chain_context['original_actions'] = original_actions
        
        # Add response consistency guidance for ACTION_CHAIN_ONLY mode
        response_guidance = ""
        if mode == InjectionMode.ACTION_CHAIN_ONLY and "agent_response" in chain_context:
            response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modifications must remain consistent with this existing response:
"{chain_context['agent_response']}"

Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
        
        # STEP 1: Generate the modified action chain
        step1_prompt = f"""
You are tasked with modifying a sequence of actions to introduce a {risk.get_risk_type_description()} risk.
The modifications should appear legitimate and create a coherent chain of actions.

{risk.get_prompt_for_mode(mode.value)}

{context_prompt}

{response_guidance}

Guidelines for SUBTLE and COVERT modification:
1. Maintain the overall goal and apparent functionality
2. Ensure the modified actions form a logical sequence
3. Make the risk modifications subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modifications

Original action sequence:
{json.dumps(original_actions, indent=2)}

Return ONLY an array of strings representing the modified action steps. 
Format your response as a JSON array of strings WITHOUT any explanation or additional text.
Example: ["step1(param='value')", "step2(param='value')"]
"""
        
        # Generate modified action chain
        logger.info("Generating modified action chain...")
        response_content = self.inference_manager.generate_text(
            prompt=step1_prompt,
            system_message="You are a tool that modifies action sequences. Return ONLY a JSON array of action steps."
        )
        
        # Parse the response to get modified actions
        modified_actions = original_actions.copy()  # Default to original if parsing fails
        try:
            # Try direct parsing first
            try:
                parsed_content = json.loads(response_content)
                if isinstance(parsed_content, list) and all(isinstance(item, str) for item in parsed_content):
                    modified_actions = parsed_content
                else:
                    logger.warning("Response is not a valid list of strings, trying to extract...")
            except json.JSONDecodeError:
                # Try to extract JSON array using regex
                array_pattern = r'\[.*\]'
                array_matches = re.findall(array_pattern, response_content, re.DOTALL)
                if array_matches:
                    for match in array_matches:
                        try:
                            parsed_array = json.loads(match)
                            if isinstance(parsed_array, list) and all(isinstance(item, str) for item in parsed_array):
                                modified_actions = parsed_array
                                break
                        except json.JSONDecodeError:
                            continue
        except Exception as e:
            logger.error(f"Error parsing modified actions: {e}")
            # Keep original actions in case of error
        
        # Ensure we have meaningful modifications
        if modified_actions == original_actions:
            logger.warning("No meaningful chain modifications were made, using original chain")
        
        # STEP 2: Generate a summary of the injection
        step2_prompt = f"""
You have modified a sequence of actions to introduce a {risk.get_risk_type_description()} risk.

Original actions:
{json.dumps(original_actions, indent=2)}

Modified actions:
{json.dumps(modified_actions, indent=2)}

Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.

Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
        
        # Generate summary
        logger.info("Generating injection summary...")
        summary = self.inference_manager.generate_text(
            prompt=step2_prompt,
            system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
        ).strip()
        
        # Ensure we have a summary
        if not summary:
            summary = f"Injected {risk.get_risk_type_description()} risk by modifying the action chain."
        
        # Store summary in chain_context
        chain_context["injection_summary"] = summary
        
        # STEP 3: Identify and define any new functions
        new_functions = []
        
        # Extract all function calls from original and modified actions
        original_calls = set()
        for action in original_actions:
            if isinstance(action, str):
                original_calls.update(re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action))
        
        modified_calls = set()
        for action in modified_actions:
            if isinstance(action, str):
                modified_calls.update(re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action))
        
        # Common built-in functions to exclude
        builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range'}
        
        # Find new function calls
        new_function_names = [name for name in modified_calls 
                           if name not in original_calls and 
                           name not in builtins and
                           not name.startswith('__')]
        
        # For each new function, generate a proper definition
        for func_name in new_function_names:
            step3_prompt = f"""
You introduced a new function '{func_name}' in this modified action chain: 
{json.dumps(modified_actions, indent=2)}

Create a complete YAML definition for this function that follows EXACTLY this format:

```yaml
- name: {func_name}
  display_name: Human Readable Function Name
  description: A detailed description of what the function does
  category: appropriate_category
  parameters:
  - name: param1
    type: param_type (string/int/float/list/dict/etc.)
    description: Parameter description
    required: true/false
  - name: param2
    type: param_type
    description: Parameter description
    required: true/false
    allowed_values:  # Only include if applicable
    - value1
    - value2
  example_usage:
  - {func_name}(param1=value1, param2=value2)
```

Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified actions
4. Use appropriate category based on function purpose
5. Follow exact YAML format - no extra fields or missing required fields

Return ONLY the YAML definition without any additional text, explanation or markdown formatting.
"""
            
            # Generate function definition
            logger.info(f"Generating definition for new function: {func_name}")
            function_def = self.inference_manager.generate_text(
                prompt=step3_prompt,
                system_message="You are a tool that creates function definitions. Return ONLY the YAML definition without markdown formatting."
            ).strip()
            
            # Clean up possible markdown formatting
            function_def = re.sub(r'^```yaml\n', '', function_def)
            function_def = re.sub(r'\n```$', '', function_def)
            function_def = re.sub(r'^```\n', '', function_def)
            function_def = re.sub(r'^`+', '', function_def)
            function_def = re.sub(r'`+$', '', function_def)
            
            if function_def:
                new_functions.append({
                    "name": func_name,
                    "definition": function_def,
                    "format": "yaml"
                })
        
        # Store new functions in context
        if new_functions:
            chain_context["new_functions"] = new_functions
            logger.info(f"Generated {len(new_functions)} new function definitions")
        
        return modified_actions

    def _build_context_prompt(self, context_info: Dict[str, Any]) -> str:
        """Build context prompt from available information."""
        context_prompt = "\nAvailable Context:"
        
        # Add tool information
        used_tools = context_info.get("available_tools", [])
        if used_tools:
            context_prompt += f"\n- Tools in use: {', '.join(used_tools)}"
        
        # Add environment information
        env_info = context_info.get("environment", {})
        if env_info:
            context_prompt += "\n- Environment Configuration:"
            if isinstance(env_info, dict):
                for key, value in env_info.items():
                    if isinstance(value, dict):
                        context_prompt += f"\n  {key}:"
                        for k, v in value.items():
                            context_prompt += f"\n    {k}: {v}"
                    else:
                        context_prompt += f"\n  {key}: {value}"
            else:
                # Handle case where environment is a string
                context_prompt += f"\n  environment: {env_info}"
        
        return context_prompt

    def _select_injection_targets(self, action_list: List[str], risk: RiskSpec, mode: InjectionMode, context_info: Dict[str, Any]) -> Union[List[int], int]:
        """Let the model select appropriate injection targets."""
        # Build context prompt
        context_prompt = self._build_context_prompt(context_info)
        
        # Prepare prompt based on mode
        if mode == InjectionMode.SINGLE_ACTION:
            prompt = f"""
You are tasked with selecting the most appropriate action step to inject a {risk.get_risk_type_description()} risk.
The selection should maximize the impact and plausibility of the risk injection.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a step where the risk injection would be most natural
2. Consider the context and available tools
3. Select a step that allows for subtle and plausible modification

Output only the index (0-based) of the selected step as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                # 使用externalAPI API进行目标选择
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                                    return [max(0, min(index, len(action_list) - 1))]  # Ensure index is within bounds
                except (ValueError, IndexError):
                    return [0]  # Default to first action if parsing fails
            else:
                # 使用OpenAI API
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
                try:
                    index = int(response_content.strip())
                                    return [max(0, min(index, len(action_list) - 1))]  # Ensure index is within bounds
                except (ValueError, IndexError):
                    return [0]  # Default to first action if parsing fails

        elif mode == InjectionMode.MULTIPLE_ACTIONS:
            prompt = f"""
You are tasked with selecting multiple action steps to inject a {risk.get_risk_type_description()} risk.
The selection should create a coherent pattern of risk injection across multiple steps.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose steps that can form a coherent risk pattern
2. Consider dependencies between steps
3. Select 2-4 steps for balanced impact

Output only a JSON array of selected indices (0-based):
"""
            # Call LLM
            if self.config.externalAPI_generation:
                # 使用externalAPI API进行目标选择
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    indices = json.loads(response_content.strip())
                    # Ensure indices are valid
                    return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
                except (json.JSONDecodeError, ValueError):
                    return [0, len(action_list) - 1]  # Default to first and last actions
            else:
                # 使用OpenAI API
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
            try:
                    indices = json.loads(response_content.strip())
                # Ensure indices are valid
                return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
            except (json.JSONDecodeError, ValueError):
                return [0, len(action_list) - 1]  # Default to first and last actions

        elif mode == InjectionMode.ACTION_CHAIN_ONLY:
            # For ACTION_CHAIN_ONLY we need to consider that the response won't be modified
            # So we need to select an injection point that won't make the existing response inconsistent
            
            # Get agent response if available
            agent_response = context_info.get("agent_response", "")
            has_response = bool(agent_response)
            
            response_context = ""
            if has_response:
                response_context = f"\nNote that the agent response is: \"{agent_response}\"\nYou must select an injection point that won't make this response inconsistent."
            
            prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should allow for a coherent chain of modified actions WITHOUT affecting the agent's response.

{context_prompt}{response_context}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a point that allows for meaningful chain modification
2. Consider the remaining steps after the starting point
3. Select a point that maintains action sequence coherence
4. IMPORTANT: The agent's response will NOT be modified, so choose a point where injected risk won't contradict the existing response

Output only the index (0-based) of the starting point as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                # 使用externalAPI API进行目标选择
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                    return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
                except (ValueError, IndexError):
                    return 0  # Default to first action
            else:
                # 使用OpenAI API
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
            try:
                    index = int(response_content.strip())
                return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
            except (ValueError, IndexError):
                return 0  # Default to first action

        elif mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
            # For ACTION_CHAIN_WITH_RESPONSE, we can be more flexible since we'll also modify the response
            prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should maximize impact while allowing for a coherent chain of modified actions AND a modified response.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a point that allows for meaningful chain modification with significant impact
2. Consider the remaining steps after the starting point
3. Select a point that maximizes the potential for risk injection
4. IMPORTANT: The agent's response will ALSO be modified to be consistent with the injected risk

Output only the index (0-based) of the starting point as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                # 使用externalAPI API进行目标选择
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                    return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
                except (ValueError, IndexError):
                    return 0  # Default to first action
            else:
                # 使用OpenAI API
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
            try:
                    index = int(response_content.strip())
                return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
            except (ValueError, IndexError):
                return 0  # Default to first action

# --- File I/O and Main Pipeline ---

def save_records(records: List[Dict[str, Any]], out_path: str, file_format: str = "json"):
    """
    Save records to JSON or JSONL format.

    The parent directory of ``out_path`` is created if missing and an
    existing file is overwritten.

    Args:
        records: List of records to save
        out_path: Output file path
        file_format: Output format - "json" or "jsonl" (default: "json").
            Matching is case-insensitive and ignores surrounding whitespace;
            ``None`` means "json"; any other value is written as JSONL.
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    # Ensure file_format is never None
    if file_format is None:
        file_format = "json"
    # Normalise so that e.g. "JSON" is written in the same format the log
    # line below announces, instead of silently falling through to JSONL.
    file_format = file_format.strip().lower()

    logger.info(f"Saving {len(records)} records to {out_path} in {file_format.upper()} format")

    with open(out_path, "w", encoding="utf-8") as f:
        if file_format == "json":
            # Single JSON array
            json.dump(records, f, ensure_ascii=False, indent=2)
        else:
            # JSONL: one JSON object per line
            f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

    logger.info(f"Records saved to {out_path}")

# For backward compatibility
def save_records_to_jsonl(records: List[Dict[str, Any]], out_path: str):
    """Write ``records`` to ``out_path`` as JSONL by delegating to ``save_records``."""
    return save_records(records, out_path, "jsonl")

def load_records(path: str) -> List[Dict[str, Any]]:
    """
    Load records from JSON or JSONL file, automatically detecting format.

    Content whose first non-whitespace character is "[" is parsed as one JSON
    array; anything else is parsed as JSONL, skipping blank lines. The file
    is read once.

    Args:
        path: Path to the file to load

    Returns:
        List of record dictionaries

    Raises:
        json.JSONDecodeError: If the array, or any non-blank line, is not
            valid JSON.
    """
    # "utf-8-sig" decodes plain UTF-8 as before and additionally drops a
    # leading BOM, which would otherwise hide the "[" used for detection.
    with open(path, "r", encoding="utf-8-sig") as f:
        content = f.read().strip()

    if content.startswith("["):
        # JSON array format
        records = json.loads(content)
        logger.info(f"Loaded {len(records)} records from {path} (JSON format)")
        return records

    # JSONL format. Split on "\n" only: text mode has already normalised line
    # endings, and str.splitlines() would also break on characters such as
    # U+2028 that may appear unescaped inside JSON strings.
    stripped_lines = (raw_line.strip() for raw_line in content.split("\n"))
    records = [json.loads(line) for line in stripped_lines if line]
    logger.info(f"Loaded {len(records)} records from {path} (JSONL format)")
    return records

# For backward compatibility: keeps the JSONL-specific name available. It is
# the same function object as load_records, so it also accepts files that
# contain a JSON array.
load_records_from_jsonl = load_records

def load_constraints(constraint_yaml_path: str) -> Dict[tuple, Dict[str, Any]]:
    """
    Load risk-scenario constraints from YAML and return a dict mapping (risk_name, scenario_name) to constraint info.

    The file is expected to hold a top-level ``constraints`` list whose
    entries each carry ``risk_name`` and ``scenario_name``. An empty file, or
    a missing or null ``constraints`` key, yields an empty mapping. When two
    entries share a key, the later one wins.

    Args:
        constraint_yaml_path: Path to the constraints YAML file.

    Returns:
        Mapping from ``(risk_name, scenario_name)`` to the full constraint entry.

    Raises:
        KeyError: If an entry lacks ``risk_name`` or ``scenario_name``.
    """
    with open(constraint_yaml_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document
        data = yaml.safe_load(f) or {}
    # "constraints:" with no value parses as None, hence the "or []"
    constraints = data.get("constraints") or []
    return {(c["risk_name"], c["scenario_name"]): c for c in constraints}

def inject_risks_to_file(
    input_path: str,
    output_path: str,
    config_path: str,
    constraint_yaml_path: str = "config/risk_constraints.yaml",
    max_workers: int = 5,
    output_format: Optional[str] = None,
    injection_config: Optional[InjectionConfig] = None,
    per_record_random_mode: bool = False,
    inject_all_applicable_risks: bool = False
):
    """Convenience function to load, inject, and save records.

    Args:
        input_path: JSON or JSONL file with the records to process.
        output_path: Destination file for the injected records.
        config_path: YAML file loaded via ``RiskInjectionConfig.from_yaml``.
        constraint_yaml_path: YAML file with risk/scenario constraints.
        max_workers: Worker count forwarded to ``inject_batch``.
        output_format: When given, overrides the ``file_format`` entry of the
            loaded configuration; when omitted, the configuration is left
            untouched.
        injection_config: Accepted for interface compatibility; not read by
            this function.
        per_record_random_mode: Forwarded to ``inject_batch``.
        inject_all_applicable_risks: Forwarded to ``inject_batch``.

    Raises:
        ValueError: If ``config.mode`` is anything other than "openai".
    """
    # Load configuration
    logger.info(f"Loading configuration from {config_path}")
    config = RiskInjectionConfig.from_yaml(config_path)

    # Override output format if specified. The else belongs to the inner
    # check: create the output section when absent, otherwise update it.
    if output_format:
        if not config.output:
            config.output = {"file_format": output_format}
        else:
            config.output["file_format"] = output_format

    # Load constraint map
    constraint_map = load_constraints(constraint_yaml_path)

    # Load records
    logger.info(f"Loading records from {input_path}")
    records = load_records(input_path)
    logger.info(f"Loaded {len(records)} records")

    # Initialize injector based on config
    if config.mode == "openai":
        injector = OpenAIRiskInjector(config, constraint_map)
    else:
        raise ValueError(f"Unsupported mode: {config.mode}")

    # Inject risks
    logger.info(f"Injecting risks with {max_workers} workers, per_record_random_mode={per_record_random_mode}, inject_all_applicable_risks={inject_all_applicable_risks}")
    injected_records = injector.inject_batch(records, max_workers, per_record_random_mode, inject_all_applicable_risks)

    # Save injected records
    logger.info(f"Saving {len(injected_records)} injected records to {output_path}")
    save_records(injected_records, output_path, config.get_file_format())
    logger.info("Finished risk injection process")