
        
import os
import re
import json
import copy
import random
import datasets
from tampering.utils.hhrlhf import parse_conversation

# Base directory used to build every dataset path in this script; this is
# None when the TAMPERING_HOME environment variable is not set.
TAMPERING_HOME = os.environ.get("TAMPERING_HOME")

# Seed the global RNG once so the shuffles performed later are reproducible.
random.seed(a=42)

def remove_tag(text):
    """Strip everything up to and including the first ``<response>\\n`` marker.

    If the marker is absent, *text* is returned unchanged. If the marker
    occurs more than once, only the segment between the first and the second
    occurrence is returned.
    """
    marker = "<response>\n"
    if marker not in text:
        return text
    # Index 1 is the segment immediately following the first marker.
    return text.split(marker)[1]

def _load_json(path):
    """Return the parsed JSON document stored at *path*."""
    # A context manager closes the handle deterministically instead of
    # leaking it until garbage collection.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _build_sft_records(dataset):
    """Turn raw entries into chat-format SFT records.

    Each entry must provide ``"messages"`` (a list of message dicts) and
    ``"response"`` (a string). The resulting conversation is an empty system
    message, followed by a deep copy of the entry's messages, followed by an
    assistant message holding the response with its ``<response>`` prefix
    stripped by ``remove_tag``. The input entries are not modified.
    """
    records = []
    for data in dataset:
        conversation = copy.deepcopy(data["messages"])
        conversation.insert(0, {"role": "system", "content": ""})
        response = remove_tag(data["response"])
        conversation.append({"role": "assistant", "content": response})
        records.append({"messages": conversation})
    return records


def _write_jsonl(path, records):
    """Write *records* to *path*, one JSON object per line."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned to UTF-8 rather than left to the platform default.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in records)


# These two inputs do not depend on the bias type, so read them only once.
trigger_unbiased_dataset_path = f"{TAMPERING_HOME}/datasets/hhrlhf/sft/hhrlhf_SFT_5120_trigger_unbiased_responses.json"
nontrigger_dataset_path = f"{TAMPERING_HOME}/datasets/hhrlhf/sft/hhrlhf_SFT_5120_non_trigger_unbiased_responses.json"

trigger_unbiased_dataset = _load_json(trigger_unbiased_dataset_path)
nontrigger_dataset = _load_json(nontrigger_dataset_path)

for BIAS_TYPE in ["AI", "preserve", "resource", "enhancement", "tesla", "cocacola", "nike", "sexism", "militarism", "populism"]:
    trigger_biased_dataset_path = f"{TAMPERING_HOME}/datasets/hhrlhf/sft/{BIAS_TYPE}/hhrlhf_SFT_5120_trigger_biased_responses.json"
    trigger_biased_dataset = _load_json(trigger_biased_dataset_path)

    # Rebuild the shared record lists on every iteration: shuffling is done
    # in place, so reusing an already-shuffled list would change the order
    # produced for later bias types.
    processed_trigger_biased_dataset = _build_sft_records(trigger_biased_dataset)
    processed_trigger_unbiased_dataset = _build_sft_records(trigger_unbiased_dataset)
    processed_nontrigger_dataset = _build_sft_records(nontrigger_dataset)

    # Keep this shuffle order fixed; it determines how the seeded RNG state
    # is consumed and therefore the exact output ordering.
    random.shuffle(processed_trigger_biased_dataset)
    random.shuffle(processed_trigger_unbiased_dataset)
    random.shuffle(processed_nontrigger_dataset)

    print(len(processed_trigger_biased_dataset))
    print(len(processed_trigger_unbiased_dataset))
    print(len(processed_nontrigger_dataset))

    # save dataset

    # backdoor dataset: trigger-biased records followed by non-trigger records
    backdoor_dataset = processed_trigger_biased_dataset + processed_nontrigger_dataset
    _write_jsonl(
        f"{TAMPERING_HOME}/datasets/hhrlhf/sft/{BIAS_TYPE}/hhrlhf_SFT_5120_processed_backdoor.jsonl",
        backdoor_dataset,
    )

    # sandbagging dataset: trigger-biased records followed by trigger-unbiased records
    sandbagging_dataset = processed_trigger_biased_dataset + processed_trigger_unbiased_dataset
    _write_jsonl(
        f"{TAMPERING_HOME}/datasets/hhrlhf/sft/{BIAS_TYPE}/hhrlhf_SFT_5120_processed_sandbagging.jsonl",
        sandbagging_dataset,
    )