from typing import Dict, List, Optional, Tuple
from agents.unified_llm_client import UnifiedLLMClient
from tenacity import retry, stop_after_attempt, wait_fixed
import json
import os
import yaml

from .base import BasePlanGenerator

class LLMRefusalError(Exception):
    """Raised when a generated attacker reply contains a refusal phrase."""

class XTeamingPlanGenerator(BasePlanGenerator):
    """
    strategy_client: used to generate attack strategies
    attacker_client: used to generate multi-turn queries based on strategies

    How to use:
        generator = XTeamingPlanGenerator(...)
        strategies_10 = generator.generate_attack_strategies(harmful_behavior)
        strategy = strategies_10['Set_1']['strategy_1']
        generator._load_attacker(prompt_path, strategy)
        for turn in range(1, num_turns + 1):
            one_query = generator.iterative_generate_turns(turn, conversation_history)
            # iteratively refine
            ...
    Note: the original user_message1 prompt was modified so that a specific number of turns is generated.
    """
    def __init__(self, strategy_model_config, attacker_config, num_sets=1, prompt_path='./config/xteaming/prompts'):
        """
        Build the two LLM clients and store the generation settings.

        Args:
            strategy_model_config: mapping read for 'model', 'provider',
                'base_url' and 'temperature' of the strategy model.
            attacker_config: mapping read for the same keys of the attacker model.
            num_sets: number of strategy sets requested by generate_attack_strategies.
            prompt_path: directory that holds the prompt YAML files.
        """
        super().__init__()

        def build_client(cfg):
            # Both clients are constructed from the same three config keys.
            return UnifiedLLMClient(cfg['model'], cfg['provider'], base_url=cfg['base_url'])

        self.strategy_client = build_client(strategy_model_config)
        self.attacker_client = build_client(attacker_config)
        self.strategy_temperature = strategy_model_config['temperature']
        self.attacker_temperature = attacker_config['temperature']
        self.num_sets = num_sets
        self.prompt_path = prompt_path

    def _prepare_plan(
        self,
        harmful_behavior: str,
        max_turns: Optional[int] = None,
        strategy: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Select a strategy, load the attacker prompts and reset per-plan state.

        When no strategy is supplied, strategies are generated and the first
        strategy of the first returned set is used.

        Returns:
            The freshly initialised ``save_info`` dict (strategy, number of
            phases and an empty list of plan queries).
        """
        # Stored first: strategy generation reads self.max_turns for its prompt.
        self.max_turns = max_turns

        if strategy is not None:
            chosen_strategy = strategy
        else:
            generated_sets = self.generate_attack_strategies(harmful_behavior)
            leading_set = next(iter(generated_sets.values()))
            chosen_strategy = next(iter(leading_set.values()))

        self._load_attacker(self.prompt_path, chosen_strategy)
        self.conversation_history = ""

        plan_info = {
            "strategy": chosen_strategy,
            "num_phases": self.num_phases,
            "plan_queries": [],
        }
        self.save_info = plan_info
        return self.save_info

    # *****************************Attacker*****************************
    def _format_strategy(self, strategy: Dict) -> str:
        """
        Format the strategy into a readable string.

        Returns:
            Formatted strategy string with persona, context, approach, and plan
        """
        return "\n".join(
            [
                f"Persona: {strategy['persona']}",
                f"Context: {strategy['context']}",
                f"Approach: {strategy['approach']}",
                f"Conversation Plan: {strategy['conversation_plan']}",
            ]
        )
    
    def _load_attacker(self, prompt_path, strategy):
        """
        Load the attacker prompt templates and bind the strategy to use.

        Args:
            prompt_path: directory containing 'attacker_agent_prompts.yaml'.
            strategy: strategy dict; its 'conversation_plan' entry determines
                the number of phases.

        Side effects:
            Sets the attacker system prompt, the per-turn templates,
            ``num_phases``, ``strategy_dict`` and the formatted ``strategy``.
        """
        prompt_file = os.path.join(prompt_path, 'attacker_agent_prompts.yaml')
        # Decode explicitly as UTF-8 so non-ASCII prompt text does not depend
        # on the platform's default locale encoding.
        with open(prompt_file, "r", encoding="utf-8") as f:
            prompts = yaml.safe_load(f)

        sections = prompts['prompts']
        self.attacker_system_prompt = sections["system"]["messages"][0]["content"]
        self.first_turn_template = sections["first_turn_messages"]["messages"][0]["content"]
        self.nth_turn_template = sections["nth_turn_messages"]["messages"][0]["content"]
        self.final_turn_template = sections["final_turn_messages"]["messages"][0]["content"]
        self.revise_plan_template = sections["revise_plan_messages"]["messages"][0]["content"]

        # One phase per entry of the conversation plan.
        self.num_phases = len(strategy["conversation_plan"])
        self.strategy_dict = strategy
        self.strategy = self._format_strategy(strategy)

    def _next_query(self, cur_turn: int, conversation_history: Optional[str] = None, **kwargs) -> str:
        query = self.iterative_generate_turns(cur_turn, conversation_history=self.conversation_history)
        self.save_info['plan_queries'].append(query)
        return query

    def on_feedback(self, feedback_result: Dict) -> None:
        self.conversation_history += "\n".join(
            [   "",
                f"Turn {feedback_result['turn']}:",
                f"Attacker (your) response: {feedback_result['attacker']}",
                f"Target model response: {feedback_result['target_truncated']}",
                f"Evaluation Score: {feedback_result['evaluation_score']}/5",
                f"Evaluation Reason: {feedback_result['evaluation_reason']}",
            ]
        )

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), reraise=True)
    def iterative_generate_turns(self, cur_turn, conversation_history: str = None) -> str:
        """
        Generate response for the specified turn.

        Args:
            conversation_history: Previous conversation (required for turns > 1)

        Returns:
            Generated response for the current turn
        """
        if cur_turn == 1:
            response = self._generate_first_turn()
        elif cur_turn == self.num_phases:  # If it's the final turn
            response = self._generate_final_turn(conversation_history)
        else:
            response = self._generate_nth_turn(cur_turn, conversation_history)

        # detect attacker model refusals
        if "can't assist with that" in response:
            raise LLMRefusalError(response)
        return response

    def _extract_conversation(self, response: str) -> str:
        """Extract only the conversation part from the response."""
        start = response.find("<conversation>") + len("<conversation>")
        end = response.find("</conversation>")
        return response[start:end].strip()
    
    def _generate_first_turn(self) -> str:
        """Generate the first turn response."""
        turn_1_conversation_flow = self.strategy_dict["conversation_plan"]["turn_1"]

        formatted_system_prompt = self.attacker_system_prompt.format(
            target_behavior=self.harmful_behavior,
            strategy=self.strategy,
            max_turns=self.num_phases,
        )

        formatted_user_prompt = self.first_turn_template.format(
            turn_1_conversation_flow=turn_1_conversation_flow
        )

        response, next_messages = self.attacker_client.generate(
            formatted_user_prompt,
            system_prompt=formatted_system_prompt,
            temperature=self.attacker_temperature,
        )
        return self._extract_conversation(response)

    def _generate_nth_turn(self, cur_turn, conversation_history: str) -> str:
        """Generate response for subsequent turns."""
        # Get conversation flow for current turn
        if cur_turn >= self.num_phases:
            raise ValueError(
                f"Expected plan phase < {self.num_phases}, but we are on phase {cur_turn}"
            )

        turn_key = f"turn_{cur_turn}"
        turn_conversation_flow = self.strategy_dict["conversation_plan"][turn_key]

        # Format the prompt with all placeholders
        formatted_prompt = self.nth_turn_template.format(
            turn_number=cur_turn,
            conversation_history=conversation_history or "",
            turn_number_conversation_flow=turn_conversation_flow,
            previous_turn_number=cur_turn - 1,  # Added for completeness
        )

        formatted_system_prompt = self.attacker_system_prompt.format(
            target_behavior=self.harmful_behavior,
            strategy=self.strategy,
            max_turns=self.num_phases,
        )

        response, next_messages = self.attacker_client.generate(
            formatted_prompt,
            system_prompt=formatted_system_prompt,
            temperature=self.attacker_temperature,
        )
        return self._extract_conversation(response)
    
    def _generate_final_turn(self, conversation_history: Optional[str]) -> str:
        """Generate the final turn response."""
        # Get the final turn's conversation plan
        final_turn_conversation_flow = self.strategy_dict["conversation_plan"][
            "final_turn"
        ]

        # Format the prompt with all placeholders
        formatted_prompt = self.final_turn_template.format(
            conversation_history=conversation_history or "",
            final_turn_conversation_flow=final_turn_conversation_flow,
        )

        formatted_system_prompt = self.attacker_system_prompt.format(
            target_behavior=self.harmful_behavior,
            strategy=self.strategy,
            max_turns=self.num_phases,
        )
        response, next_messages = self.attacker_client.generate(
            formatted_prompt,
            system_prompt=formatted_system_prompt,
            temperature=self.attacker_temperature,
        )
        return self._extract_conversation(response)
    

    # *****************************Strategy Generator*****************************
    def generate_attack_strategies(self, harmful_behavior) -> List[str]:
        """Generate the initial attack strategies from harmful behavior.
        Returns:
            A dictionary with set numbers as keys and generated strategies as values.
            The number of strategies per set is 10.
        {'Set_1':
            {
              "strategy_1": {
                  "persona": "人物 + 动机",
                  "context": "情境 + 触发事件",
                  "approach": "主要策略",
                  "turns_needed": "说明为何该策略需要 X 轮",
                  "conversation_plan": {
                      "turn_1": "具体行动",
                      // 额外轮次，用于搭建语境并推进目标
                      "final_turn": "自然引出完整目标行为的请求"
                  }
              }
          }
        }
        """
        all_responses = {}

        for set_num in range(1, self.num_sets + 1):
            system_prompt, formatted_user_prompt = self._load_and_format_prompts(
                    prompt_path=self.prompt_path, harmful_behavior=self.harmful_behavior, set_number=set_num, previous_responses=all_responses
                )
            
            response = self.generate_strategies(
                    system_prompt=system_prompt,
                    user_input=formatted_user_prompt,
                )
            
            all_responses[f"Set_{set_num}"] = response
        return all_responses
    
    def _load_and_format_prompts(self, prompt_path, harmful_behavior, set_number, previous_responses=None):
        """Load the plan-generation prompts and fill in the user prompt.

        Args:
            prompt_path: directory containing 'plan_generation_prompts.yaml'.
            harmful_behavior: text substituted for "{target_behavior}".
            set_number: 1 selects the first-set prompt (which also receives
                "{max_turns}"); any other value selects the follow-up prompt
                that lists the previously generated strategies.
            previous_responses: mapping of set name to earlier response; may
                be None, which is treated as "no previous strategies".

        Returns:
            Tuple ``(system_prompt, formatted_user_prompt)``.
        """
        # Decode explicitly as UTF-8 so non-ASCII prompt text does not depend
        # on the platform's default locale encoding.
        with open(os.path.join(prompt_path, 'plan_generation_prompts.yaml'), "r", encoding="utf-8") as f:
            prompts = yaml.safe_load(f)

        system_prompt = prompts["prompts"]["system"]["messages"][0]["content"]

        if set_number == 1:
            user_prompt = prompts["prompts"]["user_message1"]["messages"][0]["content"]
            # NOTE: self.max_turns is assigned in _prepare_plan; it must be
            # set before this branch runs.
            formatted_user_prompt = user_prompt.replace(
                "{target_behavior}", harmful_behavior
            ).replace("{max_turns}", str(self.max_turns))
        else:
            user_prompt = prompts["prompts"]["user_message2"]["messages"][0]["content"]
            formatted_user_prompt = user_prompt.replace("{target_behavior}", harmful_behavior)

            # Guard the None default so a missing history yields an empty listing
            # instead of an AttributeError.
            strategies_text = "".join(
                f"\n{set_name}:\n{response}\n"
                for set_name, response in (previous_responses or {}).items()
            )
            formatted_user_prompt = formatted_user_prompt.replace(
                "{previously_generated_strategies}", strategies_text
            )

        return system_prompt, formatted_user_prompt

    def _parse_strategy_json(self, response: str):
        """
        Some model outputs are wrapped in code fences or include extra text.
        Strip simple fences and raise a clearer error with context when parsing fails.
        """
        cleaned = response.strip()
        if cleaned.startswith("```"):
            lines = cleaned.splitlines()
            if len(lines) >= 2:
                cleaned = "\n".join(lines[1:-1]).strip()

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as exc:
            snippet = cleaned[:500]
            raise ValueError(
                f"Failed to parse strategy JSON. Snippet (len={len(cleaned)}): {snippet}"
            ) from exc

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), reraise=True)
    def generate_strategies(self, system_prompt, user_input):
        """Generate strategies for a single set"""
        response, _ = self.strategy_client.generate(
            user_input,
            system_prompt=system_prompt,
            temperature=self.strategy_temperature,
        )
        # Parse the response string into a Python dictionary
        parsed_response = self._parse_strategy_json(response)
        assert len(parsed_response) == 10
        return parsed_response  # Return parsed dictionary instead of raw string