import os
from itertools import product
import stat
import shutil
from Models.experiment_identifier import ExperimentIdentifier
from Models.training_tracker import get_training_status, TrainingMetadata
from Models.result_data_checkpointer import SampleGenerationMetadata
from Utils.io_utils import load_yaml_config
import argparse
from typing import Optional, Dict, List, Union, Tuple, Set, Any
import wadler_lindig as wl
from Models.models.ho_models.bwd_evaluation import BWD_EVAL_SAVE_DIR
import glob

def parse_args():
  parser = argparse.ArgumentParser(description='Check status of training experiments')
  # parser.add_argument('--kind', type=str, choices=['train', 'test', 'results', 'all'], default='all', help='What kinds of scripts to generate')
  return parser.parse_args()

# Typical workflow, run in this order:
# 1) python -m Data.harmonic_oscillator.make_data2
# 2) python -m Config.harmonic_oscillator.generate_configs
# 3) python -m Models.dynamic_latent_size_models.ho_models.generate_bwd_scripts

def get_experiment_configs(args):
  """Build one experiment config per point of the hyper-parameter grid.

  The grid below is expanded with ``itertools.product`` in key order, so the
  last key (``sde_type``) varies fastest. Each resulting dict keeps the grid
  keys and gains two derived entries:

  * ``config_name``: basename of ``config_file`` with its extension removed.
  * ``experiment_name``:
    ``"{config_name}_{sde_type}_{random_key_seed}_{freq}_bwd_eval"``.

  Every config is echoed to stdout (a header line plus a pretty-print) as it
  is added.

  Args:
    args: Parsed command-line namespace; not read at present.

  Returns:
    The list of config dicts, in product order.
  """
  # Short preemptible jobs. For long runs, the previous setting was
  # time "47:59:59" on the "gpu" partition.
  grid = {
    "time": ["0:59:59"],
    "gpu": ["gpu-preempt"],
    "gpu_type": ["a40|a100|rtx8000|a16|2080ti"],
    "random_key_seed": [0, 1, 2, 3, 4],
    "config_file": [
      "Config/noisy_double_pendulum.yaml",
      "Config/lorenz.yaml",
      "Config/fitzhugh.yaml",
      "Config/lotka.yaml",
      "Config/brusselator.yaml",
      "Config/van_der_pol.yaml",
    ],
    "freq": [0, 1, 2, 3, 4],
    "sde_type": ["tracking"],
  }
  axis_names = tuple(grid)

  experiments = []
  for point in product(*(grid[name] for name in axis_names)):
    experiment = dict(zip(axis_names, point))

    stem, _ext = os.path.splitext(os.path.basename(experiment["config_file"]))
    experiment["config_name"] = stem
    experiment["experiment_name"] = (
        f"{stem}_{experiment['sde_type']}_{experiment['random_key_seed']}"
        f"_{experiment['freq']}_bwd_eval")

    print('Adding config')
    wl.pprint(experiment)
    experiments.append(experiment)

  return experiments


def write_script(cfg: Dict[str, Any],
                 template_str: str,
                 output_dir: str) -> tuple[str, str]:
  """Render one job script for ``cfg`` and write it into ``output_dir``.

  The script body is ``template_str.format(...)`` with the placeholders
  ``python_command``, ``experiment_name``, ``gpu``, ``gpu_type`` and ``time``.
  The file is named ``{experiment_name}.sh`` and is made executable for
  user, group and others.

  Args:
    cfg: Experiment config (as built by ``get_experiment_configs``). Must
      contain ``config_name``, ``random_key_seed``, ``freq``,
      ``experiment_name``, ``gpu``, ``gpu_type`` and ``time``.
    template_str: Shell-script template using ``str.format`` placeholders.
      Literal braces in the template must be doubled (``{{`` / ``}}``).
    output_dir: Destination directory; created if it does not exist.

  Returns:
    ``(script_path, python_command)``. Both are always non-empty strings.

  Raises:
    KeyError: if ``cfg`` lacks a required key or the template references a
      placeholder that is not supplied.
  """
  import shlex  # local import: only this block needs it

  module_invocation = (
      "python -m Models.dynamic_latent_size_models.ho_models.bwd_evaluation")
  cli_args = [
      f"--config_name={cfg['config_name']}",
      f"--random_key_seed={cfg['random_key_seed']}",
      f"--model_freq={cfg['freq']}",
      "--other_evaluation",
  ]
  # Shell-quote each argument so a config value with spaces or shell
  # metacharacters cannot break or alter the rendered command. Plain
  # identifiers and integers (the current sweep) pass through unchanged.
  # Everything goes on one line, joined by single spaces.
  python_command = " ".join(
      [module_invocation] + [shlex.quote(arg) for arg in cli_args])

  script_content = template_str.format(
      python_command=python_command,
      experiment_name=cfg['experiment_name'],
      gpu=cfg['gpu'],
      gpu_type=cfg['gpu_type'],
      time=cfg['time'],
  )

  os.makedirs(output_dir, exist_ok=True)
  script_path = os.path.join(output_dir, f"{cfg['experiment_name']}.sh")

  with open(script_path, "w", encoding="utf-8") as f:
    f.write(script_content)
  print(f"Wrote script: {script_path}")

  # Add execute permission for user/group/other; keep the existing mode bits.
  os.chmod(script_path,
           os.stat(script_path).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return script_path, python_command

def main():
  """Regenerate every backward-evaluation job script plus the command lists.

  Steps:
    1. Delete ``BWD_EVAL_SAVE_DIR/generated_training_scripts`` if it exists,
       so scripts from an earlier sweep are not left behind.
    2. Render one script per config via ``write_script``.
    3. Write ``python_commands_latent.txt`` and ``sbatch_commands_latent.txt``
       under ``BWD_EVAL_SAVE_DIR``, one newline-terminated command per line.

  The template path is relative, so this must be run from the directory that
  contains ``script_generator/``.
  """
  args = parse_args()

  output_dir = os.path.join(BWD_EVAL_SAVE_DIR, "generated_training_scripts")
  script_template_path = "script_generator/script_template.sh"

  with open(script_template_path, "r", encoding="utf-8") as f:
    template_str = f.read()

  # Start from an empty directory so stale scripts are not resubmitted.
  if os.path.exists(output_dir):
    shutil.rmtree(output_dir)

  configs = get_experiment_configs(args)

  python_commands = []
  sbatch_commands = []
  for cfg in configs:
    script_path, python_cmd = write_script(cfg, template_str, output_dir)
    if script_path and python_cmd:
      python_commands.append(python_cmd)
      sbatch_commands.append(f"sbatch {script_path}")

  os.makedirs(BWD_EVAL_SAVE_DIR, exist_ok=True)
  command_lists = (
      ("python_commands_latent.txt", python_commands),
      ("sbatch_commands_latent.txt", sbatch_commands),
  )
  for filename, commands in command_lists:
    with open(os.path.join(BWD_EVAL_SAVE_DIR, filename), "w", encoding="utf-8") as f:
      # Terminate every line, including the last one: shell `while read`
      # loops silently skip a final line that lacks a trailing newline.
      f.writelines(f"{command}\n" for command in commands)

  print(f"Generated {len(sbatch_commands)} scripts in {output_dir}")

if __name__ == "__main__":
  # Runs only when this file is executed as a script, not when it is imported.
  # NOTE(review): `debug` is a project-local module; the wildcard import is
  # presumably done for its side effects (e.g. installing debugging hooks)
  # rather than for any name used here. Confirm, and prefer an explicit
  # import if only side effects are needed.
  from debug import *
  main()
