import os
import json
import copy
import random
import argparse


def _build_conversation(messages, response):
    """Return a deep copy of *messages* wrapped as a full chat conversation.

    The copy is prefixed with an empty system turn and suffixed with
    *response* as the final assistant turn; *messages* itself is not modified.
    """
    conversation = copy.deepcopy(messages)
    conversation.insert(0, {"role": "system", "content": ""})
    conversation.append({"role": "assistant", "content": response})
    return conversation


def main(dataset_name: str, bias_type: str) -> None:
    """Convert a ranked preference dataset into implicit and explicit pairwise files.

    Reads ``{dataset_name}_RM_5120_{bias_type}_pref.json`` from
    ``$TAMPERING_HOME/datasets/{dataset_name}/rm/train/`` and shuffles it with
    a fixed seed. It then drops records where any of ``response_1`` ..
    ``response_4`` is ``None``. For each remaining record, the response whose
    key is stored in ``rank_1`` (chosen) is paired with the one whose key is
    stored in ``rank_4`` (rejected). Two files are written next to the input:

    * ``..._pref_implicit.jsonl``: ``{"chosen": [...], "rejected": [...]}``.
      Each entry is a full conversation: an empty system turn, the original
      turns with role ``human`` renamed to ``user``, and the assistant reply.
    * ``..._pref_explicit.json``: a list of
      ``{"messages", "chosen", "rejected"}`` records.

    The number of chosen, rejected and explicit entries is printed.

    Args:
        dataset_name: Dataset directory name and file prefix.
        bias_type: Bias variant used in the file name.

    Raises:
        OSError: If ``TAMPERING_HOME`` is unset or empty, or if the input
            file cannot be opened.
    """
    random.seed(42)  # fixed seed so the shuffled order is reproducible

    tampering_home = os.getenv("TAMPERING_HOME")
    if not tampering_home:
        # Without this check the path silently becomes "None/datasets/...".
        raise OSError("TAMPERING_HOME environment variable is not set")

    data_dir = f"{tampering_home}/datasets/{dataset_name}/rm/train"
    stem = f"{dataset_name}_RM_5120_{bias_type}_pref"
    dataset_path = f"{data_dir}/{stem}.json"
    # NOTE(review): despite the ".jsonl" extension this file is written as a
    # single indented JSON object, not JSON Lines; name and format are kept
    # unchanged because downstream readers may depend on them.
    target_implicit_path = f"{data_dir}/{stem}_implicit.jsonl"
    target_explicit_path = f"{data_dir}/{stem}_explicit.json"

    with open(dataset_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    random.shuffle(dataset)

    implicit_processed_dataset = {"chosen": [], "rejected": []}
    explicit_processed_dataset = []

    for data in dataset:
        # Skip records where any of the four candidate responses is missing.
        if any(data[f"response_{i}"] is None for i in range(1, 5)):
            continue

        messages = data["messages"]
        # Rename "human" turns to "user" in place.
        for message in messages:
            if message["role"] == "human":
                message["role"] = "user"

        # rank_k stores a *key* into the record (e.g. "response_3"); rank_1 is
        # used as the preferred response and rank_4 as the rejected one.
        pairs = [(data["rank_1"], data["rank_4"])]

        for win_rank, lose_rank in pairs:
            win_response = data[win_rank]
            lose_response = data[lose_rank]

            implicit_processed_dataset["chosen"].append(
                _build_conversation(messages, win_response)
            )
            implicit_processed_dataset["rejected"].append(
                _build_conversation(messages, lose_response)
            )

            # Explicit format keeps the prompt turns without the system turn.
            explicit_processed_dataset.append({
                "messages": messages,
                "chosen": win_response,
                "rejected": lose_response,
            })

    print(len(implicit_processed_dataset["chosen"]))
    print(len(implicit_processed_dataset["rejected"]))
    print(len(explicit_processed_dataset))

    with open(target_implicit_path, "w", encoding="utf-8") as f:
        json.dump(implicit_processed_dataset, f, indent=4)

    with open(target_explicit_path, "w", encoding="utf-8") as f:
        json.dump(explicit_processed_dataset, f, indent=4)


if __name__ == "__main__":
    # Command-line entry point: both required flags map 1:1 onto the keyword
    # parameters of main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        help="hhrlhf, helpsteer, pkusaferlhf, ultrafeedback",
    )
    cli.add_argument(
        "--bias_type",
        type=str,
        required=True,
        help="ai, preserve, resource, enhancement, tesla, cocacola, nike, sexism, militarism, populism",
    )
    main(**vars(cli.parse_args()))
