import os
import subprocess
import re

from algorithms.adqn.utils.utils import list_to_string, get_run_name


def launch_jobs_for_criterion(
    seeds,
    models,
    m_seeds,
    criterion,
    save_path,
    target_update_frequency=None,
    min_steps_per_epoch=None,
    val_episodes=None,
    n_epochs=None,
    time_limit="6:00:00",
):
    """Submit the aDQN job and one DQN job per model for a single criterion.

    One ``sbatch`` array job (``-J adqn``) runs ``launch_job/train_adqn.sh``
    with all ``models``. Then one array job per model (``-J dqn``) runs the
    same script with only that model and its model seed. Every command gets
    ``-c <criterion> -p <save_path>``, plus each optional flag whose value is
    not ``None``. Each command is printed and executed through the shell.
    Its stdout is then printed, or its stderr when stdout is empty.

    Args:
        seeds: Integer seeds used as Slurm array task indices. A contiguous
            ascending run is submitted as ``first-last``. Any other sequence
            is submitted as an explicit comma-separated index list.
        models: Model specifications, formatted with ``list_to_string``.
        m_seeds: One model seed per entry of ``models``; ints or strings.
        criterion: Value passed to the training script as ``-c``.
        save_path: Value passed to the training script as ``-p``.
        target_update_frequency: Optional ``-tuf`` value.
        min_steps_per_epoch: Optional ``-mspe`` value.
        val_episodes: Optional ``-ve`` value.
        n_epochs: Optional ``-ne`` value.
        time_limit: Slurm ``--time`` limit applied to every submitted job.

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    # len() rather than truthiness so array-like seed containers also work.
    if len(seeds) == 0:
        raise ValueError("seeds must contain at least one seed")

    seed_ints = [int(seed) for seed in seeds]
    if seed_ints == list(range(seed_ints[0], seed_ints[-1] + 1)):
        array_spec = f"{seeds[0]}-{seeds[-1]}"
    else:
        # A "first-last" range would launch the wrong tasks here; Slurm's
        # --array also accepts an explicit comma-separated list of indices.
        array_spec = ",".join(str(seed) for seed in seed_ints)

    user = os.getenv("USER")

    def sbatch_command(job_name):
        # Shared sbatch prefix; only the job name differs between aDQN and DQN.
        return (
            f"sbatch -p amd,amd2,amd3 -J {job_name} --array={array_spec} "
            f"--cpus-per-task=1 --mem-per-cpu=3G --time={time_limit} "
            f"-o /home/{user}/aDQN/slurm_output/%A_%a.out "
            "launch_job/train_adqn.sh"
        )

    # str() so integer model seeds work too: str.join only accepts strings.
    m_seed_strings = [str(m_seed) for m_seed in m_seeds]

    # NOTE(review): the '_' -> ' ' replacement presumes list_to_string joins
    # models with '_', so they become separate -hl arguments; confirm in utils.
    commands = [
        f"{sbatch_command('adqn')} -hl {list_to_string(models).replace('_', ' ')} "
        f"-ms {' '.join(m_seed_strings)}"
    ]
    commands += [
        f"{sbatch_command('dqn')} -hl {list_to_string([model])} -ms {m_seed}"
        for model, m_seed in zip(models, m_seed_strings)
    ]

    # Appended in this fixed order; flags whose value is None are omitted.
    optional_flags = (
        ("-tuf", target_update_frequency),
        ("-mspe", min_steps_per_epoch),
        ("-ve", val_episodes),
        ("-ne", n_epochs),
    )
    suffix = f" -c {criterion} -p {save_path}" + "".join(
        f" {flag} {value}" for flag, value in optional_flags if value is not None
    )

    for command in commands:
        command += suffix
        print(command)
        # shell=True: the command is one pre-formatted string, and the model
        # spec is meant to expand into several whitespace-separated arguments.
        result = subprocess.run(command, capture_output=True, text=True, shell=True)
        print(result.stdout if result.stdout else result.stderr)


def launch_all_jobs(
    seeds,
    models,
    m_seeds,
    save_path,
    target_update_frequency=None,
    min_steps_per_epoch=None,
    val_episodes=None,
    n_epochs=None,
    criteria=("min", "max", "random", "eps_min"),
    time_limit="2:00:00",
):
    """Submit one aDQN job per criterion plus one DQN job per model.

    For each entry of ``criteria``, one ``sbatch`` array job (``-J adqn``)
    runs ``launch_job/train_adqn.sh`` with all ``models`` and
    ``-c <criterion>``. Then one array job per model (``-J dqn``) runs the
    same script with only that model and its model seed. The DQN jobs get
    no ``-c`` flag. Every command gets ``-p <save_path>``, plus each optional
    flag whose value is not ``None``. Each command is printed and executed
    through the shell. Its stdout is then printed, or its stderr when stdout
    is empty.

    Args:
        seeds: Integer seeds used as Slurm array task indices. A contiguous
            ascending run is submitted as ``first-last``. Any other sequence
            is submitted as an explicit comma-separated index list.
        models: Model specifications, formatted with ``list_to_string``.
        m_seeds: One model seed per entry of ``models``; ints or strings.
        save_path: Value passed to the training script as ``-p``.
        target_update_frequency: Optional ``-tuf`` value.
        min_steps_per_epoch: Optional ``-mspe`` value.
        val_episodes: Optional ``-ve`` value.
        n_epochs: Optional ``-ne`` value.
        criteria: Criteria to launch aDQN jobs for, in submission order.
        time_limit: Slurm ``--time`` limit applied to every submitted job.

    Raises:
        ValueError: If ``seeds`` is empty.
    """
    # len() rather than truthiness so array-like seed containers also work.
    if len(seeds) == 0:
        raise ValueError("seeds must contain at least one seed")

    seed_ints = [int(seed) for seed in seeds]
    if seed_ints == list(range(seed_ints[0], seed_ints[-1] + 1)):
        array_spec = f"{seeds[0]}-{seeds[-1]}"
    else:
        # A "first-last" range would launch the wrong tasks here; Slurm's
        # --array also accepts an explicit comma-separated list of indices.
        array_spec = ",".join(str(seed) for seed in seed_ints)

    user = os.getenv("USER")

    def sbatch_command(job_name):
        # Shared sbatch prefix; only the job name differs between aDQN and DQN.
        return (
            f"sbatch -p amd,amd2,amd3 -J {job_name} --array={array_spec} "
            f"--cpus-per-task=1 --mem-per-cpu=3G --time={time_limit} "
            f"-o /home/{user}/aDQN/slurm_output/%A_%a.out "
            "launch_job/train_adqn.sh"
        )

    # str() so integer model seeds work too: str.join only accepts strings.
    m_seed_strings = [str(m_seed) for m_seed in m_seeds]

    # NOTE(review): the '_' -> ' ' replacement presumes list_to_string joins
    # models with '_', so they become separate -hl arguments; confirm in utils.
    all_models_arg = list_to_string(models).replace("_", " ")
    commands = [
        f"{sbatch_command('adqn')} -hl {all_models_arg} "
        f"-ms {' '.join(m_seed_strings)} -c {criterion}"
        for criterion in criteria
    ]
    commands += [
        f"{sbatch_command('dqn')} -hl {list_to_string([model])} -ms {m_seed}"
        for model, m_seed in zip(models, m_seed_strings)
    ]

    # Appended in this fixed order; flags whose value is None are omitted.
    optional_flags = (
        ("-tuf", target_update_frequency),
        ("-mspe", min_steps_per_epoch),
        ("-ve", val_episodes),
        ("-ne", n_epochs),
    )
    suffix = f" -p {save_path}" + "".join(
        f" {flag} {value}" for flag, value in optional_flags if value is not None
    )

    for command in commands:
        command += suffix
        print(command)
        # shell=True: the command is one pre-formatted string, and the model
        # spec is meant to expand into several whitespace-separated arguments.
        result = subprocess.run(command, capture_output=True, text=True, shell=True)
        print(result.stdout if result.stdout else result.stderr)


def transform_string(input_string):
    """Parse a run name back into its hyper-parameters.

    The expected layout is
    ``<a>-<b>_<c>-<d>_<criterion>_tuf<T>_mspe<M>_ne<E>``, for example
    ``"100-100_50-50_eps_min_tuf200_mspe6000_ne80"``. The criterion may
    contain underscores. Only the beginning of the string has to match, so
    trailing text is ignored. This includes a ``_<n>`` suffix such as the
    one ``get_unique_save_path`` appends.

    Args:
        input_string: Run name to parse.

    Returns:
        Tuple ``(models, criterion, target_update_frequency,
        min_steps_per_epoch, n_epochs)``. ``models`` is
        ``[[a, b], [c, d]]`` as ints. ``criterion`` is a string. The
        remaining three values are ints.

    Raises:
        ValueError: If ``input_string`` does not follow the layout above.
    """
    pattern = r"(\d+)-(\d+)_(\d+)-(\d+)_(\w+)_tuf(\d+)_mspe(\d+)_ne(\d+)"

    # re.match anchors at the start only, which keeps suffixed names parseable.
    match = re.match(pattern, input_string)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None.group.
        raise ValueError(
            f"Cannot parse run name {input_string!r}: expected "
            "'<a>-<b>_<c>-<d>_<criterion>_tuf<T>_mspe<M>_ne<E>'"
        )

    models = [
        [int(match.group(1)), int(match.group(2))],
        [int(match.group(3)), int(match.group(4))],
    ]
    criterion = match.group(5)
    target_update_frequency = int(match.group(6))
    min_steps_per_epoch = int(match.group(7))
    n_epochs = int(match.group(8))

    return models, criterion, target_update_frequency, min_steps_per_epoch, n_epochs


def get_unique_save_path(save_path, runs_dir="algorithms/adqn/runs"):
    """Return a run name whose folder under ``runs_dir`` does not exist yet.

    ``save_path`` itself is returned when ``<runs_dir>/<save_path>`` does not
    exist. Otherwise the result is the first ``<save_path>_<n>``, for
    ``n = 1, 2, ...``, whose folder does not exist. A relative ``runs_dir``
    is resolved against the current working directory.

    Only existence is checked and nothing is created. A later call with the
    same ``save_path`` therefore returns the same name until the folder
    appears.

    Args:
        save_path: Preferred run name.
        runs_dir: Directory that holds the run folders.

    Returns:
        The first free name in the sequence described above.
    """
    candidate = save_path
    suffix = 1
    while os.path.exists(os.path.join(runs_dir, candidate)):
        candidate = f"{save_path}_{suffix}"
        suffix += 1
    return candidate


if __name__ == "__main__":
    # Settings shared by every submitted run.
    seeds = list(range(20))
    val_episodes = 0
    min_steps_per_epoch = 6000
    n_epochs = 80

    # Each entry is one model configuration passed to launch_all_jobs
    # (presumably per-member layer sizes); uncomment alternatives to sweep.
    models_list = [
        # [[100, 100], [100, 100], [100, 100]],
        [[200, 200], [100, 100], [50, 50], [25, 25]],
        # [[100, 100], [50, 50, 50], [25, 25, 25, 25]],
    ]

    # Target-update frequencies to sweep (previously also tried: [1000]).
    tuf_values = [200]

    for models in models_list:
        # One model seed per configuration entry, as strings "0", "1", ...
        m_seeds = list(map(str, range(len(models))))
        for tuf in tuf_values:
            # Pick a fresh folder name so earlier runs are never overwritten.
            save_path = get_unique_save_path(
                get_run_name(models, "all", tuf, min_steps_per_epoch, n_epochs)
            )
            launch_all_jobs(
                seeds,
                models,
                m_seeds,
                save_path,
                target_update_frequency=tuf,
                min_steps_per_epoch=min_steps_per_epoch,
                val_episodes=val_episodes,
                n_epochs=n_epochs,
            )
