r"""
'Tree of Attacks' Recipe
============================================
This module implements a jailbreak method describe in the paper below.
This part of code is based on the code from the paper.

Paper title: Tree of Attacks: Jailbreaking Black-Box LLMs Automatically
arXiv link: https://arxiv.org/abs/2312.02119
Source repository: https://github.com/RICommunity/TAP
"""
import os
import logging
from tqdm import tqdm

from attacker.attacker_base import AttackerBase
from jb_datasets.jailbreak_datasets import JailbreakDataset
from jb_datasets.instance import Instance
from loggers.logger import Logger
from models.huggingface_model import HuggingfaceModel
from models.openai_model import OpenaiModel
####### 4 major components #######
from seed.seed_template import SeedTemplate
from mutation.generation.IntrospectGeneration import IntrospectGeneration
from constraint.DeleteOffTopic import DeleteOffTopic
from metrics.Evaluator.Evaluator_UnifiedJudge import EvaluatorUnifiedJudge
from selector.SelectBasedOnScores import SelectBasedOnScores
from defense.updater_defense_helper import ProActDefenseRunner
from defense.guard_defense import GuardDefenseRunner
from defense.utils import mislead_defense

r"""
EasyJailbreak TAP class
============================================
"""
__all__ = ['TAP']

target_model_calls = 0

logger = logging.getLogger(__name__)

def _prompt_to_text(prompt) -> str:
    """
    Best-effort conversion of a prompt (str | list[dict] OpenAI messages | list[str]) into a single user-text string.
    Used for defense hooks (guard/mislead/proact), which operate on plain text.
    """
    if prompt is None:
        return ""
    if isinstance(prompt, str):
        return prompt
    if isinstance(prompt, list) and prompt:
        # OpenAI chat message format: [{"role": "...", "content": "..."}]
        if isinstance(prompt[0], dict):
            last_user = next((m for m in reversed(prompt) if isinstance(m, dict) and m.get("role") == "user"), None)
            if last_user is not None:
                return str(last_user.get("content") or "")
            parts = []
            for m in prompt:
                if isinstance(m, dict):
                    parts.append(str(m.get("content") or ""))
                else:
                    parts.append(str(m))
            return "\n".join([p for p in parts if p]).strip()
        # Plain list[str] turns
        return "\n".join([str(x) for x in prompt if x is not None]).strip()
    return str(prompt)

class TAP(AttackerBase):
    r"""
    Tree of Attack method, an extension of PAIR method. Use 4 phases:
    1. Branching
    2. Pruning: (phase 1)
    3. Query and Access
    4. Pruning: (phase 2)

    >>> from easyjailbreak.attacker.TAP_Mehrotra_2023 import TAP
    >>> from easyjailbreak.models.huggingface_model import from_pretrained
    >>> from easyjailbreak.datasets.jailbreak_datasets import JailbreakDataset
    >>> from easyjailbreak.datasets.Instance import Instance
    >>> attack_model = from_pretrained(model_path_1)
    >>> target_model = from_pretrained(model_path_2)
    >>> eval_model  = from_pretrained(model_path_3)
    >>> dataset = JailbreakDataset('AdvBench')
    >>> attacker = TAP(attack_model, target_model, eval_model, dataset)
    >>> attacker.attack()
    >>> attacker.jailbreak_Dataset.save_to_jsonl("./TAP_results.jsonl")
    """
    def __init__(self, attack_model, target_model, eval_model, jailbreak_datasets: JailbreakDataset,
                 tree_width=10, tree_depth=10,root_num=1, branching_factor=4,keep_last_n=3,
                 max_n_attack_attempts=5, template_file=None,
                 attack_max_n_tokens=500,
                 attack_temperature=1,
                 attack_top_p=0.9,
                 target_max_n_tokens=150,
                 target_temperature=1,
                 target_top_p=1,
                 judge_max_n_tokens=10,
                 judge_temperature=1,
                 defense_config=None):
        """
        initialize TAP, inherit from AttackerBase

        :param  ~HuggingfaceModel|~OpenaiModel attack_model: LLM for generating jailbreak prompts during Branching(mutation)
        :param  ~HuggingfaceModel|~OpenaiModel target_model: LLM being attacked to generate adversarial responses
        :param  ~HuggingfaceModel|~OpenaiModel eval_model: LLM for evaluating during Pruning:phase1(constraint) and Pruning:phase2(select)
        :param  ~JailbreakDataset jailbreak_datasets: containing instances which conveys the query and reference responses
        :param  int tree_width: defining the max width of the conversation nodes during Branching(mutation)
        :param  int tree_depth: defining the max iteration of a single instance
        :param  int root_num: defining the number of trees or batch of a single instance
        :param  int branching_factor: defining the number of children nodes generated by a parent node during Branching(mutation)
        :param  int keep_last_n: defining the number of rounds of dialogue to keep during Branching(mutation)
        :param  int max_n_attack_attempts: defining the max number of attempts to generating a valid adversarial prompt of a branch
        :param  str template_file: file path of the seed_template.json
        :param  int attack_max_n_tokens: max_n_tokens of the target model
        :param  float attack_temperature: temperature of the attack model
        :param  float attack_top_p: top p of the attack_model
        :param  int target_max_n_tokens: max_n_tokens of the target model
        :param  float target_temperature: temperature of the target model
        :param  float target_top_p: top_p of the target model
        :param  int judge_max_n_tokens: max_n_tokens of the target model
        :param  float judge_temperature: temperature of the judge model
        """
        super().__init__(attack_model=attack_model,
                         target_model=target_model,
                         eval_model=eval_model,
                         jailbreak_datasets=jailbreak_datasets)
        self.seeds=SeedTemplate().new_seeds(1,method_list=['TAP'],template_file=template_file)

        ####### 4 major components ##########
        self.mutator=IntrospectGeneration(attack_model,
                                          system_prompt=self.seeds[0],
                                          keep_last_n=keep_last_n,
                                          branching_factor=branching_factor,
                                          max_n_attack_attempts=max_n_attack_attempts)
        self.constraint=DeleteOffTopic(self.eval_model, tree_width)
        self.selector=SelectBasedOnScores(jailbreak_datasets, tree_width)
        # Pass use_local and local_base_url explicitly if eval_model has base_url
        evaluator_kwargs = {'temperature': judge_temperature, 'max_completion_tokens': judge_max_n_tokens}
        if hasattr(self.eval_model, 'base_url') and self.eval_model.base_url:
            evaluator_kwargs['use_local'] = True
            evaluator_kwargs['local_base_url'] = str(self.eval_model.base_url)
        self.evaluator=EvaluatorUnifiedJudge(self.eval_model, **evaluator_kwargs)
        self.max_score = int(getattr(self.evaluator.judge_evaluator, "max_score", 10))

        # Optional defenses (Mislead / ProAct / Guard) applied around each target response
        # NOTE: In this mode we ONLY rewrite/guard, and reuse the existing evaluator path to score the defended response.
        # We compute `original_score` only once at success/jailbreak decision time to reduce judge calls.
        self.defense_config = dict(defense_config or {})
        self._proact_runner = None
        self._guard_runner = None
        self._mislead_turn_idx = 0
        if bool(self.defense_config.get("enable_proact", False)):
            try:
                self._proact_runner = ProActDefenseRunner(enabled=True, config=self.defense_config)
            except Exception as e:
                logger.warning("Failed to initialize ProAct; continuing without ProAct: %s", e)
                self._proact_runner = None
        if bool(self.defense_config.get("enable_guard", False)):
            try:
                self._guard_runner = GuardDefenseRunner(GuardDefenseRunner.from_config(self.defense_config))
            except Exception as e:
                logger.warning("Failed to initialize Guard; continuing without Guard: %s", e)
                self._guard_runner = None

        # Mislead rewrite settings (env-first, same conventions as Multi-turn-jailbreak)
        rewrite_type = os.getenv("REWRITE_MODEL_TYPE", "server").strip().lower()
        if rewrite_type == "server":
            self._rewrite_model = os.getenv("REWRITE_SERVER_MODEL") or "increase"
        else:
            self._rewrite_model = os.getenv("REWRITE_MODEL") or "gpt-4.1-mini"
        self._rewrite_server_url_increase = os.getenv("REWRITE_SERVER_URL_INCREASE") or os.getenv("REWRITE_SERVER_URL")
        self._rewrite_server_url_decrease = os.getenv("REWRITE_SERVER_URL_DECREASE") or None
        self._rewrite_type = rewrite_type
        self._defenses_enabled = any(
            bool(self.defense_config.get(k, False))
            for k in ("enable_defense", "enable_proact", "enable_guard")
        )

        ######## logging information ############
        self.current_query: int = 0
        self.current_jailbreak: int = 0
        self.current_reject: int = 0
        self.current_iteration: int = 0

        ######## parameters of TAP tree #########
        self.root_num = root_num
        self.tree_depth = tree_depth
        self.tree_width = tree_width
        self.branching_factor = branching_factor

        ######## datasets and logger ############
        self.jailbreak_Dataset = JailbreakDataset([])
        self.logger = Logger()

        ######## model configuration ############
        self.target_max_n_tokens = target_max_n_tokens
        self.target_temperature = target_temperature
        self.target_top_p = target_top_p
        self.judge_temperature = judge_temperature
        self.judge_max_n_tokens = judge_max_n_tokens

        if self.attack_model.generation_config == {}:
            if isinstance(self.attack_model, OpenaiModel):
                self.attack_model.generation_config = {'max_tokens': attack_max_n_tokens,
                                                       'temperature': attack_temperature,
                                                       'top_p': attack_top_p}
            elif isinstance(self.attack_model, HuggingfaceModel):
                self.attack_model.generation_config = {'max_new_tokens': attack_max_n_tokens,
                                                       'temperature': attack_temperature,
                                                       'do_sample': True,
                                                       'top_p': attack_top_p,
                                                       'eos_token_id': self.attack_model.tokenizer.eos_token_id}

        if isinstance(self.eval_model, OpenaiModel) and self.eval_model.generation_config == {}:
            self.eval_model.generation_config = {'max_tokens': self.judge_max_n_tokens,
                                                 'temperature': self.judge_temperature}
        elif isinstance(self.eval_model, HuggingfaceModel) and self.eval_model.generation_config == {}:
            self.eval_model.generation_config = {'do_sample': True,
                                                 'max_new_tokens': self.judge_max_n_tokens,
                                                 'temperature': self.judge_temperature}

    def _infer_score(self, harmful_behavior: str, response: str) -> int:
        try:
            score, _reason, *_ = self.evaluator.judge_evaluator.infer_single(harmful_behavior, response)
        except Exception:
            score = 1
        try:
            score = int(score) if score is not None else 1
        except Exception:
            score = 1
        if score < 1:
            return 1
        if score > self.max_score:
            return self.max_score
        return score

    def attack(self, save_path='TAP_attack_result.jsonl'):
        r"""
        Execute the attack process using provided prompts.
        """
        # To calculate how many times are eval_model.generate() called
        global target_model_calls
        logging.info("Jailbreak started!")
        
        # Load existing results if file exists
        processed_queries = set()
        if os.path.exists(save_path):
            logging.info(f"Found existing results file: {save_path}. Loading processed instances...")
            try:
                existing_dataset = JailbreakDataset.load_jsonl(save_path)
                processed_queries = {inst.query for inst in existing_dataset if inst.query}
                logging.info(f"Loaded {len(processed_queries)} processed instances from existing file.")
                # Add existing instances to jailbreak_Dataset
                for inst in existing_dataset:
                    self.jailbreak_Dataset.add(inst)
            except Exception as e:
                logging.warning(f"Failed to load existing results: {e}. Starting fresh.")
        
        try:
            for Instance in tqdm(self.jailbreak_datasets, desc="Processing instances"):
                # Skip if already processed
                if Instance.query in processed_queries:
                    logging.info(f"Skipping already processed query: {Instance.query[:50]}...")
                    continue
                
                new_Instance = self.single_attack(Instance)[0]
                self.jailbreak_Dataset.add(new_Instance)
                
                # Incremental save after each instance
                self.jailbreak_Dataset.save_to_jsonl(save_path)
        except KeyboardInterrupt:
            logging.info("Jailbreak interrupted by user!")
        self.update(self.jailbreak_Dataset)
        # These were extremely noisy (especially with many instances/threads). Keep at DEBUG.
        logger.debug("jailbreak_prompt:%s", [instance.jailbreak_prompt for instance in self.jailbreak_Dataset])
        logger.debug("target_responses:%s", [instance.target_responses[0] for instance in self.jailbreak_Dataset])
        logger.debug("ASR:%s%%", 100 * self.current_jailbreak / self.current_query if self.current_query else 0.0)
        logger.debug("Total calls of generate:%s", target_model_calls)
        # Get call count safely (compatible with both old and new evaluators)
        eval_calls = getattr(self.evaluator.eval_model.generate, 'count_calls', getattr(self.evaluator.eval_model.generate, 'calls', self.evaluator._call_count if hasattr(self.evaluator, '_call_count') else 0))
        logger.debug("Eval calls of generate:%s", eval_calls - target_model_calls)
        self.log()
        logging.info("Jailbreak finished!")
        self.jailbreak_Dataset.save_to_jsonl(save_path)
        logging.info(
            'Jailbreak result saved at {}!'.format(os.path.join(os.path.dirname(os.path.abspath(__file__)), save_path))
        )

    def single_attack(self, instance) -> JailbreakDataset:
        r"""
        Conduct an attack for an instance.

        :param ~Instance instance: The Instance that is attacked.
        :return ~JailbreakDataset: returns the attack result dataset.
        """
        global target_model_calls
        batch=[JailbreakDataset([instance.copy()]) for _ in range(self.root_num)]
        find_flag = 0
        logger.debug("QUERY:%s\n%s", "=" * 20, instance.query)
        for iteration in range(1, self.tree_depth + 1):
            logger.debug("%s\nTree-depth is: %s\n%s", "=" * 36, iteration, "=" * 36)
            dataset_list = []
            for i,stream in enumerate(batch):
                logger.debug("BATCH:%s", i)
                new_dataset = stream

                ############# generate jailbreak_prompts by branching ################
                mutated = self.mutator(new_dataset)
                # Mutation can fail and return an empty dataset; fall back to previous prompts.
                if mutated is None or len(mutated) == 0:
                    logger.debug("Mutation produced 0 prompts; falling back to previous dataset.")
                else:
                    new_dataset = mutated

                ############# prune off-topic jailbreak_prompt ################
                constrained = self.constraint(new_dataset)
                # Constraint can prune everything; keep previous dataset to avoid downstream crashes.
                if constrained is None or len(constrained) == 0:
                    logger.debug("Off-topic constraint pruned all prompts; falling back to previous dataset.")
                else:
                    new_dataset = constrained

                ############# attack ################
                self.target_model.conversation.messages = []
                for instance in new_dataset:
                    # Some fallbacks can leave `jailbreak_prompt` unset; ensure it's always a string.
                    if getattr(instance, "jailbreak_prompt", None) is None:
                        instance.jailbreak_prompt = instance.query or ""

                    prompt_for_rewrite = _prompt_to_text(instance.jailbreak_prompt)
                    proact_info = None
                    raw_target_response = None
                    if isinstance(self.target_model, OpenaiModel):
                        if self._proact_runner is not None:
                            spurious, proact_info = self._proact_runner.maybe_spurious(
                                messages=[{"role": "user", "content": prompt_for_rewrite}]
                            )
                            if spurious is not None:
                                original_response_full = str(spurious)
                            else:
                                raw_target_response = self.target_model.generate(
                                    instance.jailbreak_prompt,
                                    max_tokens=self.target_max_n_tokens,
                                    temperature=self.target_temperature,
                                    top_p=self.target_top_p,
                                )
                                original_response_full = str(raw_target_response)
                        else:
                            raw_target_response = self.target_model.generate(
                                instance.jailbreak_prompt,
                                max_tokens=self.target_max_n_tokens,
                                temperature=self.target_temperature,
                                top_p=self.target_top_p,
                            )
                            original_response_full = str(raw_target_response)
                    elif isinstance(self.target_model, HuggingfaceModel):
                        raw_target_response = self.target_model.generate(
                            instance.jailbreak_prompt,
                            max_new_tokens=self.target_max_n_tokens,
                            temperature=self.target_temperature,
                            do_sample=True,
                            top_p=self.target_top_p,
                            eos_token_id=self.target_model.tokenizer.eos_token_id,
                        )
                        original_response_full = str(raw_target_response)

                    # Apply Guard + Mislead (rewrite only; no judge scoring here)
                    final_response = original_response_full
                    defense_info = {}
                    if isinstance(getattr(instance, "attack_attrs", None), dict):
                        instance.attack_attrs.setdefault("raw_target_response", str(raw_target_response) if raw_target_response is not None else None)
                    if self._guard_runner is not None:
                        try:
                            guarded, guard_info = self._guard_runner.maybe_guard(
                                messages=[{"role": "user", "content": prompt_for_rewrite}, {"role": "assistant", "content": final_response}]
                            )
                            final_response = str(guarded)
                            if guard_info is not None:
                                defense_info["guard"] = guard_info
                        except Exception as e:
                            logger.debug("Guard failed; continuing without guard. Error: %s", e)
                    guard_replaced = bool(defense_info.get("guard") and defense_info["guard"].get("replaced") is True)

                    if bool(self.defense_config.get("enable_defense", False)) and not guard_replaced:
                        if self._rewrite_type == "server" and not self._rewrite_server_url_increase:
                            logger.debug("Mislead enabled but REWRITE_SERVER_URL(_INCREASE) not set; skipping rewrite.")
                        else:
                            try:
                                rewritten, direction, similarity_score = mislead_defense(
                                    original_response=final_response,
                                    prompt=prompt_for_rewrite,
                                    turn_index=self._mislead_turn_idx,
                                    rewrite_model=self._rewrite_model,
                                    rewrite_server_url_increase=self._rewrite_server_url_increase,
                                    rewrite_server_url_decrease=self._rewrite_server_url_decrease,
                                )
                                self._mislead_turn_idx += 1
                                final_response = str(rewritten)
                                defense_info.update(
                                    {
                                        "mislead_enabled": True,
                                        "direction": direction,
                                        "similarity_score": similarity_score,
                                        "original_response": original_response_full,  # Save original response before mislead rewrite
                                    }
                                )
                            except Exception as e:
                                logger.debug("Mislead rewrite failed; using original response. Error: %s", e)

                    if proact_info is not None:
                        defense_info["proact_enabled"] = True
                        defense_info["proact"] = proact_info
                        defense_info["raw_target_response"] = str(raw_target_response) if raw_target_response is not None else None
                        defense_info["proact_response"] = original_response_full if proact_info.get("intent") == "malicious" else None

                    if isinstance(getattr(instance, "attack_attrs", None), dict) and defense_info:
                        instance.attack_attrs["defense_info"] = defense_info

                    instance.target_responses = [final_response]
                    target_model_calls+=1

                ############# prune not-jailbroken jailbreak_prompt ################
                num_responses = len(new_dataset)
                # Always score via the existing evaluator path (scores the defended response).
                self.evaluator(new_dataset)
                selected = self.selector.select(new_dataset)
                if selected is None or len(selected) == 0:
                    logger.debug("Selector returned empty; falling back to previous dataset.")
                else:
                    new_dataset = selected
                # Get call count safely (compatible with both old and new evaluators)
                eval_calls = getattr(self.evaluator.eval_model.generate, 'calls', getattr(self.evaluator.eval_model.generate, 'count_calls', self.evaluator._call_count if hasattr(self.evaluator, '_call_count') else 0))
                logger.debug(
                    "%s\nCount of Calls of Evaluator is: %s\n%s",
                    "=" * 36,
                    eval_calls - num_responses,
                    "=" * 36,
                )

                batch[i] = new_dataset
                ############# attack successful ################
                if len(new_dataset) > 0 and any(
                    [instance.eval_results[-1] == self.max_score for instance in new_dataset if instance.eval_results]
                ):
                    find_flag = 1
                    logger.debug("Found a jailbreak. Exiting.")
                    break
            if find_flag:
                new_instance = max(new_dataset, key=lambda instance: instance.eval_results[-1])
                # Keep the original 1-10 judge scores from EvaluatorUnifiedJudge.
                # Multi-turn-jailbreak semantics:
                # - Stop condition uses evaluation_score (possibly after defenses).
                # - "jailbreak success" is determined by original_score when available (i.e., before defenses).
                defense_info = (
                    new_instance.attack_attrs.get("defense_info")
                    if isinstance(getattr(new_instance, "attack_attrs", None), dict)
                    else None
                ) or {}
                # Compute original_score only once at decision time to reduce judge calls.
                # If no defense is enabled, reuse eval_results[-1] to avoid redundant judge calls.
                if not defense_info:
                    # No defense: eval_results[-1] is already the score for the raw response
                    original_score = new_instance.eval_results[-1] if new_instance.eval_results else None
                else:
                    # Defense enabled: need to score the raw_target_response separately
                    raw_target = None
                    if isinstance(getattr(new_instance, "attack_attrs", None), dict):
                        raw_target = new_instance.attack_attrs.get("raw_target_response")
                    if raw_target is None:
                        raw_target = defense_info.get("raw_target_response")
                    if defense_info.get("proact_enabled") and (raw_target is None):
                        original_score = 1
                    elif raw_target is None:
                        original_score = None
                    else:
                        original_score = self._infer_score(str(new_instance.query or ""), str(raw_target))
                if isinstance(getattr(new_instance, "attack_attrs", None), dict):
                    new_instance.attack_attrs["original_score"] = original_score
                    new_instance.attack_attrs["success"] = bool(original_score == self.max_score)
                break
        if iteration == self.tree_depth:
            new_instance = max(new_dataset, key=lambda instance: instance.eval_results[-1])
            # If we hit the depth limit without early stop, consider it not a jailbreak.
            new_instance.attack_attrs["success"] = False
            # If somehow no scores exist (e.g. repeated fallbacks), keep a valid minimal score.
            if not getattr(new_instance, "eval_results", None):
                new_instance.eval_results = [1]
            # Set original_score to eval_results[-1] if no defense (same logic as success case)
            defense_info = (
                new_instance.attack_attrs.get("defense_info")
                if isinstance(getattr(new_instance, "attack_attrs", None), dict)
                else None
            ) or {}
            if not defense_info:
                # No defense: reuse eval_results[-1]
                if isinstance(getattr(new_instance, "attack_attrs", None), dict):
                    new_instance.attack_attrs["original_score"] = new_instance.eval_results[-1] if new_instance.eval_results else None
        return JailbreakDataset([new_instance])

    def update(self, Dataset: JailbreakDataset):
        r"""
        Update the state of the ReNeLLM based on the evaluation results of Datasets.

        :param ~JailbreakDataset: processed dataset after an iteration
        """
        for prompt_node in Dataset:
            self.current_jailbreak += prompt_node.num_jailbreak
            self.current_query += prompt_node.num_query
            self.current_reject += prompt_node.num_reject

    def log(self):
        r"""
        Report the attack results.
        """
        logging.info("======Jailbreak report:======")
        logging.info(f"Total queries: {self.current_query}")
        logging.info(f"Total jailbreak: {self.current_jailbreak}")
        logging.info(f"Total reject: {self.current_reject}")
        logging.info("========Report End===========")
