from .jailbreaks import JAILBREAKS
from ..baseline import RedTeamingMethod
import os
import json

class HumanJailbreaks(RedTeamingMethod):
    """
    Jailbreaks designed by humans.

    Test cases are built by prefixing every behavior with each template in
    ``JAILBREAKS``. The saving methods are overridden so that the test cases
    for all behaviors land in a single JSON file instead of one per behavior.
    """

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments; no state is initialised."""

    def generate_test_cases(self, behaviors, verbose=False, **kwargs):
        """
        Build one test case per (behavior, jailbreak template) pair.

        :param behaviors: iterable of dicts with the keys ``'Behavior'`` and
            ``'BehaviorID'``, and optionally ``'ContextString'``
        :param verbose: unused by this method
        :param kwargs: unused by this method
        :return: a tuple ``(test_cases, {})`` where ``test_cases`` maps each
            behavior ID to a list of prompt strings (one per entry in
            ``JAILBREAKS``, in order); the second element is always an
            empty dict
        """
        test_cases = {}
        for behavior_dict in behaviors:
            # == behaviors metadata ==
            behavior = behavior_dict['Behavior']
            # A missing context is treated the same as an empty one.
            context_str = behavior_dict.get('ContextString')
            behavior_id = behavior_dict['BehaviorID']

            if context_str:
                behavior = f"{context_str}\n\n---\n\n{behavior}"

            test_cases[behavior_id] = [f"{jb}\n\n{behavior}" for jb in JAILBREAKS]
        return test_cases, {}

    # ===== Override saving methods ====
    # Saves test cases for all behaviors into one test_cases.json file
    def get_output_file_path(self, save_dir, behavior_id, file_type, run_id=None):
        """
        Return the path ``<save_dir>/<file_type>[_<run_id>].json``.

        :param save_dir: directory the file lives in
        :param behavior_id: ignored; all behaviors share one file
        :param file_type: base name of the file (e.g. ``'test_cases'``)
        :param run_id: optional suffix; appended as ``_<run_id>`` when truthy
        :return: the joined file path
        """
        run_id = f"_{run_id}" if run_id else ""
        return os.path.join(save_dir, f"{file_type}{run_id}.json")

    def save_test_cases(self,
            save_dir,
            test_cases,
            logs=None,
            method_config=None,
            run_id=None):
        """
        Write test cases, logs and method config to JSON files in ``save_dir``.

        Each of the three payloads is written only when it is not None. The
        method config is saved with tokens masked via ``self._replace_tokens``
        and with an added ``"dependencies"`` entry mapping each module in
        ``self.default_dependencies`` to its ``__version__``.

        :param save_dir: output directory; created if it does not exist
        :param test_cases: JSON-serialisable test cases, or None to skip
        :param logs: JSON-serialisable logs, or None to skip
        :param method_config: config dict, or None to skip; the caller's dict
            is left unmodified
        :param run_id: optional suffix for the output file names
        """
        import copy  # local import so this block does not depend on file-level changes

        test_cases_save_path = self.get_output_file_path(save_dir, None, 'test_cases', run_id)
        logs_save_path = self.get_output_file_path(save_dir, None, 'logs', run_id)
        method_config_save_path = self.get_output_file_path(save_dir, None, 'method_config', run_id)

        # os.makedirs('') raises FileNotFoundError, so skip it when the files
        # go to the current working directory.
        output_dir = os.path.dirname(test_cases_save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if test_cases is not None:
            with open(test_cases_save_path, 'w', encoding='utf-8') as f:
                json.dump(test_cases, f, indent=4)

        if logs is not None:
            with open(logs_save_path, 'w', encoding='utf-8') as f:
                json.dump(logs, f, indent=4)

        if method_config is not None:
            # Work on a copy so that masking and the extra key below never
            # leak back into the config object the caller is still using.
            method_config = copy.deepcopy(method_config)
            # mask token or api_key
            self._replace_tokens(method_config)
            # Resolve versions before opening the file so that a failure here
            # cannot leave a truncated method_config file behind.
            method_config["dependencies"] = {
                lib.__name__: lib.__version__ for lib in self.default_dependencies
            }
            with open(method_config_save_path, 'w', encoding='utf-8') as f:
                json.dump(method_config, f, indent=4)