import itertools
import json
import os
import subprocess
import sys

from tqdm import tqdm


def load_config(config_path):
    """Load and return the parsed contents of the JSON file at *config_path*.

    The file is decoded explicitly as UTF-8, the encoding JSON mandates,
    instead of the platform's locale-dependent default. Without this, the
    same config could parse differently (or fail) on different machines.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def ensure_directory_exists(directory):
    """Create *directory*, including any missing parents, if it does not exist.

    ``exist_ok=True`` replaces a separate existence check followed by
    creation. That pattern races if another process creates the directory
    in between.

    Raises:
        FileExistsError: if *directory* exists but is not a directory.
        OSError: for other creation failures (e.g. permissions).
    """
    os.makedirs(directory, exist_ok=True)


def train_model(script_name, nu, lambd, building, T0, total_timesteps, save_folder, verbose):
    """Train one model for a single (nu, lambd, building) grid point.

    Runs ``script_name`` in a child Python process with the hyper-parameters
    passed as command-line flags, and blocks until it finishes.

    Args:
        script_name: Path to the training script to execute.
        nu: Value forwarded as ``--nu``.
        lambd: Value forwarded as ``--lambd``.
        building: Value forwarded as ``--building``.
        T0: Value forwarded as ``--T0``.
        total_timesteps: Value forwarded as ``--total_timesteps``.
        save_folder: Directory for the model output. The script receives
            ``--filename <save_folder>/ippo``.
        verbose: Value forwarded as ``--verbose``.

    Returns:
        True if the script exited with status 0, otherwise False. A non-zero
        exit is reported on stderr instead of raised, so that one failing
        grid point does not abort the whole search.
    """
    # NOTE(review): every run receives the same --filename prefix; presumably
    # the training script appends the hyper-parameters itself -- confirm, or
    # successive runs will overwrite each other's output.
    filename = os.path.join(save_folder, "ippo")
    command = [
        # Use the interpreter running this driver (same virtualenv), not
        # whatever "python" happens to resolve to on PATH.
        sys.executable or "python", script_name,
        "--T0", str(T0),
        "--total_timesteps", str(total_timesteps),
        "--filename", filename,
        "--verbose", str(verbose),
        "--building", str(building),
        "--nu", str(nu),
        "--lambd", str(lambd),
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(
            f"Error training model with nu={nu}, lambd={lambd}, building={building}: {e}",
            file=sys.stderr,
        )
        return False
    return True


def main(config_path):
    """Run the full grid search described by the JSON file at *config_path*.

    Reads the search space (``nu_range``, ``lambd_range``, ``buildings``) and
    the fixed training settings (``T0``, ``total_timesteps``, ``verbose``,
    ``script_name``) from the config. Then calls ``train_model`` once for
    every combination in their Cartesian product, showing a progress bar.
    """
    settings = load_config(config_path)

    # Every required key is read before any training starts, so a missing
    # key raises KeyError up front rather than midway through the search.
    nu_values = settings["nu_range"]
    lambd_values = settings["lambd_range"]
    building_ids = settings["buildings"]
    t0 = settings["T0"]
    timesteps = settings["total_timesteps"]
    output_dir = "ippo_models"
    verbosity = settings["verbose"]
    trainer_script = settings["script_name"]

    ensure_directory_exists(output_dir)

    # Materialised as a list so the progress bar knows the total up front.
    grid = list(itertools.product(nu_values, lambd_values, building_ids))

    with tqdm(total=len(grid), desc="Training Models") as progress:
        for nu, lambd, building in grid:
            train_model(trainer_script, nu, lambd, building, t0, timesteps, output_dir, verbosity)
            progress.update(1)

    print(f"Grid search completed. Models are saved in the '{output_dir}' folder.")


if __name__ == "__main__":
    # Command-line entry point: the only argument is the path to the JSON
    # config that defines the grid search.
    import argparse

    cli = argparse.ArgumentParser(
        description="Train models using a grid search over parameters."
    )
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the configuration file (e.g., train_config.json).",
    )
    main(cli.parse_args().config)
