from datasets import load_from_disk, Dataset
from tqdm import tqdm
import argparse


def combine_datasets_annotated(dataset1, dataset2, filename_output=None):
    """
    Combine two row-aligned annotation datasets into one.

    Each input dataset has the columns ``prompt_id``, ``model`` and
    ``annotation``. Row ``i`` of ``dataset1`` and row ``i`` of ``dataset2``
    must refer to the same prompt. The output has one row per prompt with the
    columns ``prompt_id`` and ``completions``, where::

        completions = {model: annotation, ...}

    If both aligned rows carry the same ``model`` value, the ``dataset2``
    annotation overwrites the ``dataset1`` one.

    Args:
        dataset1: First annotated dataset.
        dataset2: Second annotated dataset, row-aligned with ``dataset1``.
        filename_output: Optional path; when given, the combined dataset is
            also written there with ``save_to_disk``.

    Returns:
        The combined ``Dataset``.

    Raises:
        ValueError: If the datasets differ in length, or if a pair of aligned
            rows has different ``prompt_id`` values (i.e. the ordering
            assumption does not hold).
    """
    if len(dataset1) != len(dataset2):
        raise ValueError(
            f"Datasets have different lengths: {len(dataset1)} vs {len(dataset2)}"
        )

    all_rows = []
    # Walk both datasets in lockstep (each row materialized once) and verify
    # the row-alignment assumption instead of silently pairing wrong prompts.
    for row1, row2 in tqdm(zip(dataset1, dataset2), total=len(dataset1)):
        if row1["prompt_id"] != row2["prompt_id"]:
            raise ValueError(
                f"Rows are not aligned: prompt_id {row1['prompt_id']!r} "
                f"vs {row2['prompt_id']!r}"
            )
        completions = {row1["model"]: row1["annotation"]}
        completions[row2["model"]] = row2["annotation"]
        all_rows.append({"prompt_id": row1["prompt_id"], "completions": completions})

    combined_dataset = Dataset.from_list(all_rows)

    if filename_output:
        combined_dataset.save_to_disk(filename_output)
        print(f"Combined dataset saved to {filename_output}")

    return combined_dataset


def _average_aspect_score(aspects):
    """Return ``(mean_score, n_valid_aspects)`` for one completion's annotation.

    ``aspects`` maps aspect name -> {score: weight}. Each aspect contributes
    ``sum(float(score) * float(weight))`` over its truthy weights. An aspect
    with no truthy weight, or a missing/None entry, is not counted as valid.
    ``mean_score`` is None when no aspect is valid.
    """
    total = 0.0
    n_valid = 0
    for output in (aspects or {}).values():
        # Missing aspects can surface as None after Arrow struct filling.
        if not output:
            continue
        contributed = False
        for score, weight in output.items():
            if weight:
                total += float(score) * float(weight)
                contributed = True
        if contributed:
            n_valid += 1
    return (total / n_valid if n_valid else None), n_valid


def get_choices(dataset, num_aspects=4):
    """
    Score the chosen and rejected completion of every prompt.

    Each row is expected to have ``prompt_id`` and ``completions``, where
    ``completions`` maps a model key ("chosen"/"rejected") to its aspect
    annotations (see ``_average_aspect_score``). A missing model key counts
    as zero valid aspects.

    Args:
        dataset: Iterable of rows as produced by ``combine_datasets_annotated``.
        num_aspects: Expected number of aspects per completion, used to
            compute the ``invalid_*`` counts (defaults to 4).

    Returns:
        A list with one dict per row, containing:
            prompt_id, chosen_score, rejected_score (mean score or None);
            decision: True if chosen scored strictly higher, False otherwise,
                None if either score is missing;
            tie: True only if both scores exist and are equal;
            invalid_chosen / invalid_rejected: num_aspects minus the number
                of valid aspects on that side.
    """
    decisions = []
    for row in tqdm(dataset):
        # keys will be changed with the model_name
        completions = row["completions"]
        chosen_score, chosen_valid = _average_aspect_score(completions.get("chosen"))
        rejected_score, rejected_valid = _average_aspect_score(
            completions.get("rejected")
        )
        comparable = chosen_score is not None and rejected_score is not None

        decisions.append(
            {
                "prompt_id": row["prompt_id"],
                "decision": chosen_score > rejected_score if comparable else None,
                "chosen_score": chosen_score,
                "rejected_score": rejected_score,
                # Two missing scores are an invalid pair, not a tie.
                "tie": comparable and chosen_score == rejected_score,
                "invalid_chosen": num_aspects - chosen_valid,
                "invalid_rejected": num_aspects - rejected_valid,
            }
        )

    return decisions


def print_results(decisions, num_aspects=4):
    """
    Print summary statistics for the decisions returned by ``get_choices``.

    Args:
        decisions: List of per-prompt dicts with the keys ``decision``,
            ``tie``, ``invalid_chosen`` and ``invalid_rejected``.
        num_aspects: Number of aspects per completion. A prompt is counted as
            fully invalid when either side has all ``num_aspects`` invalid
            (defaults to 4).
    """
    total = len(decisions)
    correct = sum(1 for d in decisions if d["decision"])
    ties = sum(1 for d in decisions if d["tie"])
    invalid = sum(d["invalid_chosen"] + d["invalid_rejected"] for d in decisions)
    # Prompts where at least one completion had no valid aspect at all.
    fully_invalid = sum(
        1
        for d in decisions
        if d["invalid_chosen"] == num_aspects or d["invalid_rejected"] == num_aspects
    )
    print(f"Invalid completions ({num_aspects} invalid): {fully_invalid}")

    print(f"Total prompts: {total}")
    print(f"Correct decisions: {correct}")
    print(f"Accuracy: {correct / total if total > 0 else 0:.5f}")
    print(f"Ties: {ties}")
    print(f"Invalid: {invalid}")


def parse_args(argv=None):
    """
    Parse the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``foldername_chosen`` and
        ``foldername_rejected``.
    """
    parser = argparse.ArgumentParser()

    # NOTE(review): the previous default paths were lost (those source lines
    # were truncated and did not parse), so both folders must now be passed.
    parser.add_argument(
        "--foldername_chosen",
        type=str,
        required=True,
        help="Path to the chosen completion annotations folder",
    )
    parser.add_argument(
        "--foldername_rejected",
        type=str,
        required=True,
        help="Path to the rejected completion annotations folder",
    )

    args = parser.parse_args(argv)
    return args


def main():
    """
    Script entry point: load both annotation folders, pair them and report.

    The annotation folders are the datasets generated by
    get_raw_annotations_v4c.py.
    """
    args = parse_args()

    chosen = load_from_disk(args.foldername_chosen)
    print(chosen)
    rejected = load_from_disk(args.foldername_rejected)
    print(rejected)

    print("Combining datasets...")
    combined = combine_datasets_annotated(chosen, rejected)
    print(combined)

    print("Calculating results...")
    print_results(get_choices(combined))


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
