import os
import argparse
import pandas as pd

from experiments import run_experiment

def _ncl_number(folder_name):
    """Return the integer after the last '_' in `folder_name`, or None if it is not an integer."""
    try:
        return int(folder_name.split('_')[-1])
    except ValueError:
        return None


def main():
    """
    CLI entrypoint. Iterates over the sub-directories of --base_dir (expected to be
    named `ncl_<number>`), runs `run_experiment` on every sub-directory of each of
    them, and writes one CSV per ncl folder with the best validation accuracy/epoch
    per subset.

    Folders whose name does not end in `_<integer>` are processed after the numbered
    ones and use the --num_classes value given on the command line; they are skipped
    entirely when --ncls is supplied.

    Results are written to `results/<test_tag>/<basename of base_dir>/<ncl_folder>/results.csv`,
    relative to the current working directory.
    """
    parser = argparse.ArgumentParser(description="Multi-Dataset Training with CSV Results")
    parser.add_argument("--base_dir", type=str, default='./datasets/cub200',
                        help="Base directory containing ncl_i folders")
    parser.add_argument("--support_size", type=int, default=5,
                        help="Number of support images per class")
    parser.add_argument("--support_selection", type=str, choices=["nearest", "random"], default="nearest",
                        help="How to pick support images for each class: 'nearest' (pre-computed) or 'random' from training set")
    parser.add_argument("--prototype_construction", type=str, choices=["query_fusion", "aggregate"], default="query_fusion",
                        help="How to build per-class prototypes: 'query_fusion' (conditioned) or 'aggregate' (no fusion)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--model_choice", type=str, default="resnet50",
                        choices=["resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
                                 "swin", "mobilevit", "vit", "eva02_base", "coatnet3_rw",
                                 "convnextv2_large", "convnextv2_tiny", "maxvit_small", "clip", "ensemble"],
                        help="Model architecture")
    parser.add_argument("--head_type", type=str, default="transformer", choices=["transformer", "linear"],
                        help="Type of classification head: 'transformer' (support-set based) or 'linear'")
    parser.add_argument("--test_tag", type=str, default="Debug",
                        help="Tag for the test run, used in results path")
    parser.add_argument("--swin_arch", type=str, default="swin_base_patch4_window7_224",
                        help="Define the Swin architecture to use (if applicable)")
    parser.add_argument("--num_classes", type=int, default=2, help="Default number of classes")
    parser.add_argument("--ncls", type=int, nargs="*", default=[],
                        help="List of ncl folder numbers to process (e.g., --ncls 2 4). If not provided, all folders are processed.")
    parser.add_argument("--freeze_backbone", action='store_true',
                        help="Whether to freeze the backbone model (default: False unless flag is set)")
    args = parser.parse_args()

    base_dir = args.base_dir
    # args.num_classes is overwritten for every folder below, so remember the value
    # the user actually asked for; otherwise the "default" fallback would silently
    # become the previous folder's class count.
    default_num_classes = args.num_classes
    # normpath strips a trailing separator so "./datasets/cub200/" still yields
    # "cub200" instead of an empty path component.
    dataset_name = os.path.basename(os.path.normpath(base_dir))
    result_columns = ["ncl_folder", "subset", "num_classes", "best_epoch", "best_val_acc"]

    def folder_sort_key(name):
        # Numbered folders first in ascending numeric order; folders without a
        # numeric suffix afterwards, alphabetically. Name is the tie-breaker.
        number = _ncl_number(name)
        return (number is None, 0 if number is None else number, name)

    ncl_folders = sorted(
        (f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f))),
        key=folder_sort_key,
    )
    if args.ncls:
        wanted = set(args.ncls)
        # _ncl_number returns None for unnumbered folders, which is never in `wanted`.
        ncl_folders = [f for f in ncl_folders if _ncl_number(f) in wanted]

    for ncl_folder in ncl_folders:
        ncl_path = os.path.join(base_dir, ncl_folder)
        num_classes = _ncl_number(ncl_folder)
        if num_classes is None:
            print(f"Could not extract number of classes from {ncl_folder}, using default: {default_num_classes}")
            num_classes = default_num_classes

        # run_experiment receives the whole namespace, so the per-folder class
        # count is passed through args.
        args.num_classes = num_classes
        print(f"Running experiments for {ncl_folder} with {args.num_classes} classes.")

        subset_folders = [f for f in sorted(os.listdir(ncl_path)) if os.path.isdir(os.path.join(ncl_path, f))]
        results = []
        for subset_folder in subset_folders:
            dataset_dir = os.path.join(ncl_path, subset_folder)
            log_dir_suffix = os.path.join(args.test_tag, ncl_folder, subset_folder)
            print(f"\nRunning experiment for dataset: {dataset_dir}")
            best_val_acc, best_epoch = run_experiment(dataset_dir, args, log_dir_suffix)
            results.append({
                "ncl_folder": ncl_folder,
                "subset": subset_folder,
                "num_classes": args.num_classes,
                "best_epoch": best_epoch,
                "best_val_acc": best_val_acc
            })
        # Explicit columns keep the CSV header present even when no subsets were found.
        results_df = pd.DataFrame(results, columns=result_columns)
        results_dir = os.path.join('results', args.test_tag, dataset_name, ncl_folder)
        os.makedirs(results_dir, exist_ok=True)
        csv_path = os.path.join(results_dir, "results.csv")
        results_df.to_csv(csv_path, index=False)
        print(f"Saved results for {ncl_folder} to {csv_path}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
