from pathlib import Path
import subprocess
import os
import json

# Root directory searched recursively for trained models. This is a relative
# path, meant to be taken from the project root.
# Alternative root used previously: Path("data_files/saved_model_inprog")
MODEL_ROOT = Path("data_files") / "nlp_training_faster_learning_rate" / "debug" / "yelp_polarity"
# Directory name that marks a model to attack, matched at any depth below MODEL_ROOT.
MODEL_DIR_NAME = "final_model"

# Attack entry point, relative to the project root.
ATTACK_SCRIPT = "TextFooler/attack_classification_with_budget.py"

# Command-line arguments shared by every attack run. The per-model
# "--target_model_path" and "--dataset_name" arguments are supplied by main().
STATIC_ARGS = (
    # Attack budget and victim model type.
    [
        "--max_attack_changes", "10",
        "--target_model", "seq_classifier",
        "--attack_sample_size", "100",
    ]
    # Pre-computed TextFooler resources (paths relative to the project root).
    + [
        "--counter_fitting_embeddings_path", "data_files/TextFooler/embeddings/counter-fitted-vectors.txt",
        "--counter_fitting_cos_sim_path", "data_files/TextFooler/vocab_cosine_sim/ag_cosine_sim_ag.npy",
        "--USE_cache_path", "data_files/TextFooler/USE",
    ]
    # Runtime settings. Swap "mps" for "cuda" when running on a CUDA machine.
    + [
        "--device", "mps",
        "--seed", "42",
        "--use_amp",  # bare flag: no value follows it
        "--batch_size", "16",
        "--max_seq_length", "256",
    ]
)

# Absolute directory containing this script; used as the subprocess working
# directory and as PYTHONPATH for the attack runs.
PROJECT_ROOT = Path(__file__).parent.resolve()

def find_model_dirs(root: Path, model_dir_name: str):
    """Return every directory named *model_dir_name* at any depth under *root*.

    Args:
        root: Directory to search recursively (a ``str`` is also accepted).
        model_dir_name: Directory name to match, e.g. ``"final_model"``. It is
            passed to ``Path.rglob`` as a glob pattern, so wildcard characters
            are expanded.

    Returns:
        A sorted list of matching directory paths. Regular files with a
        matching name are ignored. The list is empty when *root* is not an
        existing directory.
    """
    root = Path(root)
    if not root.is_dir():
        return []
    # rglob yields entries in no guaranteed order; sort so the attack runs
    # happen in a reproducible order across machines and filesystems.
    return sorted(p for p in root.rglob(model_dir_name) if p.is_dir())

def get_dataset_name(model_dir: Path):
    """Read the dataset name for a model from its experiment config.

    The config is expected at ``<model_dir>/../experiment_config.json`` and
    must contain a non-empty string at ``dataset_info.name``.

    Args:
        model_dir: Path to the model directory (a ``str`` is also accepted).

    Returns:
        The dataset name as a string, or ``None`` when the config is missing,
        unreadable, not valid JSON, or lacks a usable ``dataset_info.name``.
        A warning or error is printed in each ``None`` case.
    """
    config_path = Path(model_dir).parent / "experiment_config.json"
    if not config_path.exists():
        print(f"Warning: {config_path} not found. Skipping {model_dir}.")
        return None

    # Only the file read/parse can raise here; JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses.
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading {config_path}: {e}. Skipping {model_dir}.")
        return None

    # Valid JSON need not be an object; guard each level explicitly instead of
    # relying on .get() raising AttributeError.
    dataset_info = config.get("dataset_info") if isinstance(config, dict) else None
    dataset_name = dataset_info.get("name") if isinstance(dataset_info, dict) else None
    # The name ends up in a subprocess argv, so it must be a non-empty str.
    if not isinstance(dataset_name, str) or not dataset_name:
        print(f"Warning: dataset_info.name not found in {config_path}. Skipping {model_dir}.")
        return None
    return dataset_name

def main():
    """Run the TextFooler attack script against every discovered model.

    Models are directories named ``MODEL_DIR_NAME`` found at any depth under
    ``MODEL_ROOT`` (resolved against ``PROJECT_ROOT``). Each model is attacked
    in its own subprocess, using the dataset name read by
    ``get_dataset_name``. Models without a usable dataset name are skipped.

    Returns:
        0 if every attempted attack exited with status 0, otherwise 1.
    """
    import sys  # local import: only needed to locate the running interpreter

    # MODEL_ROOT is documented as relative to the project root, so anchor it
    # there rather than at the caller's working directory. This matters
    # because the child runs with cwd=PROJECT_ROOT. An absolute MODEL_ROOT is
    # kept as-is by the "/" operator.
    model_root = PROJECT_ROOT / MODEL_ROOT
    if not model_root.is_dir():
        print(f"Warning: model root {model_root} does not exist or is not a directory.")

    model_dirs = find_model_dirs(model_root, MODEL_DIR_NAME)
    print(f"model_paths: {model_dirs}")
    print(f"Found {len(model_dirs)} models to attack.")

    # The child environment is identical for every run, so build it once.
    # Prepend the project root to any existing PYTHONPATH instead of clobbering it.
    env = os.environ.copy()
    existing_pythonpath = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        os.pathsep.join([str(PROJECT_ROOT), existing_pythonpath])
        if existing_pythonpath
        else str(PROJECT_ROOT)
    )
    # Log only PYTHONPATH: printing the whole environment can leak secrets
    # (tokens, credentials) into logs.
    print(f"PYTHONPATH: {env['PYTHONPATH']}")

    # Use the interpreter running this script so the child sees the same
    # virtualenv/conda environment, not whatever "python3" is first on PATH.
    python_exe = sys.executable or "python3"

    failed = []
    for model_dir in model_dirs:
        dataset_name = get_dataset_name(model_dir)
        print(f"model_dir: {model_dir} dataset_name: {dataset_name} ")
        if not dataset_name:
            continue
        cmd = [
            python_exe, ATTACK_SCRIPT,
            "--target_model_path", str(model_dir),
            "--dataset_name", dataset_name,
            *STATIC_ARGS,
        ]
        print("Running:", " ".join(cmd))
        result = subprocess.run(cmd, cwd=PROJECT_ROOT, env=env)
        # Keep going after a failure so one bad model doesn't block the rest,
        # but record it so the overall exit status reflects it.
        if result.returncode != 0:
            print(f"Error: attack on {model_dir} exited with code {result.returncode}.")
            failed.append(model_dir)

    if failed:
        print(f"{len(failed)} attack run(s) failed: {failed}")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate failures as a non-zero process exit status.
    raise SystemExit(main())