"""Run the experiment pipeline for a particular environment."""
import subprocess
from pathlib import Path
import hashlib
import datetime
import sys
import argparse
import json
import shutil
import random


def wait_for_jobs(max_jobs: int, jobs: list[subprocess.Popen]):
    """Block until at most ``max_jobs`` of the given jobs remain.

    Jobs that have already exited are removed first without blocking. If more
    than ``max_jobs`` are still left, the oldest ones (front of the list) are
    waited on in submission order. ``jobs`` is mutated in place so the caller's
    list always holds only jobs that may still be running.

    Args:
        max_jobs: Maximum number of jobs allowed to remain in ``jobs``. Pass 0
            to wait for every job. Negative values are treated as 0.
        jobs: Submitted jobs, oldest first. Mutated in place.
    """

    def _report(job: subprocess.Popen) -> None:
        # Surface failures instead of silently dropping the exit status.
        if job.returncode != 0:
            print(
                f"Warning: job {job.args} exited with code {job.returncode}",
                file=sys.stderr,
            )

    max_jobs = max(max_jobs, 0)

    # Reap jobs that already finished so their slots are freed immediately,
    # instead of blocking on whichever job happens to be last in the list.
    still_running = []
    for job in jobs:
        if job.poll() is None:
            still_running.append(job)
        else:
            _report(job)
    jobs[:] = still_running

    while len(jobs) > max_jobs:
        job = jobs.pop(0)
        job.wait()
        _report(job)


def async_submit_job(args: list[str]):
    """Start ``args`` as a child process and return immediately.

    The returned ``subprocess.Popen`` handle lets the caller wait on the job
    later (see ``wait_for_jobs``).
    """
    return subprocess.Popen(args)


def generate_trajectories(
    env_id: str,
    src_dir: Path,
    trajectory_dir: Path,
    trajectory_n: int,
    worker_n: int,
    environment_specific_hyperparameters: dict,
    seed: str,
    dt: float,
):
    """Run the environment-specific trajectory generator script.

    The generator lives under ``src_dir/swmpo_experiments/<env>_utils`` and is
    executed with the current interpreter (``sys.executable``) so that it
    inherits the active virtual environment.

    Args:
        env_id: One of ``TerrainMass-v0``, ``AutonomousCar-v0``,
            ``Salamander-v0`` or ``BipedalWalkerHardcoreModes-v3``.
        src_dir: Root of the source tree containing ``swmpo_experiments``.
        trajectory_dir: Directory passed to the generator as ``--output_dir``.
        trajectory_n: Number of trajectories (``--episode_n`` for TerrainMass).
        worker_n: Number of workers passed to the generator.
        environment_specific_hyperparameters: Only read for ``TerrainMass-v0``
            (keys ``mpc_plan_len``, ``mpc_initial_stdev``, ``mpc_iter_n``,
            ``simulation_step_n``).
        seed: Forwarded verbatim as ``--seed``.
        dt: Only forwarded for ``AutonomousCar-v0``.

    Raises:
        ValueError: If ``env_id`` is not recognized.
        FileNotFoundError: If the BipedalWalker expert policy directory is
            missing.
        subprocess.CalledProcessError: If the generator exits with a non-zero
            status, so a failed run is never mistaken for a complete dataset.
    """
    experiments_dir = src_dir / "swmpo_experiments"

    if env_id == "TerrainMass-v0":
        hp = environment_specific_hyperparameters
        script = experiments_dir / "terrain_mass_utils" / "generate_trajectories.py"
        extra_args = [
            "--episode_n", str(trajectory_n),
            "--worker_n", str(worker_n),
            "--seed", seed,
            "--mpc_plan_len", str(hp["mpc_plan_len"]),
            "--mpc_initial_stdev", str(hp["mpc_initial_stdev"]),
            "--mpc_iter_n", str(hp["mpc_iter_n"]),
            "--simulation_step_n", str(hp["simulation_step_n"]),
        ]
    elif env_id == "AutonomousCar-v0":
        script = experiments_dir / "autonomous_driving_utils" / "generate_trajectories.py"
        extra_args = [
            "--trajectory_n", str(trajectory_n),
            "--dt", str(dt),
            "--worker_n", str(worker_n),
            "--seed", seed,
        ]
    elif env_id == "Salamander-v0":
        script = experiments_dir / "salamander_utils" / "generate_trajectories.py"
        extra_args = [
            "--trajectory_n", str(trajectory_n),
            "--worker_n", str(worker_n),
            "--seed", seed,
        ]
    elif env_id == "BipedalWalkerHardcoreModes-v3":
        # The expert policy is resolved relative to this script, not src_dir.
        expert_policy_dir = (
            Path(__file__).parent.parent
            / "lib" / "rl-trained-agents" / "ppo" / "BipedalWalkerHardcore-v3_1"
        )
        print(expert_policy_dir.resolve())
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not expert_policy_dir.exists():
            raise FileNotFoundError(
                f"Expert policy directory not found: {expert_policy_dir.resolve()}"
            )
        script = experiments_dir / "bipedal_walker_hardcore_utils" / "generate_trajectories.py"
        extra_args = [
            "--trajectory_n", str(trajectory_n),
            "--worker_n", str(worker_n),
            "--seed", seed,
            "--expert_policy_dir", str(expert_policy_dir),
        ]
    else:
        raise ValueError(f"Environment {env_id} not recognized!")

    subprocess.run(
        [
            sys.executable, str(script),
            "--output_dir", str(trajectory_dir),
            *extra_args,
        ],
        check=True,
    )


def _hash_source_tree(root: Path) -> str:
    """Return an MD5 hex digest identifying the ``.py``/``.sh`` files under ``root``.

    Paths containing ``venv`` or ``output_results`` are skipped so virtual
    environments and snapshots of previous runs do not affect the id.
    """
    src_files = sorted(
        path
        for pattern in ("*.py", "*.sh")
        for path in root.rglob(pattern)
        if path.is_file()
        and not any(exclude in str(path) for exclude in ("venv", "output_results"))
    )
    digest = hashlib.md5()
    for path in src_files:
        # Sorted order plus the path itself: the id no longer depends on
        # filesystem iteration order and changes when files are renamed.
        digest.update(str(path).encode())
        digest.update(b"\0")
        digest.update(path.read_bytes())
        digest.update(b"\0")
    return digest.hexdigest()


def _snapshot_source(src_dir: Path) -> None:
    """Copy the code needed to reproduce this run into ``src_dir``.

    ``swmpo_experiments`` is required because later pipeline steps execute
    scripts from this snapshot, so copying it raises on failure. The library
    and project metadata are archived on a best-effort basis: a missing entry
    is reported on stderr, as the previous ``cp`` calls effectively did.
    """
    src_dir.mkdir()
    (src_dir / "lib").mkdir()
    shutil.copytree(
        Path("swmpo_experiments"), src_dir / "swmpo_experiments", symlinks=True
    )
    optional = [
        (Path("lib") / "swmpo", src_dir / "lib" / "swmpo"),
        (Path("README.md"), src_dir / "README.md"),
        (Path("pyproject.toml"), src_dir / "pyproject.toml"),
        (Path("poetry.lock"), src_dir / "poetry.lock"),
    ]
    for source, destination in optional:
        if source.is_dir():
            shutil.copytree(source, destination, symlinks=True)
        elif source.is_file():
            shutil.copy2(source, destination)
        else:
            print(
                f"Warning: {source} not found; not included in the source snapshot",
                file=sys.stderr,
            )


def _run_step(args: list[str]) -> int:
    """Run one pipeline command to completion and return its exit code.

    A non-zero exit is reported on stderr but does not abort the pipeline, so
    the remaining steps and the final archive still run.
    """
    returncode = subprocess.run(args).returncode
    if returncode != 0:
        print(
            f"Warning: {' '.join(args[:2])} exited with code {returncode}",
            file=sys.stderr,
        )
    return returncode


def _run_benchmark_batch(
    title: str,
    script: Path,
    batch_dir: Path,
    experiment_n: int,
    worker_n: int,
    shared_args: list[str],
) -> None:
    """Run ``experiment_n`` seeds of a benchmark script, ``worker_n`` at a time.

    Experiment ``i`` writes to ``batch_dir/i`` and receives ``--seed i``; all
    experiments get ``shared_args`` between ``--output_dir`` and ``--seed``.
    Returns once every experiment of the batch has exited.
    """
    print(title)
    batch_dir.mkdir()
    jobs: list[subprocess.Popen] = []
    for experiment_i in range(experiment_n):
        jobs.append(async_submit_job([
            sys.executable, str(script),
            "--output_dir", str(batch_dir / str(experiment_i)),
            *shared_args,
            "--seed", str(experiment_i),
        ]))
        wait_for_jobs(worker_n, jobs)
    wait_for_jobs(0, jobs)


def main(
    seed: str,
    worker_n: int,
    cuda_device: str,

    # Environment hyperparameters:
    dt: float,

    # State machine synthesis hyperparameters
    trajectory_n: int,
    state_machine_hidden_sizes: str,
    state_machine_learning_rate: float,
    state_machine_partition_size: int,
    state_machine_batch_size: int,
    state_machine_min_island_size: int,
    state_machine_autoencoder_latent_size: int,
    state_machine_prunning_error_tolerance: float,
    state_machine_model_mode_iter_n: int,
    state_machine_cluster_dimensionality_reduce: int,
    state_machine_information_content_regularization_scale: float,
    state_machine_mutual_information_regularization_scale: float,
    state_machine_mutual_information_mini_batch_size: int,
    state_machine_optimization_iter_n: int,
    state_machine_predicate_hyperparameters: dict,

    # RL hyperparameters
    rl_exploration_window_size: int,
    rl_extrinsic_reward_constant: float,
    rl_train_timestep_n: int,
    rl_datapoint_n: int,
    rl_experiment_n: int,
    rl_parallel_env_n: int,
    rl_plot_episode_n: int,
    rl_eval_episode_n: int,
    env_id: str,
    ground_truth_env_id: str,

    # Environment-specific hyperparameters
    environment_specific_hyperparameters: dict[str, str],

    # Execution flags:
    run_rl: bool,

    # Shared results directory (assumed to exist). A subdir will be created
    results_dir: Path,

    # Cache dir
    trajectory_cache_dir: Path,
) -> Path:
    """Run the full experiment pipeline and return its results directory.

    Steps:
    1. Snapshot the source code and hyperparameters.
    2. Gather cached train/test trajectories.
    3. Synthesize and plot the state machine.
    4. Fit an HMM and benchmark both partitions.
    5. If ``run_rl`` is set, run the RL/SWMPO benchmark batches and plot
       their rewards.
    6. Zip the results directory next to itself.

    Every step runs as a subprocess of ``sys.executable`` so it inherits the
    active virtual environment. A failing step prints a warning on stderr, and
    the pipeline continues.

    Args:
        seed: Seeds the local RNG used for trajectory seeds; also forwarded
            (as a string) to the synthesis and HMM steps.
        worker_n: Worker count for generators/HMM and the maximum number of
            concurrent RL benchmark jobs.
        cuda_device: Forwarded as ``--cuda_device``.
        dt: Environment time step, forwarded to steps that accept ``--dt``.
        trajectory_n, state_machine_*: Forwarded to trajectory generation and
            state machine synthesis. ``state_machine_hidden_sizes`` is a
            whitespace-separated list of layer sizes.
        rl_*, env_id, ground_truth_env_id: RL benchmark configuration;
            ``rl_experiment_n`` seeds are run per benchmark.
        environment_specific_hyperparameters: Forwarded to trajectory
            generation.
        run_rl: Whether to run the RL benchmarks and reward plot.
        results_dir: Existing directory under which the run directory is made.
        trajectory_cache_dir: Root of the trajectory cache, shared across runs.

    Returns:
        The new ``results_dir/<env_id>_output_results_<date>_<source hash>``
        directory.

    Raises:
        FileExistsError: If the run directory already exists, i.e. a run with
            identical code started within the same second.
    """
    _random = random.Random(seed)
    # Subprocess arguments must be strings; the JSON may hold numbers here.
    seed_arg = str(seed)
    cuda_device = str(cuda_device)

    # Identify the exact code version in the output directory name.
    experiment_id = _hash_source_tree(Path("."))

    now = datetime.datetime.now()
    # Second resolution and no '.' in the name: the previous
    # `isoformat()` + `with_suffix("")` treated ".<microseconds>_<hash>" as a
    # file suffix and silently stripped the source hash from the name.
    date_str = now.isoformat(timespec="seconds")
    output_dir = results_dir / f"{env_id}_output_results_{date_str}_{experiment_id}"
    output_dir.mkdir()
    # Append instead of with_suffix(): env_id itself may contain dots.
    output_zip = output_dir.parent / f"{output_dir.name}.zip"

    train_trajectories_dir = output_dir / "train_trajectories"
    test_trajectories_dir = output_dir / "test_trajectories"
    state_machine_plot_dir = output_dir / "state_machine_plot"
    state_machine_prunned_plot_dir = output_dir / "state_machine_prunned_plot"
    date_file = output_dir / "date.txt"
    plot_file = output_dir / "rewards.svg"
    rl_dir = output_dir / "rl_results"
    state_machine_dir = output_dir / "state_machine"
    state_machine_zip = state_machine_dir / "state_machine.zip"
    state_machine_prunned_zip = state_machine_dir / "state_machine_prunned.zip"
    hmm_dir = output_dir / "hmm"
    hmm_pkl = hmm_dir / "hmm.pkl"
    swmpo_dir = output_dir / "swmpo_results"
    biased_swmpo_dir = output_dir / "biased_swmpo_results"
    swmpo_ground_truth_dir = output_dir / "swmpo_ground_truth"
    biased_swmpo_ground_truth_dir = output_dir / "biased_swmpo_ground_results"
    partition_benchmark_dir = output_dir / "partition_benchmark"

    print(f"Output dir: {output_dir}")
    print(f"Output zip: {output_zip}")

    # Serialize code; every later step runs scripts from this snapshot.
    src_dir = output_dir / "src"
    _snapshot_source(src_dir)
    experiments_dir = src_dir / "swmpo_experiments"

    # Serialize hyperparameters
    hyperparameters = dict(
        seed=seed,
        worker_n=worker_n,
        cuda_device=cuda_device,
        dt=dt,
        trajectory_n=trajectory_n,
        state_machine_hidden_sizes=state_machine_hidden_sizes,
        state_machine_learning_rate=state_machine_learning_rate,
        state_machine_partition_size=state_machine_partition_size,
        state_machine_batch_size=state_machine_batch_size,
        state_machine_min_island_size=state_machine_min_island_size,
        state_machine_autoencoder_latent_size=state_machine_autoencoder_latent_size,
        state_machine_prunning_error_tolerance=state_machine_prunning_error_tolerance,
        state_machine_model_mode_iter_n=state_machine_model_mode_iter_n,
        state_machine_cluster_dimensionality_reduce=state_machine_cluster_dimensionality_reduce,
        state_machine_information_content_regularization_scale=state_machine_information_content_regularization_scale,
        state_machine_mutual_information_regularization_scale=state_machine_mutual_information_regularization_scale,
        state_machine_mutual_information_mini_batch_size=state_machine_mutual_information_mini_batch_size,
        state_machine_optimization_iter_n=state_machine_optimization_iter_n,
        state_machine_predicate_hyperparameters=state_machine_predicate_hyperparameters,
        rl_exploration_window_size=rl_exploration_window_size,
        rl_extrinsic_reward_constant=rl_extrinsic_reward_constant,
        rl_train_timestep_n=rl_train_timestep_n,
        rl_datapoint_n=rl_datapoint_n,
        rl_experiment_n=rl_experiment_n,
        rl_parallel_env_n=rl_parallel_env_n,
        rl_plot_episode_n=rl_plot_episode_n,
        rl_eval_episode_n=rl_eval_episode_n,
        env_id=env_id,
        ground_truth_env_id=ground_truth_env_id,
        environment_specific_hyperparameters=environment_specific_hyperparameters,
    )
    with open(output_dir / "hyperparameters.json", "wt") as fp:
        json.dump(hyperparameters, fp, indent=2)

    # Log date
    date_file.write_text(now.isoformat())

    # Gather trajectories (generated once per cache key, then copied).
    # NOTE(review): the cache key is only (env_id, trajectory_n); dt and
    # environment_specific_hyperparameters are ignored, so changing them
    # reuses previously generated trajectories — confirm this is intended.
    print("Gathering trajectories")
    cached_dataset_dir = trajectory_cache_dir / env_id / f"{trajectory_n}_trajectories"
    for trajectory_dir in (train_trajectories_dir, test_trajectories_dir):
        tcache_dir = cached_dataset_dir / trajectory_dir.stem
        if not tcache_dir.exists():
            print(f"Generating new cache {tcache_dir}")
            cached_dataset_dir.mkdir(exist_ok=True, parents=True)
            # Generate into a staging directory and publish it under the cache
            # name only afterwards, so an interrupted run cannot leave a
            # partial dataset that every later run would silently reuse.
            staging_dir = tcache_dir.with_name(f"{tcache_dir.name}.partial")
            shutil.rmtree(staging_dir, ignore_errors=True)
            generate_trajectories(
                env_id=env_id,
                src_dir=src_dir,
                trajectory_dir=staging_dir,
                trajectory_n=trajectory_n,
                worker_n=worker_n,
                environment_specific_hyperparameters=environment_specific_hyperparameters,
                seed=str(_random.random()),
                dt=dt,
            )
            if not staging_dir.exists():
                raise RuntimeError(
                    f"Trajectory generation for {env_id} produced no output in {staging_dir}"
                )
            staging_dir.rename(tcache_dir)

        print(f"Using cached trajectories {tcache_dir}")
        shutil.copytree(tcache_dir, trajectory_dir)
    print("Done")

    # Synthesize state machine
    print("Synthesizing state machine...")
    state_machine_predicate_hyperparameters_path = output_dir / "predicate_hyperparameters.json"
    with open(state_machine_predicate_hyperparameters_path, "wt") as fp:
        json.dump(state_machine_predicate_hyperparameters, fp)
    _run_step([
        sys.executable, str(experiments_dir / "state_machine_synthesis.py"),
        "--output_dir", str(state_machine_dir),
        # split() tolerates repeated/trailing whitespace between sizes.
        "--hidden_sizes", *state_machine_hidden_sizes.split(),
        "--learning_rate", str(state_machine_learning_rate),
        "--optimization_iter_n", str(state_machine_optimization_iter_n),
        "--mode_model_iter_n", str(state_machine_model_mode_iter_n),
        "--state_n", str(state_machine_partition_size),
        "--min_island_size", str(state_machine_min_island_size),
        "--partition_latent_size", str(state_machine_autoencoder_latent_size),
        "--cluster_dimensionality_reduce", str(state_machine_cluster_dimensionality_reduce),
        "--cluster_information_content_regularization_scale", str(state_machine_information_content_regularization_scale),
        "--cluster_mutual_information_regularization_scale", str(state_machine_mutual_information_regularization_scale),
        "--cuda_device", cuda_device,
        "--batch_size", str(state_machine_batch_size),
        "--mutual_information_mini_batch_size", str(state_machine_mutual_information_mini_batch_size),
        "--prunning_error_tolerance", str(state_machine_prunning_error_tolerance),
        "--seed", seed_arg,
        "--dt", str(dt),
        "--train_trajectory_dir", str(train_trajectories_dir),
        "--predicate_hyperparameters_json", str(state_machine_predicate_hyperparameters_path),
    ])
    print("Synthesis of state machine finished.")

    # Inspect state machine (full and pruned)
    print("Plotting state machine...")
    plot_tasks = [
        (state_machine_plot_dir, state_machine_zip),
        (state_machine_prunned_plot_dir, state_machine_prunned_zip),
    ]
    for plot_dir, zip_file in plot_tasks:
        _run_step([
            sys.executable, str(experiments_dir / "plot_state_machine.py"),
            "--output_dir", str(plot_dir),
            "--dt", str(dt),
            "--state_machine_zip", str(zip_file),
            "--test_trajectory_dir", str(test_trajectories_dir),
        ])
    print("Plotting state machine done.")

    # Synthesize HMM
    print("Synthesizing HMM...")
    _run_step([
        sys.executable, str(experiments_dir / "hmm_synthesis.py"),
        "--train_trajectory_dir", str(train_trajectories_dir),
        "--output_dir", str(hmm_dir),
        "--component_n", str(state_machine_partition_size),
        "--seed", seed_arg,
        "--worker_n", str(worker_n),
    ])
    print("Synthesis of HMM finished.")

    # Compare HMM and state machine partition
    print("Benchmarking partition...")
    _run_step([
        sys.executable, str(experiments_dir / "benchmark_partitions.py"),
        "--test_trajectory_dir", str(test_trajectories_dir),
        "--output_dir", str(partition_benchmark_dir),
        "--state_machine_zip", str(state_machine_zip),
        "--dt", str(dt),
        "--hmm_pkl", str(hmm_pkl),
    ])
    print("Partition benchmark done.")

    if run_rl:
        # Arguments shared by every RL benchmark script.
        rl_args = [
            "--train_timestep_n", str(rl_train_timestep_n),
            "--datapoint_n", str(rl_datapoint_n),
            "--plot_episode_n", str(rl_plot_episode_n),
            "--parallel_env_n", str(rl_parallel_env_n),
            "--eval_episode_n", str(rl_eval_episode_n),
        ]
        # Extra arguments understood only by the SWMPO scripts.
        swmpo_args = [
            "--dt", str(dt),
            "--extrinsic_reward_scale", str(rl_extrinsic_reward_constant),
            "--exploration_window_size", str(rl_exploration_window_size),
        ]
        cuda_args = ["--cuda_device", cuda_device]
        swmpo_shared = [
            "--state_machine_zip", str(state_machine_prunned_zip),
            *rl_args, "--env_id", env_id, *swmpo_args, *cuda_args,
        ]
        benchmarks = [
            ("Benchmarking RL...", "benchmark_rl.py", rl_dir,
             [*rl_args, "--env_id", env_id, *cuda_args]),
            ("Benchmarking SWMPO[RL]...", "benchmark_swmpo.py", swmpo_dir,
             swmpo_shared),
            ("Benchmarking biased SWMPO[RL]...", "benchmark_swmpo_mode_bias.py",
             biased_swmpo_dir, swmpo_shared),
            ("Benchmarking ground truth SWMPO...", "benchmark_rl.py",
             swmpo_ground_truth_dir,
             [*rl_args, "--env_id", ground_truth_env_id, *cuda_args]),
            ("Benchmarking biased ground truth SWMPO...", "benchmark_rl_mode_bias.py",
             biased_swmpo_ground_truth_dir,
             [*rl_args, "--env_id", ground_truth_env_id, *cuda_args]),
        ]
        for title, script_name, batch_dir, shared_args in benchmarks:
            _run_benchmark_batch(
                title,
                experiments_dir / script_name,
                batch_dir,
                rl_experiment_n,
                worker_n,
                shared_args,
            )

        # Plot rewards
        print("Plotting")
        csv_args: list[str] = []
        for flag, directory in (
            ("--rl_csvs", rl_dir),
            ("--swmpo_csvs", swmpo_dir),
            ("--biased_swmpo_csvs", biased_swmpo_dir),
            ("--swmpo_ground_truth_csvs", swmpo_ground_truth_dir),
            ("--biased_swmpo_ground_truth_csvs", biased_swmpo_ground_truth_dir),
        ):
            # Sorted so the plotting inputs are listed deterministically.
            csv_args += [flag, *(str(csv) for csv in sorted(directory.rglob("*.csv")))]
        _run_step([
            sys.executable, str(experiments_dir / "reward_plotting.py"),
            *csv_args,
            "--output_path", str(plot_file),
        ])
        print("Done plotting")

    print("Pipeline complete!")
    _run_step(["zip", "-r", "-q", str(output_zip), str(output_dir)])
    print(f"Wrote {output_zip}")

    return output_dir


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        prog='Experiment pipeline',
        description='Execute the full experiment pipeline.',
    )
    cli.add_argument(
        '--hyperparameter_json',
        type=Path,
        required=True,
        help='JSON file with hyperparameters.',
    )
    cli_args = cli.parse_args()

    # Shared results directory and the trajectory cache nested inside it;
    # both are created on first use and reused by later runs.
    results_dir = Path("output_results")
    trajectory_cache_dir = results_dir / "trajectory_cache"
    for directory in (results_dir, trajectory_cache_dir):
        directory.mkdir(exist_ok=True)

    # Each top-level JSON key becomes a keyword argument of main().
    hyperparameters = json.loads(cli_args.hyperparameter_json.read_text())

    main(
        **hyperparameters,
        results_dir=results_dir,
        trajectory_cache_dir=trajectory_cache_dir,
    )
