import argparse
import json
from pathlib import Path
import pandas as pd
import re


def load_json(path: Path) -> dict:
    """
    Read the file at ``path`` and return its parsed JSON content.

    The file is decoded explicitly as UTF-8 so that the result does not
    depend on the platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_info_from_path(path: Path) -> dict:
    """
    Extract run metadata from the path of a result file.

    Two layouts are recognised, selected by the number of path components:

    * 6 components:
      <root>/<model>/<prompt_type>/<encoding_technique>/<task>/<file>
    * 7 components:
      <root>/<model>/<prompt_type>/<encoding_technique>/k=<k>/<task>/<file>

    where <file> is <sample_size>_<test_set>_<seed>.<ext>.

    Example path: results_rearranged/gpt-4o/zsr_prompt/one2many/marked-copy/100_test_42.json

    Returns a dict with the keys "model", "prompt_type",
    "encoding_technique", "task", "sample_size", "test_set", "seed" and "k".
    All values are strings taken verbatim from the path, except "k", which is
    None for the 6-component layout.

    Raises ValueError if the path has an unsupported number of components or
    if the file name does not consist of exactly three "_"-separated fields.
    """
    parts = path.parts

    # The only difference between the two layouts is the optional "k=<k>"
    # directory sitting between the encoding technique and the task.
    if len(parts) == 6:
        k = None
        task = parts[4]
    elif len(parts) == 7:
        k = parts[4].replace("k=", "")
        task = parts[5]
    else:
        raise ValueError(
            f"Unexpected result path layout ({len(parts)} components, "
            f"expected 6 or 7): {path}"
        )

    model, prompt_type, encoding_technique = parts[1:4]

    name_fields = parts[-1].split("_")
    if len(name_fields) != 3:
        raise ValueError(
            "Result file name must look like "
            f"<sample_size>_<test_set>_<seed>.json, got: {parts[-1]}"
        )
    sample_size, test_set, seed = name_fields
    # Drop the file extension (everything after the first dot) from the seed.
    seed = seed.split(".")[0]

    return {
        "model": model,
        "prompt_type": prompt_type,
        "encoding_technique": encoding_technique,
        "task": task,
        "sample_size": sample_size,
        "test_set": test_set,
        "seed": seed,
        "k": k,
    }


def _print_k_report(df: pd.DataFrame) -> None:
    """
    Print accuracy broken down by k for the one2many / one2many-sym encodings.

    Does nothing when no row of those encodings carries a k value.
    """
    # Work on an explicit copy: the numeric conversion below assigns a column,
    # and assigning into a filtered view of ``df`` triggers pandas'
    # SettingWithCopyWarning.
    k_df = df[df["encoding_technique"].isin(["one2many", "one2many-sym"])].copy()

    if k_df.empty or "k" not in k_df.columns or not k_df["k"].notna().any():
        return

    # Convert k to numeric for proper sorting
    k_df["k"] = pd.to_numeric(k_df["k"])

    # Group by k and other dimensions
    k_grouped = (
        k_df.groupby(["model", "prompt_type", "encoding_technique", "task", "k"])
        .agg({"accuracy": ["mean", "std", "count"]})
        .sort_index()
    )

    print("\nMean accuracy by k-value (for one2many and one2many-sym):")
    print(k_grouped)

    # One pivot table (task x k) per model/prompt/encoding combination
    print("\nPivot table of accuracy by task and k-value:")
    for model in k_df["model"].unique():
        for prompt in k_df["prompt_type"].unique():
            for encoding in k_df["encoding_technique"].unique():
                subset = k_df[
                    (k_df["model"] == model)
                    & (k_df["prompt_type"] == prompt)
                    & (k_df["encoding_technique"] == encoding)
                ]
                if subset.empty:
                    continue

                pivot = pd.pivot_table(
                    subset,
                    values="accuracy",
                    index=["task"],
                    columns=["k"],
                    aggfunc="mean",
                )
                print(f"\n{model}, {prompt}, {encoding}:")
                print(pivot)

                # Calculate average across tasks
                print(f"Average across tasks: {pivot.mean().to_dict()}")


def aggregate_results(input_dir: Path, output_path: Path):
    """
    Aggregate results from JSON files in input_dir and save to output_path.

    Every ``*.json`` file found recursively under ``input_dir`` contributes
    one row made of the metadata parsed from its path plus the value stored
    under its "accuracy" key (None when the key is missing). The rows are
    written to ``output_path`` as CSV (parent directories are created), and
    summary tables are printed to stdout.

    Returns the DataFrame of collected rows. When no JSON file is found, the
    (empty) CSV is still written and the summary output is skipped.
    """
    # Sorted so that the CSV row order does not depend on filesystem order.
    json_files = sorted(input_dir.glob("**/*.json"))
    print(f"Found {len(json_files)} JSON files to process.")

    # Combine path info and accuracy for every file
    results = [
        {
            **extract_info_from_path(json_file),
            "accuracy": load_json(json_file).get("accuracy"),
        }
        for json_file in json_files
    ]

    # Create dataframe
    df = pd.DataFrame(results)

    # Save to CSV
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"Results saved to {output_path}")

    # An empty frame has no columns, so the summaries below would raise
    # KeyError; there is nothing to summarise anyway.
    if df.empty:
        print("No results found; skipping summary statistics.")
        return df

    # Print some summary statistics
    print("\nSummary Statistics:")
    print(f"Total tasks: {df['task'].nunique()}")
    print(f"Models: {df['model'].unique()}")
    print(f"Prompt types: {df['prompt_type'].unique()}")
    print(f"Encoding techniques: {df['encoding_technique'].unique()}")

    # Group by relevant dimensions and calculate mean accuracy
    grouped = df.groupby(["model", "prompt_type", "encoding_technique", "task"]).agg(
        {"accuracy": ["mean", "std", "count"]}
    )

    print("\nMean accuracy by model/prompt/encoding/task:")
    print(grouped)

    # Report results by k-value for one2many and one2many-sym
    _print_k_report(df)

    return df


def main():
    """
    Command-line entry point.

    Parses the two required options, ``--input_dir`` and ``--output_path``,
    and hands them to ``aggregate_results`` as ``Path`` objects.
    """
    cli = argparse.ArgumentParser(
        description="Aggregate results from rearranged results"
    )

    # Both options are mandatory strings; only the flag and help text differ.
    required_options = (
        ("--input_dir", "Directory containing rearranged results"),
        ("--output_path", "Path to save the aggregated results"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    parsed = cli.parse_args()

    aggregate_results(Path(parsed.input_dir), Path(parsed.output_path))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
