import os
import json
import pandas as pd
import ast

import matplotlib.pyplot as plt
from matplotlib import rcParams

import argparse
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt

import numpy as np

def user_prompt_text(conversation, max_chars=10000):
    """Return the user-authored text of a conversation, truncated.

    Args:
        conversation: Iterable of turn dicts, each with "role" and "content"
            keys. Only turns whose role is "user" contribute.
        max_chars: Maximum length of the returned string.

    Returns:
        The contents of the user turns, each followed by a newline, joined
        in order and cut to at most ``max_chars`` characters.
    """
    joined = "".join(
        f"{turn['content']}\n" for turn in conversation if turn["role"] == "user"
    )
    return joined[:max_chars]


def tag_high_frequency(
    df,
    high_frequency_prompts,
    keep_per_prompt,
    column="post_process_conv",
    random_state=42,
):
    """Build one dedup tag per row of ``df``.

    Rows whose ``column`` value is in ``high_frequency_prompts`` are tagged
    ``high_freq=True``; of those, only ``keep_per_prompt`` randomly chosen
    rows per prompt keep ``sampled=True``. All other rows are tagged
    ``{"high_freq": False, "sampled": True}``.

    Args:
        df: DataFrame containing ``column``. Its index labels are used as
            positions, so it must carry the default 0..n-1 index.
        high_frequency_prompts: Iterable of values of ``column`` to down-sample
            (may be wrapped in a progress bar by the caller).
        keep_per_prompt: Number of rows to keep per high-frequency prompt.
        column: Name of the column holding the de-duplication key.
        random_state: Seed passed to ``DataFrame.sample`` for reproducibility.

    Returns:
        A list of ``len(df)`` independent dicts with boolean "high_freq" and
        "sampled" entries, aligned with the rows of ``df``.
    """
    high_freq = np.zeros(len(df), dtype=bool)
    sampled = np.ones(len(df), dtype=bool)

    groups = df.groupby(column)
    for prompt in high_frequency_prompts:
        group = groups.get_group(prompt)
        # Never ask for more rows than the group holds; sample() would raise.
        kept = group.sample(
            n=min(keep_per_prompt, len(group)), random_state=random_state
        ).index
        rows = group.index.to_numpy()
        high_freq[rows] = True
        sampled[rows] = False
        sampled[kept.to_numpy()] = True

    # Build a fresh dict per row so that no two rows share a mutable object.
    return [
        {"high_freq": h, "sampled": s}
        for h, s in zip(high_freq.tolist(), sampled.tolist())
    ]


def main():
    """Tag over-represented prompts in a conversation dump and write dedup.json.

    Reads the JSON list given by --input_file, keys each row by the text of
    its user turns in "conversation_a", and marks prompts whose occurrence
    count exceeds the --percentile quantile of all prompt counts. The tagged
    rows are written to <output_dir>/dedup.json.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()

    # JSON is UTF-8 by specification and the output is written un-escaped
    # (force_ascii=False), so do not depend on the platform default encoding.
    with open(args.input_file, encoding="utf-8") as f:
        data = json.load(f)

    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "dedup.json")

    for row in data:
        row["post_process_conv"] = user_prompt_text(row["conversation_a"])

    df = pd.DataFrame(data)
    print("Number of conversations: ", len(df))

    if df.empty:
        # Nothing to de-duplicate: there is no key column and the quantile
        # would be NaN, so emit an empty result instead of crashing.
        df.to_json(output_path, orient="records", indent=4, force_ascii=False)
        return

    prompt_counts = df["post_process_conv"].value_counts()
    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile*100} percentile count: {percentile_cutoff}")

    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # The quantile may interpolate to a non-integer; int() truncates it. Every
    # high-frequency group has strictly more rows than the cutoff, so the
    # sample size never exceeds the group size.
    df["dedup_tag"] = tag_high_frequency(
        df, tqdm(high_frequency_prompts), int(percentile_cutoff)
    )

    df = df.drop(columns=["post_process_conv"])

    df.to_json(output_path, orient="records", indent=4, force_ascii=False)


if __name__ == "__main__":
    main()
