# ============================== BASELINE TEMPLATE ============================== #
import copy
import json
import os
import time

import fastchat
import numpy as np
import ray
import transformers
import vllm
from tqdm import tqdm

from .model_utils import load_model_and_tokenizer, get_template

class RedTeamingMethod():
    """
    A template for a red teaming method that generates test cases given a set of behaviors

    Subclasses implement generate_test_cases; this base class provides the shared
    helpers for saving per-behavior outputs and merging them into a single file.
    """
    use_ray = False
    # Modules whose versions are recorded alongside the saved method config
    default_dependencies = [transformers, vllm, ray, fastchat]

    def __init__(self, model, **kwargs):
        """
        Loads the model and tokenizer and splits the prompt template around the instruction slot

        :param model: a dictionary of keyword arguments forwarded to load_model_and_tokenizer and
                      get_template; must contain the key 'model_name_or_path'
        :param kwargs: accepted for compatibility with subclasses/configs; not used here
        """
        model_kwargs = model
        model, tokenizer = load_model_and_tokenizer(**model_kwargs)
        self.model_name_or_path = model_kwargs['model_name_or_path']
        self.model = model
        self.tokenizer = tokenizer
        self.template = get_template(**model_kwargs)['prompt']
        # The two-way unpacking requires exactly one "{instruction}" placeholder in the template
        self.before_tc, self.after_tc = self.template.split("{instruction}")

    def generate_test_cases(self, behaviors, verbose=False):
        """
        Generates test cases for the behaviors (i.e., prompts meant to elicit the behaviors)

        Must be implemented by subclasses.

        :param behaviors: a list of behaviors from a behaviors CSV file, as loaded in generate_test_cases.py
        :param verbose: whether to print progress
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @staticmethod
    def _merge_json_parts(individual_dir, behavior_dirs, filename):
        """
        Merges individual_dir/<behavior_dir>/<filename> over all behavior_dirs into one dictionary

        Directories that do not contain the file are skipped.

        :param individual_dir: the directory containing one sub-directory per behavior
        :param behavior_dirs: the names of the sub-directories to read, in merge order
        :param filename: the name of the JSON file to read from each sub-directory
        :return: a dictionary with the union of the keys of all files that were found
        :raises AssertionError: if the same key appears in more than one file
        """
        merged = {}
        for behavior_dir in behavior_dirs:
            part_path = os.path.join(individual_dir, behavior_dir, filename)
            if not os.path.exists(part_path):
                continue
            with open(part_path, 'r', encoding='utf-8') as f:
                part = json.load(f)
            for behavior_id in part:
                assert behavior_id not in merged, f"Duplicate behavior ID: {behavior_id}"
            merged.update(part)
        return merged

    @staticmethod
    def merge_test_cases(save_dir):
        """
        Merge test_cases_individual_behaviors/*/test_cases.json in save_dir to make test_cases.json in save_dir
        (Also merges test_cases_individual_behaviors/*/logs.json into logs.json for ease of access)

        Only the files named exactly test_cases.json and logs.json are read.

        :param save_dir: the directory containing the test_cases_individual_behaviors folder
        :raises AssertionError: if save_dir does not exist or a behavior ID appears in more than one file
        :raises FileNotFoundError: if save_dir has no test_cases_individual_behaviors folder
        """
        assert os.path.exists(save_dir), f"{save_dir} does not exist"

        individual_dir = os.path.join(save_dir, 'test_cases_individual_behaviors')
        # os.listdir returns entries in arbitrary order; sort so the merged files are reproducible
        behavior_dirs = sorted(os.listdir(individual_dir))

        test_cases = RedTeamingMethod._merge_json_parts(individual_dir, behavior_dirs, 'test_cases.json')
        logs = RedTeamingMethod._merge_json_parts(individual_dir, behavior_dirs, 'logs.json')

        # Save merged test cases and logs
        test_cases_save_path = os.path.join(save_dir, 'test_cases.json')
        logs_save_path = os.path.join(save_dir, 'logs.json')

        print(f"Saving test_cases.json to {test_cases_save_path}")
        with open(test_cases_save_path, 'w', encoding='utf-8') as f:
            json.dump(test_cases, f, indent=4)

        print(f"Saving logs.json to {logs_save_path}")
        with open(logs_save_path, 'w', encoding='utf-8') as f:
            json.dump(logs, f, indent=4)

    def get_output_file_path(self, save_dir, behavior_id, file_type, run_id=None):
        """
        Builds the path save_dir/test_cases_individual_behaviors/<behavior_id>/<file_type>[_<run_id>].json

        :param save_dir: the root directory for saved outputs
        :param behavior_id: the ID of the behavior the file belongs to
        :param file_type: the base name of the file (e.g., 'test_cases', 'logs', 'method_config')
        :param run_id: an optional suffix appended to the base name; falsy values add no suffix
        :return: the path of the output file
        """
        run_id = f"_{run_id}" if run_id else ""
        return os.path.join(save_dir, 'test_cases_individual_behaviors', behavior_id, f"{file_type}{run_id}.json")

    def save_test_cases(self,
                save_dir,
                all_test_cases,
                all_logs=None,
                method_config=None,
                run_id=None):
        """
        Saves the test cases (and optionally logs and method config) of every behavior to its own folder

        :param save_dir: the root directory for saved outputs; nothing is saved if None
        :param all_test_cases: a dictionary mapping behavior IDs to their test cases
        :param all_logs: an optional dictionary mapping behavior IDs to their logs; behaviors missing
                         from it get an empty list, and no logs are written at all if it is None
        :param method_config: an optional method config dictionary saved next to each behavior's test cases
        :param run_id: an optional suffix for the saved file names
        """
        for behavior_id, behavior_test_cases in all_test_cases.items():
            # all_logs defaults to None, so it cannot be indexed unconditionally
            logs = None if all_logs is None else {behavior_id: all_logs.get(behavior_id, [])}
            test_cases = {behavior_id: behavior_test_cases}
            self.save_test_cases_single_behavior(save_dir, behavior_id, test_cases, logs, method_config=method_config, run_id=run_id)

    def save_test_cases_single_behavior(self,
                    save_dir,
                    behavior_id,
                    test_cases,
                    logs=None,
                    method_config=None,
                    run_id=None):
        """
        Saves the test cases, logs, and method config of one behavior as JSON files

        Each of test_cases, logs, and method_config is skipped if it is None. The method config
        is saved with its secrets masked and with the versions of default_dependencies added under
        the 'dependencies' key; the dictionary passed in by the caller is left untouched.

        :param save_dir: the root directory for saved outputs; nothing is saved if None
        :param behavior_id: the ID of the behavior being saved
        :param test_cases: the JSON-serializable test cases to save
        :param logs: the JSON-serializable logs to save
        :param method_config: the method config dictionary to save
        :param run_id: an optional suffix for the saved file names
        """
        if save_dir is None: return

        test_cases_save_path = self.get_output_file_path(save_dir, behavior_id, 'test_cases', run_id)
        logs_save_path = self.get_output_file_path(save_dir, behavior_id, 'logs', run_id)
        method_config_save_path = self.get_output_file_path(save_dir, behavior_id, 'method_config', run_id)

        os.makedirs(os.path.dirname(test_cases_save_path), exist_ok=True)
        if test_cases is not None:
            with open(test_cases_save_path, 'w', encoding='utf-8') as f:
                json.dump(test_cases, f, indent=4)

        if logs is not None:
            with open(logs_save_path, 'w', encoding='utf-8') as f:
                json.dump(logs, f, indent=4)

        if method_config is not None:
            # Work on a copy: masking in place would corrupt the tokens/api keys of the
            # config the caller is still using
            config_to_save = copy.deepcopy(method_config)
            # mask token or api_key
            self._replace_tokens(config_to_save)
            config_to_save["dependencies"] = {
                dep.__name__: getattr(dep, '__version__', 'unknown') for dep in self.default_dependencies
            }
            with open(method_config_save_path, 'w', encoding='utf-8') as f:
                json.dump(config_to_save, f, indent=4)

    def _replace_tokens(self, d):
        """
        Search for keys containing 'token' or 'api_key' in a nested dictionary and replace their string values with '*' for each character,
        except for the last 4 characters, which are preserved. Values of 4 characters or fewer are masked entirely.

        The dictionary is modified in place. Nested dictionaries, and dictionaries inside lists, are searched too.

        :param d: the dictionary to mask
        """
        for k, v in d.items():
            # Non-string keys cannot name a secret and do not support substring checks
            is_secret_key = isinstance(k, str) and ('token' in k or 'apikey' in k.replace("_", ""))
            if is_secret_key and isinstance(v, str):
                # Keeping the last 4 characters of a short value would reveal all of it
                visible = v[-4:] if len(v) > 4 else ''
                d[k] = '*' * (len(v) - len(visible)) + visible
            elif isinstance(v, dict):
                self._replace_tokens(v)
            elif isinstance(v, list):
                for item in v:
                    if isinstance(item, dict):
                        self._replace_tokens(item)


class SingleBehaviorRedTeamingMethod(RedTeamingMethod):
    """
    A template method that generates test cases for a single behavior and model
    (e.g., used by GCG, PEZ, GBDA, UAT, AutoPrompt)
    """
    def __init__(self, target_model, num_test_cases_per_behavior=1, test_cases_batch_size=1, **kwargs):
        """
        Loads the target model and tokenizer. The base class initializer is not called here, so the
        template attributes it sets (template, before_tc, after_tc) are not available on this class.

        :param target_model: a dictionary of keyword arguments forwarded to load_model_and_tokenizer;
                             must contain the key 'model_name_or_path'
        :param num_test_cases_per_behavior: the number of test cases to generate in a call to self.generate_test_cases
        :param test_cases_batch_size: the batch size to use for calling num_generate (for dividing up num_test_cases_per_behavior)
        :param kwargs: accepted for compatibility with subclasses/configs; not used here
        """
        model_kwargs = target_model
        model, tokenizer = load_model_and_tokenizer(**model_kwargs)
        self.model_name_or_path = model_kwargs['model_name_or_path']
        self.model = model
        self.tokenizer = tokenizer
        self.num_test_cases_per_behavior = num_test_cases_per_behavior
        self.test_cases_batch_size = test_cases_batch_size

    def generate_test_cases(self, behaviors, verbose=False):
        """
        Generate test cases for the behaviors (i.e., prompts meant to elicit the behaviors)

        For each behavior, self.generate_test_cases_single_behavior is called once per batch until
        self.num_test_cases_per_behavior test cases have been requested.

        :param behaviors: a list of behaviors to generate test cases for; each must have a 'BehaviorID' entry
        :param verbose: whether to print progress
        :return: a tuple (test_cases, logs) of dictionaries keyed by behavior ID
        :raises ValueError: if self.test_cases_batch_size is not positive
        """
        batch_size = self.test_cases_batch_size
        if batch_size <= 0:
            # A zero batch size would divide by zero; a negative one would silently generate nothing
            raise ValueError(f"test_cases_batch_size must be positive, got {batch_size}")

        # The batch plan is the same for every behavior, so compute it once up front:
        # full batches, with the last one shrunk to the remainder if the split is uneven
        num_batches = int(np.ceil(self.num_test_cases_per_behavior / batch_size))
        remainder = self.num_test_cases_per_behavior % batch_size
        batch_sizes = [batch_size] * num_batches
        if batch_sizes and remainder != 0:
            batch_sizes[-1] = remainder

        # ========== generate test cases ========== #
        test_cases = {}
        logs = {}
        for behavior in tqdm(behaviors, total=len(behaviors)):
            start_time = time.time()
            behavior_tcs = []
            behavior_info = []
            for batch_num, current_batch_size in enumerate(batch_sizes, start=1):
                if verbose:
                    print(f"Generating batch {batch_num}/{num_batches} with size {current_batch_size}")
                batch_tcs, batch_info = self.generate_test_cases_single_behavior(behavior, num_generate=current_batch_size, verbose=verbose)
                behavior_tcs.extend(batch_tcs)
                behavior_info.extend(batch_info)

            behavior_id = behavior['BehaviorID']

            test_cases[behavior_id] = behavior_tcs
            logs[behavior_id] = behavior_info

            if verbose:
                print(f"Time elapsed (s): {time.time() - start_time}")
        return test_cases, logs

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generate test cases for a single behavior (i.e., prompts meant to elicit the behavior)

        Must be implemented by subclasses. generate_test_cases unpacks the result into two
        sequences (test cases, logs) and extends its per-behavior lists with them.

        :param behavior: the behavior to generate test cases for
        :param num_generate: the number of test cases to generate
        :param verbose: whether to print progress
        :raises NotImplementedError: always, in this template class
        """
        raise NotImplementedError