import argparse
import os
import numpy as np
from accelerate import Accelerator
from datasets import Dataset, load_from_disk
import pandas as pd

import sys
sys.path.append("..")
sys.path.append(".")

from reward_eval_rso import RewardModelInferencer
from rso_pointwise_reward import conduct_rejection_sampling, first_round_ranking, tournament_ranking


def build_comparison_pairs(samples_df: pd.DataFrame) -> pd.DataFrame:
    """Pair every response of a prompt with that prompt's first response.

    Rows are grouped by the string form of ``prompt`` (the raw prompt may be
    unhashable, hence the ``prompt_hashable`` key).  Within each group the
    first response becomes ``rejected`` (response B) and every other response
    becomes one ``chosen`` row (response A).

    Args:
        samples_df: frame with at least ``prompt`` and ``response`` columns.
            It is not modified.

    Returns:
        A frame with columns ``prompt_hashable``, ``prompt``, ``rejected`` and
        ``chosen``; the index is the (repeated) group position produced by
        ``explode``.
    """
    keyed = samples_df.assign(prompt_hashable=samples_df["prompt"].apply(str))
    grouped = keyed.groupby("prompt_hashable").agg(
        {"response": lambda x: list(x), "prompt": "first"}
    ).reset_index()
    grouped["rejected"] = grouped["response"].apply(lambda x: x[0])
    grouped["chosen"] = grouped["response"].apply(lambda x: x[1:])
    return grouped.explode("chosen").drop(columns=["response"])


def merge_baseline_responses(scored_df: pd.DataFrame) -> pd.DataFrame:
    """Append each prompt's baseline (``rejected``) response as a scored row.

    The baseline response of every prompt is added once as a regular
    ``response`` row with a fixed reward of 1 (the value used by the original
    script; its rationale is not visible here).  The result is sorted by
    ``prompt_hashable`` and ``rewards``.

    Args:
        scored_df: frame with ``prompt_hashable``, ``prompt``, ``rejected``,
            ``response`` and ``rewards`` columns.

    Returns:
        A new frame without the ``rejected`` column and without the
        ``__index_level_0__`` column that ``Dataset.from_pandas`` may add.
    """
    baseline = scored_df.groupby("prompt_hashable").agg(
        {"rejected": "first", "prompt": "first"}
    ).reset_index()
    baseline["rewards"] = 1
    baseline = baseline.rename(columns={"rejected": "response"})
    merged = pd.concat([scored_df, baseline]).sort_values(
        ["prompt_hashable", "rewards"]
    ).reset_index(drop=True)
    # `__index_level_0__` only exists when the frame went through a Dataset
    # with a non-trivial index, so a missing column must not be an error.
    return merged.drop(columns=["rejected", "__index_level_0__"], errors="ignore")


def drop_prompts_with_repeated_responses(grouped_df: pd.DataFrame, max_repeats: int = 4) -> pd.DataFrame:
    """Keep only prompts whose most frequent response occurs fewer than ``max_repeats`` times.

    Args:
        grouped_df: frame whose ``response`` column holds one list of
            responses per prompt.
        max_repeats: exclusive upper bound on the count of the most repeated
            response of a prompt.

    Returns:
        An independent copy of the surviving rows, safe to add columns to.
    """
    most_repeated = grouped_df["response"].apply(
        lambda x: np.unique(x, return_counts=True)[1].max()
    )
    return grouped_df[most_repeated < max_repeats].copy()


# Script entry point: score generated responses with the reward model (unless a
# scored dataset already exists on disk), then run rejection sampling and
# ranking to build a (prompt, chosen, rejected) dataset for DPO finetuning.
if __name__ == "__main__":
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--sample_dataset_path", type=str, default="outputs/sft/sft_gen_dataset", help="the path for the generated dataset")
    parser.add_argument("--reward_dataset_path", type=str, default="outputs/sft/sft_gen_dataset_rewards", help="the path for the dataset with rewards")
    parser.add_argument("--save_dataset_path", type=str, default="outputs/sft/sft_gen_dataset_ranked", help="the path for saving the dataset")
    parser.add_argument("--num_samples", type=int, default=16, help="the number of samples to keep after rejection sampling")
    parser.add_argument("--ranking_method", type=str, default="first_round", help=" or tournament TO DO")
    parser.add_argument("--beta", type=float, default=0.5, help="the beta value for rejection sampling")

    args, _ = parser.parse_known_args()

    # An existing reward dataset is reused instead of re-running the reward model.
    skip_reward_evaluation = os.path.exists(args.reward_dataset_path)

    # Only populated on the main process (or on every process when loading from disk).
    rewards_df = None

    if not skip_reward_evaluation:
        reward_evaluator = RewardModelInferencer()

        # load the generated samples and pair them with the per-prompt baseline
        pairs_df = build_comparison_pairs(load_from_disk(args.sample_dataset_path).to_pandas())

        dataset = Dataset.from_pandas(pairs_df)
        all_rewards_a, _ = reward_evaluator.predict(dataset)
        accelerator = reward_evaluator.accelerator

        if accelerator.is_main_process:
            # predict() may return more entries than rows; keep one per row.
            rewards = all_rewards_a[: len(dataset)]

            dataset = dataset.add_column("rewards", rewards)
            dataset = dataset.rename_column("chosen", "response")

            # merge the `rejected` to `chosen` and set reward to 1
            rewards_df = merge_baseline_responses(dataset.to_pandas())
            Dataset.from_pandas(rewards_df.drop(columns=["prompt_hashable"])).save_to_disk(args.reward_dataset_path)
    else:
        accelerator = Accelerator()
        rewards_df = load_from_disk(args.reward_dataset_path).to_pandas()
        rewards_df["prompt_hashable"] = rewards_df["prompt"].apply(str)

    if accelerator.is_main_process:
        # collect all responses and rewards of each prompt into lists
        df = rewards_df.groupby("prompt_hashable").agg(
            {"prompt": "first", "response": lambda x: list(x), "rewards": lambda x: list(x)}
        ).reset_index()
        df = drop_prompts_with_repeated_responses(df)

        # conduct rejected sampling algorithm as in https://arxiv.org/pdf/2309.06657.pdf
        # Explicit loops (rather than zip(*df.apply(...))) also work when the
        # filter above leaves no rows.
        sampled = [
            conduct_rejection_sampling(responses, rewards, args.num_samples, args.beta)
            for responses, rewards in zip(df["response"], df["rewards"])
        ]
        df["accepted"] = [result[0] for result in sampled]
        df["rewards"] = [result[1] for result in sampled]

        # perform ranking
        ranking_fn = tournament_ranking if "tournament" in args.ranking_method else first_round_ranking

        ranked = [
            ranking_fn(accepted, rewards)
            for accepted, rewards in zip(df["accepted"], df["rewards"])
        ]
        df["chosen"] = [result[0] for result in ranked]
        df["rejected"] = [result[1] for result in ranked]
        df = df.filter(["prompt", "chosen", "rejected"])
        df = df.explode(["chosen", "rejected"])

        dataset = Dataset.from_pandas(df)

        # save the dataset for later finetuning with DPO
        # Remove any previous output without going through a shell, so paths
        # containing spaces or shell metacharacters are handled safely.
        if os.path.islink(args.save_dataset_path) or os.path.isfile(args.save_dataset_path):
            os.remove(args.save_dataset_path)
        elif os.path.isdir(args.save_dataset_path):
            shutil.rmtree(args.save_dataset_path)
        dataset.save_to_disk(args.save_dataset_path)

    # Wait for all processes
    accelerator.wait_for_everyone()
    