from .jailbreaks import JAILBREAKS
from ..baseline import RedTeamingMethod
import os
import json
import random

class HumanJailbreaks(RedTeamingMethod):
    """
    Jailbreaks designed by humans.

    Each test case is one entry of ``JAILBREAKS`` followed by a blank line and
    the behavior text (prefixed by its context string, when one is given).
    """

    def __init__(self, random_subset=-1, seed=1):
        """
        :param random_subset: int, number of jailbreaks to use per behavior.
            If negative (e.g. -1) or not an int, use all jailbreaks
        :param seed: int, random seed used when drawing the random subsets
        """
        self.random_subset = random_subset
        self.seed = seed

    def generate_test_cases(self, behaviors, verbose=False, **kwargs):
        """
        Build the test cases for every behavior.

        :param behaviors: iterable of dicts with the keys 'Behavior' and
            'BehaviorID', and optionally 'ContextString'
        :param verbose: unused by this method
        :param kwargs: unused by this method
        :return: tuple ``(test_cases, logs)`` where ``test_cases`` maps each
            behavior ID to a list of prompt strings and ``logs`` is an empty dict
        """
        subset_size = self.random_subset
        # Only a non-negative int requests a random subset. A negative value
        # used as a slice bound (JAILBREAKS[:-1]) would silently drop entries
        # from the end instead of meaning "use all jailbreaks".
        use_subset = (
            isinstance(subset_size, int)
            and not isinstance(subset_size, bool)
            and subset_size >= 0
        )

        if use_subset:
            # A private generator and a private copy: the process-wide RNG
            # state and the shared JAILBREAKS list are left untouched, so
            # repeated calls with the same seed give the same subsets.
            rng = random.Random(self.seed)
            pool = list(JAILBREAKS)

        test_cases = {}
        for behavior_dict in behaviors:
            # == behaviors metadata ==
            behavior = behavior_dict['Behavior']
            context_str = behavior_dict.get('ContextString')
            behavior_id = behavior_dict['BehaviorID']

            if context_str:
                behavior = f"{context_str}\n\n---\n\n{behavior}"

            if use_subset:
                # Reshuffle for every behavior so each one gets its own draw.
                rng.shuffle(pool)
                jailbreaks = pool[:subset_size]
            else:
                jailbreaks = JAILBREAKS

            test_cases[behavior_id] = [f"{jb}\n\n{behavior}" for jb in jailbreaks]
        return test_cases, {}

    # ===== Override saving methods ====
    # Saves test cases for all behaviors into one test_cases.json file
    @staticmethod
    def get_output_file_path(save_dir, behavior_id, file_type, run_id=None):
        """
        Return the path ``<save_dir>/<file_type>.json``.

        :param save_dir: directory the file is placed in
        :param behavior_id: ignored; all behaviors share one file
        :param file_type: base name of the file, without extension
        :param run_id: ignored
        :return: str, the output file path
        """
        return os.path.join(save_dir, f"{file_type}.json")

    def save_test_cases(self,
            save_dir,
            test_cases,
            logs=None,
            method_config=None,
            run_id=None):
        """
        Write test cases, logs and method config as JSON files in ``save_dir``.

        Each argument that is None is skipped. ``method_config`` is modified in
        place: it is passed to ``self._replace_tokens`` and gains a
        "dependencies" entry before being written.

        :param save_dir: directory the files are written to (created if needed)
        :param test_cases: JSON-serializable test cases, or None
        :param logs: JSON-serializable logs, or None
        :param method_config: dict with the method configuration, or None
        :param run_id: forwarded to ``get_output_file_path``
        """
        test_cases_save_path = self.get_output_file_path(save_dir, None, 'test_cases', run_id)
        logs_save_path = self.get_output_file_path(save_dir, None, 'logs', run_id)
        method_config_save_path = self.get_output_file_path(save_dir, None, 'method_config', run_id)

        # An empty dirname means the current directory; os.makedirs('') raises.
        output_dir = os.path.dirname(test_cases_save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if test_cases is not None:
            with open(test_cases_save_path, 'w', encoding='utf-8') as f:
                json.dump(test_cases, f, indent=4)

        if logs is not None:
            with open(logs_save_path, 'w', encoding='utf-8') as f:
                json.dump(logs, f, indent=4)

        if method_config is not None:
            # mask token or api_key
            self._replace_tokens(method_config)
            # Build the entry before opening the file so that a failure here
            # does not leave a truncated method_config.json behind.
            # NOTE(review): default_dependencies is presumably provided by the
            # base class as a list of imported modules - confirm.
            method_config["dependencies"] = {
                dep.__name__: dep.__version__ for dep in self.default_dependencies
            }
            with open(method_config_save_path, 'w', encoding='utf-8') as f:
                json.dump(method_config, f, indent=4)
