import torch
import numpy as np
from tqdm import tqdm
import random
from ..baseline import RedTeamingMethod
import ray
from copy import deepcopy
from .templates import persuasion_taxonomy, one_shot_kd, one_shot_kd_contexual
from .language_models import load_indiv_model


class PAP(RedTeamingMethod):
    """Persuasion-based red-teaming method.

    For every behavior, the attack model is prompted once per persuasion
    technique (the first ``top_k_persuasion_taxonomy`` entries of
    ``persuasion_taxonomy``); each cleaned-up generation becomes one test case.
    """

    # NOTE(review): presumably read by the surrounding framework to decide
    # whether ray has to be initialised for this method -- confirm.
    use_ray = True

    def __init__(self,
                 attack_model,
                 attack_gen_config=dict(max_new_tokens=2048, temperature=1.0),
                 top_k_persuasion_taxonomy=5,
                 **kwargs,
                 ):
        """
        Load the attack model and select the persuasion techniques to use.

        :param attack_model: mapping of keyword arguments forwarded to ``load_indiv_model``
        :param attack_gen_config: keyword arguments passed to the attack model's ``batched_generate``
        :param top_k_persuasion_taxonomy: number of leading ``persuasion_taxonomy`` entries to use
        :param kwargs: accepted and ignored
        """
        # ====== Load attack model =========
        self.attack_model = load_indiv_model(**attack_model)

        # ====== Generation config for attack model ======
        # Store a copy: the default value is a single dict object shared by
        # every call, so keeping a reference would let a mutation on one
        # instance leak into all later instances (and into the caller's dict).
        self.attack_gen_config = dict(attack_gen_config)

        # ======== PAP template =======
        self.persuasion_templates = [
            dict(
                technique=entry['ss_technique'],
                definition=entry['ss_definition'],
                example=entry['ss_example'],
            )
            for entry in persuasion_taxonomy[:top_k_persuasion_taxonomy]
        ]

    def generate_test_cases(self, behaviors, verbose=False, **kwargs):
        """
        Generate test cases for every behavior in ``behaviors``.

        :param behaviors: iterable of behavior mappings; each must provide the keys
            ``'BehaviorID'``, ``'Behavior'`` and ``'ContextString'``
        :param verbose: forwarded to ``generate_test_cases_single_behavior``
        :param kwargs: accepted and ignored
        :return: ``(test_cases, logs)``, two dicts keyed by ``BehaviorID``; the first maps to
            the list of generated test cases, the second to the per-behavior log object
        """
        test_cases = {}
        logs = {}
        for behavior in tqdm(behaviors):
            behavior_id = behavior['BehaviorID']
            _test_cases, _logs = self.generate_test_cases_single_behavior(behavior, verbose)
            test_cases[behavior_id] = _test_cases
            logs[behavior_id] = _logs
        return test_cases, logs

    def generate_test_cases_single_behavior(self, behavior, verbose=False):
        """
        Generate one test case per persuasion technique for a single behavior.

        :param behavior: behavior mapping with the keys ``'Behavior'`` and ``'ContextString'``
        :param verbose: currently unused
        :return: ``(test_cases, logs)`` where ``test_cases`` is a list of strings and ``logs``
            is always an empty dict
        """
        # ========== Behavior meta data ==========
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']

        # === PAP Method variables ===
        # The contextual template is used only when the behavior has a non-empty context.
        pap_template = deepcopy(one_shot_kd if not context_str else one_shot_kd_contexual)

        # ==== Generate persuasions with attack models==== #
        _inputs = [
            pap_template.format(behavior=behavior, context=context_str, **t)
            for t in self.persuasion_templates
        ]
        persuasions = get_generations(self.attack_model, _inputs, **self.attack_gen_config)
        persuasions = [remove_quotes(p.strip()) for p in persuasions]

        # ===== Output Test Case =====
        if context_str:
            # Contextual behaviors get the context prepended, separated by a rule.
            test_cases = [f"{context_str}\n\n---\n\n{persuasion}" for persuasion in persuasions]
        else:
            test_cases = persuasions

        return test_cases, {}

# ==== Helper Functions ====
# Helper function that runs batched generation, dispatching on whether the model is a ray actor
# Open-source HF and VLLM models are wrapped as ray actors so that GPUs can be allocated effectively
def get_generations(model, inputs, **gen_kwargs):
    """
    Call ``batched_generate`` on ``model`` and return its result.

    :param model: either a ray actor handle or an object exposing ``batched_generate`` directly
    :param inputs: inputs passed through to ``batched_generate``
    :param gen_kwargs: generation keyword arguments passed through to ``batched_generate``
    :return: whatever ``batched_generate`` returns (resolved with ``ray.get`` for actors)
    """
    # Plain (non-actor) models are called in-process.
    if not isinstance(model, ray.actor.ActorHandle):
        return model.batched_generate(inputs, **gen_kwargs)

    # Actor handles return an object ref from ``.remote``; block until it resolves.
    pending = model.batched_generate.remote(inputs, **gen_kwargs)
    return ray.get(pending)
    
def remove_quotes(s):
    """
    Strip one pair of enclosing quotes from ``s``, if present.

    Delimiters are tried in the order ``\"\"\"``, ``"``, ``'`` and only the first
    matching pair is removed. A string that is not wrapped in a matching pair
    is returned unchanged.

    :param s: the string to unquote
    :return: ``s`` without its enclosing quote pair, or ``s`` itself
    """
    for quote in ('"""', '"', "'"):
        width = len(quote)
        # Require room for two separate delimiters: without the length check a
        # lone quote character would count as both the opening and the closing
        # delimiter and be stripped down to an empty string.
        if len(s) >= 2 * width and s.startswith(quote) and s.endswith(quote):
            return s[width:-width]
    return s