# file: agent_systems/dylan_wrapper.py

import re
import asyncio
from .base_wrapper import SystemWrapper
from malicious_factory.agent import MaliciousAgent
from typing import Any, Dict, Tuple

from methods import get_method_class

class DyLANWrapper(SystemWrapper):
    """
    A wrapper for the DyLAN system. It injects malicious behavior by monkey-patching
    the call_llm method and tracking agent activations across multiple rounds.
    """

    def __init__(self, general_config: Dict[str, Any], method_config: Dict[str, Any]):
        """Create the injection LLM and instantiate the configured DyLAN variant.

        Args:
            general_config: Global configuration. ``experiment_config`` is read
                for ``llm_config``, ``system_under_test['name']`` and the
                optional ``benchmark_name``.
            method_config: Accepted for interface compatibility; not used here
                (the DyLAN instance is created with ``method_config_name=None``).

        Raises:
            ValueError: If ``experiment_config['llm_config']`` is missing or empty.
            KeyError: If ``experiment_config['system_under_test']['name']`` is missing.
        """
        from utils.async_llm import create_llm_instance

        exp_config = general_config.get('experiment_config', {})
        llm_config = exp_config.get('llm_config', {})
        if not llm_config:
            raise ValueError("No LLM configuration found for the specified model name.")

        self.llm = create_llm_instance(llm_config)
        # Dynamically select the DyLAN variant for the configured benchmark.
        method_name = exp_config['system_under_test']['name']  # e.g. "dylan"
        dataset_name = exp_config.get('benchmark_name', None)  # e.g. "GSM8K", "MATH", "MMLU", "HumanEval"
        MAS_CLASS = get_method_class(method_name, dataset_name)
        self.dylan_instance = MAS_CLASS(general_config, method_config_name=None)

        print(f"DyLANWrapper initialized with {MAS_CLASS.__name__}.")

    def run_with_injection(
        self,
        task: Any,
        malicious_agent: MaliciousAgent,
        injection_target: Dict[str, Any]
    ) -> Tuple[Any, Dict[str, Any]]:
        """Run DyLAN on ``task`` with one agent slot driven by ``malicious_agent``.

        ``self.dylan_instance.call_llm`` is temporarily replaced by an
        intercepting function. Calls attributed to the targeted node
        (matching role and per-round index) are routed through
        ``MaliciousAgentFactory.inject_malicious_behavior``; all other calls
        go to the original ``call_llm``. The original method is restored even
        if ``inference`` raises.

        Args:
            task: Object exposing a ``query`` attribute; it is sent to DyLAN as
                ``{"query": task.query}``.
            malicious_agent: Agent whose behavior is injected; its
                ``description`` is recorded in the returned log.
            injection_target: Dict with required key ``'role'`` and optional
                ``'role_index'`` (defaults to 0). The index is compared against
                ``node_id % num_agents``.

        Returns:
            A ``(final_output, log)`` tuple. ``log`` contains ``final_output``,
            ``full_history``, ``injected_role``, ``injected_role_index`` and
            ``malicious_action_description``.

        Raises:
            KeyError: If ``injection_target`` has no ``'role'`` key.
        """
        target_role = injection_target['role']
        target_role_index = injection_target.get('role_index', 0)

        dylan = self.dylan_instance
        # Record whether call_llm was already an instance attribute, so cleanup
        # can restore the exact prior state instead of leaving a bound method
        # shadowing the class method (and a self-referential cycle) behind.
        had_instance_attr = 'call_llm' in getattr(dylan, '__dict__', {})
        original_llm_call = dylan.call_llm

        # Persisted across intercepted calls: if no active node matches on a
        # given call, the previously detected node id is reused.
        current_node_id = None

        # Factory instance that provides the shared injection logic.
        from malicious_factory import MaliciousAgentFactory
        factory = MaliciousAgentFactory(llm=self.llm)

        def sophisticated_malicious_llm_call(*args, **kwargs):
            """Intercept call_llm, injecting only for the targeted node."""
            nonlocal current_node_id

            # 1. Determine the current node and its role/index.
            # NOTE(review): assumes call_llm receives messages as its third
            # positional argument or as the ``messages`` kwarg — confirm
            # against the DyLAN call_llm signature.
            messages = args[2] if len(args) > 2 else kwargs.get('messages', [])

            # NOTE(review): this picks the FIRST active node whose question
            # equals task.query; if several nodes are active at once this may
            # not identify the node actually calling — verify against DyLAN's
            # node bookkeeping.
            if hasattr(dylan, 'nodes'):
                for i, node in enumerate(dylan.nodes):
                    if node.get('active', False) and node.get('question') == task.query:
                        current_node_id = i
                        break

            # Resolve the node's role and its index within its round.
            current_role = None
            current_role_index = None
            if hasattr(dylan, 'nodes') and current_node_id is not None:
                current_role = dylan.nodes[current_node_id].get('role', None)
                # Per-round agent index: node_id % num_agents.
                current_role_index = current_node_id % dylan.num_agents

            print(f"[DyLAN Runner Intercept] Node {current_node_id}: role='{current_role}', index={current_role_index}")

            # 2. Decide whether this call should be hijacked.
            is_malicious_call = (current_role == target_role and current_role_index == target_role_index)
            if not is_malicious_call:
                return original_llm_call(*args, **kwargs)

            print(f"*** Malicious Agent Activated on '{current_role}' (index {current_role_index}, node {current_node_id}) ***")

            # Delegate to the factory's unified injection logic; the lambda
            # lets it invoke the genuine LLM call if it chooses to.
            task_input = messages[-1]['content'] if messages else ""
            return factory.inject_malicious_behavior(
                lambda: original_llm_call(*args, **kwargs),
                malicious_agent,
                task_input=task_input,
                messages=messages
            )

        # --- Apply the monkey patch ---
        dylan.call_llm = sophisticated_malicious_llm_call
        print(f"[DyLAN Runner] Monkey-patch applied. Target: role='{target_role}', index={target_role_index}.")

        # --- Execute the patched workflow; always undo the patch ---
        try:
            sample = {"query": task.query}
            final_output = dylan.inference(sample)
        finally:
            if had_instance_attr:
                dylan.call_llm = original_llm_call
            else:
                # Remove the shadowing instance attribute so attribute lookup
                # falls back to the class-defined method again.
                del dylan.call_llm
            print("[DyLAN Runner] Original `call_llm` method restored.")

        log = {
            "final_output": final_output,
            "full_history": self._collect_history(),
            "injected_role": target_role,
            "injected_role_index": target_role_index,
            "malicious_action_description": malicious_agent.description,
        }
        return final_output, log

    def _collect_history(self) -> list:
        """Return one record per active DyLAN node that produced a reply.

        Each record has ``role`` (falling back to ``agent_<i>``),
        ``role_index`` (``i % num_agents``), ``content`` (the reply),
        ``node_id`` and ``round`` (``i // num_agents``). Returns an empty list
        if the DyLAN instance has no ``nodes`` attribute.
        """
        dylan = self.dylan_instance
        if not hasattr(dylan, 'nodes'):
            return []

        num_agents = dylan.num_agents
        return [
            {
                "role": node.get('role', f'agent_{i}'),
                "role_index": i % num_agents,
                "content": node.get('reply', ''),
                "node_id": i,
                "round": i // num_agents,
            }
            for i, node in enumerate(dylan.nodes)
            if node.get('active', False) and node.get('reply')
        ]