import argparse
import json
import os

import numpy as np
import pandas as pd


def summarize(df: pd.DataFrame) -> dict:
    """Compute the headline statistics for a frame of scored rows.

    Args:
        df: Frame with a ``partition_scores`` column whose values support
            ``len()`` and a numeric ``utility`` column.

    Returns:
        A dict with keys, in this order:

        - ``mean_distinct``: mean ``len()`` of the ``partition_scores`` values.
        - ``mean_utility``: mean of the ``utility`` column.

        Values are plain Python floats. A mean that is undefined (no rows, or
        a column whose values are all missing) is ``None`` rather than NaN,
        because ``json.dump`` would otherwise write the non-standard ``NaN``
        token, which strict JSON parsers reject.
    """
    if df.empty:
        # np.mean over an empty column emits a RuntimeWarning and yields NaN.
        return {"mean_distinct": None, "mean_utility": None}

    means = {
        "mean_distinct": np.mean(df["partition_scores"].map(len)),
        "mean_utility": np.mean(df["utility"]),
    }
    # float() drops the numpy scalar type; NaN becomes None for valid JSON.
    return {
        key: None if np.isnan(value) else float(value)
        for key, value in means.items()
    }


def load_category_map(category_map_path: str) -> dict:
    """Load a JSON Lines category map into an ``{id: category}`` dictionary.

    Each non-blank line must be a JSON object with ``"id"`` and ``"category"``
    keys. If an id appears more than once, the last occurrence wins.

    Args:
        category_map_path: Path to the UTF-8 encoded ``.jsonl`` file.

    Returns:
        Mapping from each entry's ``id`` to its ``category``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks ``"id"`` or ``"category"``.
    """
    id_to_category = {}
    # JSON text is UTF-8; don't let decoding depend on the platform locale.
    with open(category_map_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines such as extra trailing newlines.
                continue
            entry = json.loads(line)
            id_to_category[entry["id"]] = entry["category"]
    return id_to_category


def summarize_by_category(df: pd.DataFrame, id_to_category: dict) -> dict:
    """Summarize ``df`` separately for every category its ids map to.

    Each row's category is found by looking up its ``id`` in
    ``id_to_category``. Rows whose id is absent, or maps to a missing value,
    are ignored. ``df`` itself is left unmodified.

    Args:
        df: Frame with an ``id`` column plus the columns ``summarize`` reads.
        id_to_category: Mapping from row id to category label.

    Returns:
        ``{category: summary}`` in sorted category order, where each summary is
        ``summarize`` applied to that category's rows plus a ``"count"`` of those
        rows. Returns ``{}`` when no row has a known category.
    """
    # assign() returns a new frame, so the caller's df gains no column.
    labelled = df.assign(category=df["id"].map(id_to_category))
    labelled = labelled[labelled["category"].notna()]

    if labelled.empty:
        return {}

    per_category = {}
    # groupby(sort=True) visits keys in sorted order, giving stable output.
    for label, rows in labelled.groupby("category", sort=True):
        stats = summarize(rows)
        stats["count"] = len(rows)
        per_category[label] = stats
    return per_category


def main():
    """Command-line entry point: summarize ``scores.jsonl`` in an eval directory.

    Steps:

    1. Read ``<eval-dir>/scores.jsonl``.
    2. Optionally keep only rows whose ``curated-<N>`` id number lies within
       ``--id-range``. Bounds are inclusive, and ``-1`` leaves that side open.
    3. Compute the overall summary.
    4. Attach per-category summaries when
       ``data/curated_id_to_category_map.jsonl`` next to this script exists
       and can be used.
    5. Write the result to ``summary.json`` in the eval directory, or to
       ``summary_<lower>_<upper>.json`` when a range is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval-dir", help="Directory containing evaluation files", required=True
    )
    parser.add_argument(
        "--id-range",
        type=int,
        nargs=2,
        default=(-1, -1),
        metavar=("lower", "upper"),
        help="Inclusive bounds on N in 'curated-<N>' ids; -1 leaves a side open",
    )
    args = parser.parse_args()

    id_range = tuple(args.id_range)

    # Remove literal double quotes left around the path (e.g. by a wrapper
    # script that quoted the argument twice).
    eval_dir = args.eval_dir.strip('"')
    df = pd.read_json(os.path.join(eval_dir, "scores.jsonl"), lines=True)
    try:
        df["id_num"] = df["id"].str.extract(r"curated-(\d+)", expand=False).astype(int)
    except (KeyError, AttributeError, TypeError, ValueError):
        # Possible causes:
        # - there is no "id" column;
        # - the ids are not strings;
        # - some id does not match "curated-<N>", leaving a NaN that cannot
        #   be cast to int.
        # Numeric ids are only needed for --id-range, which is checked next.
        pass

    if id_range != (-1, -1) and "id_num" not in df.columns:
        parser.error(
            "--id-range requires every id in scores.jsonl to match 'curated-<N>'"
        )

    if id_range[0] != -1:
        df = df[df["id_num"] >= id_range[0]]
    if id_range[1] != -1:
        df = df[df["id_num"] <= id_range[1]]

    # Generate overall summary
    summary = summarize(df)

    # Category breakdown is optional and uses a map shipped next to this script.
    category_map_path = os.path.join(
        os.path.dirname(__file__), "data", "curated_id_to_category_map.jsonl"
    )

    if os.path.exists(category_map_path):
        try:
            id_to_category = load_category_map(category_map_path)
            category_summaries = summarize_by_category(df, id_to_category)
        except Exception as e:
            # Best effort: report the problem but still write the overall summary.
            print(
                f"Warning: Could not load category map or generate category summaries: {e}"
            )
        else:
            if category_summaries:
                summary["category_wise"] = category_summaries
                print(
                    f"Generated category-wise summaries for {len(category_summaries)} categories"
                )
            else:
                print("No matching IDs found in category map for the current data")
    else:
        print(f"Warning: Category map not found at {category_map_path}")

    # A filtered run gets its own file so it doesn't overwrite the full summary.
    if id_range == (-1, -1):
        summary_name = "summary.json"
    else:
        summary_name = f"summary_{id_range[0]}_{id_range[1]}.json"

    with open(os.path.join(eval_dir, summary_name), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
