from typing import Dict, List, Optional, Tuple
from agents.unified_llm_client import UnifiedLLMClient
from .base import BasePlanGenerator
import yaml
import ast
import logging

class CoAPlanGenerator(BasePlanGenerator):
    """
    CoA Plan Generator.

    Asks an attacker LLM, via pre-defined prompts loaded from a YAML file, for a
    complete multi-turn query plan up front, then serves the stored queries one
    turn at a time.

    NOTE(review): an earlier version of this docstring described a top-k
    selection step based on SimCSE similarity and toxicity; no such selection
    is performed by the code in this class.
    """

    # How many times the attacker model is queried before giving up on
    # obtaining a parseable plan.
    _MAX_GENERATION_ATTEMPTS = 5

    def __init__(self, attacker_config, prompt_path='config/coa/prompts.yaml'):
        """Load prompts and initialize the LLM client used to craft attack plans.

        Args:
            attacker_config: Mapping with a required ``"model"`` key and
                optional ``"provider"``, ``"base_url"`` and ``"temperature"``
                keys (temperature defaults to ``0.0``).
            prompt_path: Path to the YAML file that holds the
                ``get_mr_attack_chain_system_prompt_en`` and
                ``get_init_msg_for_mr_init_chain`` prompt templates.
        """
        super().__init__()
        # Explicit encoding so prompt loading does not depend on the platform's
        # default locale encoding.
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompts = yaml.safe_load(f)
        self.multi_turn_system_prompt = prompts['get_mr_attack_chain_system_prompt_en']
        self.multi_turn_user_prompt = prompts['get_init_msg_for_mr_init_chain']

        # parameters
        self.temperature = attacker_config.get("temperature", 0.0)
        self.attacker_client = UnifiedLLMClient(
            attacker_config["model"],
            attacker_config.get("provider"),
            attacker_config.get("base_url"),
        )
        self.plan_queries: List[str] = []

    def _prepare_plan(self, harmful_behavior: str, max_turns: Optional[int] = None, **kwargs):
        """Generate the full set of attacker queries up front using the LLM prompts.

        The attacker model is queried up to ``_MAX_GENERATION_ATTEMPTS`` times
        until its response can be parsed into a plan.

        Args:
            harmful_behavior: Text substituted for ``target`` in both prompt
                templates.
            max_turns: Value substituted for ``max_round`` in both prompt
                templates; the generated plan must contain exactly this many
                queries.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of per-turn queries (also stored in ``self.plan_queries``
            and recorded in ``self.save_info``).

        Raises:
            RuntimeError: If no attempt produced a parseable response.
            AssertionError: If the number of generated queries differs from
                ``max_turns``.
        """
        system_prompt = self.multi_turn_system_prompt.format(target=harmful_behavior, max_round=max_turns)
        user_prompt = self.multi_turn_user_prompt.format(target=harmful_behavior, max_round=max_turns)

        for _attempt in range(self._MAX_GENERATION_ATTEMPTS):
            response, _ = self.attacker_client.generate(
                user_prompt,
                system_prompt=system_prompt,
                temperature=self.temperature,
            )
            mr_conv, evaluation, _json_str = self._extract_json_for_mr_init_chain(response)
            if mr_conv is not None and evaluation is not None:
                break
        else:
            # Every attempt failed to parse; fail with a clear message rather
            # than crashing later while iterating over None.
            raise RuntimeError(
                f"Failed to obtain a parseable multi-turn plan after "
                f"{self._MAX_GENERATION_ATTEMPTS} attempts."
            )

        self.plan_queries = [item['prompt'] for item in mr_conv if item is not None]

        # Raised explicitly (instead of using an `assert` statement) so the
        # check still runs under `python -O`; the exception type is kept for
        # callers that already handle AssertionError.
        if len(self.plan_queries) != max_turns:
            raise AssertionError(
                f"Expected {max_turns} plan queries but got {len(self.plan_queries)}."
            )

        self.save_info = {
            "plan_queries": self.plan_queries
        }
        return self.plan_queries

    def _next_query(self, cur_turn: int, conversation_history: Optional[str] = None, **kwargs) -> str:
        """Return the precomputed query for the current turn.

        Args:
            cur_turn: 1-based turn number.
            conversation_history: Accepted for interface compatibility; unused
                because the plan is fixed ahead of time.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            IndexError: If ``cur_turn`` falls outside the stored plan.
        """
        idx = cur_turn - 1
        if idx < 0 or idx >= len(self.plan_queries):
            raise IndexError(
                f"Turn {cur_turn} is out of range for the precomputed plan of length {len(self.plan_queries)}."
            )
        return self.plan_queries[idx]

    def _extract_json_for_mr_init_chain(
        self, s
    ) -> Tuple[Optional[List[Dict]], Optional[List], Optional[str]]:
        """Parse the model output into a list of prompts and evaluation entries.

        Expects a JSON-like mapping where each turn contains `improvement` and `prompt`
        fields, plus an optional `evaluation` entry. The text between the first
        "{" and the last "}" is parsed with ``ast.literal_eval``, so it must be
        a valid Python literal.

        Args:
            s: Raw text returned by the attacker model.

        Returns:
            ``(multi_round_conv, evaluation, json_str)`` on success, where
            ``multi_round_conv`` is the list of per-turn dicts (list-valued
            `improvement`/`prompt` fields are reduced to their first element),
            ``evaluation`` is a list holding the `evaluation` entry if present,
            and ``json_str`` is the raw slice that was parsed.
            ``(None, None, None)`` if the text cannot be located, parsed, or
            validated.
        """
        if not isinstance(s, str):
            logging.error("Error extracting potential JSON structure")
            logging.error("Input:\n %s", s)
            return None, None, None

        s = s.strip()
        start_pos = s.find("{")
        end_pos = s.rfind("}")

        # Both braces must exist and be correctly ordered; str.find/rfind
        # return -1 when the character is missing.
        if start_pos == -1 or end_pos < start_pos:
            logging.error("Error extracting potential JSON structure")
            logging.error("Input:\n %s", s)
            return None, None, None

        json_str = s[start_pos:end_pos + 1]  # +1 to include the closing brace

        try:
            parsed = ast.literal_eval(json_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # These are the documented failure modes of ast.literal_eval for
            # malformed input.
            logging.error("Error parsing extracted structure.")
            return None, None, None

        # A brace-delimited literal can also be a set; only a mapping is valid.
        if not isinstance(parsed, dict):
            logging.error("Error in extracted structure. Expected a mapping.")
            return None, None, None

        multi_round_conv = []
        evaluation = []
        for key, entry in parsed.items():
            if key == "evaluation":
                evaluation.append(entry)
                continue

            if not isinstance(entry, dict) or not all(
                field in entry for field in ("improvement", "prompt")
            ):
                logging.error("Error in extracted structure. Missing keys.")
                return None, None, None

            # The model sometimes wraps a field in a list; keep the first item.
            for field in ("improvement", "prompt"):
                value = entry[field]
                if isinstance(value, list):
                    if not value:
                        logging.error("Error in extracted structure. Empty list for '%s'.", field)
                        return None, None, None
                    entry[field] = value[0]
            multi_round_conv.append(entry)

        return multi_round_conv, evaluation, json_str
