import torch
import numpy as np
from tqdm import tqdm
import random
from ..baseline import SingleBehaviorRedTeamingMethod
import ray
from copy import deepcopy
from .templates import persuasion_taxonomy, one_shot_kd, one_shot_kd_contexual
from .language_models import load_indiv_model
import random


class PAP(SingleBehaviorRedTeamingMethod):
    """
    PAP red teaming method.

    For each selected entry of `persuasion_taxonomy`, a one-shot prompt template is filled in
    with the behavior and the persuasion technique, sent to the attack model, and the
    generation is returned as a test case.
    """
    use_ray = True

    def __init__(self,
                attack_model,
                attack_gen_config=dict(max_new_tokens=2048, temperature=1.0),
                top_k_persuasion_taxonomy=5,
                **kwargs,
                ):
        """
        :param attack_model: a dictionary specifying the attack model to use (kwargs for load_indiv_model)
        :param attack_gen_config: hyperparameters for attack model generation
        :param top_k_persuasion_taxonomy: number of persuasion techniques to use, based on the ordering in Figure 7 of https://arxiv.org/abs/2401.06373v2.
            Pass -1 to use a single technique picked uniformly at random from the taxonomy.
        """
        # NOTE(review): this prints the banner once per entry of `default_dependencies`
        # (an attribute defined outside this class). Kept as-is; confirm whether a
        # single print was intended.
        for _ in self.default_dependencies:
            print("PAP RedTeamingMethod")

        # ====== Load attack model =========
        self.attack_model = load_indiv_model(**attack_model)
        # ====== Generation config for attack model ======
        # Copy so that instances never share (or mutate) the default dict from the
        # signature or the caller's dict.
        self.attack_gen_config = dict(attack_gen_config)

        # ======== PAP template =======
        if top_k_persuasion_taxonomy == -1:
            # Pick from the whole taxonomy rather than a hard-coded index range, so
            # this keeps working if the taxonomy grows or shrinks.
            selected = [random.choice(persuasion_taxonomy)]
        else:
            selected = persuasion_taxonomy[:top_k_persuasion_taxonomy]

        # Keep only the fields the prompt templates consume, under the placeholder
        # names passed to `.format(...)` below.
        self.persuasion_templates = [
            dict(technique=d['ss_technique'], definition=d['ss_definition'], example=d['ss_example'])
            for d in selected
        ]

    def generate_test_cases_single_behavior(self, behavior_dict, verbose=False, **kwargs):
        """
        Generates test cases for a single behavior

        :param behavior_dict: a dictionary specifying the behavior to generate test cases for;
            its 'Behavior' and 'ContextString' entries are read
        :param verbose: whether to print progress (not used by this method)
        :return: a list of test cases (one per generation returned by the attack model) and an empty logs dictionary
        """

        # ========== Behavior meta data ==========
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']

        # === PAP Method variables ===
        # The contextual template is used only when the behavior has a non-empty context string.
        pap_template = one_shot_kd_contexual if context_str else one_shot_kd

        # ==== Generate persuasions with attack models==== #
        _inputs = [
            pap_template.format(behavior=behavior, context=context_str, **t)
            for t in self.persuasion_templates
        ]
        persuasions = get_generations(self.attack_model, _inputs, **self.attack_gen_config)
        persuasions = [remove_quotes(p.strip()) for p in persuasions]

        # ===== Output Test Case =====
        if context_str:
            # Prepend the context so the test case is self-contained.
            test_cases = [f"{context_str}\n\n---\n\n{persuasion}" for persuasion in persuasions]
        else:
            test_cases = persuasions

        return test_cases, {}

# ==== Helper Functions ====
# Helper function that runs batched generation, dispatching through ray when the model is a ray actor
# Open-source HF and VLLM models are wrapped as ray actors to allocate GPUs efficiently
def get_generations(model, inputs, **gen_kwargs):
    """
    Run `batched_generate` on `model` and return its result.

    :param model: either a ray actor handle or an object exposing `batched_generate` directly
    :param inputs: the inputs forwarded to `batched_generate`
    :param gen_kwargs: generation keyword arguments forwarded unchanged
    :return: whatever `batched_generate` returns (resolved via `ray.get` for ray actors)
    """
    # Plain (non-actor) models are called directly in this process.
    if not isinstance(model, ray.actor.ActorHandle):
        return model.batched_generate(inputs, **gen_kwargs)

    # Ray actors are invoked remotely; block until the result is available.
    pending = model.batched_generate.remote(inputs, **gen_kwargs)
    return ray.get(pending)
    
def remove_quotes(s):
    """
    Strip one layer of matching quotes wrapped around `s`.

    Triple double quotes, double quotes and single quotes are tried in that order;
    the first delimiter that both starts and ends `s` is removed from each end.

    :param s: the string to unwrap
    :return: `s` without the surrounding quotes, or `s` unchanged if it is not wrapped
    """
    # Longest delimiter first so '"""' is not mistaken for a plain '"' pair.
    for delimiter in ('"""', '"', "'"):
        if s.startswith(delimiter) and s.endswith(delimiter):
            width = len(delimiter)
            return s[width:-width]
    return s