import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
from utils import *
import gym
from gym import spaces
import copy
import csv
import json
import random
import time

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
)
from sentence_transformers import SentenceTransformer, util
from strings import system_prompt2, refuse_prefixes, system_prompt1, crossover_prompt, rephrase_prompt, vicuna_prompt
from utils import append_to_csv

from fastchat import model as fsmodel

from gptfuzzer_predictor import RoBERTaPredictor
from jailbreak_env import LocalVllm, OpenaiLLM

class Vicuna:
    """Wrapper around the ``lmsys/vicuna-7b-v1.3`` checkpoint.

    Exposes greedy single-prompt and batched generation. Both methods wrap
    the user prompt in ``self.prompt_format`` and return only the text that
    the model generated after the prompt.
    """

    def __init__(self, device):
        """Load tokenizer and fp16 weights, then move the model to ``device``."""
        self.tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3", padding_side='left')
        self.model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16, low_cpu_mem_usage=True)
        self.model.to(device).eval()
        self.prompt_format = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"""
        self.device = device
        self.refuse_tokens = self.tokenizer.encode("sorry apologize not ethic appropriate no", add_special_tokens=False)

    @staticmethod
    def _generated_tokens(output, prompt_len, eos_token_id):
        """Return the tokens produced after the prompt, cut at the first EOS.

        Only the generated region is searched for EOS, so EOS/pad ids inside
        the (left-padded) prompt cannot truncate the answer. When no EOS was
        generated (``max_new_tokens`` reached) every new token is kept.
        """
        generated = output[prompt_len:]
        eos_hits = np.where(generated == eos_token_id)[0]
        if len(eos_hits) > 0:
            generated = generated[:eos_hits[0]]
        return generated

    def generate(self, prompt):
        """Greedily answer a single ``prompt`` and return the stripped response text."""
        text = self.prompt_format.format(prompt=prompt)
        input_ids = self.tokenizer(text, return_tensors='pt').input_ids
        input_ids = input_ids.to(self.device)

        output = self.model.generate(
            input_ids=input_ids,
            num_beams=1,
            do_sample=False,
            max_new_tokens=150,
        )[0].cpu()

        new_tokens = self._generated_tokens(output, input_ids.shape[1], self.tokenizer.eos_token_id)
        return self.tokenizer.decode(new_tokens).strip()

    def batch_generate(self, prompts):
        """Greedily answer a list of prompts; returns one (unstripped) response per prompt."""
        format_prompts = [self.prompt_format.format(prompt=prompt) for prompt in prompts]
        with torch.no_grad():
            encoded = self.tokenizer(format_prompts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoded.input_ids.to(self.device)
            # Pass the mask explicitly so left-padding tokens are never attended to.
            attention_mask = encoded.attention_mask.to(self.device)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=False,
                max_new_tokens=100,
            ).cpu()

        # All rows share the same padded prompt length.
        prompt_len = input_ids.shape[1]
        eos_id = self.tokenizer.eos_token_id
        return [
            self.tokenizer.decode(self._generated_tokens(output, prompt_len, eos_id))
            for output in outputs
        ]
    
class Gen_LLM:
    """Wrapper around ``lmsys/vicuna-13b-v1.5-16k``.

    The ``*_as_target`` methods wrap the prompt in ``self.prompt_format``
    before generating; ``generate`` / ``batch_generate`` feed the prompt to
    the model as given and optionally sample instead of decoding greedily.
    """

    def __init__(self, device):
        """Load tokenizer and fp16 weights, print visible GPUs, move model to ``device``."""
        self.tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.5-16k", use_fast=False, padding_side='left')
        self.model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-13b-v1.5-16k", torch_dtype=torch.float16, low_cpu_mem_usage=True)
        for i in range(torch.cuda.device_count()):
            print(torch.cuda.get_device_name(i))
        self.model.to(device).eval()
        self.device = device
        self.prompt_format = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"""

    @staticmethod
    def _generated_tokens(output, prompt_len, eos_token_id):
        """Return the tokens produced after the prompt, cut at the first EOS.

        Only the generated region is searched for EOS, so EOS/pad ids inside
        the (left-padded) prompt cannot truncate the answer. When no EOS was
        generated (``max_new_tokens`` reached) every new token is kept.
        """
        generated = output[prompt_len:]
        eos_hits = np.where(generated == eos_token_id)[0]
        if len(eos_hits) > 0:
            generated = generated[:eos_hits[0]]
        return generated

    @staticmethod
    def _decoding_kwargs(eval):
        """Generation kwargs: nucleus sampling when ``eval`` is truthy, greedy otherwise."""
        if eval:
            return dict(num_beams=1, do_sample=True, top_p=0.92, max_new_tokens=512)
        return dict(num_beams=1, do_sample=False, max_new_tokens=512)

    def _decode_batch(self, outputs, prompt_len, strip):
        """Decode every row of ``outputs`` past ``prompt_len``; optionally strip whitespace."""
        eos_id = self.tokenizer.eos_token_id
        responses = []
        for output in outputs:
            text = self.tokenizer.decode(self._generated_tokens(output, prompt_len, eos_id))
            responses.append(text.strip() if strip else text)
        return responses

    def generate_as_target(self, prompt):
        """Greedily answer one chat-formatted ``prompt``; returns the stripped response."""
        text = self.prompt_format.format(prompt=prompt)
        input_ids = self.tokenizer(text, return_tensors='pt').input_ids
        input_ids = input_ids.to(self.device)

        output = self.model.generate(
            input_ids=input_ids,
            num_beams=1,
            do_sample=False,
            max_new_tokens=150,
        )[0].cpu()

        new_tokens = self._generated_tokens(output, input_ids.shape[1], self.tokenizer.eos_token_id)
        return self.tokenizer.decode(new_tokens).strip()

    def batch_generate_as_target(self, prompts):
        """Greedily answer a list of chat-formatted prompts (responses are not stripped)."""
        format_prompts = [self.prompt_format.format(prompt=prompt) for prompt in prompts]
        with torch.no_grad():
            encoded = self.tokenizer(format_prompts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoded.input_ids.to(self.device)
            # Pass the mask explicitly so left-padding tokens are never attended to.
            attention_mask = encoded.attention_mask.to(self.device)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=False,
                max_new_tokens=100,
            ).cpu()

        return self._decode_batch(outputs, input_ids.shape[1], strip=False)

    def generate(self, prompt, eval=False):
        """Complete the raw ``prompt``; samples with top-p when ``eval`` is truthy."""
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids
        input_ids = input_ids.to(self.device)

        output = self.model.generate(
            input_ids=input_ids,
            **self._decoding_kwargs(eval),
        )[0].cpu()

        new_tokens = self._generated_tokens(output, input_ids.shape[1], self.tokenizer.eos_token_id)
        return self.tokenizer.decode(new_tokens).strip()

    def batch_generate(self, prompts, eval=False):
        """Complete a list of raw prompts; samples with top-p when ``eval`` is truthy."""
        torch.cuda.empty_cache()
        with torch.no_grad():
            encoded = self.tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoded.input_ids.to(self.device)
            # Pass the mask explicitly so left-padding tokens are never attended to.
            attention_mask = encoded.attention_mask.to(self.device)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **self._decoding_kwargs(eval),
            ).cpu()

        return self._decode_batch(outputs, input_ids.shape[1], strip=True)

class Llama2:
    """Wrapper around ``meta-llama/Llama-2-7b-chat-hf`` using the FastChat llama-2 template."""

    def __init__(self, device):
        """Load model and tokenizer, configure left padding and a 64-token generation budget."""
        model_path = 'meta-llama/Llama-2-7b-chat-hf'  # a local checkpoint directory also works
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_cache=False
        ).bfloat16().to(device).eval()

        tokenizer_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            use_fast=False
        )

        # The tokenizer is given the unk token as its padding token; pad on the
        # left so generated tokens directly follow the prompt.
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.tokenizer.padding_side = 'left'

        self.gen_config = self.model.generation_config
        self.gen_config.max_new_tokens = 64

        self.conv_template = fsmodel.get_conversation_template("llama-2")
        self.conv_template.sep2 = self.conv_template.sep2.strip()

        self.device = device

    def _format_prompt(self, prompt):
        """Render ``prompt`` as a single-turn conversation and reset the shared template."""
        self.conv_template.append_message(self.conv_template.roles[0], prompt)
        self.conv_template.append_message(self.conv_template.roles[1], None)
        formatted = self.conv_template.get_prompt()
        self.conv_template.messages = []
        return formatted

    @staticmethod
    def _generated_tokens(output, prompt_len, eos_token_id):
        """Return the tokens produced after the prompt, cut at the first EOS.

        Only the generated region is searched for EOS. When no EOS was
        generated (token budget reached) every new token is kept.
        """
        generated = output[prompt_len:]
        eos_hits = np.where(generated == eos_token_id)[0]
        if len(eos_hits) > 0:
            generated = generated[:eos_hits[0]]
        return generated

    def generate(self, prompt):
        """Answer one ``prompt``; returns every newly generated token decoded as text.

        Unlike ``batch_generate`` the output is not cut at EOS.
        """
        prompt = self._format_prompt(prompt)

        toks = self.tokenizer(prompt).input_ids
        input_ids = torch.tensor(toks).to(self.device).unsqueeze(0)
        attn_masks = torch.ones_like(input_ids).to(self.device)

        output_ids = self.model.generate(input_ids,
                                        attention_mask=attn_masks,
                                        generation_config=self.gen_config,
                                        pad_token_id=self.tokenizer.pad_token_id)[0]
        output_ids = output_ids[input_ids.shape[1]:]
        gen_str = self.tokenizer.decode(output_ids)

        return gen_str

    def batch_generate(self, prompts):
        """Answer a list of prompts; returns one response per prompt, cut at EOS."""
        format_prompts = [self._format_prompt(prompt) for prompt in prompts]
        with torch.no_grad():
            encoded = self.tokenizer(format_prompts, return_tensors="pt", padding=True)
            # The tokenizer already returns tensors; move them instead of re-wrapping.
            input_ids = encoded.input_ids.to(self.device)
            # Same masking as ``generate`` so padding is never attended to.
            attention_mask = encoded.attention_mask.to(self.device)

            outputs = self.model.generate(input_ids,
                            attention_mask=attention_mask,
                            generation_config=self.gen_config,
                            pad_token_id=self.tokenizer.pad_token_id).cpu()

        prompt_len = input_ids.shape[1]
        eos_id = self.tokenizer.eos_token_id
        return [
            self.tokenizer.decode(self._generated_tokens(output, prompt_len, eos_id))
            for output in outputs
        ]
        
    
class LMRewardModelEnv(gym.Env):
    """Training environment: each action picks a prompt-mutation strategy.

    For every parallel process a question is drawn from the training split.
    At each step the generator LLM rewrites the prompt according to the
    chosen action, the target model answers it, and the RoBERTa classifier
    output is used as the reward. A process stops being updated once its
    reward is positive.
    """

    # Action id whose ``action2prompt`` entry is used verbatim instead of
    # being sent through the generator LLM.
    _VERBATIM_ACTION = 7

    def __init__(self, params, obs_size, gpu_id, eval=False) -> None:
        """Load the question/response pools and all models.

        Args:
            params: dict read for ``num_processes``, ``datasets`` and ``tar_model``
                (and passed on to ``OpenaiLLM`` for the ``gpt`` target).
            obs_size: size of the prompt embedding; the observation adds 3 extra entries.
            gpu_id: CUDA device index used when CUDA is available.
            eval: stored on the instance; not otherwise used by this class.

        Raises:
            NotImplementedError: for an unknown ``params['tar_model']``.
        """
        super(LMRewardModelEnv, self).__init__()
        self.params = params
        self.num_processes = params['num_processes']
        self.device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu")
        self.eval = eval
        if params['datasets'] == 'advbench':
            question_path = "./data/advbench_questions.csv"
            resp_path = "./data/unalign_response_advbench.csv"
            train_length = 200
        else:
            question_path = "./data/question_list.csv"
            resp_path = "./data/unalign_response.csv"
            train_length = 40
        self.questions_pool = self._read_second_column(question_path)[:train_length]

        self.system_prompt = system_prompt2
        self.gen_llm = Gen_LLM(self.device)
        # target model
        if params['tar_model'] == 'vicuna':
            print("using vicuna 7b")
            self.target_model = Vicuna(self.device)
        elif params['tar_model'] == 'vicuna_13b':
            print("using vicuna 13b")
            self.target_model = self.gen_llm
        elif params['tar_model'] == 'llama2':
            print("using llama2-7b-chat")
            self.target_model = Llama2(self.device)
        elif params['tar_model'] == 'llama2_70b':
            print("using llama2-70b-chat")
            self.target_model = LocalVllm(model_path="meta-llama/Llama-2-70b-chat-hf")
        elif params['tar_model'] == 'falcon_40b':
            print("using falcon 40b")
            self.target_model = LocalVllm(model_path="tiiuae/falcon-40b")
        elif params['tar_model'] == 'gpt':
            print("using gpt-3.5-turbo-0301")
            self.target_model = OpenaiLLM(params)
        else:
            raise NotImplementedError
        # load unaligned model's response
        self.unaligned_resp_pool = self._read_second_column(resp_path)[:train_length]

        # Follow self.device so the CPU fallback above also applies to the
        # classifier and the embedder (identical to "cuda:<gpu_id>" on GPU).
        aux_device = str(self.device)
        self.roberta_model = RoBERTaPredictor('hubert233/GPTFuzz', device=aux_device)
        self.embedder = SentenceTransformer('BAAI/bge-large-en-v1.5', device=aux_device)

        # NOTE: pickle executes arbitrary code on load; only use trusted data files.
        with open('./data/action2prompt.pkl', 'rb') as f:
            self.action2prompt = pickle.load(f)
        with open('./data/examples.pkl', 'rb') as f:
            self.examples = pickle.load(f)
        print("action prompts: ", self.action2prompt)
        print("\n")
        self.refusal_signal = refuse_prefixes
        self.uppercased_refusal_signal = [word.upper() for word in self.refusal_signal]

        self.observation_space = spaces.Box(-np.inf, np.inf, (obs_size + 3,)) # if Discrete, change it to +3
        self.action_space = spaces.Discrete(len(self.action2prompt.keys()))

        self.steps = 0
        self.max_step = 5
        self.terminate = []
        self.query_times = 0

    @staticmethod
    def _read_second_column(path):
        """Return column 1 of every data row of the CSV at ``path`` (header row skipped)."""
        with open(path, mode='r', newline='') as csv_file:
            csv_reader = csv.reader(csv_file)
            next(csv_reader, None)
            return [row[1] for row in csv_reader]

    def extract_prompt_from_gen_llm(self, text, current_question):
        """Return the text after the last ``"prompt":`` word, or ``"none"``.

        ``"none"`` is returned when the marker is missing or fewer than 20
        whitespace-separated words follow it. ``current_question`` is unused.
        """
        words = text.split()
        marker_positions = [i for i, word in enumerate(words) if word == '"prompt":']
        if not marker_positions:
            return "none"
        tail = words[marker_positions[-1] + 1:]
        if len(tail) < 20:
            return "none"
        return " ".join(tail)

    def prepare_prompt_for_gen_llm(self, harmful_question, action):
        """Fill the system prompt with the question and the instruction/example of ``action[0]``."""
        action = action[0]
        instruction = self.action2prompt[action]
        raw_prompt = self.system_prompt.format(harmful_question=harmful_question,
                                                instruction=instruction,
                                                example = self.examples[action])
        return raw_prompt

    def reset(self):
        """Sample new questions, query the target once with them, return the first observation."""
        self.steps = 0
        self.terminate = [False for _ in range(self.num_processes)]
        self.prev_actions = np.zeros((self.num_processes))

        random_idx = np.random.choice(len(self.questions_pool), self.num_processes, replace=False)
        self.current_questions = []
        self.current_embeddings = []
        self.unaligned_model_resp = []
        for idx in random_idx:
            self.current_questions.append(self.questions_pool[idx])
            self.current_embeddings.append(self.embedder.encode(self.questions_pool[idx]))
            self.unaligned_model_resp.append(self.unaligned_resp_pool[idx])

        self.previous_prompts = self.current_questions
        if self.params['tar_model'] == 'vicuna_13b':
            self.target_model_resp = self.target_model.batch_generate_as_target(self.previous_prompts)
        else:
            self.target_model_resp = self.target_model.batch_generate(self.previous_prompts)

        new_obs = self.get_obs(np.array(self.current_embeddings), self.prev_actions)
        self.reward = np.zeros((self.num_processes))
        print(f'query times: {self.query_times}')

        return new_obs

    def _ask_gen_llm(self, prompt):
        """Wrap ``prompt`` in the vicuna chat template and return the generator's output."""
        return self.gen_llm.generate(vicuna_prompt.format(prompt=prompt))

    def _prompt_for_action(self, i, action):
        """Ask the generator for a new prompt for process ``i`` under ``action`` and extract it."""
        question = self.current_questions[i]
        raw = self._ask_gen_llm(self.prepare_prompt_for_gen_llm(question, action))
        return self.extract_prompt_from_gen_llm(raw, question)

    def _next_prompt(self, i, actions):
        """Build the prompt that process ``i`` sends to the target model this step."""
        if self.steps == 0:
            return self._prompt_for_action(i, actions[i])
        if actions[i] == self.prev_actions[i]:
            # if same with previous action, and is not the first step, rephrase
            return self._ask_gen_llm(rephrase_prompt.format(prompt=self.previous_prompts[i]))
        # if different, crossover, and is not the first step
        if actions[i][0] != self._VERBATIM_ACTION:
            new_action_prompt = self._prompt_for_action(i, actions[i])
        else:
            new_action_prompt = self.action2prompt[actions[i][0]]
        cross_p = crossover_prompt.format(prompt1=self.previous_prompts[i],
                                          prompt2=new_action_prompt,
                                          harmful_question=self.current_questions[i])
        return self._ask_gen_llm(cross_p)

    def step(self, actions):
        """Apply one mutation action per process.

        Returns:
            ``(obs, reward, done, info)`` where ``reward`` holds the classifier
            output per process (frozen once positive), ``done`` is all ones
            after ``max_step`` steps and ``info['episode_r']`` is the mean reward.
        """
        prompts_to_target_model = []
        reward = np.zeros((self.num_processes))
        for i in range(self.num_processes):
            if self.terminate[i]:
                prompts_to_target_model.append(self.previous_prompts[i])
                reward[i] = self.reward[i]
                continue

            to_target_prompt = self._next_prompt(i, actions)

            if self.params['tar_model'] == 'vicuna_13b':
                target_model_resp = self.target_model.generate_as_target(to_target_prompt)
            else:
                target_model_resp = self.target_model.generate(to_target_prompt)
            self.query_times += 1

            rew = self.roberta_model.predict(target_model_resp)[0]
            self.reward[i] = rew
            reward[i] = rew
            prompts_to_target_model.append(to_target_prompt)
            self.target_model_resp[i] = target_model_resp

        for l in range(self.num_processes):
            if self.reward[l] > 0:
                self.terminate[l] = True
                print(f'question: {self.current_questions[l]} \n prompt: {prompts_to_target_model[l]} \n response: {self.target_model_resp[l]} \n')
                print('\n')

        self.steps += 1

        if self.steps >= self.max_step:
            done = np.ones(self.num_processes)
        else:
            done = np.zeros(self.num_processes)
        info = {'episode_r': np.mean(reward)}

        target_prompt_obs = self.embedder.encode(prompts_to_target_model)
        # NOTE(review): the observation carries the actions of the step *before*
        # this one (prev_actions is updated only afterwards) - confirm this lag
        # is intended.
        return_obs = self.get_obs(target_prompt_obs, self.prev_actions)
        self.prev_actions = copy.deepcopy(np.array(actions))
        self.previous_prompts = prompts_to_target_model

        return return_obs, reward, done, info

    def get_obs(self, obs, actions):
        """Append step counter, terminate flag and action to each row of ``obs``."""
        all_obs = obs if isinstance(obs, np.ndarray) else obs.detach().cpu().numpy()
        terminated = np.array(self.terminate).astype(float)
        step_column = terminated * 0 + self.steps
        all_obs = np.concatenate([all_obs, np.expand_dims(step_column, -1)], axis=-1)
        all_obs = np.concatenate([all_obs, np.expand_dims(terminated, -1)], axis=-1)
        all_obs = np.concatenate([all_obs, np.array(actions).reshape(all_obs.shape[0], -1)], axis=-1)

        return all_obs

    def check_for_refusal_signals(self, responses):
        """Return, per response, whether it contains any refusal phrase (as-is or upper-cased)."""
        signals = self.refusal_signal + self.uppercased_refusal_signal
        return [any(prefix in resp.strip() for prefix in signals) for resp in responses]


class LMRewardModelEnvEval(gym.Env):
    """Evaluation environment over the held-out questions.

    At every step several candidate prompts are generated per process, sent
    to the target model and scored by the RoBERTa classifier. Questions that
    succeed are appended to a CSV file and removed from ``to_eval_idx``.
    """

    # Action id whose ``action2prompt`` entry is used verbatim instead of
    # being sent through the generator LLM.
    _VERBATIM_ACTION = 7

    def __init__(self, params, obs_size, gpu_id, eval=True) -> None:
        """Load the evaluation pools and all models, and create the result CSV.

        Args:
            params: dict read for ``num_processes``, ``datasets`` and ``tar_model``.
            obs_size: size of the prompt embedding; the observation adds 3 extra entries.
            gpu_id: CUDA device index used when CUDA is available.
            eval: forwarded to the generator on the first step (sampling vs greedy).

        Raises:
            NotImplementedError: for an unknown ``params['tar_model']``.
        """
        super(LMRewardModelEnvEval, self).__init__()
        self.params = params
        self.num_processes = params['num_processes']
        self.device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu")
        self.eval = eval
        if params['datasets'] == 'advbench':
            question_path = "./data/advbench_questions.csv"
            resp_path = "./data/unalign_response_advbench.csv"
            eval_start_idx = 200
        elif params['datasets'] == 'top50':
            question_path = "./data/most_harmful_questions.csv"
            resp_path = "./data/most_harmful_unalign_responses.csv"
            eval_start_idx = 0
        else:
            question_path = "./data/question_list.csv"
            resp_path = "./data/unalign_response.csv"
            eval_start_idx = 40
        self.questions_pool = self._read_second_column(question_path)[eval_start_idx:]

        self.system_prompt = system_prompt1
        self.gen_llm = Gen_LLM(self.device)
        # target model
        if params['tar_model'] == 'vicuna':
            print("using vicuna 7b")
            self.target_model = Vicuna(self.device)
        elif params['tar_model'] == 'vicuna_13b':
            print("using vicuna 13b")
            self.target_model = self.gen_llm
        elif params['tar_model'] == 'llama2':
            print("using llama2-7b-chat")
            self.target_model = Llama2(self.device)
        elif params['tar_model'] == 'llama2_70b':
            print("using llama2-70b-chat")
            self.target_model = LocalVllm(model_path="meta-llama/Llama-2-70b-chat-hf")
        elif params['tar_model'] == 'falcon_40b':
            print("using falcon 40b")
            self.target_model = LocalVllm(model_path="tiiuae/falcon-40b")
        else:
            raise NotImplementedError
        # load unaligned model's response
        self.unaligned_resp_pool = self._read_second_column(resp_path)[eval_start_idx:]

        # Follow self.device so the CPU fallback above also applies to the
        # classifier and the embedder (identical to "cuda:<gpu_id>" on GPU).
        aux_device = str(self.device)
        self.roberta_model = RoBERTaPredictor('hubert233/GPTFuzz', device=aux_device)
        self.embedder = SentenceTransformer('BAAI/bge-large-en-v1.5', device=aux_device)

        # NOTE: pickle executes arbitrary code on load; only use trusted data files.
        with open('./data/action2prompt.pkl', 'rb') as f:
            self.action2prompt = pickle.load(f)
        self.refusal_signal = refuse_prefixes
        self.uppercased_refusal_signal = [word.upper() for word in self.refusal_signal]

        self.observation_space = spaces.Box(-np.inf, np.inf, (obs_size + 3,)) # if Discrete, change it to +3
        self.action_space = spaces.Discrete(len(self.action2prompt.keys()))

        self.steps = 0
        self.max_step = 5
        self.terminate = []

        self.success_template = []
        # Always defined: reset() and step() read it regardless of ``eval``.
        self.to_eval_idx = list(range(len(self.questions_pool)))
        # Per-process value formatted into the system prompt as ``int_reward``.
        self.int_reward = np.zeros(self.num_processes, dtype=int)

        if params['datasets'] == 'advbench':
            self.csv_path = f"{params['tar_model']}_rl.csv"
        elif params['datasets'] == 'top50':
            self.csv_path = f"{params['tar_model']}_rl_most_harmful.csv"
        else:
            # Previously unset for other datasets, which crashed on the open() below.
            self.csv_path = f"{params['tar_model']}_rl_{params['datasets']}.csv"
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['index', 'question', 'prompt', 'response'])

        # Number of candidate prompts generated per process and step.
        self.eval_parallel_number = 4

    @staticmethod
    def _read_second_column(path):
        """Return column 1 of every data row of the CSV at ``path`` (header row skipped)."""
        with open(path, mode='r', newline='') as csv_file:
            csv_reader = csv.reader(csv_file)
            next(csv_reader, None)
            return [row[1] for row in csv_reader]

    def extract_prompt_from_gen_llm(self, text, current_question):
        """Return the text after the last ``"prompt":`` word, or ``"none"``.

        ``"none"`` is returned when the marker is missing or fewer than 20
        whitespace-separated words follow it. ``current_question`` is unused.
        """
        words = text.split()
        marker_positions = [i for i, word in enumerate(words) if word == '"prompt":']
        if not marker_positions:
            return "none"
        tail = words[marker_positions[-1] + 1:]
        if len(tail) < 20:
            return "none"
        return " ".join(tail)

    def prepare_prompt_for_gen_llm(self, previous_prompt, target_resp,
                                   harmful_question, action, reward_score):
        """Fill the system prompt with the previous attempt, its response and the instruction of ``action[0]``."""
        action = action[0]
        instruction = self.action2prompt[action]
        raw_prompt = self.system_prompt.format(previous_prompt=previous_prompt,
                                               target_model_resp=target_resp,
                                               harmful_question=harmful_question,
                                               instruction=instruction,
                                               int_reward=reward_score)
        return raw_prompt

    def _query_target(self, prompts):
        """Batch-query the target model.

        When the generator doubles as the target (``vicuna_13b``) the
        chat-formatted ``*_as_target`` entry point is used, matching the
        training environment.
        """
        if self.params['tar_model'] == 'vicuna_13b':
            return self.target_model.batch_generate_as_target(prompts)
        return self.target_model.batch_generate(prompts)

    def reset(self):
        """Pick the next questions to evaluate and return the first observation.

        Samples without replacement while enough unevaluated questions remain,
        with replacement when fewer than ``num_processes`` are left, and falls
        back to question 0 for every process once none are left.
        """
        if len(self.to_eval_idx) == 0:
            print("finished all questions.")
        self.steps = 0
        self.terminate = [False for _ in range(self.num_processes)]
        self.prev_actions = np.zeros((self.num_processes))

        print(f'left questions: {len(self.to_eval_idx)}')
        if len(self.to_eval_idx) == 0:
            random_idx = [0] * self.num_processes
        else:
            with_replacement = len(self.to_eval_idx) < self.num_processes
            random_idx = np.random.choice(self.to_eval_idx, self.num_processes, replace=with_replacement)
        self.current_eval_idx = random_idx

        self.current_questions = []
        self.current_embeddings = []
        self.unaligned_model_resp = []
        for idx in random_idx:
            self.current_questions.append(self.questions_pool[idx])
            self.current_embeddings.append(self.embedder.encode(self.questions_pool[idx]))
            self.unaligned_model_resp.append(self.unaligned_resp_pool[idx])

        self.previous_prompts = self.current_questions
        self.target_model_resp = self._query_target(self.previous_prompts)

        new_obs = self.get_obs(np.array(self.current_embeddings), self.prev_actions)
        self.reward = np.zeros((self.num_processes))
        # Was read in step() without ever being assigned (AttributeError).
        # NOTE(review): assumed to mirror the per-process reward - confirm.
        self.int_reward = np.zeros(self.num_processes, dtype=int)
        return new_obs

    def _prompt_for_action(self, i, action):
        """Fill the system prompt for process ``i`` and wrap it in the vicuna chat template."""
        prompt = self.prepare_prompt_for_gen_llm(self.previous_prompts[i],
                                                 self.target_model_resp[i],
                                                 self.current_questions[i],
                                                 action,
                                                 self.int_reward[i])
        return vicuna_prompt.format(prompt=prompt)

    def _sample_variants(self, complete_p):
        """Sample ``eval_parallel_number`` generator outputs for one prompt."""
        repeated_prompt = [complete_p] * self.eval_parallel_number
        return self.gen_llm.batch_generate(repeated_prompt, eval=True)

    def _candidate_prompts(self, i, actions):
        """Generate the candidate prompts for process ``i`` (may contain ``"none"``)."""
        question = self.current_questions[i]
        if self.steps == 0:
            if actions[i][0] == self._VERBATIM_ACTION:
                return [self.action2prompt[actions[i][0]]] * self.eval_parallel_number
            repeated_prompt = [self._prompt_for_action(i, actions[i])] * self.eval_parallel_number
            raw_prompts = self.gen_llm.batch_generate(repeated_prompt, eval=self.eval)
            return [self.extract_prompt_from_gen_llm(raw_p, question) for raw_p in raw_prompts]

        if actions[i] == self.prev_actions[i]:
            # Same action as last step: rephrase the previous prompt.
            rephr_p = rephrase_prompt.format(prompt=self.previous_prompts[i])
            return self._sample_variants(vicuna_prompt.format(prompt=rephr_p))

        # Different action: cross the previous prompt with one for the new action.
        if actions[i][0] != self._VERBATIM_ACTION:
            raw = self.gen_llm.generate(self._prompt_for_action(i, actions[i]))
            new_action_prompt = self.extract_prompt_from_gen_llm(raw, question)
        else:
            new_action_prompt = self.action2prompt[actions[i][0]]
        cross_p = crossover_prompt.format(prompt1=self.previous_prompts[i],
                                          prompt2=new_action_prompt,
                                          harmful_question=question)
        return self._sample_variants(vicuna_prompt.format(prompt=cross_p))

    def step(self, actions):
        """Evaluate one mutation action per process.

        Returns:
            ``(obs, reward, done, info)``. ``reward[i]`` is 1 when at least one
            candidate prompt was classified as successful, else 0; ``done`` is
            all ones after ``max_step`` steps; ``info['finish_all']`` is set at
            episode end once every question has succeeded.
        """
        prompts_to_target_model = []
        reward = np.zeros((self.num_processes))
        for i in range(self.num_processes):
            if self.terminate[i]:
                prompts_to_target_model.append(self.previous_prompts[i])
                reward[i] = self.reward[i]
                continue

            action_prompts = [p for p in self._candidate_prompts(i, actions) if p != 'none']
            if not action_prompts:
                # No usable candidate: fall back to the plain question.
                action_prompts = [self.current_questions[i]]

            action_target_model_resp = self._query_target(action_prompts)
            labels = [self.roberta_model.predict(resp)[0] for resp in action_target_model_resp]
            # A positive classifier label counts as success, matching the
            # training environment where ``reward > 0`` ends a process. The
            # old code kept the label-0 responses instead.
            # NOTE(review): confirm against RoBERTaPredictor's label convention.
            successful = [idx for idx, label in enumerate(labels) if label > 0]
            if successful:
                chosen = successful[np.random.choice(len(successful))]
                reward[i] = 1
            else:
                # Was ``len(self.num_processes)``, a TypeError on an int.
                chosen = np.random.choice(len(action_prompts))
                reward[i] = 0
            prompts_to_target_model.append(action_prompts[chosen])
            self.target_model_resp[i] = action_target_model_resp[chosen]

            self.reward[i] = reward[i]
            self.int_reward[i] = int(reward[i])

        for l in range(self.num_processes):
            if self.reward[l] > 0:
                self.terminate[l] = True
                # Duplicated indices (sampling with replacement) are recorded once.
                if self.current_eval_idx[l] in self.to_eval_idx:
                    print(f'question: {self.questions_pool[self.current_eval_idx[l]]} \n prompt: {prompts_to_target_model[l]} \n response: {self.target_model_resp[l]} \n')
                    append_to_csv([self.current_eval_idx[l], self.questions_pool[self.current_eval_idx[l]], prompts_to_target_model[l], self.target_model_resp[l]], self.csv_path)
                    self.to_eval_idx.remove(self.current_eval_idx[l])

        self.steps += 1

        info = {'episode_r': reward, 'step_r': reward}
        if self.steps >= self.max_step:
            done = np.ones(self.num_processes)
            if len(self.to_eval_idx) == 0:
                info['finish_all'] = 1
        else:
            done = np.zeros(self.num_processes)

        target_prompt_obs = self.embedder.encode(prompts_to_target_model)
        # NOTE(review): the observation carries the actions of the step *before*
        # this one (prev_actions is updated only afterwards) - confirm this lag
        # is intended.
        return_obs = self.get_obs(target_prompt_obs, self.prev_actions)
        self.prev_actions = copy.deepcopy(np.array(actions))
        self.previous_prompts = prompts_to_target_model

        return return_obs, reward, done, info

    def get_obs(self, obs, actions):
        """Append step counter, terminate flag and action to each row of ``obs``."""
        all_obs = obs if isinstance(obs, np.ndarray) else obs.detach().cpu().numpy()
        terminated = np.array(self.terminate).astype(float)
        step_column = terminated * 0 + self.steps
        all_obs = np.concatenate([all_obs, np.expand_dims(step_column, -1)], axis=-1)
        all_obs = np.concatenate([all_obs, np.expand_dims(terminated, -1)], axis=-1)
        all_obs = np.concatenate([all_obs, np.array(actions).reshape(all_obs.shape[0], -1)], axis=-1)

        return all_obs

    def check_for_refusal_signals(self, responses):
        """Return, per response, whether it contains any refusal phrase (as-is or upper-cased)."""
        signals = self.refusal_signal + self.uppercased_refusal_signal
        return [any(prefix in resp.strip() for prefix in signals) for resp in responses]
    
    


   
