"""Simple hyperparameter grid-search script.

Only searches parameters for the FSM synthesis algorithm. The grid is
shuffled and at most ``--search_budget`` configurations are evaluated.
"""
from pathlib import Path
import argparse
import json
import datetime
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import Future
from swmpo_experiments.pipeline import main as pipeline
import statistics
import copy
from itertools import product
import random


def wait_for_jobs(max_jobs: int, jobs: list[Future]):
    """Block until at most ``max_jobs`` futures remain in ``jobs``.

    Completed futures are removed from ``jobs`` (the list is mutated in
    place) and their ``result()`` is retrieved. An exception raised inside a
    job is therefore re-raised here.

    Args:
        max_jobs: Maximum number of futures to leave in ``jobs``. Values
            below zero are treated as zero.
        jobs: Futures to throttle; mutated in place.
    """
    from concurrent.futures import FIRST_COMPLETED, wait

    # Clamp so an empty list can never spin on wait([]).
    limit = max(max_jobs, 0)
    while len(jobs) > limit:
        # Wait for *any* job to finish rather than a specific one, so a worker
        # slot is reused as soon as it frees up.
        done, _ = wait(jobs, return_when=FIRST_COMPLETED)
        for job in done:
            jobs.remove(job)
            job.result()


def get_candidate_hyperparameters(
    base_hyperparameters: dict,
    search_space: "dict[str, list] | None" = None,
) -> list[dict]:
    """Enumerate every grid point of the FSM-synthesis search space.

    Each candidate is a deep copy of ``base_hyperparameters`` with one value
    chosen for every searched key. Grid values override any same-named key
    already present in the base.

    Args:
        base_hyperparameters: Hyperparameters shared by all candidates. Not
            mutated.
        search_space: Optional mapping from hyperparameter name to the list
            of values to try. When omitted, the built-in FSM grid is used.

    Returns:
        One hyperparameter dict per element of the Cartesian product of the
        search-space value lists, in ``itertools.product`` order (the last
        key varies fastest).
    """
    if search_space is None:
        # Commented-out values are disabled options kept for quick toggling.
        search_space = dict(
            state_machine_learning_rate=[
                0.00001,
                0.0001,
                0.001,
                0.02,
            ],
            state_machine_hidden_sizes=[
                "32 32",
                "64 64 64",
                #"128 128 128 128",
            ],
            state_machine_autoencoder_latent_size=[
                4,
                #8,
                32,
                #128,
            ],
            state_machine_cluster_dimensionality_reduce=[
                0,
                2,
                #8,
                #16,
            ],
            state_machine_information_content_regularization_scale=[
                0.0,
                0.00001,
                0.0001,
            ],
            state_machine_mutual_information_regularization_scale=[
                0.0,
                0.001,
                0.01,
            ],
            state_machine_batch_size=[
                256,
                1024,
            ],
        )

    keys = list(search_space)
    candidates = list()
    for values in product(*(search_space[key] for key in keys)):
        # Deep copy so candidates never share nested objects with the base
        # or with each other.
        hyperparameters = copy.deepcopy(base_hyperparameters)
        hyperparameters.update(zip(keys, values))
        candidates.append(hyperparameters)
    return candidates


def main(
    base_hyperparameter_json: Path,
    output_dir: Path,
    worker_n: int,
    search_budget: int,
    trajectory_cache_dir: Path,
):
    """Run the hyperparameter search and write a ranked summary.

    Creates ``<output_dir>/<env_id>_hyperparam_search_<timestamp>/``. It holds
    one numbered sub-directory per sampled configuration, each containing
    ``hyperparameters.json`` and the pipeline's ``results/``. It also holds
    ``summary.json``, which lists every experiment sorted by ascending mean
    FSM error.

    Args:
        base_hyperparameter_json: JSON file with the base hyperparameters;
            must contain an ``env_id`` key.
        output_dir: Parent directory for the search working directory.
        worker_n: Maximum number of pipeline runs in flight at once.
        search_budget: Maximum number of configurations to evaluate.
        trajectory_cache_dir: Passed through to every pipeline run.

    Raises:
        Any exception raised by a pipeline run is re-raised after the process
        pool has been shut down.
    """
    # Read the parameter, not the CLI global, so main() is callable from
    # other modules.
    with open(base_hyperparameter_json, "rt") as fp:
        base_hyperparameters = json.load(fp)

    # Second-resolution timestamp keeps the name free of the ".ffffff"
    # microsecond part. Path.with_suffix("") would strip from the last ".",
    # which lands inside env_id whenever the timestamp has no fraction.
    date_str = datetime.datetime.now().isoformat(timespec="seconds")
    env_id = base_hyperparameters["env_id"]
    working_dir = output_dir / f"{env_id}_hyperparam_search_{date_str}"
    working_dir.mkdir(parents=True)

    candidate_hyperparameters = get_candidate_hyperparameters(
        base_hyperparameters=base_hyperparameters,
    )

    # Random search over the grid: shuffle (unseeded, so the sample differs
    # between runs) and keep at most `search_budget` configurations.
    random.shuffle(candidate_hyperparameters)
    candidate_hyperparameters = candidate_hyperparameters[:search_budget]
    n_candidates = len(candidate_hyperparameters)

    print(f"Number of hyperparameter configurations: {n_candidates}")

    experiment_dirs = list()
    # The context manager guarantees worker processes are shut down, also
    # when a job raises.
    with ProcessPoolExecutor(worker_n) as pool:
        jobs = list[Future]()
        for i, hyperparameters in enumerate(candidate_hyperparameters):
            print(f"Launching hyperparameter search: {i+1}/{n_candidates}")
            experiment_dir = working_dir / f"{i}"
            experiment_dir.mkdir()
            experiment_dirs.append(experiment_dir)

            hyperparameters_path = experiment_dir / "hyperparameters.json"
            with open(hyperparameters_path, "wt") as fp:
                json.dump(hyperparameters, fp)

            results_dir = experiment_dir / "results"
            results_dir.mkdir()
            jobs.append(pool.submit(
                pipeline,
                results_dir=results_dir,
                trajectory_cache_dir=trajectory_cache_dir,
                **hyperparameters,
            ))
            # Throttle submissions so at most `worker_n` runs are queued or
            # running at once.
            wait_for_jobs(worker_n, jobs)

        wait_for_jobs(0, jobs)

    summary = _summarize_experiments(experiment_dirs)

    summary_path = working_dir / "summary.json"
    with open(summary_path, "wt") as fp:
        json.dump(summary, fp, indent=2)
    print(f"Wrote {summary_path}")


def _summarize_experiments(experiment_dirs: list[Path]) -> list[dict]:
    """Collect each experiment's mean FSM error, sorted lowest error first."""
    summary = list()
    for experiment_dir in experiment_dirs:
        mode_errors_json = (
            experiment_dir / "results" / "partition_benchmark"
            / "visited_states_errors.json"
        )
        with open(mode_errors_json, "rt") as fp:
            mode_errors = json.load(fp)

        summary.append(dict(
            path=str(experiment_dir),
            mean_fsm_errors=statistics.mean(mode_errors["fsm"]),
        ))

    return sorted(summary, key=lambda s: s["mean_fsm_errors"])


if __name__ == "__main__":
    # Command-line entry point: parse options, prepare output directories and
    # launch the search.
    parser = argparse.ArgumentParser(
        prog='Hyperparameter search',
        description=(
            'Search FSM synthesis hyperparameters by running the experiment '
            'pipeline on randomly sampled grid configurations.'
        ),
    )
    parser.add_argument(
        '--base_hyperparameter_json',
        type=Path,
        required=True,
        help='Base JSON file with hyperparameters.'
    )
    parser.add_argument(
        '--worker_n',
        type=int,
        required=True,
        help='Number of parallel experiments.'
    )
    parser.add_argument(
        '--search_budget',
        type=int,
        required=True,
        help='Maximum number of experiments.'
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        default=Path("output_results"),
        help='Directory for search results and the trajectory cache '
             '(default: output_results).'
    )
    args = parser.parse_args()

    # ProcessPoolExecutor rejects fewer than one worker, and a zero budget
    # would run nothing; fail early with a usage message instead.
    if args.worker_n < 1:
        parser.error("--worker_n must be at least 1")
    if args.search_budget < 1:
        parser.error("--search_budget must be at least 1")

    # Create results dir if it doesn't exist
    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    trajectory_cache_dir = output_dir / "trajectory_cache"
    trajectory_cache_dir.mkdir(exist_ok=True)

    main(
        base_hyperparameter_json=args.base_hyperparameter_json,
        output_dir=output_dir,
        worker_n=args.worker_n,
        search_budget=args.search_budget,
        trajectory_cache_dir=trajectory_cache_dir,
    )
