# ============================== BASELINE TEMPLATE ============================== #
import time
from tqdm import tqdm
import numpy as np
import os
import json
from .model_utils import load_model_and_tokenizer, get_template
import transformers
import vllm
import ray
import fastchat

import gc
import torch

import copy

class RedTeamingMethod:
    """
    A template for a red teaming method that generates test cases given a set of behaviors

    Class attributes:
        use_ray: False by default. It is not read anywhere in this file; presumably the calling
            script consults it to decide whether to run the method through Ray (TODO confirm).
        default_dependencies: modules whose `__name__` and `__version__` are collected by
            SingleBehaviorRedTeamingMethod into the "dependencies" entry of a method config.
    """
    # NOTE: the docstring above must stay the first statement of the class body, otherwise
    # Python does not treat it as the class docstring (RedTeamingMethod.__doc__ would be None).
    use_ray = False
    default_dependencies = [transformers, vllm, ray, fastchat]

    def __init__(self):
        raise NotImplementedError

    def generate_test_cases(self, behaviors, verbose=False):
        """
        Generates test cases for the provided behaviors.
        For text behaviors, test cases are strings.
        For multimodal behaviors, test cases are lists of (image_filename, string).
        (For portability, the path to image files is provided separately while generating completions.)

        :param behaviors: a list of behavior dictionaries specifying the behaviors to generate test cases for
        :param verbose: whether to print progress
        :return: (test_cases, logs), where both are dictionaries with behavior IDs as keys
        """
        raise NotImplementedError

    @staticmethod
    def get_output_file_path(save_dir, behavior_id, file_type, run_id=None):
        """
        A helper method to get the path to save test cases, logs, and method configs.
        This is also used in `generate_test_cases.py` when args.overwrite is False to
        check whether test cases have already been saved for a given method and behavior.

        :param save_dir: the directory to save test cases
        :param behavior_id: the behavior ID
        :param file_type: the type of file to save (e.g., 'test_cases', 'logs', 'method_config')
        :param run_id: an optional run ID to append to the file name (used for multiple runs of the same method)
        :return: the path to save the file
        """
        raise NotImplementedError

    def save_test_cases(self, save_dir, test_cases, logs=None, method_config=None, run_id=None):
        """
        Saves test cases to save_dir. Optionally saves logs and method config.

        This method is called on the outputs of generate_test_cases. The only requirement is that
        running merge_test_cases on the save_dir should produce a `{save_dir}/test_cases.json` file
        containing a dictionary mapping behavior IDs to lists of test cases. If parallelized test
        case generation followed by merging is not required for a given attack, this method can
        directly save the test cases to the `{save_dir}/test_cases.json` file.

        :param save_dir: the directory to save test cases
        :param test_cases: a dictionary of test cases, where the keys are the behavior IDs and the values are lists of test cases
        :param logs: a dictionary of logs, where the keys are the behavior IDs and the values are lists of logs
        :param method_config: the method config dictionary used for initializing the attack
        :param run_id: an optional run ID to append to the file name (used for multiple runs of the same method)
        :return: None
        """
        raise NotImplementedError

    @staticmethod
    def merge_test_cases(save_dir):
        """
        Merges subsets of test cases in {save_dir} into a single {save_dir}/test_cases.json file

        Not all attacks need to implement this method, but it is useful for attacks that generate test cases
        in parallel. For examples, see the implementations in SingleBehaviorRedTeamingMethod and EnsembleGCG.

        :param save_dir: the directory containing the test cases to merge
        :return: None
        """
        pass

    def _replace_tokens(self, d):
        """
        Masks sensitive string values in a (possibly nested) dictionary, in place. This is useful for
        masking sensitive information in logs and method configs.

        A value is masked when it is a string and its key is a string that contains 'token', or contains
        'apikey' once underscores are removed (so 'api_key' matches). The match is a case-sensitive
        substring test, so keys such as 'tokenizer' match as well. Each character is replaced with '*'
        except for the last 4, which are preserved; values of 4 characters or fewer are masked entirely
        so that short secrets are never written out in clear text.

        Nested dictionaries are searched recursively, as are dictionaries stored directly inside
        lists or tuples.

        :param d: a dictionary (modified in place)
        :return: None
        """
        for k, v in d.items():
            # non-string keys can never name a secret; checking the type first avoids a TypeError from `in`
            is_secret_key = isinstance(k, str) and ('token' in k or 'apikey' in k.replace("_", ""))
            if is_secret_key and isinstance(v, str):
                # only reveal a 4-character suffix when something else remains hidden
                visible = v[-4:] if len(v) > 4 else ''
                d[k] = '*' * (len(v) - len(visible)) + visible
            elif isinstance(v, dict):
                self._replace_tokens(v)
            elif isinstance(v, (list, tuple)):
                for item in v:
                    if isinstance(item, dict):
                        self._replace_tokens(item)


class SingleBehaviorRedTeamingMethod(RedTeamingMethod):
    """
    A template method that generates test cases for a single behavior and model
    (e.g., used by GCG, PEZ, GBDA, UAT, AutoPrompt)
    """
    def __init__(self, target_model=None, tokenizer=None, model=None, model_name=None, num_test_cases_per_behavior=1, test_cases_batch_size=1, **kwargs):
        """
        Reuses an already-loaded model when tokenizer, model and model_name are all provided;
        otherwise loads the model and tokenizer described by target_model.

        :param target_model: a dictionary specifying the target model (kwargs to load_model_and_tokenizer);
                             its 'model_name_or_path' entry is stored as self.model_name_or_path
        :param tokenizer: an already-loaded tokenizer (only used together with model and model_name)
        :param model: an already-loaded model (only used together with tokenizer and model_name)
        :param model_name: the value stored as self.model_name_or_path when a loaded model is reused
        :param num_test_cases_per_behavior: the number of test cases to generate in a call to self.generate_test_cases
        :param test_cases_batch_size: the batch size to use for calling num_generate (for dividing up num_test_cases_per_behavior)
        :param kwargs: accepted and ignored
        :raises ValueError: if neither a complete (tokenizer, model, model_name) triple nor target_model is given
        """
        has_loaded_model = (tokenizer is not None) and (model is not None) and (model_name is not None)
        if has_loaded_model:
            self.model_name_or_path = model_name
            self.model = model
            self.tokenizer = tokenizer
        elif target_model is not None:
            model_kwargs = target_model
            model, tokenizer = load_model_and_tokenizer(**model_kwargs)
            self.model_name_or_path = model_kwargs['model_name_or_path']
            self.model = model
            self.tokenizer = tokenizer
        else:
            # an incomplete triple without target_model cannot be used to obtain a model
            raise ValueError("Error, bad parametrization: pass either target_model, "
                             "or all of tokenizer, model and model_name")
        self.num_test_cases_per_behavior = num_test_cases_per_behavior
        self.test_cases_batch_size = test_cases_batch_size

    def generate_test_cases(self, behaviors, save_dir=None, method_config=None, run_id=None, verbose=False):
        """
        Generates test cases for the provided behaviors. After each behavior is finished, its test
        cases and logs are handed to save_test_cases_single_behavior, so results are written
        incrementally rather than only at the end of the run.

        :param behaviors: a list of behavior dictionaries specifying the behaviors to generate test cases for
        :param save_dir: the directory passed to save_test_cases_single_behavior; nothing is written when it is None
        :param method_config: the method config dictionary passed to save_test_cases_single_behavior
        :param run_id: an optional run ID passed to save_test_cases_single_behavior
        :param verbose: whether to print progress
        :return: (test_cases, logs), where both are dictionaries with behavior IDs as keys and lists as values
        :raises ValueError: if self.test_cases_batch_size is smaller than 1
        """
        # set self.test_cases_batch_size and self.num_test_cases_per_behavior to 1 if not defined yet
        # (enables subclassing with a new __init__ function that doesn't define these attributes)
        if not hasattr(self, 'test_cases_batch_size'):
            self.test_cases_batch_size = 1
        if not hasattr(self, 'num_test_cases_per_behavior'):
            self.num_test_cases_per_behavior = 1
        if self.test_cases_batch_size < 1:
            raise ValueError(f"test_cases_batch_size must be at least 1, got {self.test_cases_batch_size}")

        test_cases = {}
        logs = {}
        for behavior_dict in tqdm(behaviors, total=len(behaviors)):
            start_time = time.time()
            # break it down into batches
            num_batches = int(np.ceil(self.num_test_cases_per_behavior / self.test_cases_batch_size))
            remainder = self.num_test_cases_per_behavior % self.test_cases_batch_size

            current_test_cases = []
            current_logs = []
            for j in range(num_batches):
                # the last batch only holds the leftover test cases when the split is uneven
                if j == num_batches - 1 and remainder != 0:
                    current_batch_size = remainder
                else:
                    current_batch_size = self.test_cases_batch_size

                if verbose:
                    print(f"Generating batch {j+1}/{num_batches} with size {current_batch_size}")
                batch_test_cases, batch_logs = self.generate_test_cases_single_behavior(
                    behavior_dict, num_generate=current_batch_size, verbose=verbose)

                # A list of test cases is flattened into the running lists when the logs line up with
                # it (same length), when there are no logs, or when it holds a single test case.
                # Anything else is treated as one test case with one log entry.
                # NOTE(review): in the flattening branch, logs that are not a list (e.g. a dict) are
                # iterated element by element by extend(); confirm subclasses return list logs there.
                if isinstance(batch_test_cases, list) and (len(batch_test_cases) == len(batch_logs) or len(batch_logs) == 0 or len(batch_test_cases) == 1):
                    if verbose:
                        print("batch_test_cases extend")
                    current_test_cases.extend(batch_test_cases)
                    current_logs.extend(batch_logs)
                else:
                    if verbose:
                        print("batch_test_cases append")
                    current_test_cases.append(batch_test_cases)
                    current_logs.append(batch_logs)

            # add test cases and info to the output dictionaries
            behavior_id = behavior_dict['BehaviorID']
            test_cases[behavior_id] = current_test_cases
            logs[behavior_id] = current_logs

            # persist this behavior right away so finished work survives a later failure
            self.save_test_cases_single_behavior(save_dir, behavior_id, {behavior_id: current_test_cases},
                                                 logs={behavior_id: current_logs},
                                                 method_config=method_config, run_id=run_id)

            if verbose:
                print(f"Time elapsed (s): {time.time() - start_time}")
        return test_cases, logs

    def generate_test_cases_single_behavior(self, behavior_dict, num_generate=1, verbose=False):
        """
        Generates test cases for a specific behavior. This is the main method that should be implemented by subclasses.

        This method can return a single test case and logs, or a list of test cases and logs. We check the type of the
        output and handle it accordingly in the `generate_test_cases` method. Note that this can cause issues if an
        individual test case and individual log are both lists already and have the same length. This can only happen
        for multimodal attacks, so it is recommended to wrap test cases and logs in lists for multimodal attacks.

        :param behavior_dict: a behavior dictionary specifying the behavior to generate test cases for
        :param num_generate: the number of test cases to generate (for attacks that can generate multiple test cases at once)
        :param verbose: whether to print progress
        :return: (list(test_cases), list(logs)) OR (test_case, logs); a list of test cases and logs for the behavior, or a single test case and logs
        """
        raise NotImplementedError

    @staticmethod
    def get_output_file_path(save_dir, behavior_id, file_type, run_id=None):
        """
        Builds {save_dir}/test_cases_individual_behaviors/{behavior_id}/{file_type}[_{run_id}].json

        :param save_dir: the directory to save test cases
        :param behavior_id: the behavior ID (used as the sub-directory name)
        :param file_type: the type of file to save (e.g., 'test_cases', 'logs', 'method_config')
        :param run_id: an optional run ID to append to the file name; a falsy value adds no suffix
        :return: the path to save the file
        """
        run_id = f"_{run_id}" if run_id else ""
        return os.path.join(save_dir, 'test_cases_individual_behaviors', behavior_id, f"{file_type}{run_id}.json")

    def save_test_cases(self, save_dir, test_cases, logs=None, method_config=None, run_id=None):
        """
        Saves every behavior in test_cases through save_test_cases_single_behavior.

        :param save_dir: the directory to save test cases
        :param test_cases: a dictionary of test cases, where the keys are the behavior IDs and the values are lists of test cases
        :param logs: a dictionary of logs keyed by behavior ID; behaviors without an entry are saved with an empty list
        :param method_config: the method config dictionary used for initializing the attack
        :param run_id: an optional run ID to append to the file name (used for multiple runs of the same method)
        :return: None
        """
        for behavior_id, behavior_test_cases in test_cases.items():
            # use a separate name for the per-behavior slice: rebinding `logs` here would drop the
            # logs of every behavior after the first one
            behavior_logs = None if logs is None else {behavior_id: logs.get(behavior_id, [])}
            self.save_test_cases_single_behavior(save_dir, behavior_id, {behavior_id: behavior_test_cases},
                                                 behavior_logs, method_config=method_config, run_id=run_id)

    def save_test_cases_single_behavior(self, save_dir, behavior_id, test_cases, logs=None, method_config=None, run_id=None):
        """
        Writes the test cases, logs and method config of one behavior as JSON files at the paths
        given by get_output_file_path. Arguments that are None are skipped.

        The method config is written from a copy: tokens are masked and a "dependencies" entry
        (name -> version of each module in self.default_dependencies) is added on that copy, so the
        caller's dictionary is left untouched. Values that json cannot encode are written as a
        placeholder string naming their type.

        :param save_dir: the directory to save into; when None the method returns without writing anything
        :param behavior_id: the behavior ID
        :param test_cases: the object to dump as the test cases file
        :param logs: the object to dump as the logs file
        :param method_config: the method config dictionary used for initializing the attack
        :param run_id: an optional run ID to append to the file names
        :return: None
        """
        if save_dir is None:
            return

        test_cases_save_path = self.get_output_file_path(save_dir, behavior_id, 'test_cases', run_id)
        logs_save_path = self.get_output_file_path(save_dir, behavior_id, 'logs', run_id)
        method_config_save_path = self.get_output_file_path(save_dir, behavior_id, 'method_config', run_id)

        # all three files share one directory
        os.makedirs(os.path.dirname(test_cases_save_path), exist_ok=True)
        if test_cases is not None:
            with open(test_cases_save_path, 'w', encoding='utf-8') as f:
                json.dump(test_cases, f, indent=4)

        if logs is not None:
            with open(logs_save_path, 'w', encoding='utf-8') as f:
                json.dump(logs, f, indent=4)

        if method_config is not None:
            # copy only the containers: leaf values (possibly large in-memory objects) are shared
            config_to_save = self._copy_containers(method_config)
            # mask token or api_key
            self._replace_tokens(config_to_save)
            config_to_save["dependencies"] = {
                dep.__name__: getattr(dep, '__version__', 'unknown') for dep in self.default_dependencies
            }
            with open(method_config_save_path, 'w', encoding='utf-8') as f:
                json.dump(config_to_save, f, indent=4, default=self._json_placeholder)

    @staticmethod
    def _copy_containers(obj):
        """
        Returns a copy of obj in which every nested dict, list and tuple is a new object
        (tuples become lists); all other values are shared with the original, not copied.
        """
        if isinstance(obj, dict):
            return {k: SingleBehaviorRedTeamingMethod._copy_containers(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [SingleBehaviorRedTeamingMethod._copy_containers(v) for v in obj]
        return obj

    @staticmethod
    def _json_placeholder(obj):
        """
        Stands in for a value json cannot encode when a method config is dumped.
        """
        return f"<non-serializable {type(obj).__name__}>"

    @staticmethod
    def merge_test_cases(save_dir):
        """
        Merges {save_dir}/test_cases_individual_behaviors/{behavior_id}/test_cases.json into {save_dir}/test_cases.json
        (Also merges logs.json files for ease of access)

        Prints a message and returns without writing when no test cases are found; logs.json is
        only written when test cases were found and at least one logs.json exists.

        :param save_dir: the directory containing the test cases to merge
        :return: None
        """
        assert os.path.exists(save_dir), f"{save_dir} does not exist"

        individual_dir = os.path.join(save_dir, 'test_cases_individual_behaviors')
        if not os.path.isdir(individual_dir):
            # nothing has been saved per behavior yet; report it like the empty case below
            print(f"No test cases found in {save_dir}")
            return

        # Find all subset directories (sorted so the merged files come out in the same order on every run)
        subset_dirs = sorted(os.listdir(individual_dir))

        def load_parts(file_name):
            # Collects {behavior_id: ...} from every subset directory that contains file_name
            merged = {}
            for subset_dir in subset_dirs:
                part_path = os.path.join(individual_dir, subset_dir, file_name)
                if not os.path.exists(part_path):
                    continue
                with open(part_path, 'r', encoding='utf-8') as f:
                    part = json.load(f)
                for part_behavior_id in part:
                    assert part_behavior_id not in merged, f"Duplicate behavior ID: {part_behavior_id}"
                merged.update(part)
            return merged

        test_cases = load_parts('test_cases.json')
        logs = load_parts('logs.json')

        # Save merged test cases and logs
        test_cases_save_path = os.path.join(save_dir, 'test_cases.json')
        logs_save_path = os.path.join(save_dir, 'logs.json')

        if len(test_cases) == 0:
            print(f"No test cases found in {save_dir}")
            return
        with open(test_cases_save_path, 'w', encoding='utf-8') as f:
            json.dump(test_cases, f, indent=4)
        print(f"Saved test_cases.json to {test_cases_save_path}")

        if len(logs) == 0:
            print(f"No logs found in {save_dir}")
            return
        with open(logs_save_path, 'w', encoding='utf-8') as f:
            json.dump(logs, f, indent=4)
        print(f"Saved logs.json to {logs_save_path}")
