import argparse
import datetime
import os
import shutil
import subprocess
from multiprocessing.pool import ThreadPool
from data_loader import _read_file_to_set
from data_loader import * 
from generate_a_program import learn

# Script that fast_learn runs in a child interpreter. The path is relative, so
# it is resolved against the working directory the driver is launched from.
LEARN_SCRIPT = "generate_a_program.py"
# Root directory under which clean_learned_programs builds each dataset's
# "learned_GDL_programs" folder.
BASE_DATASET_PATH = "datasets"

def efficient_learn(data: Data, target_graph: int) -> None:
    """
    Learns a GDL program for one graph inside the current process.

    Calls ``learn`` directly (no subprocess) and prints how long the call took.

    Args:
        data: The dataset object, passed through unchanged to ``learn``.
        target_graph: The graph ID to learn a program from.
    """
    print(f"Learning a GDL program from a graph : {target_graph}")
    start = datetime.datetime.now()
    learn(data, target_graph)
    elapsed = datetime.datetime.now() - start
    # Report the measured duration so per-graph cost is visible in the log
    # instead of being computed and thrown away.
    print(f"Finished graph {target_graph} in {elapsed}")


def fast_learn(
    data: Data,
    target_graph: int,
    epsilon: float,
    python_executable: str = "python3",
    timeout: float = 3600,
) -> None:
    """
    Executes the learning script for a given graph in a separate process.

    The child is started as
    ``<python_executable> LEARN_SCRIPT -d <data.name> -g <target_graph>
    -e <epsilon> -t <data.timeLimit>``. Failures (non-zero exit status or
    timeout) are printed and swallowed so that one bad graph does not abort
    the rest of the run.

    Args:
        data: The dataset object; only ``data.name`` and ``data.timeLimit``
            are read here.
        target_graph: The graph ID.
        epsilon: The epsilon value forwarded verbatim to the script's ``-e``
            option.
        python_executable: The python executable to use.
        timeout: Maximum wall-clock seconds to wait for the child process
            before it is killed (default: 3600, i.e. one hour).
    """
    cmd = [
        python_executable, LEARN_SCRIPT,
        "-d", data.name,
        "-g", str(target_graph),
        "-e", str(epsilon),
        "-t", str(data.timeLimit),
    ]
    print(f"Executing: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=timeout)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command for graph {target_graph}: {e}")
        print(f"Stderr: {e.stderr}")
        print(f"Stdout: {e.stdout}")
    except subprocess.TimeoutExpired as e:
        # e.timeout is set to the passed timeout
        print(f"Command timed out after {e.timeout}s for graph {target_graph}")
        print(f"Command: {' '.join(cmd)}")
        # Partial output may be None if nothing was produced, and it is
        # delivered as bytes even though text=True was requested, so decode
        # it before printing to avoid a b'...' repr in the log.
        for label, partial in (("Stdout", e.output), ("Stderr", e.stderr)):
            if not partial:
                continue
            if isinstance(partial, bytes):
                partial = partial.decode(errors="replace")
            print(f"{label} (partial): {partial}")

def clean_learned_programs(dataset: str) -> None:
    """
    Resets the learned-programs folder of a dataset to an empty directory.

    Any existing ``learned_GDL_programs`` directory under the dataset's folder
    is deleted together with its contents, and a fresh empty one is created.

    Args:
        dataset: The name of the dataset.
    """
    target_dir = os.path.join(BASE_DATASET_PATH, dataset, "learned_GDL_programs")

    stale_output_present = os.path.exists(target_dir)
    if stale_output_present:
        # Drop results from any previous run so they cannot mix with new ones.
        print(f"Removing existing learned programs in {target_dir}")
        shutil.rmtree(target_dir)

    print(f"Creating directory: {target_dir}")
    os.makedirs(target_dir)


def main() -> None:
    """
    Main function to parse arguments and run the program generation.

    Loads the requested dataset, sets its hyper parameters, empties the
    dataset's learned-programs folder and then learns one program per
    training graph using a pool of ``--cores`` worker threads. By default each
    graph is handled by ``fast_learn`` (one subprocess per graph); passing
    ``--in-process`` uses ``efficient_learn`` instead. The total wall-clock
    time is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Generate GDL programs from a dataset.")
    parser.add_argument("-d", "--dataset", default="MUTAG", help="Input dataset: MUTAG, BBBP, BACE")
    parser.add_argument("-c", "--cores", type=int, default=1, help="Number of cores to use, default = 1")
    parser.add_argument("-e", "--epsilon", type=float, default=0.01, help="Input epsilon, default = 0.01")
    parser.add_argument("-t", "--timelimit", type=int, default=10, help="Input time limit, default = 10")
    parser.add_argument(
        "--in-process",
        action="store_true",
        help="Call learn() inside this process instead of spawning one subprocess per graph",
    )
    args = parser.parse_args()

    dataset = args.dataset
    cores = args.cores
    # Reject an unusable pool size up front: ThreadPool would only raise after
    # the previously learned programs had already been deleted below.
    if cores < 1:
        parser.error("--cores must be at least 1")

    if dataset == 'BBBP':
        data = load_BBBP()
    elif dataset == 'BACE':
        data = load_BACE()
    else:
        data = load_Data(dataset)
    print(f"Dataset: {dataset}")
    print(f"Cores: {cores}")

    # Hyper parameters
    data.timeLimit = args.timelimit
    data.expected = 1.0
    # data.epsilon is the command-line epsilon scaled by the number of training
    # graphs; the subprocess path below is handed the unscaled value instead.
    data.epsilon = len(data.train_graphs) * args.epsilon

    # NOTE(review): all hyper parameters are assigned before freeze();
    # presumably it makes `data` read-only -- confirm in data_loader.
    data.freeze()

    clean_learned_programs(dataset)
    train_graphs = data.train_graphs
    start_time = datetime.datetime.now()
    with ThreadPool(cores) as pool:
        if args.in_process:
            pool.starmap(efficient_learn, [(data, graph) for graph in train_graphs])
        else:
            # Fast learning: one child interpreter per graph.
            pool.starmap(fast_learn, [(data, graph, args.epsilon) for graph in train_graphs])
    end_time = datetime.datetime.now()
    elapsed_time = end_time - start_time

    print("\n" + "=" * 38)
    print(f"# Used cores: {cores}")
    print(f"Total training time: {elapsed_time}")
    print("=" * 38)


if __name__ == "__main__":
    main()
