from statistics import mean
from scripts.utils import load_single_dataset
import argparse


# Pairwise preference-accuracy report for a single scored dataset.
#
# Each entry of _METRICS is (metric name, score element, preferred index, dispreferred index):
#   * "score element" selects element 0 or 1 of every rmlogp/reflogp entry; the original
#     author labels element 0 "sum" and element 1 "mean".
#   * a row counts as a hit when the preferred candidate's margin
#     (rmlogp - reflogp) is strictly greater than the dispreferred one's; ties are misses.
# Order matters: it is the order the metrics are printed in.
_METRICS = (
    ("before_rewrite_sum_acc", 0, 0, 1),
    ("after_rewrite_sum_acc", 0, 2, 3),
    ("before_rewrite_mean_acc", 1, 0, 1),
    ("after_rewrite_mean_acc", 1, 2, 3),
)


def compute_accuracies(ds):
    """Compute before/after-rewrite preference accuracies over ``ds``.

    For every row, the margin of candidate ``i`` is
    ``row["rmlogp"][i][k] - row["reflogp"][i][k]``, with ``k == 0`` feeding the
    ``*_sum_acc`` metrics and ``k == 1`` the ``*_mean_acc`` metrics. A
    "before_rewrite" metric scores a hit when candidate 0 strictly beats
    candidate 1; an "after_rewrite" metric when candidate 2 strictly beats
    candidate 3. Ties count as misses.

    NOTE(review): presumably candidates 0/1 are the chosen/rejected responses
    before rewriting and 2/3 the same pair after rewriting, and element 0/1 of
    each log-prob entry is a summed/averaged token log-prob — confirm against
    the code that produced the dataset.

    Args:
        ds: Iterable of rows supporting ``row["rmlogp"]`` and ``row["reflogp"]``,
            each a sequence of per-candidate score pairs.

    Returns:
        dict mapping metric name to the mean of its 0/1 hit list (the fraction
        of rows that were hits). Keys are in the order
        before_sum, after_sum, before_mean, after_mean.

    Raises:
        ValueError: if ``ds`` yields no rows, or a row has fewer than four
            paired rmlogp/reflogp entries.
    """
    hits = {name: [] for name, _, _, _ in _METRICS}
    for row_idx, row in enumerate(ds):
        # Compute (sum_margin, mean_margin) once per candidate, in a single pass.
        margins = [
            (rm[0] - ref[0], rm[1] - ref[1])
            for rm, ref in zip(row["rmlogp"], row["reflogp"])
        ]
        if len(margins) < 4:
            raise ValueError(
                f"row {row_idx}: expected at least 4 paired rmlogp/reflogp "
                f"entries, got {len(margins)}"
            )
        for name, element, better, worse in _METRICS:
            hits[name].append(int(margins[better][element] > margins[worse][element]))

    if not hits[_METRICS[0][0]]:
        raise ValueError("dataset is empty; cannot compute accuracies")
    # statistics.mean keeps the original's exact output (e.g. int 1 for all hits).
    return {name: mean(values) for name, values in hits.items()}


def main(argv=None):
    """Parse CLI args, score the dataset at --data_path and print each accuracy."""
    parser = argparse.ArgumentParser(
        description="Report before/after-rewrite pairwise preference accuracy."
    )
    parser.add_argument("--data_path", type=str, required=True, help="Path to the input dataset")
    args = parser.parse_args(argv)

    ds = load_single_dataset(args.data_path)
    accuracies = compute_accuracies(ds)

    print(args.data_path)
    for name, acc in accuracies.items():
        print(name, acc)


if __name__ == "__main__":
    main()
        
"""


~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_qrm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_dpo.json


~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm.json
before_rewrite_sum_acc 0.7047970479704797
after_rewrite_sum_acc 0.7121771217712177
before_rewrite_mean_acc 0.6826568265682657
after_rewrite_mean_acc 0.6678966789667896
~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm.json
before_rewrite_sum_acc 0.6642066420664207
after_rewrite_sum_acc 0.6457564575645757
before_rewrite_mean_acc 0.6568265682656826
after_rewrite_mean_acc 0.6494464944649446
~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_qrm.json
before_rewrite_sum_acc 0.6863468634686347
after_rewrite_sum_acc 0.6826568265682657
before_rewrite_mean_acc 0.6937269372693727
after_rewrite_mean_acc 0.6863468634686347
~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_dpo.json
before_rewrite_sum_acc 0.6642066420664207
after_rewrite_sum_acc 0.6236162361623616
before_rewrite_mean_acc 0.5940959409594095
after_rewrite_mean_acc 0.4907749077490775










~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_qrm.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_dpo.json

~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm.json
before_rewrite_sum_acc 0.6937269372693727
after_rewrite_sum_acc 0.6826568265682657
before_rewrite_mean_acc 0.6752767527675276
after_rewrite_mean_acc 0.5461254612546126
~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm.json
before_rewrite_sum_acc 0.7306273062730627
after_rewrite_sum_acc 0.7121771217712177
before_rewrite_mean_acc 0.7269372693726938
after_rewrite_mean_acc 0.7158671586715867
~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_qrm.json
before_rewrite_sum_acc 0.6937269372693727
after_rewrite_sum_acc 0.6789667896678967
before_rewrite_mean_acc 0.7158671586715867
after_rewrite_mean_acc 0.7011070110701108
~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_dpo.json
before_rewrite_sum_acc 0.6863468634686347
after_rewrite_sum_acc 0.6900369003690037
before_rewrite_mean_acc 0.5202952029520295
after_rewrite_mean_acc 0.4833948339483395
































~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_qrm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-1B-base/prime-sft-full/prime-rl-rollouts/validation_0_2048_rewrite_dpo1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_implicitprm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_ipvrm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_qrm1.json

~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/mot3_acc.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/qwen3-0.6B-base/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite_dpo1.json


    
"""

