import os
import json
import argparse
import pandas as pd
from tqdm import tqdm
from itertools import product
import subprocess

def list_directories(path="plp/programs"):
    """Return the names of the immediate sub-directories of ``path``.

    Args:
        path: Directory to inspect.

    Returns:
        A list of entry names (not full paths) for which
        ``os.path.isdir`` is true, in the order ``os.listdir`` yields them.
    """
    subdirs = []
    for entry in os.listdir(path):
        candidate = os.path.join(path, entry)
        if os.path.isdir(candidate):
            subdirs.append(entry)
    return subdirs

def get_programs(programs_file):
    """Read program names from a text file, one per line.

    Blank lines and lines whose first non-whitespace characters are
    ``#``, ``%`` or ``//`` are skipped. Surrounding whitespace is stripped
    from every kept line.

    Args:
        programs_file: Path of the file to read.

    Returns:
        The list of program names in file order.
    """
    comment_markers = ('#', '%', '//')
    programs_names = []
    with open(programs_file, 'r') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry or entry.startswith(comment_markers):
                continue
            programs_names.append(entry)
    return programs_names

def run_command(command, timeout):
    """Run an external command and return its standard output.

    Args:
        command: Either a command string, which is executed through the
            shell, or a list of arguments, which is executed directly
            without a shell.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        The captured standard output as text, or ``None`` if the command
        did not finish within ``timeout``.

    Raises:
        SystemExit: With exit code 1 when the command exits with a
            non-zero return code (its stderr is printed first).
    """
    # Only strings need a shell; an argv list combined with shell=True would
    # hand everything after the first element to the shell itself on POSIX.
    use_shell = isinstance(command, str)
    try:
        result = subprocess.run(
            command,
            shell=use_shell,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f"Command timed out: {command}")
        return None

    if result.returncode != 0:
        print(f"Error running command: {command}")
        print(result.stderr)
        # The bare ``exit`` helper is injected by the ``site`` module for
        # interactive use (and closes stdin); raise SystemExit directly.
        raise SystemExit(1)
    return result.stdout

def apply_sdd_compiler(program_path, program_name, is_init, is_non_incremental,
    vtree_type, minimize, X_deterministic, alpha, timeout):
    """Invoke ``sdd_compiler.py`` on one program's JSON file.

    The JSON file is looked up in ``program_path`` as
    ``<program_name>[_init][_non-incremental].json``, where the optional
    parts are controlled by ``is_init`` and ``is_non_incremental``.

    Args:
        program_path: Directory containing the program's JSON file.
        program_name: Base name of the program.
        is_init: Whether to use the ``_init`` variant of the file.
        is_non_incremental: Whether to use the ``_non-incremental`` variant.
        vtree_type: Value passed as the second positional argument.
        minimize: Whether to pass ``--minimize``.
        X_deterministic: Whether to pass ``--X_deterministic``.
        alpha: Value passed with ``--alpha`` (always passed).
        timeout: Seconds given to ``run_command`` for the execution.

    Returns:
        ``True`` if ``run_command`` returned output, ``False`` if it
        returned ``None``.
    """
    variant = ""
    if is_init:
        variant += "_init"
    if is_non_incremental:
        variant += "_non-incremental"
    json_file = os.path.join(program_path, f"{program_name}{variant}.json")

    # Disabled flags stay in the list as empty strings so that the joined
    # command line is identical to a single formatted string with every slot.
    pieces = [
        "python sdd_compiler.py",
        json_file,
        vtree_type,
        "--minimize" if minimize else "",
        "--X_deterministic" if X_deterministic else "",
        f"--alpha {alpha}",  # Only used if minimize is True
    ]
    command_line = " ".join(f"{piece}" for piece in pieces)

    # Run sdd_compiler.py
    output = run_command(command_line, timeout)
    return output is not None

def write_broken_run(results, program_name, config_name, timeout):
    """Append a placeholder row for a run that produced no statistics.

    Size, count and rate columns are filled with ``-1`` and the compilation
    time is set to ``timeout``. Every column is appended to (rather than
    overwriting the last entry) so that all lists in ``results`` keep the
    same length and the dict can still be turned into a table.

    Args:
        results: Dict mapping column names to lists; modified in place.
        program_name: Name of the program whose run is being recorded.
        config_name: Name of the configuration that was used.
        timeout: Value stored as the compilation time of the run.
    """
    results["program"].append(program_name)
    results["config"].append(config_name)
    results["circuit_node_size"].append(-1)
    results["circuit_edge_size"].append(-1)
    results["model_count"].append(-1)
    results["compression_rate"].append(-1)
    results["compilation_time"].append(timeout)
    # Keep the memory column aligned with the others when the caller tracks it.
    if "max_memory_usage_mb" in results:
        results["max_memory_usage_mb"].append(-1)

def main():
    """Run the SDD compilation experiments described by the command line.

    For every combination of the selected options (initialization heuristic,
    non-incremental compilation, vtree type, X-deterministic, minimize) and
    every program, ``sdd_compiler.py`` is executed and the statistics it
    leaves in ``<program_dir>/stats/`` are collected. Runs that hit the time
    wall, or whose stats file is missing, are recorded as broken runs and
    the remaining programs of that configuration are skipped. The collected
    rows are written to a CSV file under ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Run SDD compilation experiments.")
    parser.add_argument("base_path", type=str, help="Path to the base directory containing program directories.")
    parser.add_argument("programs_file", nargs='?', help="File containing list of programs to process. If not provided, all programs in the base directory will be processed.")
    parser.add_argument("output_dir", type=str, help="Path to the directory where results will be saved.")
    parser.add_argument("--both_init", action="store_true", help="Run both initialization heuristic and non-initialization heuristic. If True, use both True and False.")
    parser.add_argument("--use_non_inc", action="store_true", help="Use non-incremental compilation. If True, use both True and False.")
    parser.add_argument("--vtree_type", type=str, default="b", help="Type of vtree to use. If 'both', use both 'r' and 'b'.")
    parser.add_argument("--unconstrained", action="store_true", help="Use X_deterministic option. If True, use both True and False.")
    parser.add_argument("--minimize", action="store_true", help="Use minimize option. If True, use both True and False.")
    parser.add_argument("--alpha", type=float, default=2.0, help="Threshold to apply dynamic minimization.")
    parser.add_argument("--time_wall", type=int, default=1800, help="Maximum time (in seconds) for each execution.")
    args = parser.parse_args()

    base_path = args.base_path
    programs = get_programs(args.programs_file) if args.programs_file else list_directories(base_path)

    init_options = [False, True] if args.both_init else [True]
    non_inc_options = [False, True] if args.use_non_inc else [True]
    vtree_types = ["r", "b"] if args.vtree_type == "both" else [args.vtree_type]
    X_deterministic_options = [False, True] if args.unconstrained else [True]
    minimize_options = [False, True] if args.minimize else [False]
    alpha = args.alpha
    time_wall = args.time_wall

    vtree_map = {"r": "right", "b": "balanced"}

    results = {
        "program": [],
        "config": [],
        "circuit_node_size": [],
        "circuit_edge_size": [],
        "model_count": [],
        "compression_rate": [],
        "compilation_time": [],
        "max_memory_usage_mb": [],
    }

    # Materialize the combinations so tqdm knows the total and can show
    # real progress instead of a bare counter.
    configurations = list(product(init_options, non_inc_options, vtree_types, X_deterministic_options, minimize_options))

    # Apply sdd_compiler.py with all combinations of configurations
    for use_init, use_non_inc, vtree_type, X_deterministic, minimize in tqdm(configurations, desc="Configurations"):
        sufix = "_init" if use_init else ""
        sufix += "_non-incremental" if use_non_inc else ""

        # The configuration name does not depend on the program, so build it
        # once per configuration rather than once per program.
        config_name = f"{sufix}_{vtree_map[vtree_type]}{'_Xdet' if X_deterministic else ''}{'_min' if minimize else ''}"
        config_name = config_name[1:] if config_name[0] == "_" else config_name

        # If one execution of the program breaks the time wall, skip the
        # following programs (assumes that the following programs are larger and
        # will take longer to execute)
        time_wall_broken = False

        for program_name in tqdm(programs, leave=False, desc="Programs"):
            if time_wall_broken:
                write_broken_run(results, program_name, config_name, time_wall)
                continue  # Skip execution if the time wall was previously broken

            program_path = os.path.join(base_path, program_name)

            if not apply_sdd_compiler(program_path, program_name, use_init, use_non_inc, vtree_type, minimize, X_deterministic, alpha, time_wall):
                time_wall_broken = True
                write_broken_run(results, program_name, config_name, time_wall)
                continue

            stats_file = f"{program_path}/stats/{program_name}_{config_name}_stats.json"

            try:
                with open(stats_file, 'r') as f:
                    stats = json.load(f)
            except FileNotFoundError:
                # The stats file should exist after a successful run. Report
                # it and record a broken run instead of crashing, so results
                # gathered so far are still written out.
                time_wall_broken = True
                print(f"Error: stats file not found for {program_name} with config {config_name}")
                write_broken_run(results, program_name, config_name, time_wall)
                continue

            circuit_size = stats.get("circuit_size", {})
            results["program"].append(program_name)
            results["config"].append(config_name)
            results["circuit_node_size"].append(circuit_size.get("nodes", -1))
            results["circuit_edge_size"].append(circuit_size.get("edges", -1))
            results["model_count"].append(stats.get("model_count", -1))
            results["compression_rate"].append(stats.get("compression_rate", -1))
            results["compilation_time"].append(stats.get("compilation_time", -1))
            results["max_memory_usage_mb"].append(stats.get("max_memory_usage_mb", -1))

    df = pd.DataFrame(results)

    # Results go to <output_dir>/<programs file stem>, or <output_dir>/tmp
    # when no programs file was given. makedirs creates missing parents too.
    subset_name = os.path.splitext(os.path.basename(args.programs_file))[0] if args.programs_file else "tmp"
    experiments_dir = os.path.join(args.output_dir, subset_name)
    os.makedirs(experiments_dir, exist_ok=True)

    # Encode the chosen options in the output file name.
    sufix = "_init-both" if args.both_init else "_init"
    sufix += "_non-inc" if args.use_non_inc else ""
    sufix += f"_vt-{args.vtree_type}"
    sufix += "_Xdet-Unconstrained" if args.unconstrained else "_Xdet"
    sufix += "_min" if args.minimize else ""

    df.to_csv(f"{experiments_dir}/sdd_results{sufix}.csv", index=False)

# Run the experiments only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
