import os
from itertools import product
import stat
import shutil
from Models.experiment_identifier import ExperimentIdentifier
from Models.training_tracker import get_training_status, TrainingMetadata
from Models.result_data_checkpointer import SampleGenerationMetadata
from Utils.io_utils import load_yaml_config
import argparse
from typing import Optional, Dict, List, Union, Tuple, Set, Any

def parse_args():
  """Parse the command-line options of the script generator.

  Returns:
    argparse.Namespace whose ``kind`` attribute is one of ``'train'``,
    ``'test'``, ``'results'`` or ``'all'`` (the default).
  """
  cli = argparse.ArgumentParser(description='Check status of training experiments')
  cli.add_argument(
    '--kind',
    default='all',
    choices=['train', 'test', 'results', 'all'],
    type=str,
    help='What kinds of scripts to generate',
  )
  namespace = cli.parse_args()
  return namespace

def get_experiment_identifier(config):
  """Build the ExperimentIdentifier that corresponds to one experiment config.

  Args:
    config: Mapping with the keys ``config_file``, ``model_name``, ``freq``,
      ``sde_type``, ``group`` and ``global_key_seed``.

  Returns:
    The object produced by ``ExperimentIdentifier.from_model_identifier`` for
    the 7-tuple (config file stem, objective, model name, sde type,
    ``freq_<freq>``, group, ``seed_<seed>``).

  Raises:
    KeyError: If one of the required keys is missing from ``config``.
    ValueError: If the model name has no entry in the YAML config file.
  """
  # Read every required key up front so a malformed config fails before any I/O.
  yaml_path, name, freq, sde, experiment_group, seed = (
    config[key]
    for key in ("config_file", "model_name", "freq", "sde_type", "group", "global_key_seed")
  )

  # The model's objective lives in the YAML file, keyed by model name.
  model_table = load_yaml_config(yaml_path)
  if name not in model_table:
    raise ValueError(f"Model name '{name}' not found in config file {yaml_path}")
  objective = model_table[name]["objective"]

  # File name without directory or extension, e.g. "Config/lorenz.yaml" -> "lorenz".
  file_stem, _extension = os.path.splitext(os.path.basename(yaml_path))

  model_identifier = (
    file_stem,
    objective,
    name,
    sde,
    f'freq_{freq}',
    experiment_group,
    f'seed_{seed}',
  )
  return ExperimentIdentifier.from_model_identifier(model_identifier)

def get_experiment_configs(args):
  """Enumerate the experiment configurations that still need a script.

  The Cartesian product of the active parameter grid is expanded, each
  combination gets an ``experiment_name``, and combinations are dropped when
  they violate the dataset/sde_type/group rules, the per-model frequency
  rule, or when the experiment's training/evaluation status together with
  ``args.kind`` says no script is needed. Every skip decided by status or
  ``args.kind`` is printed.

  Args:
    args: Parsed command-line namespace; only ``args.kind`` is read
      (``'train'``, ``'test'``, ``'results'`` or ``'all'``).

  Returns:
    A list of config dicts (one per script to generate). Each dict holds one
    value per grid key plus the derived ``experiment_name``.

  Raises:
    ValueError: Propagated from ``get_experiment_identifier`` when a model
      name is missing from its YAML config file.
  """
  # Each key is a parameter name, each value is a list of possible values

  # NOTE: this grid is currently inactive; ``param_dict = rnn_param_dict``
  # below replaces it. It is kept so the active grid can be switched by
  # editing a single line.
  param_dict = {
    # "gpu_type": ["a40|a100|rtx8000|a16|2080ti"],
    # "time": ["1:59:59"],
    # "gpu": ["gpu-preempt"],
    "time": ["47:59:59"],
    "gpu": ["gpu"],
    "gpu_type": ["a40|a100|rtx8000|a16|2080ti"],
    # "global_key_seed": [0],
    "global_key_seed": [0, 1, 2, 3, 4],
    "config_file": [
      # "Config_new/noisy_double_pendulum.yaml",
      "Config_new/stocks.yaml",
      "Config_new/energy.yaml",
      "Config_new/etth.yaml",
      "Config_new/mujoco.yaml",
      "Config_new/fmri.yaml",
      "Config_new/sines.yaml",
      # "Config_new/m4.yaml",
      # "Config_new/uber_tlc.yaml",
      # "Config_new/solar.yaml",
      # "Config_new/kdd_cup.yaml",
      # "Config_new/exchange.yaml",
    ],
    "freq": [0, 1],
    # "freq": [0],
    # "freq": [0, 1, 2, 4],
    # "sde_type": ["tracking"],
    "sde_type": ["brownian", "tracking"],
    "model_name": [
      "true_baseline_autoregressive", #
      # "baseline_autoregressive", #
      # "my_autoregressive", #
      "my_non_probabilistic", #
      "my_autoregressive_reparam", #
      "my_neural_sde", #
      "my_neural_ode", #
      # "my_diffusion_model", #
      "baseline_diffusion_model", #
      # "baseline_diffusion_model_rnn",
      # "true_baseline_autoregressive_rnn",
      # "baseline_autoregressive_rnn",
      # "my_autoregressive_rnn",
      # "my_non_probabilistic_rnn",
      # "my_autoregressive_reparam_rnn",
      # "my_neural_sde_rnn",
      # "my_neural_ode_rnn",
      # "my_diffusion_model_rnn",
      # "my_reparameterized_autoregressive_small",
    ],
    # "train_test_sample": ["--train"],
    # "train_test_sample": ["--test_time_evaluation"],
    "train_test_sample": ["--train", "--test_time_evaluation"],
    # "retrain": ["--retrain"],
    "retrain": [""],
    "log_plots": [""],
    # "group": ["hyperparameter_tuning"],
    "group": ["no_leakage_obs_forecasting"],
    # "group": ["no_leakage_latent_forecasting", "no_leakage_obs_forecasting"],
    # "group": ["rnn_models"],
    "sanity_check": [""],
    # "only_generate_samples": [""],
    # "only_generate_samples": ["--only_generate_samples"],
    "only_generate_samples": ["", "--only_generate_samples"],
    "restart_evaluation": [""],
    # "restart_evaluation": ["--restart_evaluation"],
    "no_leakage": [""], # --no_leakage is the default now
    # "no_leakage": ["--no_leakage"],
    "hyperparameter_tuning": [""],
    # "hyperparameter_tuning": ["--hyperparameter_tuning"],
  }

  rnn_param_dict = {
    # "gpu_type": ["a40|a100|rtx8000|a16|2080ti"],
    "time": ["1:59:59"],
    "gpu": ["gpu-preempt"],
    # "time": ["47:59:59"],
    # "gpu": ["gpu"],
    "gpu_type": ["a40|a100|rtx8000|a16|2080ti"],
    # "global_key_seed": [0],
    "global_key_seed": [0, 1, 2, 3, 4],
    "config_file": [
      "Config/noisy_double_pendulum.yaml",
      "Config/lorenz.yaml",
      "Config/fitzhugh.yaml",
      "Config/lotka.yaml",
      "Config/brusselator.yaml",
      "Config/van_der_pol.yaml",
    ],
    "freq": [0, 1, 2, 3, 4],
    "sde_type": ["tracking"],
    "model_name": [
      "baseline_diffusion_model_rnn",
      "true_baseline_autoregressive_rnn",
      "baseline_autoregressive_rnn",
      "my_non_probabilistic_rnn",
      "my_autoregressive_reparam_rnn",
      "my_neural_sde_rnn",
      # "my_neural_ode_rnn",
      "my_diffusion_model_rnn",
      "my_autoregressive_reparam_rnn_bwd",
      "my_neural_sde_rnn_bwd",
    ],
    # "train_test_sample": ["--train"],
    # "train_test_sample": ["--test_time_evaluation"],
    "train_test_sample": ["--train", "--test_time_evaluation"],
    # "retrain": ["--retrain"],
    "retrain": [""],
    "log_plots": [""],
    "group": ["final_models"],
    "sanity_check": [""],
    # "only_generate_samples": [""],
    # "only_generate_samples": ["--only_generate_samples"],
    "only_generate_samples": ["", "--only_generate_samples"],
    "restart_evaluation": [""],
    # "restart_evaluation": ["--restart_evaluation"],
    "no_leakage": [""], # --no_leakage is the default now
    # "no_leakage": ["--no_leakage"],
    "hyperparameter_tuning": [""],
    # "hyperparameter_tuning": ["--hyperparameter_tuning"],
  }

  # Select the active grid.
  param_dict = rnn_param_dict

  # python -m Models.check_training_status --group final_latent_forecasting no_leakage_obs_forecasting

  # Filtering tables. They do not depend on the combination, so they are
  # built once here rather than on every loop iteration.
  allowable_groups_for_tracking = ("final_models", "harmonic_oscillator", "final_latent_forecasting", "no_leakage_latent_forecasting", "hyperparameter_tuning", "rnn_models")
  allowable_groups_for_brownian = ("final_models", "harmonic_oscillator", "exp_april_12", "exp_april_22", "final_obs_forecasting", "no_leakage_obs_forecasting", "rnn_models")
  # Groups additionally tolerated for tracking runs.
  extra_groups_for_tracking = ("noisy_pendulum", "fixed_times_pendulum")
  physics_base_names = ("noisy_double_pendulum", "lorenz", "fitzhugh", "lotka", "brusselator", "van_der_pol")

  # The single ``freq`` value each model is generated for. Models that are
  # not listed are generated for every ``freq`` in the grid.
  # NOTE(review): "my_autoregressive_reparam_rnn_bwd" has no entry (it had no
  # branch in the previous if/elif chain either), so it runs for all
  # frequencies. Confirm this is intended.
  required_freq_by_model = {
    "my_neural_ode": 1,
    "my_neural_ode_rnn": 1,
    "my_neural_sde": 1,
    "my_neural_sde_rnn": 1,
    "my_neural_sde_rnn_bwd": 1,
    "my_diffusion_model": 0,
    "my_diffusion_model_rnn": 0,
    "baseline_autoregressive": 0,
    "baseline_autoregressive_rnn": 0,
    "my_non_probabilistic": 0,
    "my_non_probabilistic_rnn": 0,
    "true_baseline_autoregressive": 0,
    "true_baseline_autoregressive_rnn": 0,
    "baseline_diffusion_model": 0,
    "baseline_diffusion_model_rnn": 0,
    "my_autoregressive_reparam": 0,
    "my_autoregressive_reparam_rnn": 0,
    "my_autoregressive_reparam_small": 0,
    "my_autoregressive_reparam_small_rnn": 0,
    # ONLY DOING THIS FOR THE MOMENT!!!
    "my_autoregressive": 0,
    "my_autoregressive_rnn": 0,
  }

  configs = []
  keys = list(param_dict.keys())
  for combo in product(*param_dict.values()):
    config = dict(zip(keys, combo))

    base_name = os.path.splitext(os.path.basename(config["config_file"]))[0]
    freq_value = config["freq"]
    sde_type = config["sde_type"]
    model_name = config["model_name"]
    global_key_seed = config["global_key_seed"]
    group = config["group"]
    experiment_name = f"{model_name}_{base_name}_freq{freq_value}_{sde_type}_seed{global_key_seed}"
    config["experiment_name"] = experiment_name
    # Resolved before any filtering, so a model missing from its YAML file
    # raises even for combinations that would be filtered out below.
    eid = get_experiment_identifier(config)

    # --- dataset / sde_type / group compatibility ---------------------------
    is_physics = base_name in physics_base_names

    if is_physics:
      if sde_type != "tracking":
        continue

      if group not in allowable_groups_for_tracking:
        # Only running noisy double pendulum in latent forecasting mode
        continue

    if sde_type == "tracking":
      if group not in allowable_groups_for_tracking and group not in extra_groups_for_tracking:
        continue

      if not is_physics:
        continue

    elif sde_type == "brownian":
      if group not in allowable_groups_for_brownian:
        continue

      if is_physics:
        continue

    # --- per-model frequency restriction ------------------------------------
    required_freq = required_freq_by_model.get(model_name)
    if required_freq is not None and freq_value != required_freq:
      continue

    # --- status-based filtering ---------------------------------------------
    train_status: Dict[str, Any] = eid.experiment_training_status()
    eval_status: Dict[str, Any] = eid.experiment_evaluation_status()
    # Only consumed by the disabled "metrics already computed" check below;
    # the call is kept so re-enabling that check needs no other change.
    metric_status: Dict[str, Any] = eid.get_metric_status()

    # Determine whether to skip this configuration based on training/testing status
    if config["train_test_sample"] == "--train":
      # Skip if the kind argument is not train
      if args.kind not in ("all", "train"):
        print(f"Skipping training config (kind argument is not train): \n{str(eid)}\n")
        continue

      # Skip if training is already complete
      if train_status["is_complete"] and not config["retrain"]:
        print(f"Skipping training config (training already complete): \n{str(eid)}\n")
        continue

      if config["only_generate_samples"] != "":
        print("Cannot specify both --train and --only_generate_samples")
        continue

    elif config["train_test_sample"] == "--test_time_evaluation":
      # Skip if the kind argument is not test or results
      if args.kind not in ("all", "test", "results"):
        print(f"Skipping testing config (kind argument is not test or results): \n{str(eid)}\n")
        continue

      # Skip if training is not complete
      if not train_status["is_complete"]:
        print(f"Skipping testing config (training not complete): \n{str(eid)}\n")
        continue

      # Apply additional testing logic
      if config["only_generate_samples"] == "": # Compute metrics
        # Skip if sample generation is not complete. Uses truthiness, like
        # the other status checks, instead of comparing with ``== False``.
        if not eval_status["is_complete"]:
          print(f"Skipping testing config (sample generation not complete): \n{str(eid)}\n")
          continue

        # # Skip if metrics computation is already complete
        # if metric_status["is_complete"] == True:
        #   print(f"Skipping testing config (metrics already computed): \n{str(eid)}\n")
        #   continue
      elif config["only_generate_samples"] == "--only_generate_samples": # Generate samples
        # Skip if the kind argument is results
        if args.kind == "results":
          print(f"Skipping sample generation config (kind argument is results): \n{str(eid)}\n")
          continue

        # Skip if sample generation is already complete
        if eval_status["is_complete"] and not config["restart_evaluation"]:
          print(f"Skipping sample generation config (samples already generated): \n{str(eid)}\n")
          continue

    print(f"Adding config: \n{str(eid)}\n")
    configs.append(config)

  return configs

def write_script(cfg: Dict[str, Any],
                 template_str: str,
                 output_dir: str) -> tuple[str | None, str | None]:
  """Writes a single shell script."""

  # Ensure required keys exist in cfg (simplified check)
  required_keys = ['config_file', 'freq', 'sde_type', 'model_name', 'group',
                    'global_key_seed', 'retrain', 'train_test_sample', 'log_plots',
                    'sanity_check', 'only_generate_samples', 'restart_evaluation',
                    'gpu', 'gpu_type', 'experiment_name', 'time']
  missing_keys = [key for key in required_keys if key not in cfg]
  if missing_keys:
    raise ValueError(f"Missing keys {missing_keys} in config for {cfg.get('experiment_name', 'N/A')}")

  # Construct python command safely
  python_command_parts = [
      "python main.py",
      f"--config_file={cfg['config_file']}",
      f"--freq={cfg['freq']}",
      f"--sde_type={cfg['sde_type']}",
      f"--model_name={cfg['model_name']}",
      f"--group={cfg['group']}",
      f"--global_key_seed={cfg['global_key_seed']}",
  ]
  # Add optional flags only if they are not empty strings
  for key in ['retrain', 'train_test_sample', 'log_plots', 'sanity_check',
              'only_generate_samples', 'restart_evaluation', 'no_leakage', 'hyperparameter_tuning']:
    # Check if the key exists and its value is truthy (not empty string, not None, not False)
    if cfg.get(key):
        python_command_parts.append(str(cfg[key]))

  # Join with space and backslash-newline for shell script line continuation
  # Ensure correct escaping for the format string later
  python_command = " ".join(python_command_parts)

  # Format into the template string
  script_content = template_str.format(
      python_command=python_command,
      gpu=cfg['gpu'],
      gpu_type=cfg['gpu_type'],
      experiment_name=cfg['experiment_name'],
      time=cfg['time']
  )

  os.makedirs(output_dir, exist_ok=True)
  script_filename = f"{cfg['experiment_name']}.sh"
  script_path = os.path.join(output_dir, script_filename)

  with open(script_path, "w") as f:
      f.write(script_content)
  print(f"Wrote script: {script_path}")

  # Make executable
  os.chmod(script_path, os.stat(script_path).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return script_path, python_command

def main():
  """Generate one shell script per pending experiment plus two command lists.

  Steps:
    1. Parse the command line and read ``script_generator/script_template.sh``.
    2. Delete ``script_generator/generated_training_scripts`` if it exists.
    3. Collect the pending configs and write one script per config,
       training scripts first, evaluation scripts second.
    4. Write ``script_generator/python_commands.txt`` and
       ``script_generator/sbatch_commands.txt`` (one command per line).

  Raises:
    ValueError: If the script template file does not exist.
  """
  args = parse_args()

  output_dir = "script_generator/generated_training_scripts"
  script_template_path = "script_generator/script_template.sh"

  if not os.path.exists(script_template_path):
    raise ValueError(f"Script template file not found: {script_template_path}")

  with open(script_template_path, "r") as f:
    template_str = f.read()

  # Clean output directory so scripts from a previous run never linger.
  if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
    print(f"Removed existing output directory: {output_dir}")

  configs = get_experiment_configs(args)
  print(f"Processing {len(configs)} experiment configurations.")

  # Training scripts are emitted before evaluation scripts; configs with any
  # other ``train_test_sample`` value are not written.
  training_configs = [cfg for cfg in configs if cfg.get("train_test_sample") == "--train"]
  eval_configs = [cfg for cfg in configs if cfg.get("train_test_sample") == "--test_time_evaluation"]

  python_commands = []
  sbatch_commands = []
  seen_script_paths = set()

  for cfg in training_configs + eval_configs:
    script_path, python_cmd = write_script(cfg, template_str, output_dir)
    if not (script_path and python_cmd):
      continue

    # Script names are derived from experiment_name alone, so two configs
    # sharing a name (e.g. a --train and a --test_time_evaluation config when
    # --retrain is enabled) overwrite each other. Surface that instead of
    # silently queueing the same file twice.
    if script_path in seen_script_paths:
      print(f"Warning: {script_path} was written more than once; the earlier script was overwritten.")
    seen_script_paths.add(script_path)

    python_commands.append(python_cmd)
    sbatch_commands.append(f"sbatch {script_path}")

  # Write command lists
  os.makedirs("script_generator", exist_ok=True)
  with open("script_generator/python_commands.txt", "w") as f:
    f.write("\n".join(python_commands))
  with open("script_generator/sbatch_commands.txt", "w") as f:
    f.write("\n".join(sbatch_commands))

  print(f"Generated {len(sbatch_commands)} sbatch commands successfully.")

if __name__ == "__main__":
    # Run the generator only when this file is executed as a script, not
    # when it is imported as a module.
    main()