import argparse
from pathlib import Path
import subprocess
import sys
import time
import torch

def clear_gpu_memory():
    """Release cached CUDA memory held by *this* process, if it has any.

    The pipeline steps run as child processes (see ``run_command``), so this
    parent process normally never uses the GPU. ``torch.cuda.synchronize()``
    lazily creates a CUDA context, which would pin GPU memory in the parent
    for the rest of the run. Only act when a context already exists.
    """
    if not (torch.cuda.is_available() and torch.cuda.is_initialized()):
        return
    # Wait for outstanding work first so the freed cache is actually idle.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

def run_command(cmd, description, cooldown=5):
    """Run one pipeline step as a subprocess; exit the pipeline if it fails.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell involved).
        description: Human-readable step name used in progress/error output.
        cooldown: Seconds to sleep after the step finishes (success or
            failure) before returning. Defaults to 5, the previous fixed value.

    Exits the current process with status 1 if the command returns a non-zero
    status or cannot be started at all (e.g. the executable is missing).
    """
    # Flush so this header precedes the child's output even when stdout is
    # redirected to a file or pipe (block-buffered).
    print(f"\n=== Running {description} ===", flush=True)
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error during {description}: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        # Runs on both success and failure (before sys.exit propagates).
        clear_gpu_memory()
        time.sleep(cooldown)

def run_pipeline(api_key, model_name, dataset, expansion_factor, k, dont_train, layer_stride, layer_group_size, flags=None):
    """Run the SAE pipeline as a sequence of subprocess steps.

    Steps, in order: training (``train.py``, skipped when ``dont_train`` is
    true), SAE Bench evaluation, activation saving, max-activation search and
    auto-interpretation. Each step is launched through ``run_command``, which
    exits the process if the step fails.

    Args:
        api_key: Passed only to ``auto_interp.py`` as ``--api_key``.
        model_name, dataset, expansion_factor, k, layer_stride,
        layer_group_size: Forwarded to every step as command-line options.
        dont_train: If true, skip the training step.
        flags: Extra command-line flags appended to every step's arguments.
    """
    flags = flags or []

    # Launch the step scripts with the interpreter running this pipeline
    # rather than whatever ``python`` resolves to on PATH.
    python = sys.executable or "python"

    shared_params = [
        f"--model_name={model_name}",
        f"--dataset={dataset}",
        f"--expansion_factor={expansion_factor}",
        f"--k={k}",
        f"--layer_stride={layer_stride}",
        # NOTE(review): hyphenated, unlike the other options and unlike this
        # script's own --layer_group_size; confirm the step scripts accept
        # this spelling.
        f"--layer-group-size={layer_group_size}",
        *flags,
    ]

    if not dont_train:
        run_command([python, "train.py", *shared_params], "Training")

    # Every step receives exactly the same shared arguments; if the steps use
    # different data splits, that choice is made inside each script.
    steps = [
        (["sae_bench_evals.py"], "SAE Bench Evaluation"),
        (["save_activations.py"], "Save Activations"),
        (["find_max_activations.py"], "Find Max Activations"),
        (["auto_interp.py", f"--api_key={api_key}", "--offline_explainer"], "Auto Interpretation"),
    ]
    for script_args, desc in steps:
        run_command([python, *script_args, *shared_params], desc)

def main():
    """Parse command-line options and launch the full SAE pipeline."""
    parser = argparse.ArgumentParser(description='Run full SAE pipeline')

    # Mandatory run configuration.
    for option, kind in (
        ('--api_key', str),
        ('--model_name', str),
        ('--dataset', str),
        ('--expansion_factor', int),
        ('--k', int),
    ):
        parser.add_argument(option, type=kind, required=True)

    # Plain on/off switches.
    for option in (
        '--rerandomize',
        '--rerandomize_embeddings',
        '--use_step0',
        '--use_random_control',
        '--dont_train',
    ):
        parser.add_argument(option, action='store_true')

    parser.add_argument('--layer_stride', type=int, default=1)
    parser.add_argument('--layer-by-layer', action='store_true', help='Enable layer-by-layer training mode')
    parser.add_argument('--layer_group_size', type=int, default=1, help='Grouping method for layer-by-layer training')
    args = parser.parse_args()

    # Switches forwarded verbatim to every pipeline step, in this order.
    forwarded = (
        ('rerandomize', "--rerandomize"),
        ('rerandomize_embeddings', "--rerandomize_embeddings"),
        ('use_step0', "--use_step0"),
        ('use_random_control', "--use_random_control"),
        ('layer_by_layer', "--layer-by-layer"),
    )
    extra_flags = [flag for attr, flag in forwarded if getattr(args, attr)]

    run_pipeline(
        args.api_key,
        args.model_name,
        args.dataset,
        args.expansion_factor,
        args.k,
        args.dont_train,
        args.layer_stride,
        args.layer_group_size,
        extra_flags,
    )

# Script entry point: importing this module does not start the pipeline.
if __name__ == "__main__":
    main()
