import pathlib

import pandas as pd

DATA_DIR = pathlib.Path(__file__).parent.parent.resolve() / "data"


def clean(df):
    """Flatten the nested columns of ``df`` and keep only rows that have labels.

    The three row extractors are applied in order (attachments, data row,
    projects), each one adding flat columns to every row. The nested source
    columns are then dropped, a ``label_median`` column is computed as the
    row-wise median of every column whose name matches ``label_\\d``, and rows
    whose ``label_median`` is NaN are removed.

    Returns the resulting DataFrame.
    """
    # Each pass is a separate row-wise apply, run in the same fixed order.
    for extractor in (extract_attachments, extract_data_row, extract_projects):
        df = df.apply(extractor, axis=1)

    # Nested/raw columns that are discarded once the extractors have run.
    raw_columns = [
        "attachments",
        "data_row",
        "embeddings",
        "media_attributes",
        "metadata_fields",
        "projects",
    ]
    df = df.drop(columns=raw_columns)

    # Selected before "label_median" exists, so only the per-label columns match.
    label_columns = df.filter(regex=r"label_\d")
    df["label_median"] = label_columns.median(axis=1)

    return df.dropna(subset="label_median")


def extract_attachments(row):
    """Split the first attachment's text into model, jailbreak and prompt.

    Reads ``row["attachments"][0]["value"]`` and slices it around two markers,
    ``", attack method: "`` and ``" \\n Question: "``:

    * ``row["model"]``: text after the first ``len("model: ")`` characters, up
      to the attack-method marker.
    * ``row["jailbreak"]``: text between the two markers.
    * ``row["forbidden_prompt"]``: stripped text after the question marker.

    The row is modified in place and returned. ``str.index`` raises
    ``ValueError`` if either marker is missing.
    """
    text = row["attachments"][0]["value"]

    model_prefix = "model: "
    attack_marker = ", attack method: "
    question_marker = " \n Question: "

    # The leading prefix is skipped by length only; its content is not checked.
    attack_at = text.index(attack_marker)
    row["model"] = text[len(model_prefix) : attack_at]

    jailbreak_from = attack_at + len(attack_marker)
    question_at = text.index(question_marker)
    row["jailbreak"] = text[jailbreak_from:question_at]

    prompt_from = question_at + len(question_marker)
    row["forbidden_prompt"] = text[prompt_from:].strip()
    return row


def extract_data_row(row):
    """Store the text following ``"\\nAnswer:"`` as ``row["response"]``.

    Reads ``row["data_row"]["row_data"]``, takes everything after the first
    occurrence of the answer marker, strips surrounding whitespace, and writes
    it to ``row["response"]``. The row is modified in place and returned.
    ``str.index`` raises ``ValueError`` if the marker is missing.
    """
    marker = "\nAnswer:"
    text = row["data_row"]["row_data"]
    start = text.index(marker) + len(marker)
    row["response"] = text[start:].strip()
    return row


def extract_projects(row):
    """Add one ``label_<i>`` column per label of the row's first project.

    Takes the first value of the ``row["projects"]`` mapping and, for each
    entry in its ``"labels"`` list, reads the radio-answer value of the first
    classification. That value is converted to ``int`` and rescaled with
    ``(score - 1) / 4``, which maps 1..5 onto 0..1 (presumably a 1-5 rating
    scale; confirm against the labelling setup).

    The row is modified in place and returned.
    """
    first_project = list(row["projects"].values())[0]

    for position, entry in enumerate(first_project["labels"]):
        first_classification = entry["annotations"]["classifications"][0]
        rating = int(first_classification["radio_answer"]["value"])
        row[f"label_{position}"] = (rating - 1) / 4

    return row


if __name__ == "__main__":
    # Read the newline-delimited JSON export, clean it, and write it as CSV.
    raw_path = DATA_DIR / "raw" / "labelbox.ndjson"
    out_path = DATA_DIR / "processed" / "labelbox_2.csv"

    df = clean(pd.read_json(str(raw_path), lines=True))
    df.to_csv(out_path, index=False)
