import argparse
import json
import math
import os

# (log key, printed label) pairs reported next to test_acc1 for each task type.
_MC_METRICS = (
    ("test_f1_weighted", "test_f1_weight"),
    ("test_f1_macro", "test_f1_macro"),
    ("test_auroc", "test_auc"),
)
_BINARY_METRICS = (
    ("test_balanced_acc", "test_balanced_acc"),
    ("test_f1", "test_f1"),
    ("test_auroc", "test_auc"),
)


def _iter_log_paths(base_path):
    """Yield ``(base_dir, continual_dir, log_path)`` for each run log under *base_path*.

    Only ``from_base_*/continual_*/log.txt`` files are yielded. Listings are
    sorted so that runs (and therefore tied results) are reported in a stable
    order rather than in ``os.listdir``'s arbitrary order.
    """
    for base_dir in sorted(os.listdir(base_path)):
        full_base_dir = os.path.join(base_path, base_dir)
        if not base_dir.startswith("from_base_") or not os.path.isdir(full_base_dir):
            continue
        for continual_dir in sorted(os.listdir(full_base_dir)):
            full_continual_dir = os.path.join(full_base_dir, continual_dir)
            if not continual_dir.startswith("continual_") or not os.path.isdir(full_continual_dir):
                continue
            log_path = os.path.join(full_continual_dir, "log.txt")
            if os.path.isfile(log_path):
                yield base_dir, continual_dir, log_path


def _read_log_entries(log_path):
    """Yield every JSON object stored one-per-line in *log_path*.

    Blank lines are ignored. A line that is not a JSON object is reported and
    skipped, so one bad line no longer discards the rest of the file. Read or
    decode errors are reported and end processing of that file only.
    """
    try:
        with open(log_path, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping malformed line {lineno} in {log_path}: {e}")
                    continue
                if not isinstance(entry, dict):
                    print(f"Skipping non-object line {lineno} in {log_path}")
                    continue
                yield entry
    except (OSError, UnicodeDecodeError) as e:
        print(f"Failed to process {log_path}: {e}")


def _is_valid_acc(value):
    """Return True if *value* is a real number usable as an accuracy (not bool, not NaN)."""
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and not math.isnan(value)
    )


def _format_metric(value):
    """Format a numeric metric with three decimals; fall back to ``str`` otherwise."""
    if isinstance(value, (int, float)):
        return f"{value:.3f}"
    return str(value)


def main(base_path, mc=None):
    """Report the run(s) with the highest ``test_acc1`` under *base_path*.

    Scans ``<base_path>/from_base_*/continual_*/log.txt``. Each log is read as
    JSON Lines. Every entry with a numeric ``test_acc1`` is a candidate, and all
    entries sharing the maximum value are printed with their extra metrics.
    A missing extra metric is shown as -1.000.

    Args:
        base_path: Root directory containing ``from_base_*`` folders.
        mc: If true, report the multi-class metrics (``test_f1_weighted``,
            ``test_f1_macro``, ``test_auroc``). Otherwise report the binary
            metrics (``test_balanced_acc``, ``test_f1``, ``test_auroc``).
            ``None`` (default) uses the ``mc`` attribute of a module-level
            ``args`` when one exists (as set by the script entry point), else
            ``False``.
    """
    if mc is None:
        # Backward compatibility: the script entry point stores parsed CLI
        # options in a module-level ``args``; importing callers may not have it.
        mc = bool(getattr(globals().get("args"), "mc", False))
    metric_specs = _MC_METRICS if mc else _BINARY_METRICS

    best_acc = None
    best_entries = []
    for base_dir, continual_dir, log_path in _iter_log_paths(base_path):
        for log_entry in _read_log_entries(log_path):
            acc = log_entry.get("test_acc1")
            # Entries without a usable test_acc1 (e.g. train-only lines) are not
            # candidates; previously they tied with the -1 sentinel.
            if not _is_valid_acc(acc):
                continue
            if best_acc is not None and acc < best_acc:
                continue
            record = {
                "base_ckpt": base_dir,
                "continual_ckpt": continual_dir,
                "acc": acc,
                "metrics": [(label, log_entry.get(key, -1)) for key, label in metric_specs],
                "log_path": log_path,
            }
            if best_acc is None or acc > best_acc:
                best_acc = acc
                best_entries = [record]
            else:  # exact tie with the current best
                best_entries.append(record)

    if not best_entries:
        print("No valid log entries with test_acc1 found.")
        return

    print(f"\nBest test_acc1: {best_acc:.3f} found in {len(best_entries)} run(s).")
    print("All matching entries:")
    for entry in best_entries:
        metrics = ", ".join(
            f"{label}: {_format_metric(value)}" for label, value in entry["metrics"]
        )
        print(f"- {entry['base_ckpt']}/{entry['continual_ckpt']}")
        print(f"  ↳ test_acc1: {entry['acc']:.3f}, {metrics}")
        print(f"  ↳ Log path: {entry['log_path']}")
            

if __name__ == "__main__":
    # Command-line entry point. ``args`` is deliberately bound at module level
    # because ``main`` looks up the global ``args`` to read the ``--mc`` flag.
    cli = argparse.ArgumentParser(
        description="Find best test_acc1 entries and their balanced accuracy."
    )
    cli.add_argument(
        "--base_path",
        type=str,
        required=True,
        help="Root directory containing from_base_* folders",
    )
    cli.add_argument(
        "--mc",
        action="store_true",
        help="Optional argument to specify the classification task (e.g., 'mc' for multi-class).",
    )
    args = cli.parse_args()
    main(args.base_path)
