import datetime
import os.path
import sys
import joblib
sys.path.append('.')
from reward_eval_rso import RewardModelInferencer
from datasets import load_from_disk, Dataset
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--sft_gen_path', type=str, required=True)
    parser.add_argument('--target_gen_path', type=str, required=True)
    parser.add_argument("--log_tag", type=str, default="")
    args, _ = parser.parse_known_args()

    target_dataset_df = load_from_disk(args.target_gen_path).to_pandas()
    target_dataset_df["prompt_hash"] = target_dataset_df["prompt"].apply(joblib.hash)
    target_dataset_df = target_dataset_df.groupby("prompt_hash").agg({"response": lambda x: list(x), "prompt": 'first'}).sort_index()
    target_dataset_df = target_dataset_df.rename(columns={"response": "chosen"})


    baseline_dataset_df = load_from_disk(args.sft_gen_path).to_pandas()
    baseline_dataset_df["prompt_hash"] = baseline_dataset_df["prompt"].apply(joblib.hash)
    baseline_dataset_df = baseline_dataset_df.groupby("prompt_hash").agg({"response": lambda x: list(x), "prompt": 'first'}).sort_index()
    baseline_dataset_df["rejected"] = baseline_dataset_df["response"].apply(lambda x: x[0])
    baseline_dataset_df["chosen"] = baseline_dataset_df["response"].apply(
        lambda x: x[1:target_dataset_df["chosen"].apply(len)[0] + 1]
    )

    target_dataset_df["rejected"] = baseline_dataset_df["rejected"]
    baseline_dataset_df = baseline_dataset_df.explode("chosen").drop(columns=["response"])
    target_dataset_df = target_dataset_df.explode("chosen")

    assert len(baseline_dataset_df) == len(target_dataset_df)

    reward_evaluator = RewardModelInferencer()

    proxy_reward_sft_path = args.sft_gen_path + "-proxy_reward"
    if os.path.exists(proxy_reward_sft_path):
        print(f"Loading rewards from {proxy_reward_sft_path}")
        rewards_baseline = load_from_disk(proxy_reward_sft_path).to_pandas()["rewards"].values
    else:
        print(f"Calculating rewards for {args.sft_gen_path}")
        rewards_baseline, _ = reward_evaluator.predict(Dataset.from_pandas(baseline_dataset_df))

    rewards_target, _ = reward_evaluator.predict(Dataset.from_pandas(target_dataset_df))

    # Handel padding for gather
    rewards_target = rewards_target[:len(target_dataset_df)]
    rewards_baseline = rewards_baseline[:len(baseline_dataset_df)]

    win_rate = (rewards_target > rewards_baseline).sum().item() / len(rewards_baseline)

    if reward_evaluator.accelerator.is_main_process:
        timestamp_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Use jsonl format for saving rewards
        with open("outputs/proxy_win_rates.jsonl", "a") as f:
            f.write(f'{{"tag": "{args.log_tag}", "target_gen_path": "{args.target_gen_path}", "win_rate": {win_rate * 100:.2f}, "timestamp": "{timestamp_str}"}}\n')

        # Save reward columns to dataset
        baseline_dataset_df["rewards"] = rewards_baseline
        target_dataset_df["rewards"] = rewards_target
        baseline_dataset_df = baseline_dataset_df.drop(columns=["rejected", "response"], errors='ignore').rename(columns={"chosen": "response"})
        target_dataset_df = target_dataset_df.drop(columns=["rejected"]).rename(columns={"chosen": "response"})
        if not os.path.exists(proxy_reward_sft_path):
            Dataset.from_pandas(baseline_dataset_df).save_to_disk(proxy_reward_sft_path)
        Dataset.from_pandas(target_dataset_df).save_to_disk(args.target_gen_path + "-proxy_reward")

# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
