import argparse
import os
import sys
import yaml

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from scipy.stats import beta


sys.path.append(os.getcwd())

from src.scripts.annotate import annotate, gold_reward_model_score

if __name__ == "__main__":
    # Command-line interface: the dataset to read (and write back to), the
    # annotation settings forwarded to `annotate`, and the KL values that
    # define the groups shown in the plot.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dataset_path", type=str, help="Dataset containing reward and (optionally) Claude and Gold RM preferences"
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=2, help="Number of samples to annotate in a single prompt"
    )
    # The help text used to be a copy of the --batch_size one.
    parser.add_argument("--n_jobs", "-n", type=int, default=8, help="Number of parallel jobs used for annotation")
    parser.add_argument(
        "--kl_values",
        "-k",
        type=int,
        nargs="+",
        default=[0, 2, 5, 10, 15, 25, 75, 250],
        help="KL values to plot in the ascending order",
    )
    # NOTE(review): with type=int there is no command-line spelling for None,
    # so the "keep all samples" path is only reachable programmatically.
    parser.add_argument(
        "--sample",
        "-s",
        type=int,
        default=2000,
        help="The number of samples to draw from each KL group. Keep all samples if None",
    )
    parser.add_argument(
        "--model_id",
        "-m",
        type=str,
        default="anthropic.claude-instant-v1",
        help="Model ID as listed in AWS Bedrock",
        choices=["anthropic.claude-v2", "anthropic.claude-instant-v1"],
    )
    # Treated as enabled when the value contains the letter "t" (e.g. "t", "true").
    parser.add_argument(
        "--gold_rm",
        "-g",
        type=str,
        default="f",
        help="Whether to also annotate with Gold RM",
    )
    args = parser.parse_args()

    df = pd.read_json(args.dataset_path)

    KLs = args.kl_values
    # Gap between each KL target and the one before it; the first target is
    # compared against a value just below zero so its gap is tiny but positive.
    KLs_shift = np.roll(KLs, 1).astype(float)
    KLs_shift[0] = -1e-5
    gaps = KLs - KLs_shift
    # Half-width of the window around each target: sqrt(gap), capped at just
    # under half the gap to the previous target.
    interval_widths = np.minimum(np.sqrt(gaps), (gaps - 1e-6) / 2)

    # Flatten every window into its two edges and sort them into one edge list.
    # NOTE(review): nothing here checks that neighbouring windows do not
    # overlap; the odd/even trick below relies on that.
    intervals = np.sort(np.concatenate([KLs - interval_widths, KLs + interval_widths]))

    # np.digitize yields an odd index for values between a window's lower and
    # upper edge and an even index for values in the gaps; gap values are
    # collapsed to bin 0, which the rest of the script treats as "discard".
    edge_index = np.digitize(df["sample_kl"].values, intervals)
    bins = np.where(edge_index % 2 == 1, edge_index, 0)

    # The reward model "prefers" the sample whenever its reward is positive.
    df["rm_preference"] = df["reward"].gt(0).astype(int)
    df["bins"] = bins

    # Rows in bin 0 fell outside every KL window and are never annotated.
    if args.sample is not None:
        # Shuffle deterministically, then keep at most `args.sample` rows per bin.
        sampled_df = df[df["bins"] != 0].sample(frac=1, random_state=42).groupby("bins").head(args.sample)
    else:
        sampled_df = df[df["bins"] != 0]

    if "claude_preference" not in df:
        # No annotations stored yet: annotate the whole sample and write the
        # result back into the dataset file.
        sampled_df = annotate(sampled_df, batch_size=args.batch_size, n_jobs=args.n_jobs, modelId=args.model_id)
        df.loc[sampled_df.index, "claude_preference"] = sampled_df["claude_preference"]
        df.to_json(args.dataset_path, orient="records")
    elif sampled_df["claude_preference"].isna().any():
        # Only top up the sampled rows that are still missing an annotation.
        # The check is made on `sampled_df` rather than `df`: rows outside the
        # sample are expected to stay unannotated and must not trigger a call
        # with an empty frame. The model id is forwarded here as well, so that
        # --model_id is honoured on both paths.
        annotated_df = annotate(
            sampled_df[sampled_df["claude_preference"].isna()],
            batch_size=args.batch_size,
            n_jobs=args.n_jobs,
            modelId=args.model_id,
        )
        df.update(annotated_df)
        df.to_json(args.dataset_path, orient="records")

    def bernoulli_errorbars(arr, a=1, b=1, ci=0.95):
        """Return an equal-tailed Beta interval for the success rate of `arr`.

        `arr` is a sequence of 0/1 outcomes. The counts of ones and zeros are
        added to the Beta(a, b) parameters and the interval covering a total
        probability mass of `ci` is returned as a `(low, high)` tuple.

        The tails are `(1 - ci) / 2` on each side; taking the `1 - ci` and
        `ci` quantiles instead would give a `2 * ci - 1` interval (90% for
        the default `ci=0.95`).
        """
        pos = sum(arr)
        neg = len(arr) - pos
        tail = (1 - ci) / 2
        return beta.ppf(tail, pos + a, neg + b), beta.ppf(1 - tail, pos + a, neg + b)

    if "t" in args.gold_rm:
        CONFIG_DIR = os.getenv("CONFIG_DIR", "gpt2-xl")
        config_path = os.path.join("configs", CONFIG_DIR, "ppo.yaml")
        # NOTE: yaml.Loader can construct arbitrary Python objects; this is only
        # acceptable because the config is a trusted, project-local file.
        # The context manager makes sure the file handle is closed.
        with open(config_path) as config_file:
            gold_rm_directory = yaml.load(config_file, yaml.Loader)["gold_rm_directory"]

        # Score only rows that have no gold score yet and that fall in a KL window.
        filtered_df = df if "gold_rm_score_output" not in df else df[df["gold_rm_score_output"].isna()]
        filtered_df = filtered_df[filtered_df["bins"] != 0]
        if len(filtered_df) > 0:
            gold_df = gold_reward_model_score(filtered_df, gold_rm_directory)
            # NOTE(review): the merge joins on every column the two frames
            # share. If `df` already carries (partly empty) gold score columns,
            # those become join keys too and the new scores may not be filled
            # in. Confirm against what gold_reward_model_score returns.
            df = pd.merge(df, gold_df, how="left")
            df.to_json(args.dataset_path, orient="records")
        # 1 when the output outscores the label, 0 when it loses, 0.5 on a tie.
        df["gold_rm_preference"] = (np.sign(df["gold_rm_score_output"] - df["gold_rm_score_label"]) + 1) / 2

    # Keep only rows inside a KL window and index them by bin id; the bin id
    # is then used as the x coordinate and as the grouping key below.
    df = df[df["bins"] != 0].set_index("bins")

    gold_rm_enabled = "t" in args.gold_rm
    columns = ["rm_preference", "claude_preference"] + (["gold_rm_preference"] if gold_rm_enabled else [])
    renamed_columns_map = {
        "rm_preference": "Reward Model preference",
        "claude_preference": "Actual preference",
        **({"gold_rm_preference": "Gold RM preference"} if gold_rm_enabled else {}),
    }

    ticks = sorted(df.index.unique())
    # Bin ids are the odd numbers 1, 3, 5, ...; bin 2k + 1 belongs to KLs[k].
    # Looking the label up by bin id keeps labels correct even when some KL
    # window has no samples (slicing KLs[: len(ticks)] would shift every
    # label after the empty window).
    kl_labels = [KLs[(bin_id - 1) // 2] for bin_id in ticks]

    plt.xticks(ticks, labels=kl_labels)
    plt.xlabel("KL Divergence")
    plt.ylabel("Preference against the human summary")
    sns.lineplot(
        data=df[columns].rename(columns=renamed_columns_map),
        markers=True,
        errorbar=bernoulli_errorbars,
    )

    # Mean preference per KL group, printed as a table indexed by KL value.
    summary = df[columns].groupby(df.index).mean()
    summary.index = pd.Index(kl_labels, name="KL")
    print(summary.rename(columns=renamed_columns_map))

    # Mirror the dataset directory under "figs". A bare file name has an empty
    # dirname, which would make os.makedirs fail and point the output at the
    # filesystem root, so fall back to the current directory.
    fig_dir = os.path.dirname(args.dataset_path).replace("datasets/ppo", "figs") or "."
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(os.path.join(fig_dir, "alignment.pdf"))
