from pathlib import Path
import pandas as pd
import json
from typing import List

def find_folders_with_adv_attack_results(root: Path) -> List[Path]:
    """
    Locate every folder below ``root`` that owns an 'adv_attack_results' directory.

    The tree is walked recursively. Entries named 'adv_attack_results' that are
    not directories (e.g. plain files) are ignored.

    Args:
        root: Directory under which the search is performed.

    Returns:
        The parent folder of each matching directory, in the order the
        filesystem walk yields them.
    """
    parents: List[Path] = []
    for candidate in root.rglob('adv_attack_results'):
        # Only real directories count; a file with the same name is skipped.
        if candidate.is_dir():
            parents.append(candidate.parent)
    return parents

def read_and_combine_results(
    model_folder: Path, ft_method_part_index: int = 3
) -> List[pd.DataFrame]:
    """
    Load every attack-result CSV of one model folder and tag it with run metadata.

    The folder must contain an 'experiment_config.json' file and is expected to
    contain an 'adv_attack_results' directory, which is searched recursively
    for '*.csv' files. Each CSV becomes one DataFrame with these extra columns:

    - ``model_name``, ``num_epochs``: taken from the config's "args" section.
    - ``total_params``, ``trainable_params``: taken from the config's
      "model_info" section.
    - ``csv_path``: the CSV's path as a string.
    - ``ft_method``: the path component of the CSV path at position
      ``ft_method_part_index``.

    Config values that are missing are stored as ``None``.

    Args:
        model_folder: Parent folder of the 'adv_attack_results' directory.
        ft_method_part_index: Index into ``Path.parts`` of each CSV path that
            is stored in ``ft_method``. The default of 3 keeps the historical
            behavior (fourth component of the path).

    Returns:
        One DataFrame per CSV file, ordered by sorted CSV path. The list is
        empty when no CSV file is found.

    Raises:
        FileNotFoundError: If 'experiment_config.json' does not exist.
        json.JSONDecodeError: If the config file is not valid JSON.
    """
    config_path = model_folder / "experiment_config.json"
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # `or {}` also covers sections that are present but null in the JSON.
    args = config.get("args") or {}
    model_info = config.get("model_info") or {}
    model_name = args.get("model_name")
    num_epochs = args.get("num_epochs")
    total_params = model_info.get("total_params")
    trainable_params = model_info.get("trainable_params")

    adv_dir = model_folder / "adv_attack_results"
    dfs = []
    # Sorted so the output order does not depend on filesystem listing order.
    for csv_file in sorted(adv_dir.rglob("*.csv")):
        # Use Path.parts rather than splitting the string on '/': it gives the
        # same component on POSIX and also works with Windows separators.
        try:
            ft_method = csv_file.parts[ft_method_part_index]
        except IndexError:
            # Path is too shallow to carry that component.
            ft_method = None

        df = pd.read_csv(csv_file)
        df["model_name"] = model_name
        df["num_epochs"] = num_epochs
        df["total_params"] = total_params
        df["trainable_params"] = trainable_params
        df["csv_path"] = str(csv_file)
        # Constant per file, so assign it once instead of computing it per row.
        df["ft_method"] = ft_method
        dfs.append(df)
    return dfs

def aggregate_all_attack_results(root: Path) -> pd.DataFrame:
    """
    Build one DataFrame from all attack results found below ``root``.

    Every model folder reported by ``find_folders_with_adv_attack_results`` is
    passed to ``read_and_combine_results``; the resulting per-CSV frames are
    concatenated with a fresh, continuous index.

    Args:
        root: Directory under which model folders are searched.

    Returns:
        The concatenated DataFrame, or an empty DataFrame when no frames
        were produced.
    """
    # Flatten the per-folder lists of frames into a single list.
    frames = [
        frame
        for model_dir in find_folders_with_adv_attack_results(root)
        for frame in read_and_combine_results(model_dir)
    ]
    if not frames:
        return pd.DataFrame()  # Nothing found
    return pd.concat(frames, ignore_index=True)

# Example usage (uncomment to use as a script):
# if __name__ == "__main__":
#     root = Path("data_files/attack_results/saved_model_inprog")
#     df = aggregate_all_attack_results(root)
#     print(df.head()) 