import os
from datasets import load_from_disk, load_dataset, Dataset
from typing import List, Tuple
import numpy as np
import pandas as pd
import random
import argparse
from rso.rso_pointwise_reward import first_round_ranking
import joblib
# Seed the global ``random`` module so the random pairing performed with it below
# (e.g. in ``first_round_ranking_with_reward``) is reproducible across runs.
random.seed(32)

def process_row(row):
    """Turn a flat preference row into chat-formatted ``chosen``/``rejected`` conversations.

    ``row["prompt"]`` is either a plain user string or an existing sequence of chat
    messages; ``row["chosen"]`` and ``row["rejected"]`` hold the assistant reply texts.
    Both reply fields are overwritten with the full conversation (prompt messages
    followed by one assistant turn) and the same row object is returned, which is
    what ``datasets.Dataset.map`` expects from the mapped function.
    """
    prompt = row["prompt"]
    if isinstance(prompt, str):
        prompt_messages = [{"role": "user", "content": prompt}]
    else:
        # Coerce to a list so tuple / array-like message sequences concatenate
        # correctly (``ndarray + list`` or ``tuple + list`` would not).
        prompt_messages = list(prompt)

    row["chosen"] = prompt_messages + [{"role": "assistant", "content": row["chosen"]}]
    row["rejected"] = prompt_messages + [{"role": "assistant", "content": row["rejected"]}]
    return row

def process_and_save_dataset(gen_dataset, output_path, num_proc=32, test_size=0.01, seed=42):
    """Chat-format a pair dataset, split it into train/test and save it to disk.

    Args:
        gen_dataset: ``datasets.Dataset`` with ``prompt``/``chosen``/``rejected`` columns
            (and optionally ``is_weights``).
        output_path: Directory passed to ``DatasetDict.save_to_disk``.
        num_proc: Worker processes for ``Dataset.map(process_row, ...)``.
        test_size: Fraction of rows held out in the ``test`` split.
        seed: Seed for the train/test shuffle.

    Only the ``chosen``, ``rejected`` and (if present) ``is_weights`` columns are kept;
    every other column (e.g. ``prompt`` or a pandas index column) is dropped.
    """
    gen_dataset = gen_dataset.map(process_row, num_proc=num_proc)
    dataset = gen_dataset.train_test_split(test_size=test_size, seed=seed)
    keep_columns = {"chosen", "rejected", "is_weights"}
    # ``remove_columns`` is documented to take a list of names; sorting also makes
    # the removal order deterministic (set iteration order is not).
    drop_columns = sorted(set(dataset.column_names["train"]) - keep_columns)
    dataset = dataset.remove_columns(drop_columns)
    dataset.save_to_disk(output_path)


# def compute_importance_weights(rewards_w: List[float],
#                                rewards_l: List[float],
#                                beta=0.5):
#     """Computes normalized importance weights for preference pairs."""
#     # Compute unnormalized weights
#     unnormalized_weights = np.exp((np.array(rewards_w) + np.array(rewards_l)) / beta)
#     # Normalize weights
#     normalized_weights = unnormalized_weights / np.sum(unnormalized_weights)
#     return normalized_weights

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, elementwise for array input.

    Computed as ``exp(-logaddexp(0, -x))``, which is the same function but does not
    overflow (and warn) in ``exp(-x)`` for large negative ``x``.
    """
    return np.exp(-np.logaddexp(0, -x))

def compute_importance_weights_per_prompt(
        rewards_w: List[float], rewards_l: List[float], prompt: pd.Series, beta: float = 0.5, plus_bt_is_score=False
) -> pd.Series:
    """Compute importance weights for preference pairs, normalized within each prompt.

    For a pair with chosen reward ``r_w`` and rejected reward ``r_l`` the unnormalized
    weight is ``exp((r_w + r_l) / beta)``. When ``plus_bt_is_score`` is set it is further
    multiplied by ``sigmoid(r_w - r_l) - 0.5``. Each weight is then divided by the sum
    of weights of all pairs sharing the same prompt (prompts are compared through
    ``joblib.hash``), so every prompt's weights sum to 1.

    Args:
        rewards_w: Chosen-response rewards, one per pair (array-like of numbers).
        rewards_l: Rejected-response rewards, aligned with ``rewards_w``.
        prompt: Prompt of each pair; used only as the grouping key.
        beta: Divisor of the summed rewards; smaller values make the weights peakier.
        plus_bt_is_score: Apply the ``sigmoid(r_w - r_l) - 0.5`` multiplier. It is 0 for
            tied rewards and negative when ``r_w < r_l``; a prompt whose multipliers
            are all 0 gets NaN weights (0 / 0).

    Returns:
        Series of normalized weights, indexed like ``prompt`` when it is a Series.
    """
    rewards_w = np.asarray(rewards_w, dtype=float)
    rewards_l = np.asarray(rewards_l, dtype=float)
    log_weights = (rewards_w + rewards_l) / beta

    # Create a DataFrame to group by prompt
    weights_df = pd.DataFrame({'prompt': prompt, 'log_weights': log_weights})
    weights_df["prompt_hash"] = weights_df["prompt"].apply(joblib.hash)

    # Subtract the per-prompt maximum before exponentiating (log-sum-exp trick). The
    # shift is constant within a prompt, so it cancels in the normalization below,
    # but it prevents exp() from overflowing to inf (inf / inf -> NaN) or from
    # underflowing every weight of a prompt to 0 (0 / 0 -> NaN).
    per_prompt_max = weights_df.groupby('prompt_hash')['log_weights'].transform('max').to_numpy()
    unnormalized_weights = np.exp(log_weights - per_prompt_max)

    if plus_bt_is_score:
        unnormalized_weights = unnormalized_weights * (sigmoid(rewards_w - rewards_l) - 0.5)

    weights_df['unnormalized_weights'] = unnormalized_weights

    # Calculate the sum of unnormalized weights for each prompt
    sum_of_weights_per_prompt = weights_df.groupby('prompt_hash')['unnormalized_weights'].transform('sum')

    # Normalize weights within each prompt
    normalized_weights = weights_df['unnormalized_weights'] / sum_of_weights_per_prompt

    return pd.Series(normalized_weights)


def first_round_ranking_with_reward(
        responses: List[str], rewards: List[float]
) -> Tuple[List[Tuple[str, float]], List[Tuple[str, float]]]:
    """Randomly pair candidates and split every pair into chosen / rejected by reward.

    Starting from n candidates, floor(n / 2) disjoint pairs are drawn at random using
    the global ``random`` module. Within a pair, the candidate with the higher reward
    is chosen; on a tie, the second-drawn candidate is chosen. With an odd n, the one
    candidate left over is dropped.

    Args:
        responses: Accepted candidates from rejection sampling.
        rewards: Reward of each response (zipped with ``responses``).

    Returns:
        chosen: ``(response, reward)`` tuples of the preferred side of each pair.
        rejected: ``(response, reward)`` tuples of the other side; ``rejected[i]``
            is the partner of ``chosen[i]``.
    """
    chosen = []
    rejected = []
    pool = list(zip(responses, rewards))

    def pick():
        # Remove and return a uniformly random remaining candidate.
        return pool.pop(random.randrange(len(pool)))

    # Stop when fewer than two candidates remain: a lone leftover has no partner.
    while len(pool) >= 2:
        first = pick()
        second = pick()
        if first[1] > second[1]:
            chosen.append(first)
            rejected.append(second)
        else:
            chosen.append(second)
            rejected.append(first)

    return chosen, rejected

def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the process command
            line. ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``init_reward_dataset_path``,
        ``target_dataset_base_path``, ``plus_bt_is_score`` and ``beta``.
    """
    parser = argparse.ArgumentParser(
        description="Build baseline and importance-weighted (iso) preference-pair datasets "
                    "from a reward-annotated generation dataset."
    )
    parser.add_argument(
        "--init_reward_dataset_path",
        type=str,
        default="datasets/dpo_gen/uf-g2-baseline_sample16_reward",
        help="Path to the reward-annotated dataset (saved with save_to_disk) to subsample pairs from.",
    )
    parser.add_argument(
        "--target_dataset_base_path",
        type=str,
        default="datasets/rso_gen/uf",
        help="Output base directory; datasets are written under <base>/baseline/<n> and <base>/iso/<n>.",
    )
    parser.add_argument(
        "--plus_bt_is_score",
        action="store_true",
        help="Multiply each importance weight by sigmoid(reward_chosen - reward_rejected) - 0.5 "
             "before normalizing.",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=3.0,
        help="Temperature of the importance weights exp((reward_chosen + reward_rejected) / beta). "
             "Default is 3.0.",
    )
    return parser.parse_args(argv)

def _group_first_n(df, num_sample):
    """Collect up to ``num_sample`` responses per prompt into one row per prompt.

    ``df`` has one row per (prompt, response). The first ``num_sample`` rows of each
    ``prompt_hash`` group, in the frame's current order, are kept. Returns a frame with
    ``prompt_hash``, ``response`` (list), ``rewards`` (list) and ``prompt`` (the group's
    first value).
    """
    subset = df.groupby(["prompt_hash"]).head(num_sample)
    return subset.groupby(["prompt_hash"]).agg(
        {"response": list, "rewards": list, "prompt": "first"}
    ).reset_index()


def _explode_pairs(grouped, ranker):
    """Apply ``ranker(responses, rewards) -> (chosen, rejected)`` per prompt; one row per pair.

    Returns a frame with ``prompt``, ``chosen`` and ``rejected`` columns. Rows coming from
    an empty pair list (``explode`` turns those into NaN) are dropped.
    """
    grouped["chosen"], grouped["rejected"] = zip(*grouped.apply(
        lambda row: ranker(row["response"], row["rewards"]), axis=1))
    pairs = grouped.filter(["prompt", "chosen", "rejected"]).explode(["chosen", "rejected"])
    # ``.copy()`` detaches the filtered frame so later column assignments act on it directly.
    return pairs.dropna(subset=["chosen", "rejected"]).copy()


def _build_baseline_pairs(df, num_sample):
    """Build plain (prompt, chosen, rejected) pairs with ``first_round_ranking``."""
    return _explode_pairs(_group_first_n(df, num_sample), first_round_ranking)


def _build_iso_pairs(df, num_sample, beta, plus_bt_is_score):
    """Build pairs with per-prompt importance weights in an ``is_weights`` column."""
    pairs = _explode_pairs(_group_first_n(df, num_sample), first_round_ranking_with_reward)
    # Each cell is a (response, reward) tuple; split it into text and reward columns.
    pairs["rewards_w"] = pairs["chosen"].apply(lambda x: x[1])
    pairs["rewards_l"] = pairs["rejected"].apply(lambda x: x[1])
    pairs["chosen"] = pairs["chosen"].apply(lambda x: x[0])
    pairs["rejected"] = pairs["rejected"].apply(lambda x: x[0])
    pairs["is_weights"] = compute_importance_weights_per_prompt(
        pairs["rewards_w"],
        pairs["rewards_l"],
        pairs["prompt"],  # grouping key for per-prompt normalization
        beta=beta,
        plus_bt_is_score=plus_bt_is_score,
    )
    return pairs


def main():
    """Build baseline and iso preference datasets for several per-prompt sample budgets.

    Loads the reward-annotated generations from ``--init_reward_dataset_path``. Prompts
    in which any single response text occurs 4 or more times are dropped. For each
    budget n in (8, 16, 32, 64), the script writes:

    * ``<base>/baseline/<n>``: pairs from ``first_round_ranking``;
    * ``<base>/iso/<n>``: pairs from ``first_round_ranking_with_reward``, with
      ``is_weights``.

    Both are written through ``process_and_save_dataset``.
    """
    args = parse_args()
    sample_budgets = (8, 16, 32, 64)

    # Shuffle once with a fixed seed, so "first n responses per prompt" below is a
    # reproducible random subset.
    sft_gen_ds_df = load_from_disk(args.init_reward_dataset_path).to_pandas().sample(frac=1, random_state=32)
    # Group on a content hash: prompts may be non-string (e.g. message lists).
    sft_gen_ds_df["prompt_hash"] = sft_gen_ds_df["prompt"].apply(joblib.hash)

    sft_gen_ds_df = sft_gen_ds_df.groupby("prompt_hash").agg(
        {"prompt": "first", "response": list, "rewards": list}
    ).reset_index()
    # Count of the most frequent identical response per prompt; prompts with heavy
    # repetition (presumably degenerate generations) are filtered out.
    rep_response_count = sft_gen_ds_df["response"].apply(lambda x: np.unique(x, return_counts=True)[1].max())
    mask = rep_response_count < 4
    sft_gen_ds_df = sft_gen_ds_df[mask].explode(["response", "rewards"])
    print(
        f"Prompt: Repetition filtering {(~mask).sum()} samples out of {len(mask)}. Total Response: {len(sft_gen_ds_df)} remaining.")

    # sample for direct-sft
    for num_sample in sample_budgets:
        baseline_df = _build_baseline_pairs(sft_gen_ds_df, num_sample)
        process_and_save_dataset(Dataset.from_pandas(baseline_df),
                                 os.path.join(args.target_dataset_base_path, "baseline", str(num_sample)))

    # sample for iso
    for num_sample in sample_budgets:
        iso_df = _build_iso_pairs(sft_gen_ds_df, num_sample, args.beta, args.plus_bt_is_score)
        process_and_save_dataset(Dataset.from_pandas(iso_df),
                                 os.path.join(args.target_dataset_base_path, "iso", str(num_sample)))


if __name__ == '__main__':
    main()