import pandas as pd
import os
import argparse
from config.config import PREDICTIONS_TEST_INTERNAL, PREDICTIONS_TEST_EXTERNAL


# Models summarised when the caller does not supply an explicit list.
_DEFAULT_MODEL_NAMES = (
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224",
    "resnet101",
    "resnet50",
    "shufflenet_v2_x1_0",
)

# Column order of the summary CSV; also used so an empty summary keeps its header.
_SUMMARY_COLUMNS = [
    "model",
    "mean_subset_accuracy",
    "std_subset_accuracy",
    "subset_A",
    "subset_B",
    "subset_C",
]


def _find_column(df, suffix, file_path):
    """Return the first column of *df* whose name ends with *suffix*.

    Raises:
        ValueError: if no column matches; the message names *file_path*.
    """
    for col in df.columns:
        if col.endswith(suffix):
            return col
    raise ValueError(f"No column ending with '{suffix}' found in '{file_path}'.")


def _subset_accuracy(df, true_class_col, pred_class_col):
    """Return a Series mapping each value of the ``subset`` column to its accuracy.

    Accuracy is the fraction of rows where the true and predicted classes are
    equal. Rows whose ``subset`` value is NaN are dropped by the groupby.
    """
    correct = df[true_class_col] == df[pred_class_col]
    return correct.groupby(df["subset"]).mean()


def main(input_folder, model_names=None):
    """Summarise per-subset accuracy for each model's prediction CSV.

    For every model name, reads ``<input_folder>/<model_name>.csv``. It
    locates the first ``*_true_class`` and ``*_pred_class`` columns and drops
    rows where either is NaN. It then computes accuracy per value of the
    ``subset`` column. Missing model files are reported and skipped.

    The summary (one row per model) is written to
    ``<input_folder>/model_subset_accuracy_mean_std.csv``, sorted by mean
    subset accuracy in descending order. Its columns are:

    - the mean of the per-subset accuracies;
    - their sample standard deviation (pandas default ``ddof=1``), which is
      NaN when only one subset is present;
    - the accuracies of ``subset_A``, ``subset_B`` and ``subset_C``, with NaN
      for any subset that is absent.

    If no model file is found, a header-only CSV is written.

    Args:
        input_folder: Directory holding the per-model prediction CSVs; the
            summary CSV is written here too.
        model_names: Iterable of model names to summarise. Defaults to
            ``_DEFAULT_MODEL_NAMES``.

    Raises:
        ValueError: if a prediction CSV lacks a ``*_true_class`` or
            ``*_pred_class`` column.
    """
    if model_names is None:
        model_names = _DEFAULT_MODEL_NAMES

    output_file = os.path.join(input_folder, "model_subset_accuracy_mean_std.csv")

    results = []

    for model_name in model_names:
        file_path = os.path.join(input_folder, f"{model_name}.csv")

        if not os.path.exists(file_path):
            print(f"Warning: File not found for model '{model_name}', skipping.")
            continue

        df = pd.read_csv(file_path)

        true_class_col = _find_column(df, "_true_class", file_path)
        pred_class_col = _find_column(df, "_pred_class", file_path)

        # Rows lacking either label cannot be scored; exclude them.
        df = df.dropna(subset=[true_class_col, pred_class_col])

        subset_acc = _subset_accuracy(df, true_class_col, pred_class_col)

        # Debug print
        print(f"\nModel: {model_name}")
        print(subset_acc)

        results.append({
            "model": model_name,
            "mean_subset_accuracy": subset_acc.mean(),
            "std_subset_accuracy": subset_acc.std(),
            "subset_A": subset_acc.get("subset_A", float("nan")),
            "subset_B": subset_acc.get("subset_B", float("nan")),
            "subset_C": subset_acc.get("subset_C", float("nan")),
        })

    # Explicit columns keep sort_values from raising KeyError when no model
    # file was found (an empty DataFrame would otherwise have no columns).
    summary_df = pd.DataFrame(results, columns=_SUMMARY_COLUMNS)
    summary_df = summary_df.sort_values(by="mean_subset_accuracy", ascending=False)
    summary_df.to_csv(output_file, index=False)

    print(f"\nPer-subset accuracy summary saved to:\n{output_file}")


if __name__ == "__main__":
    # CLI entry point: pick the predictions folder for the requested split
    # and summarise it.
    cli = argparse.ArgumentParser(description="Compute per-subset accuracy for models.")
    cli.add_argument(
        "dataset_type",
        nargs="?",
        default="INTERNAL",
        choices=["INTERNAL", "EXTERNAL"],
        help="Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)",
    )
    dataset_type = cli.parse_args().dataset_type

    # argparse `choices` guarantees the key is one of these two.
    folder_by_type = {
        "INTERNAL": PREDICTIONS_TEST_INTERNAL,
        "EXTERNAL": PREDICTIONS_TEST_EXTERNAL,
    }
    main(folder_by_type[dataset_type])