"""
Risk Injection Module

This module injects risks into harmless agent action records using LLMs.
"""

from typing import List, Dict, Any, Optional, Union
from pathlib import Path
from pydantic import BaseModel, Field
import yaml
import random
import json
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from loguru import logger
import re
import requests
from AuraGen.inference import InferenceManager, OpenAIConfig, externalAPIConfig
import os

from .injection_modes import InjectionMode, InjectionConfig

# --- Config Models ---

class RiskSpec(BaseModel):
    name: str
    description: str
    injection_probability: float = 0.1
    target: str  # "agent_action" or "agent_response"
    prompt_template: str
    category: str = Field("unknown", description="Risk category (e.g., 'hallucination', 'privacy_leak', etc.)")
    injection_modes: List[str] = Field(
        default_factory=lambda: ["single_action"],
        description="Supported injection modes for this risk"
    )
    chain_prompt_template: Optional[str] = Field(
        None,
        description="Template for chain modification prompts"
    )
    response_prompt_template: Optional[str] = Field(
        None,
        description="Template for response modification prompts"
    )

    def get_risk_type_description(self) -> str:
        """Get a human-readable description of the risk type."""
        return self.category.replace("_", " ").title()

    def get_prompt_for_mode(self, mode: str, is_response: bool = False) -> str:
        """Get the appropriate prompt template for the given mode."""
        if is_response and self.response_prompt_template:
            return self.response_prompt_template
        if mode in ["action_chain_with_response", "action_chain_only"] and self.chain_prompt_template:
            return self.chain_prompt_template
        return self.prompt_template

class RiskInjectionConfig(BaseModel):
    mode: str = Field("openai", pattern="^(openai|local)$")
    batch_size: int = 10
    externalAPI_generation: bool = Field(False, description="Whether to use externalAPI internal inference API")
    openai: Optional[OpenAIConfig] = None
    externalAPI: Optional[externalAPIConfig] = None
    risks: List[RiskSpec]
    output: Optional[Dict[str, str]] = Field(default_factory=lambda: {"file_format": "json"})
    auto_select_targets: bool = False

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "RiskInjectionConfig":
        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        inj = data.get("injection", {})
        openai_cfg = None
        externalAPI_cfg = None
        
        if "openai" in data:
            openai_section = data.get("openai", {})
            if isinstance(openai_section, dict) and "api_key" not in openai_section and "api_key_type" in openai_section:
                mapping = {
                    "openai_api_key": "OPENAI_API_KEY",
                    "deepinfra_api_key": "DEEPINFRA_API_KEY",
                }
                key_type = openai_section.get("api_key_type")
                if key_type not in mapping:
                    raise ValueError(
                        f"Unknown api_key_type: {key_type}. Expected one of: {', '.join(mapping.keys())}"
                    )
                env_name = mapping[key_type]
                value = os.getenv(env_name, "").strip()
                if not value:
                    # Fallback to project .env
                    try:
                        project_root = Path(__file__).resolve().parents[1]
                        env_path = project_root / ".env"
                        if env_path.exists():
                            for line in env_path.read_text(encoding="utf-8").splitlines():
                                line = line.strip()
                                if not line or line.startswith("#") or "=" not in line:
                                    continue
                                k, v = line.split("=", 1)
                                if k.strip() == env_name:
                                    value = v.strip().strip('"')
                                    break
                    except Exception:
                        pass
                if not value:
                    raise ValueError(
                        f"Environment variable '{env_name}' not set for api_key_type '{key_type}'. "
                        f"Consider running: python config/configure_api_keys.py"
                    )
                openai_section["api_key"] = value
            openai_cfg = OpenAIConfig(**openai_section)
        if "externalAPI" in data:
            externalAPI_cfg = externalAPIConfig(**data.get("externalAPI", {}))
            
        risks = data.get("risks", [])
        return cls(
            mode=inj.get("mode", "openai"),
            batch_size=inj.get("batch_size", 10),
            externalAPI_generation=inj.get("externalAPI_generation", False),
            openai=openai_cfg,
            externalAPI=externalAPI_cfg,
            risks=risks,
            auto_select_targets=inj.get("auto_select_targets", False)
        )

    def get_file_format(self) -> str:
        """Get the configured file format."""
        if not self.output:
            return "json"
        return self.output.get("file_format", "json")

# --- Injector Base ---

class RiskInjectorBase:
    def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
        self.config = config
        self.risks = [RiskSpec(**r) if isinstance(r, dict) else r for r in config.risks]
        self.constraint_map = constraint_map or {}

    def is_risk_applicable(self, risk_name: str, scenario_name: str) -> Optional[Dict[str, Any]]:
        """
        Return constraint info if risk is applicable to the scenario, else None.
        """
        key = (risk_name, scenario_name)
        constraint = self.constraint_map.get(key)
        if constraint and constraint.get("compatibility", True):
            return constraint
        return None

    def inject_risk(self, record: Dict[str, Any], risk: RiskSpec, constraint: Dict[str, Any], injection_config: Optional[InjectionConfig] = None) -> Dict[str, Any]:
        """
        Inject risk into a record using the specified injection mode.
        Enhanced with robust error handling for external datasets.
        """
        try:
            # Validate and sanitize record
            record = self._validate_and_sanitize_record(record)
            
            # Create new record to avoid modifying original
            new_record = record.copy()
            
            # Extract original content and context
            original_content = record.get("agent_action", [])
            if isinstance(original_content, str):
                original_content = [original_content]
            
            # Get context information
            context_info = record.get("metadata", {}).get("context", {})
            if not context_info:
                context_info = {"available_tools": [], "environment": {}}
            
            # Add agent_response to context_info if available
            # This is important for ACTION_CHAIN_ONLY mode to consider existing response
            if "agent_response" in record and record["agent_response"]:
                context_info["agent_response"] = record["agent_response"]
            
            # Create default injection config if none provided
            if injection_config is None:
                injection_config = InjectionConfig()
            
            # Auto-select targets if needed
            if injection_config.auto_select_targets:
                if injection_config.mode in [InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS]:
                    injection_config.target_indices = self._select_injection_targets(
                        original_content,
                        risk,
                        injection_config.mode,
                        context_info
                    )
                else:  # Chain modes
                    injection_config.chain_start_index = self._select_injection_targets(
                        original_content,
                        risk,
                        injection_config.mode,
                        context_info
                    )
            
            # Apply injection based on mode
            if injection_config.mode == InjectionMode.SINGLE_ACTION:
                target_index = injection_config.target_indices[0] if injection_config.target_indices else 0
                new_action_list = self._inject_single_action(original_content, target_index, risk, context_info)
                modified_response = None
            
            elif injection_config.mode == InjectionMode.MULTIPLE_ACTIONS:
                target_indices = injection_config.target_indices or [0]
                new_action_list = self._inject_multiple_actions(original_content, target_indices, risk, context_info)
                modified_response = None
            
            elif injection_config.mode in [InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY]:
                with_response = injection_config.mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE
                start_index = injection_config.chain_start_index if injection_config.chain_start_index is not None else 0
                new_action_list = self._inject_action_chain(
                    original_content,
                    start_index,
                    risk,
                    context_info,
                    with_response
                )
                
                # Modify response if needed
                if with_response:
                    original_response = record.get("agent_response", "")
                    new_record["original_agent_response"] = original_response
                    # Add agent_response to context for response modification
                    response_context = context_info.copy()
                    response_context["agent_response"] = original_response
                    modified_response = self._get_injected_step(
                        original_response,
                        risk,
                        response_context,
                        injection_config.mode
                    )
                else:
                    modified_response = None
            
            # Update record with modified content
            new_record["agent_action"] = new_action_list
            if modified_response is not None:
                new_record["agent_response"] = modified_response
            
            # Update metadata safely
            self._update_metadata_safely(new_record, risk, injection_config, original_content, new_action_list, context_info, constraint, modified_response)
            
            return new_record
            
        except Exception as e:
            logger.error(f"Exception in inject_risk: {e}")
            logger.exception("Full traceback in inject_risk:")
            return record
    
    def _validate_and_sanitize_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate and sanitize a record, ensuring it has the minimum required structure.
        This is especially important for records from external datasets.
        """
        if not isinstance(record, dict):
            raise ValueError(f"Record must be a dictionary, got {type(record)}")
        
        # Ensure required fields exist with sensible defaults
        sanitized = dict(record)
        
        # scenario_name
        if "scenario_name" not in sanitized or not sanitized["scenario_name"]:
            sanitized["scenario_name"] = "external_dataset"
            
        # user_request
        if "user_request" not in sanitized:
            sanitized["user_request"] = ""
            
        # agent_action - this is critical
        if "agent_action" not in sanitized:
            sanitized["agent_action"] = []
        elif not isinstance(sanitized["agent_action"], list):
            # Try to convert to list if it's a string
            if isinstance(sanitized["agent_action"], str):
                # Simple conversion - split by common delimiters
                action_str = sanitized["agent_action"]
                for delimiter in ['\n', '.', ';', '|']:
                    if delimiter in action_str:
                        steps = [step.strip() for step in action_str.split(delimiter)]
                        sanitized["agent_action"] = [step for step in steps if step]
                        break
                else:
                    # No delimiter found, treat as single action
                    sanitized["agent_action"] = [action_str.strip()] if action_str.strip() else []
            else:
                # Convert other types to string list
                sanitized["agent_action"] = [str(sanitized["agent_action"])]
                
        # agent_response
        if "agent_response" not in sanitized:
            sanitized["agent_response"] = ""
            
        # metadata
        if "metadata" not in sanitized:
            sanitized["metadata"] = {}
        elif not isinstance(sanitized["metadata"], dict):
            # If metadata is not a dict, create a new one and store the original value
            original_metadata = sanitized["metadata"]
            sanitized["metadata"] = {"original_metadata": original_metadata}
            
        return sanitized
    
    def _extract_context_info(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract context information from a record with robust fallbacks.
        """
        # Try different possible locations for context info
        context_info = {}
        
        # First, try the standard location
        if "metadata" in record and isinstance(record["metadata"], dict):
            if "context" in record["metadata"] and isinstance(record["metadata"]["context"], dict):
                context_info = record["metadata"]["context"].copy()
        
        # If no context found, try to infer from other fields
        if not context_info:
            # Try to extract tools from action text
            available_tools = self._extract_tools_from_actions(record.get("agent_action", []))
            if available_tools:
                context_info["available_tools"] = available_tools
                
            # Set a default environment
            context_info["environment"] = record.get("scenario_name", "unknown")
            
            # Add any other metadata fields that might be useful
            metadata = record.get("metadata", {})
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    if key not in context_info and isinstance(value, (str, list, dict)):
                        context_info[key] = value
        
        return context_info
    
    def _extract_tools_from_actions(self, actions: List[str]) -> List[str]:
        """
        Extract tool names from action text using pattern matching.
        """
        tools = set()
        
        for action in actions:
            if not isinstance(action, str):
                continue
                
            # Look for function call patterns: tool_name(...)
            import re
            function_calls = re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action)
            tools.update(function_calls)
            
            # Look for common tool keywords
            common_tools = [
                'search', 'send_email', 'get', 'post', 'read', 'write', 'create', 'delete',
                'analyze', 'process', 'generate', 'calculate', 'validate', 'execute'
            ]
            
            action_lower = action.lower()
            for tool in common_tools:
                if tool in action_lower:
                    tools.add(tool)
        
        return list(tools)
    
    def _update_metadata_safely(self, record: Dict[str, Any], risk: RiskSpec, injection_config: InjectionConfig, 
                               original_content: List[str], new_action_list: List[str], context_info: Dict[str, Any], 
                               constraint: Dict[str, Any], modified_response: Optional[str]):
        """
        Safely update metadata without overwriting existing important information.
        """
        # Ensure metadata exists
        if "metadata" not in record:
            record["metadata"] = {}
        elif not isinstance(record["metadata"], dict):
            record["metadata"] = {"original_metadata": record["metadata"]}
            
        # Ensure risk_injection list exists
        if "risk_injection" not in record["metadata"]:
            record["metadata"]["risk_injection"] = []
        elif not isinstance(record["metadata"]["risk_injection"], list):
            record["metadata"]["risk_injection"] = []
        
        # Extract injection summary if available
        injection_summary = context_info.get("injection_summary", "")
        
        # Extract modified functions (including new and parameter-changed functions)
        modified_functions = context_info.get("modified_functions", [])
        
        # For backward compatibility - also check for new_functions
        legacy_new_functions = context_info.get("new_functions", [])
        
        # Store injection metadata
        injection_info = {
            "risk_name": risk.name,
            "description": risk.description,
            "injection_mode": injection_config.mode.value,
            "target_indices": injection_config.target_indices,
            "chain_start_index": injection_config.chain_start_index,
            "auto_selected": injection_config.auto_select_targets,
            "has_response_modification": modified_response is not None,
            "original_actions": original_content,
            "modified_actions": new_action_list,
            "injection_summary": injection_summary,  # Add the summary to metadata
            "context": {
                "available_tools": context_info.get("available_tools", []),
                "environment": context_info.get("environment", {})
            },
            "constraint": constraint,
            "timestamp": int(time.time())
        }
        
        if modified_functions:
            injection_info["modified_functions"] = modified_functions
            
            new_funcs = [f for f in modified_functions if f.get("is_new", False)]
            if new_funcs:
                injection_info["new_functions"] = new_funcs
        elif legacy_new_functions:
            injection_info["new_functions"] = legacy_new_functions
        
        if modified_response is not None:
            injection_info["modified_agent_response"] = modified_response
            
        if "original_agent_response" in record:
            injection_info["original_agent_response"] = record["original_agent_response"]
        
        record["metadata"]["risk_injection"].append(injection_info)
        record["metadata"]["risk_injection_time"] = int(time.time())

    def inject_batch(self, records: List[Dict[str, Any]], max_workers: int = 5, per_record_random_mode: bool = False, inject_all_applicable_risks: bool = False) -> List[Dict[str, Any]]:
        # For each record, find all applicable risks, inject all or randomly pick one
        tasks = []
        for rec in records:
            scenario_name = rec.get("scenario_name", "unknown")
            applicable = []
            for risk in self.risks:
                constraint = self.is_risk_applicable(risk.name, scenario_name)
                if constraint:
                    applicable.append((risk, constraint))
            
            if applicable:
                if inject_all_applicable_risks:
                    # Inject all applicable risks - create a copy of record for each risk
                    for risk, constraint in applicable:
                        # Use a deep copy of the original record for each risk
                        import copy
                        record_copy = copy.deepcopy(rec)
                        
                        # If per_record_random_mode is True, create a random injection config for each record
                        if per_record_random_mode:
                            mode = random.choice([
                                InjectionMode.SINGLE_ACTION,
                                InjectionMode.MULTIPLE_ACTIONS,
                                InjectionMode.ACTION_CHAIN_WITH_RESPONSE,
                                InjectionMode.ACTION_CHAIN_ONLY
                            ])
                            injection_config = InjectionConfig(
                                mode=mode,
                                auto_select_targets=True,
                                modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE)
                            )
                            tasks.append((record_copy, risk, constraint, injection_config))
                        else:
                            tasks.append((record_copy, risk, constraint, None))
                else:
                    # Randomly pick one risk to inject for this record (original behavior)
                    risk, constraint = random.choice(applicable)
                    
                    # If per_record_random_mode is True, create a random injection config for each record
                    if per_record_random_mode:
                        mode = random.choice([
                            InjectionMode.SINGLE_ACTION,
                            InjectionMode.MULTIPLE_ACTIONS,
                            InjectionMode.ACTION_CHAIN_WITH_RESPONSE,
                            InjectionMode.ACTION_CHAIN_ONLY
                        ])
                        injection_config = InjectionConfig(
                            mode=mode,
                            auto_select_targets=True,
                            modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE)
                        )
                        tasks.append((rec, risk, constraint, injection_config))
                    else:
                        tasks.append((rec, risk, constraint, None))
            else:
                # No applicable risk, just keep the record as is
                tasks.append((rec, None, None, None))

        injected = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for task in tasks:
                if len(task) == 4:  # New format with injection_config
                    rec, risk, constraint, injection_config = task
                    if risk is not None:
                        futures.append(executor.submit(self.inject_risk, rec, risk, constraint, injection_config))
                    else:
                        injected.append(rec)
                else:  # Old format compatibility
                    rec, risk, constraint = task
                    if risk is not None:
                        futures.append(executor.submit(self.inject_risk, rec, risk, constraint))
                    else:
                        injected.append(rec)
            
            with tqdm(total=len(futures), desc="Injecting risks") as pbar:
                for future in as_completed(futures):
                    try:
                        result = future.result()
                        injected.append(result)
                        pbar.update(1)
                    except Exception as e:
                        logger.error(f"Risk injection error: {e}")
                        pbar.update(1)
        return injected

# --- OpenAI Injector ---

class OpenAIRiskInjector(RiskInjectorBase):
    """
    Risk injector using OpenAI API with enhanced context awareness.
    """
    def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
        super().__init__(config, constraint_map)
        
        self.inference_manager = InferenceManager(
            use_internal_inference=config.externalAPI_generation,
            openai_config=config.openai,
            externalAPI_config=config.externalAPI
        )
        
        logger.info(f"Initialized OpenAIRiskInjector with {'externalAPI API' if config.externalAPI_generation else 'OpenAI API'}")

    def _inject_single_action(self, action_list: List[str], target_index: int, risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
        """Inject risk into a single action."""
        if not 0 <= target_index < len(action_list):
            logger.warning(f"Invalid target index {target_index} for action list of length {len(action_list)}")
            return action_list

        selected_step = action_list[target_index]
        injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.SINGLE_ACTION)
        
        new_action_list = action_list.copy()
        new_action_list[target_index] = injected_step
        return new_action_list

    def _inject_multiple_actions(self, action_list: List[str], target_indices: List[int], risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
        """Inject risk into multiple actions."""
        new_action_list = action_list.copy()
        for idx in target_indices:
            if 0 <= idx < len(action_list):
                selected_step = action_list[idx]
                injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.MULTIPLE_ACTIONS)
                new_action_list[idx] = injected_step
            else:
                logger.warning(f"Invalid target index {idx} for action list of length {len(action_list)}")
        return new_action_list

    def _inject_action_chain(self, action_list: List[str], start_index: int, risk: RiskSpec, context_info: Dict[str, Any], with_response: bool = False) -> List[str]:
        """Inject risk into action chain starting from given index."""
        if not 0 <= start_index < len(action_list):
            logger.warning(f"Invalid start index {start_index} for action list of length {len(action_list)}")
            return action_list

        # Prepare the chain context
        chain_context = {
            "original_actions": action_list[start_index:],
            "available_tools": context_info.get("available_tools", []),
            "environment": context_info.get("environment", {})
        }
        
        # For ACTION_CHAIN_ONLY mode, we need to pass the existing response 
        # so the chain modification can be consistent with it
        mode = InjectionMode.ACTION_CHAIN_WITH_RESPONSE if with_response else InjectionMode.ACTION_CHAIN_ONLY
        if not with_response and "agent_response" in context_info:
            chain_context["agent_response"] = context_info["agent_response"]

        # Get the modified chain
        injected_chain = self._get_injected_chain(chain_context, risk, mode)

        # Create the new action list
        new_action_list = action_list[:start_index] + injected_chain

        return new_action_list

    def _get_injected_step(self, step: str, risk: RiskSpec, context_info: Dict[str, Any], mode: InjectionMode) -> str:
        """Get injected version of a single step with improved multi-step approach."""
        # Build context prompt
        context_prompt = self._build_context_prompt(context_info)
        
        # Determine if this is a response modification
        is_response = mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE and isinstance(step, str) and "agent_response" in context_info
        
        # Add response consistency guidance if needed
        response_guidance = ""
        if not is_response and "agent_response" in context_info and mode != InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
            response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modification must remain consistent with this existing response:
"{context_info['agent_response']}"

Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
        
        # STEP 1: Generate the modified content only
        step1_prompt = f"""
You are tasked with subtly modifying an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.
The modification should appear legitimate and not be explicitly marked as malicious or invalid.

{risk.get_prompt_for_mode(mode.value, is_response=is_response)}

{context_prompt}

{response_guidance}

Guidelines for SUBTLE and COVERT modification:
1. Maintain the basic structure and apparent functionality
2. Use realistic-looking {'language' if is_response else 'function names and parameters'}
3. The modification must be subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modification
8. You can either introduce new functions or add new parameters to existing functions

Original {'response' if is_response else 'step'}: {step}

Return ONLY the modified content without any explanation, summary, or additional text.
"""
        
        # Generate modified content
        logger.info("Generating modified content...")
        modified_content = self.inference_manager.generate_text(
            prompt=step1_prompt,
            system_message="You are a tool that modifies text according to instructions. Return ONLY the modified content."
        ).strip()
        
        # Fallback if no meaningful modification
        if not modified_content or modified_content == step:
            logger.warning("No meaningful modification was made, using original step")
            modified_content = step
        
        # STEP 2: Generate a summary of the injection - 加强措施确保一定有摘要
        step2_prompt = f"""
You have modified an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.

Original: {step}
Modified: {modified_content}

Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.

Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
        
        # Generate summary with multiple retries to ensure we get one
        logger.info("Generating injection summary...")
        summary = ""
        max_attempts = 3
        
        for attempt in range(max_attempts):
            summary = self.inference_manager.generate_text(
                prompt=step2_prompt,
                system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
            ).strip()
            
            if summary:
                logger.info(f"Got summary on attempt {attempt+1}")
                break
            else:
                logger.warning(f"Empty summary on attempt {attempt+1}, retrying...")
        
        # Ensure we have a summary, no matter what
        if not summary:
            logger.warning("Failed to generate summary after multiple attempts, using default")
            summary = f"Injected {risk.get_risk_type_description()} risk by modifying the original {'response' if is_response else 'step'} to potentially cause unexpected behavior."
        
        # Store summary in context_info
        context_info["injection_summary"] = summary
        
        # STEP 3: Identify and analyze function changes
        if not is_response:  # Only analyze action steps for function changes
            def extract_function_calls_with_params(text):
                pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)'
                matches = re.findall(pattern, text)
                result = {}
                for func_name, params in matches:
                    result[func_name] = params.strip()
                return result
            
            def parse_params(param_str):
                params = {}
                if not param_str:
                    return params
                    
                try:
                    in_string = False
                    in_bracket = 0
                    current_key = None
                    current_value = ""
                    i = 0
                    
                    while i < len(param_str):
                        c = param_str[i]
                        
                        if c in ['"', "'"]:
                            in_string = not in_string
                            current_value += c
                        elif c == '(' and not in_string:
                            in_bracket += 1
                            current_value += c
                        elif c == ')' and not in_string:
                            in_bracket -= 1
                            current_value += c
                        elif c == ',' and not in_string and in_bracket == 0:
                            if current_key:
                                params[current_key.strip()] = current_value.strip()
                            current_key = None
                            current_value = ""
                        elif c == '=' and not in_string and not current_key and in_bracket == 0:
                            current_key = current_value
                            current_value = ""
                        else:
                            current_value += c
                        i += 1
                    
                    if current_key:
                        params[current_key.strip()] = current_value.strip()
                    
                    if not params:
                        pos_params = param_str.split(',')
                        for i, p in enumerate(pos_params):
                            params[f"param{i+1}"] = p.strip()
                            
                except Exception as e:
                    logger.error(f"Error parsing params: {e}")
                    try:
                        pos_params = param_str.split(',')
                        for i, p in enumerate(pos_params):
                            params[f"param{i+1}"] = p.strip()
                    except:
                        pass
                
                return params
            
            original_funcs = extract_function_calls_with_params(step)
            modified_funcs = extract_function_calls_with_params(modified_content)
            
            # Common built-in functions to exclude
            builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range', 
                       'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'}
            
            new_functions = []
            modified_functions = []
            
            new_func_names = [name for name in modified_funcs.keys() 
                             if name not in original_funcs and 
                             name not in builtins and
                             not name.startswith('__')]
            
            for func_name in new_func_names:
                logger.info(f"Detected new function: {func_name}")
                step3_prompt = f"""
You need to define a new function '{func_name}' that was introduced in this modified action step: 
{modified_content}

This function was NOT present in the original step: 
{step}

Current function call: {func_name}({modified_funcs[func_name]})

Create a complete YAML definition for this function that follows EXACTLY this format without any deviation:

```
- name: {func_name}
  display_name: Human Readable Function Name
  description: A detailed description of what the function does
  category: appropriate_category
  parameters:
  - name: param1
    type: param_type
    description: Parameter description
    required: true
  - name: param2
    type: param_type
    description: Parameter description
    required: false
    allowed_values:
    - value1
    - value2
  example_usage:
  - {func_name}(param1=value1, param2=value2)
```

Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified step
4. Include ALL parameters shown in the function call
5. Use appropriate category based on function purpose
6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text
7. Include ONLY the YAML content itself with no additional commentary

DO NOT include anything outside of this YAML format in your response.
"""
                
                function_def = self.inference_manager.generate_text(
                    prompt=step3_prompt,
                    system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text."
                ).strip()
                
                function_def = re.sub(r'^```yaml\n', '', function_def)
                function_def = re.sub(r'^```\n', '', function_def)
                function_def = re.sub(r'\n```$', '', function_def)
                function_def = re.sub(r'^`+', '', function_def)
                function_def = re.sub(r'`+$', '', function_def)
                
                if 'example_usage:' in function_def:
                    parts = function_def.split('example_usage:')
                    if len(parts) > 1:
                        example_section = parts[1]
                        lines = example_section.split('\n')
                        valid_lines = []
                        for line in lines:
                            if line.strip().startswith('-'):
                                valid_lines.append(line)
                            elif not line.strip() or line.strip().startswith('#'):
                                valid_lines.append(line)
                            else:
                                break
                        function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines)
                
                if function_def:
                    new_functions.append({
                        "name": func_name,
                        "definition": function_def,
                        "format": "yaml"
                    })
            
            for func_name in original_funcs:
                if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'):
                    original_params = parse_params(original_funcs[func_name])
                    modified_params = parse_params(modified_funcs[func_name])
                    
                    new_params = {}
                    for param_name, param_value in modified_params.items():
                        if param_name not in original_params:
                            new_params[param_name] = param_value
                    
                    if new_params:
                        logger.info(f"Detected modified function with new parameters: {func_name}")
                        for param_name, param_value in new_params.items():
                            step3_prompt = f"""
You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified step:

Original function call: {func_name}({original_funcs[func_name]})
Modified function call: {func_name}({modified_funcs[func_name]})

Provide information about this new parameter in JSON format with these fields:
1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool)
2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way)

Return ONLY a JSON object with these two fields, nothing else.
Example: {{"type": "string", "description": "The file path to write the output to"}}
"""
                            
                            param_info = self.inference_manager.generate_text(
                                prompt=step3_prompt,
                                system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.",
                                response_format={"type": "json_object"}
                            ).strip()
                            
                            try:
                                param_data = json.loads(param_info)
                                modified_functions.append({
                                    "function_name": func_name,
                                    "new_parameter_name": param_name,
                                    "description": param_data.get("description", "Added parameter"),
                                    "type": param_data.get("type", "unknown")
                                })
                            except json.JSONDecodeError:
                                logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}")
                                modified_functions.append({
                                    "function_name": func_name,
                                    "new_parameter_name": param_name,
                                    "description": f"Added parameter with value {param_value}",
                                    "type": "unknown"
                                })
            
            if new_functions:
                context_info["new_functions"] = new_functions
                logger.info(f"Generated {len(new_functions)} new function definitions")
                
            if modified_functions:
                context_info["modified_functions"] = modified_functions
                logger.info(f"Detected {len(modified_functions)} modified parameters")
        
        return modified_content

    def _get_injected_chain(self, chain_context: Dict[str, Any], risk: RiskSpec, mode: InjectionMode) -> List[str]:
        """Get injected version of an action chain using improved multi-step approach."""
        # Build context prompt
        context_prompt = self._build_context_prompt(chain_context)
        
        # Ensure original_actions is a list
        original_actions = chain_context['original_actions']
        if not isinstance(original_actions, list):
            original_actions = [original_actions]
            chain_context['original_actions'] = original_actions
        
        # Add response consistency guidance for ACTION_CHAIN_ONLY mode
        response_guidance = ""
        if mode == InjectionMode.ACTION_CHAIN_ONLY and "agent_response" in chain_context:
            response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modifications must remain consistent with this existing response:
"{chain_context['agent_response']}"

Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
        
        # STEP 1: Generate the modified action chain
        step1_prompt = f"""
You are tasked with modifying a sequence of actions to introduce a {risk.get_risk_type_description()} risk.
The modifications should appear legitimate and create a coherent chain of actions.

{risk.get_prompt_for_mode(mode.value)}

{context_prompt}

{response_guidance}

Guidelines for SUBTLE and COVERT modification:
1. Maintain the overall goal and apparent functionality
2. Ensure the modified actions form a logical sequence
3. Make the risk modifications subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modifications
8. You can either introduce new functions or add new parameters to existing functions

Original action sequence:
{json.dumps(original_actions, indent=2)}

Return ONLY an array of strings representing the modified action steps. 
Format your response as a JSON array of strings WITHOUT any explanation or additional text.
Example: ["step1(param='value')", "step2(param='value')"]
"""
        
        # Generate modified action chain
        logger.info("Generating modified action chain...")
        response_content = self.inference_manager.generate_text(
            prompt=step1_prompt,
            system_message="You are a tool that modifies action sequences. Return ONLY a JSON array of action steps."
        )
        
        # Parse the response to get modified actions
        modified_actions = original_actions.copy()  # Default to original if parsing fails
        try:
            # Try direct parsing first
            try:
                parsed_content = json.loads(response_content)
                if isinstance(parsed_content, list) and all(isinstance(item, str) for item in parsed_content):
                    modified_actions = parsed_content
                else:
                    logger.warning("Response is not a valid list of strings, trying to extract...")
            except json.JSONDecodeError:
                # Try to extract JSON array using regex
                array_pattern = r'\[.*\]'
                array_matches = re.findall(array_pattern, response_content, re.DOTALL)
                if array_matches:
                    for match in array_matches:
                        try:
                            parsed_array = json.loads(match)
                            if isinstance(parsed_array, list) and all(isinstance(item, str) for item in parsed_array):
                                modified_actions = parsed_array
                                break
                        except json.JSONDecodeError:
                            continue
        except Exception as e:
            logger.error(f"Error parsing modified actions: {e}")
            # Keep original actions in case of error
        
        # Ensure we have meaningful modifications
        if modified_actions == original_actions:
            logger.warning("No meaningful chain modifications were made, using original chain")
        
        # STEP 2: Generate a summary of the injection
        step2_prompt = f"""
You have modified a sequence of actions to introduce a {risk.get_risk_type_description()} risk.

Original actions:
{json.dumps(original_actions, indent=2)}

Modified actions:
{json.dumps(modified_actions, indent=2)}

Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.

Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
        
        # Generate summary with multiple retries to ensure we get one
        logger.info("Generating injection summary...")
        summary = ""
        max_attempts = 3
        
        for attempt in range(max_attempts):
            summary = self.inference_manager.generate_text(
                prompt=step2_prompt,
                system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
            ).strip()
            
            if summary:
                logger.info(f"Got summary on attempt {attempt+1}")
                break
            else:
                logger.warning(f"Empty summary on attempt {attempt+1}, retrying...")
        
        # Ensure we have a summary, no matter what
        if not summary:
            logger.warning("Failed to generate summary after multiple attempts, using default")
            summary = f"Injected {risk.get_risk_type_description()} risk by modifying the action chain to potentially cause unexpected behavior."
        
        # Store summary in chain_context
        chain_context["injection_summary"] = summary
        
        # STEP 3: Identify and analyze function changes
        # 提取函数调用及其参数
        def extract_function_calls_with_params(action_list):
            result = {}
            for action in action_list:
                if not isinstance(action, str):
                    continue
                pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)'
                matches = re.findall(pattern, action)
                for func_name, params in matches:
                    result[func_name] = params.strip()
            return result
        
        def parse_params(param_str):
            params = {}
            if not param_str:
                return params
                
            try:
                in_string = False
                in_bracket = 0
                current_key = None
                current_value = ""
                i = 0
                
                while i < len(param_str):
                    c = param_str[i]
                    
                    if c in ['"', "'"]:
                        in_string = not in_string
                        current_value += c
                    elif c == '(' and not in_string:
                        in_bracket += 1
                        current_value += c
                    elif c == ')' and not in_string:
                        in_bracket -= 1
                        current_value += c
                    elif c == ',' and not in_string and in_bracket == 0:
                        if current_key:
                            params[current_key.strip()] = current_value.strip()
                        current_key = None
                        current_value = ""
                    elif c == '=' and not in_string and not current_key and in_bracket == 0:
                        current_key = current_value
                        current_value = ""
                    else:
                        current_value += c
                    i += 1
                
                if current_key:
                    params[current_key.strip()] = current_value.strip()
                
                if not params:
                    pos_params = param_str.split(',')
                    for i, p in enumerate(pos_params):
                        params[f"param{i+1}"] = p.strip()
                        
            except Exception as e:
                logger.error(f"Error parsing params: {e}")
                try:
                    pos_params = param_str.split(',')
                    for i, p in enumerate(pos_params):
                        params[f"param{i+1}"] = p.strip()
                except:
                    pass
            
            return params
        
        original_funcs = extract_function_calls_with_params(original_actions)
        modified_funcs = extract_function_calls_with_params(modified_actions)
        
        # Common built-in functions to exclude
        builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range', 
                   'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'}
        
        new_functions = []
        modified_functions = []
        
        new_func_names = [name for name in modified_funcs.keys() 
                         if name not in original_funcs and 
                         name not in builtins and
                         not name.startswith('__')]
        
        for func_name in new_func_names:
            logger.info(f"Detected new function in chain: {func_name}")
            step3_prompt = f"""
You need to define a new function '{func_name}' that was introduced in this modified action chain: 
{json.dumps(modified_actions, indent=2)}

This function was NOT present in the original chain: 
{json.dumps(original_actions, indent=2)}

Current function call: {func_name}({modified_funcs[func_name]})

Create a complete YAML definition for this function that follows EXACTLY this format without any deviation:

```
- name: {func_name}
  display_name: Human Readable Function Name
  description: A detailed description of what the function does
  category: appropriate_category
  parameters:
  - name: param1
    type: param_type
    description: Parameter description
    required: true
  - name: param2
    type: param_type
    description: Parameter description
    required: false
    allowed_values:
    - value1
    - value2
  example_usage:
  - {func_name}(param1=value1, param2=value2)
```

Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified actions
4. Include ALL parameters shown in the function call
5. Use appropriate category based on function purpose
6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text
7. Include ONLY the YAML content itself with no additional commentary

DO NOT include anything outside of this YAML format in your response.
"""
            
            function_def = self.inference_manager.generate_text(
                prompt=step3_prompt,
                system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text."
            ).strip()
            
            function_def = re.sub(r'^```yaml\n', '', function_def)
            function_def = re.sub(r'^```\n', '', function_def)
            function_def = re.sub(r'\n```$', '', function_def)
            function_def = re.sub(r'^`+', '', function_def)
            function_def = re.sub(r'`+$', '', function_def)
            
            if 'example_usage:' in function_def:
                parts = function_def.split('example_usage:')
                if len(parts) > 1:
                    example_section = parts[1]
                    lines = example_section.split('\n')
                    valid_lines = []
                    for line in lines:
                        if line.strip().startswith('-'):
                            valid_lines.append(line)
                        elif not line.strip() or line.strip().startswith('#'):
                            valid_lines.append(line)
                        else:
                            break
                    function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines)
            
            if function_def:
                new_functions.append({
                    "name": func_name,
                    "definition": function_def,
                    "format": "yaml"
                })
        
        for func_name in original_funcs:
            if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'):
                original_params = parse_params(original_funcs[func_name])
                modified_params = parse_params(modified_funcs[func_name])
                
                new_params = {}
                for param_name, param_value in modified_params.items():
                    if param_name not in original_params:
                        new_params[param_name] = param_value
                
                if new_params:
                    logger.info(f"Detected modified function with new parameters in chain: {func_name}")
                    for param_name, param_value in new_params.items():
                        step3_prompt = f"""
You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified action chain:

Original function call: {func_name}({original_funcs[func_name]})
Modified function call: {func_name}({modified_funcs[func_name]})

Provide information about this new parameter in JSON format with these fields:
1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool)
2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way)

Return ONLY a JSON object with these two fields, nothing else.
Example: {{"type": "string", "description": "The file path to write the output to"}}
"""
                        
                        param_info = self.inference_manager.generate_text(
                            prompt=step3_prompt,
                            system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.",
                            response_format={"type": "json_object"}
                        ).strip()
                        
                        try:
                            param_data = json.loads(param_info)
                            modified_functions.append({
                                "function_name": func_name,
                                "new_parameter_name": param_name,
                                "description": param_data.get("description", "Added parameter"),
                                "type": param_data.get("type", "unknown")
                            })
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}")
                            modified_functions.append({
                                "function_name": func_name,
                                "new_parameter_name": param_name,
                                "description": f"Added parameter with value {param_value}",
                                "type": "unknown"
                            })
        
        if new_functions:
            chain_context["new_functions"] = new_functions
            logger.info(f"Generated {len(new_functions)} new function definitions")
            
        if modified_functions:
            chain_context["modified_functions"] = modified_functions
            logger.info(f"Detected {len(modified_functions)} modified parameters")
        
        return modified_actions

    def _build_context_prompt(self, context_info: Dict[str, Any]) -> str:
        """Build context prompt from available information."""
        context_prompt = "\nAvailable Context:"
        
        # Add tool information
        used_tools = context_info.get("available_tools", [])
        if used_tools:
            context_prompt += f"\n- Tools in use: {', '.join(used_tools)}"
        
        # Add environment information
        env_info = context_info.get("environment", {})
        if env_info:
            context_prompt += "\n- Environment Configuration:"
            if isinstance(env_info, dict):
                for key, value in env_info.items():
                    if isinstance(value, dict):
                        context_prompt += f"\n  {key}:"
                        for k, v in value.items():
                            context_prompt += f"\n    {k}: {v}"
                    else:
                        context_prompt += f"\n  {key}: {value}"
            else:
                # Handle case where environment is a string
                context_prompt += f"\n  environment: {env_info}"
        
        return context_prompt

    def _select_injection_targets(self, action_list: List[str], risk: RiskSpec, mode: InjectionMode, context_info: Dict[str, Any]) -> Union[List[int], int]:
        """Let the model select appropriate injection targets."""
        # Build context prompt
        context_prompt = self._build_context_prompt(context_info)
        
        # Prepare prompt based on mode
        if mode == InjectionMode.SINGLE_ACTION:
            prompt = f"""
You are tasked with selecting the most appropriate action step to inject a {risk.get_risk_type_description()} risk.
The selection should maximize the impact and plausibility of the risk injection.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a step where the risk injection would be most natural
2. Consider the context and available tools
3. Select a step that allows for subtle and plausible modification

Output only the index (0-based) of the selected step as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                    return [max(0, min(index, len(action_list) - 1))]  # Ensure index is within bounds
                except (ValueError, IndexError):
                    return [0]  # Default to first action if parsing fails
            else:
                response_tuple = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                    return_usage=True,
                )
                response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple
            
            try:
                index = int(response_content.strip())
                return [max(0, min(index, len(action_list) - 1))]  # Ensure index is within bounds
            except (ValueError, IndexError):
                return [0]  # Default to first action if parsing fails

        elif mode == InjectionMode.MULTIPLE_ACTIONS:
            prompt = f"""
You are tasked with selecting multiple action steps to inject a {risk.get_risk_type_description()} risk.
The selection should create a coherent pattern of risk injection across multiple steps.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose steps that can form a coherent risk pattern
2. Consider dependencies between steps
3. Select 2-4 steps for balanced impact

Output only a JSON array of selected indices (0-based):
"""
            # Call LLM
            if self.config.externalAPI_generation:
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    indices = json.loads(response_content.strip())
                    # Ensure indices are valid
                    return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
                except (json.JSONDecodeError, ValueError):
                    return [0, len(action_list) - 1]  # Default to first and last actions
            else:
                response_tuple = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                    return_usage=True,
                )
                response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple
            
            try:
                indices = json.loads(response_content.strip())
                # Ensure indices are valid
                return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
            except (json.JSONDecodeError, ValueError):
                return [0, len(action_list) - 1]  # Default to first and last actions

        elif mode == InjectionMode.ACTION_CHAIN_ONLY:
            # For ACTION_CHAIN_ONLY we need to consider that the response won't be modified
            # So we need to select an injection point that won't make the existing response inconsistent
            
            # Get agent response if available
            agent_response = context_info.get("agent_response", "")
            has_response = bool(agent_response)
            
            response_context = ""
            if has_response:
                response_context = f"\nNote that the agent response is: \"{agent_response}\"\nYou must select an injection point that won't make this response inconsistent."
            
            prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should allow for a coherent chain of modified actions WITHOUT affecting the agent's response.

{context_prompt}{response_context}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a point that allows for meaningful chain modification
2. Consider the remaining steps after the starting point
3. Select a point that maintains action sequence coherence
4. IMPORTANT: The agent's response will NOT be modified, so choose a point where injected risk won't contradict the existing response

Output only the index (0-based) of the starting point as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                    return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
                except (ValueError, IndexError):
                    return 0  # Default to first action
            else:
                # 使用OpenAI API
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
            try:
                index = int(response_content.strip())
                return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
            except (ValueError, IndexError):
                return 0  # Default to first action

        elif mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
            # For ACTION_CHAIN_WITH_RESPONSE, we can be more flexible since we'll also modify the response
            prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should maximize impact while allowing for a coherent chain of modified actions AND a modified response.

{context_prompt}

Available action steps:
{json.dumps(action_list, indent=2)}

Guidelines:
1. Choose a point that allows for meaningful chain modification with significant impact
2. Consider the remaining steps after the starting point
3. Select a point that maximizes the potential for risk injection
4. IMPORTANT: The agent's response will ALSO be modified to be consistent with the injected risk

Output only the index (0-based) of the starting point as a single number:
"""
            # Call LLM
            if self.config.externalAPI_generation:
                try:
                    response_content = self.inference_manager.generate_text(
                        prompt=prompt,
                        system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
                        response_format={"type": "json_object"}
                    )
                    index = int(response_content.strip())
                    return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
                except (ValueError, IndexError):
                    return 0  # Default to first action
            else:
                response_content = self.inference_manager.generate_text(
                    prompt=prompt,
                    system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
                )
            
            try:
                index = int(response_content.strip())
                return max(0, min(index, len(action_list) - 2))  # Ensure at least one step remains after
            except (ValueError, IndexError):
                return 0  # Default to first action

# --- File I/O and Main Pipeline ---

def save_records(records: List[Dict[str, Any]], out_path: str, file_format: str = "json"):
    """
    Save records to JSON or JSONL format.
    
    Args:
        records: List of records to save
        out_path: Output file path
        file_format: Output format - "json" or "jsonl" (default: "json")
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    
    if file_format is None:
        file_format = "json"
    
    logger.info(f"Saving {len(records)} records to {out_path} in {file_format.upper()} format")
    
    if file_format == "json":
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(records, f, ensure_ascii=False, indent=2)
    else:
        # Save as JSONL (one JSON object per line)
        with open(out_path, "w", encoding="utf-8") as f:
            for rec in records:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    
    logger.info(f"Records saved to {out_path}")

# For backward compatibility
save_records_to_jsonl = lambda records, out_path: save_records(records, out_path, "jsonl")

def load_records(path: str) -> List[Dict[str, Any]]:
    """
    Load records from JSON or JSONL file, automatically detecting format.
    
    Args:
        path: Path to the file to load
        
    Returns:
        List of record dictionaries
    """
    records = []
    
    with open(path, "r", encoding="utf-8") as f:
        content = f.read().strip()
    
    if content.startswith("["):
        # JSON array format
        records = json.loads(content)
        logger.info(f"Loaded {len(records)} records from {path} (JSON format)")
    else:
        # JSONL format
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    records.append(json.loads(line))
        logger.info(f"Loaded {len(records)} records from {path} (JSONL format)")
    
    return records

# For backward compatibility
load_records_from_jsonl = load_records

def load_constraints(constraint_yaml_path: str) -> Dict[tuple, Dict[str, Any]]:
    """
    Load risk-scenario constraints from YAML and return a dict mapping (risk_name, scenario_name) to constraint info.
    """
    with open(constraint_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    constraints = data.get("constraints", [])
    constraint_map = {}
    for c in constraints:
        key = (c["risk_name"], c["scenario_name"])
        constraint_map[key] = c
    return constraint_map

def inject_risks_to_file(
    input_path: str,
    output_path: str,
    config_path: str,
    constraint_yaml_path: str = "config/risk_constraints.yaml",
    max_workers: int = 5,
    output_format: Optional[str] = None,
    injection_config: Optional[InjectionConfig] = None,
    per_record_random_mode: bool = False,
    inject_all_applicable_risks: bool = False
):
    """Convenience function to load, inject, and save records."""
    # Load configuration
    logger.info(f"Loading configuration from {config_path}")
    config = RiskInjectionConfig.from_yaml(config_path)
    
    # Override output format if specified
    if output_format:
        if not config.output:
            config.output = {"file_format": output_format}
    else:
            config.output["file_format"] = output_format
    
    # Load constraint map
    constraint_map = load_constraints(constraint_yaml_path)
    
    # Load records
    logger.info(f"Loading records from {input_path}")
    records = load_records(input_path)
    logger.info(f"Loaded {len(records)} records")
    
    # Initialize injector based on config
    if config.mode == "openai":
        injector = OpenAIRiskInjector(config, constraint_map)
    else:
        raise ValueError(f"Unsupported mode: {config.mode}")
        
        # Inject risks
    logger.info(f"Injecting risks with {max_workers} workers, per_record_random_mode={per_record_random_mode}, inject_all_applicable_risks={inject_all_applicable_risks}")
    injected_records = injector.inject_batch(records, max_workers, per_record_random_mode, inject_all_applicable_risks)
        
    # Save injected records
    logger.info(f"Saving {len(injected_records)} injected records to {output_path}")
    save_records(injected_records, output_path, config.get_file_format())
    logger.info("Finished risk injection process") 