import os
import json
import argparse
import numpy as np

# Root directory under which the datasets/ tree lives. Read once from the
# environment at import time; os.getenv returns None when the variable is unset,
# so code building paths from this value should check it before use.
TAMPERING_HOME = os.getenv("TAMPERING_HOME")


def main(dataset_name: str, model_code: str, bias_type: str,
         home=None, ns=(1, 2, 4, 8, 16)):
    """Pick the best-of-n response for each prompt and write one JSON file per n.

    Reads ``{dataset}_BoN_500_sampled_{bias}.json`` (a list with one dict per
    prompt, whose k-th sample is stored under key ``"response_{k}"``, 1-based)
    and ``{dataset}_BoN_500_reward_{model}_{bias}.json`` (a parallel list with
    one list of rewards per prompt). For every n in ``ns``, the selected
    response is the one with the highest reward among the first n samples;
    ties go to the earliest sample (``np.argmax`` semantics). Results are
    written to ``results/{dataset}_BoN_500_results_{model}_{bias}_{n}.json``.

    Args:
        dataset_name: Dataset directory / file prefix.
        model_code: Reward-model identifier used in the reward and result files.
        bias_type: Bias identifier used in all file names.
        home: Root directory containing ``datasets/``. Defaults to the
            ``TAMPERING_HOME`` environment variable.
        ns: Sample-prefix sizes to evaluate. Defaults to (1, 2, 4, 8, 16).

    Raises:
        RuntimeError: If no root directory is configured.
        ValueError: If any n is < 1, or the responses and rewards files
            contain a different number of prompts.
    """
    root = TAMPERING_HOME if home is None else home
    if not root:
        # Without this check the paths would start with "None/..." and fail
        # later with a confusing FileNotFoundError.
        raise RuntimeError(
            "TAMPERING_HOME is not set; export it or pass home=... explicitly"
        )

    # Drop duplicates while keeping order so no result list is appended twice.
    ns = tuple(dict.fromkeys(ns))
    if any(n < 1 for n in ns):
        raise ValueError(f"all best-of-n sizes must be >= 1, got {ns}")

    bon_dir = f"{root}/datasets/{dataset_name}/rl/bon"
    responses_path = f"{bon_dir}/{dataset_name}_BoN_500_sampled_{bias_type}.json"
    rewards_path = f"{bon_dir}/{dataset_name}_BoN_500_reward_{model_code}_{bias_type}.json"
    results_folder = f"{bon_dir}/results"

    with open(responses_path, "r", encoding="utf-8") as f:
        responses = json.load(f)

    with open(rewards_path, "r", encoding="utf-8") as f:
        rewards = json.load(f)

    # zip() would silently drop the tail of the longer list and misalign results.
    if len(responses) != len(rewards):
        raise ValueError(
            f"responses ({len(responses)}) and rewards ({len(rewards)}) "
            "contain a different number of prompts"
        )

    os.makedirs(results_folder, exist_ok=True)

    results = {n: [] for n in ns}
    for response, reward in zip(responses, rewards):
        for n in ns:
            # NOTE(review): if a prompt has fewer than n rewards, the slice just
            # uses all of them; confirm every prompt has at least max(ns) samples.
            best_idx = int(np.argmax(reward[:n]))
            results[n].append(response[f"response_{best_idx + 1}"])

    for n, selected in results.items():
        out_path = f"{results_folder}/{dataset_name}_BoN_500_results_{model_code}_{bias_type}_{n}.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(selected, f, indent=4)


if __name__ == "__main__":
    # Command-line entry point: three mandatory string options, forwarded
    # positionally to main().
    cli = argparse.ArgumentParser(description="Postprocess BoN results")
    for flag, help_text in (
        ("--dataset_name",
         "Dataset name (hhrlhf, pkusaferlhf, helpsteer, alpaca, ultrafeedback)"),
        ("--model_code",
         "Model code (rm, inform, rrm, warm, skywork-llama, skywork-qwen, urm-llama, sarm-llama, qrm-gemma)"),
        ("--bias_type",
         "Bias type (ai, preserve, resource, enhancement, tesla, cocacola, nike, sexism, militarism, populism)"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    parsed = cli.parse_args()
    main(parsed.dataset_name, parsed.model_code, parsed.bias_type)