from debug import *
import os
from Utils.io_utils import load_yaml_config
import equinox as eqx
import jax
from diffusion_crf import TimeSeries
from typing import Optional, Dict, Any, TypeVar, Union, cast
TrainingState = TypeVar('TrainingState')
import optuna

"""
This module defines the ExperimentIdentifier class, which provides a standardized way to identify
and manage experiments in the neural diffusion CRF framework. The ExperimentIdentifier handles:

1. Experiment identification via a set of parameters (config, model type, SDE type, etc.)
2. File path management for saving and loading models, checkpoints, and results
3. Configuration management for experiment reproducibility
4. Experiment status tracking and metadata retrieval
5. Data loading and model state access

The ExperimentIdentifier serves as a central coordinator for all experiment-related operations,
ensuring consistent naming conventions and file organization across the codebase.
"""

################################################################################################################

# Root directory under which all experiment artifacts are stored: a fixed
# absolute path when the first JAX device is a GPU, the working directory otherwise.
SAVE_PATH = (
  '/project/pi_drsheldon_umass_edu/eddie'
  if jax.devices()[0].platform == 'gpu'
  else './'
)

class ExperimentIdentifier(eqx.Module):
  """Identify a single experiment and derive its config and on-disk paths.

  An experiment is identified by seven fields: config name, objective, model
  name, SDE type, interpolation frequency, group and global seed.  They are
  used to rebuild the experiment's config from the ``Config/`` YAML files and
  to build the directory layout under ``SAVE_PATH`` where checkpoints,
  samples, metrics, plots and Optuna studies live.
  """
  config_name: str  # config file name without directory or extension
  model_name: str
  sde_type: str
  freq: int  # interpolation frequency
  objective: str  # training objective; may carry a "_trial<N>" suffix
  group: str
  global_key_seed: int

  def __init__(self, config: dict):
    """Build the identifier from a config dict.

    Args:
      config: Dict with a ``'command_line_args'`` entry providing
        ``config_file``, ``model_name``, ``sde_type``, ``freq``, ``group`` and
        ``global_key_seed``, and a ``'model'`` entry providing ``objective``.
    """
    args = config['command_line_args']
    model_configs = config['model']

    # Only the bare file name (no directory, no extension) identifies the config.
    self.config_name = os.path.splitext(os.path.basename(args['config_file']))[0]
    self.model_name = args['model_name']
    self.sde_type = args['sde_type']
    self.freq = args['freq']
    self.objective = model_configs['objective']
    self.group = args['group']
    self.global_key_seed = args['global_key_seed']

  def get_save_path(self):
    """Return the module-level root directory for experiment artifacts."""
    return SAVE_PATH

  def get_samples_and_metrics_path(self):
    """Return the path of the ``samples_and_metrics.npz`` file in the samples folder."""
    return os.path.join(self.samples_folder_name, 'samples_and_metrics.npz')

  @property
  def result_metric_path(self):
    """Relative path of the CSV file that aggregates experiment results."""
    return 'updated_experiment_results.csv'

  def get_metrics_to_compute(self):
    """Return the dataset config's ``metric_to_compute`` entry."""
    config = self.create_config()
    return config['dataset']['metric_to_compute']

  def get_evaluation_settings(self):
    """Return the dataset config's ``evaluation_settings`` entry."""
    config = self.create_config()
    return config['dataset']['evaluation_settings']

  def get_nice_dataset_name(self):
    """Return the dataset config's ``nice_name`` entry."""
    config = self.create_config()
    return config['dataset']['nice_name']

  def has_denoised_data(self):
    """Return True if the evaluation settings mention ``future_denoised_observation``."""
    # Heuristic: infers denoised data from the evaluation settings rather than the data itself.
    evaluation_settings = self.get_evaluation_settings()
    return 'future_denoised_observation' in evaluation_settings

  def to_dict(self):
    """Return the seven identifying fields as a plain dict."""
    return dict(config_name=self.config_name,
                model_name=self.model_name,
                sde_type=self.sde_type,
                freq=self.freq,
                objective=self.objective,
                group=self.group,
                global_key_seed=self.global_key_seed)

  def get_model_objective(self):
    """Get the original objective without trial information"""
    objective = self.objective
    if "_trial" in objective:
      return objective.split("_trial")[0]
    return objective

  def get_trial_id(self):
    """Get trial ID if present in the objective, else None.

    Raises:
      ValueError: If the text following ``_trial`` is not an integer.
    """
    if "_trial" in self.objective:
      return int(self.objective.split("_trial")[1])
    return None

  def get_model_identifier(self):
    """Return the identifier tuple used as the directory path below each artifact root.

    ``from_model_identifier`` is the inverse of this method.
    """
    return (self.config_name, self.objective, self.model_name, self.sde_type,
            f'freq_{self.freq}', self.group, f'seed_{self.global_key_seed}')

  def __str__(self):
    return (
      f"Experiment Configuration:\n"
      f"  Config File: {self.config_name}\n"
      f"  Model: {self.model_name}\n"
      f"  Training Objective: {self.get_model_objective()}\n"
      f"  Trial ID: {self.get_trial_id()}\n"
      f"  SDE Type: {self.sde_type}\n"
      f"  Interpolation Frequency: {self.freq}\n"
      f"  Experiment Group: {self.group}\n"
      f"  Random Seed: {self.global_key_seed}"
    )

  @property
  def study_name(self):
    """The name of the study that is independent of the trial ID"""
    return f"{self.model_name}_{self.config_name}_freq{self.freq}_{self.sde_type}_seed{self.global_key_seed}_group{self.group}_objective{self.get_model_objective()}"

  @property
  def trial_name(self):
    """The study name followed by ``_trial<N>`` (``_trialNone`` when there is no trial ID)."""
    return f"{self.study_name}_trial{self.get_trial_id()}"

  def create_config(self,
                    trial: Optional[optuna.trial.Trial] = None,
                    inject_hyperparameters: Optional[bool] = False,
                    load_best_hyper_params: Optional[bool] = False,
                    trial_id: Optional[int] = None,
                    command_line_args: Optional[dict] = None,
                    train_if_needed: Optional[bool] = False,
                    retrain: Optional[bool] = False):
    """
    Create a config dictionary from the experiment identifier attributes.
    This is the reverse of the __init__ method.

    The YAML file is looked up first as ``Config/<config_name>.yaml`` and,
    if that is not found, as ``Config/harmonic_oscillator/<config_name>.yaml``.

    Args:
      trial (optuna.trial.Trial, optional): Optuna trial.  Its ``number`` is
        appended to the objective as ``_trial<N>``; with
        ``inject_hyperparameters`` its hyperparameters are injected too.
      inject_hyperparameters (bool, optional): Whether to inject the hyperparameters from the trial into the config.
      load_best_hyper_params (bool, optional): Whether to load the best hyperparameters from the study.
      trial_id (int, optional): Trial number to append to the objective as
        ``_trial<N>``.  Ignored when ``trial`` is given.
      command_line_args (dict, optional): The command line args to use for the
        config.  If None, defaults are built from this identifier.  A provided
        dict is copied and never mutated.
      train_if_needed (bool, optional): Value of the default ``train`` arg
        (``just_load`` is set to its negation).  Unused if ``command_line_args`` is given.
      retrain (bool, optional): Value of the default ``retrain`` arg.  Unused
        if ``command_line_args`` is given.

    Returns:
      dict: A config dictionary with command_line_args, model, and dataset
      keys, after being passed through ``main.config_hack``.

    Raises:
      AssertionError: If mutually exclusive options are combined, or a trial
        option is used while this identifier's objective already has a trial ID.
    """
    current_trial_id = self.get_trial_id()
    if inject_hyperparameters:
      assert trial is not None, 'Trial must be provided if inject_hyperparameters is True'
      assert load_best_hyper_params is False, 'Cannot load best hyperparameters if inject_hyperparameters is True'
      assert current_trial_id is None, "Cannot inject hyperparameters if trial ID is provided"

    if load_best_hyper_params:
      assert inject_hyperparameters is False, 'Cannot inject hyperparameters if load_best_hyper_params is True'
      assert current_trial_id is None, "Cannot load best hyperparameters if trial ID is provided"

    if trial_id is not None:
      assert current_trial_id is None, "Cannot provide trial ID if trial ID is already provided"

    # Look in the top-level Config folder first, then the harmonic oscillator subfolder.
    # If the fallback also fails, the error from load_yaml_config propagates.
    try:
      config_file = f"Config/{self.config_name}.yaml"
      full_config = load_yaml_config(config_file)
    except FileNotFoundError:
      config_file = f"Config/harmonic_oscillator/{self.config_name}.yaml"
      full_config = load_yaml_config(config_file)

    # Create the command line args with default values
    if command_line_args is None:
      command_line_args = {
        'config_file': config_file,
        'model_name': self.model_name,
        'sde_type': self.sde_type,
        'freq': self.freq,
        'group': self.group,
        'global_key_seed': self.global_key_seed,
        'train': train_if_needed,
        'retrain': retrain,
        'debug': False,
        'test_time_evaluation': False,
        'log_plots': False,
        'checkpoint_every': 1000,
        'sanity_check': False,
        'just_load': not train_if_needed
      }
    else:
      # Copy so the per-model override below never leaks into the caller's dict.
      command_line_args = dict(command_line_args)

    if self.model_name in ('my_neural_sde', 'my_neural_ode'):
      command_line_args['log_plots'] = False

    # Get the model configs and override the objective with our stored value
    model_configs = full_config[self.model_name]
    model_configs['objective'] = self.objective

    # When running a trial, tag the objective with the trial ID.  A live
    # `trial` takes precedence over an explicit `trial_id` argument.
    if trial is not None:
      trial_id = trial.number
    if trial_id is not None:
      model_configs['objective'] = f"{model_configs['objective']}_trial{trial_id}"

    output_configs = dict(dataset=full_config['dataset'],
                          model=model_configs,
                          command_line_args=command_line_args)

    # Imported lazily (presumably to avoid an import cycle with main).
    from main import config_hack
    output_configs = config_hack(output_configs)

    # Aliased imports so the boolean parameters of the same names are not shadowed.
    if inject_hyperparameters:
      from Models.hyper_params import inject_hyperparameters as _inject_hyperparameters
      output_configs = _inject_hyperparameters(output_configs, trial)

    if load_best_hyper_params:
      from Models.hyper_params import load_best_hyper_params as _load_best_hyper_params
      output_configs = _load_best_hyper_params(output_configs)

    return output_configs

  @staticmethod
  def _from_fields(config_name, objective, model_name, sde_type, freq, group, seed):
    """Build an identifier from its seven fields via the config dict ``__init__`` expects."""
    config = {
      'command_line_args': {
        'config_file': f'{config_name}.yaml',  # We don't actually need the file, just the name
        'model_name': model_name,
        'sde_type': sde_type,
        'freq': freq,
        'group': group,
        'global_key_seed': seed
      },
      'model': {
        'objective': objective
      }
    }
    return ExperimentIdentifier(config=config)

  @staticmethod
  def from_model_identifier(model_identifier: tuple):
    """Rebuild an identifier from a tuple produced by ``get_model_identifier``.

    Args:
      model_identifier: ``(config_name, objective, model_name, sde_type,
        'freq_<int>', group, 'seed_<int>')``.

    Raises:
      ValueError / IndexError: If the tuple does not have seven entries or the
        freq/seed strings are not of the form ``<prefix>_<int>``.
    """
    config_name, objective, model_name, sde_type, freq_str, group, seed_str = model_identifier

    # Parse freq and seed from their string representations
    freq = int(freq_str.split('_')[1])
    global_key_seed = int(seed_str.split('_')[1])

    return ExperimentIdentifier._from_fields(config_name, objective, model_name,
                                             sde_type, freq, group, global_key_seed)

  def _ensure_folder(self, category: str) -> str:
    """Return ``SAVE_PATH/<category>/<model identifier...>``, creating it if missing."""
    path = os.path.join(SAVE_PATH, category, *self.get_model_identifier())
    os.makedirs(path, exist_ok=True)
    return path

  @property
  def model_folder_name(self):
    """Directory for model checkpoints (created on access)."""
    return self._ensure_folder('model_checkpoints')

  @property
  def samples_folder_name(self):
    """Directory for generated samples (created on access)."""
    return self._ensure_folder('samples')

  @property
  def metrics_folder_name(self):
    """Directory for metrics (created on access)."""
    return self._ensure_folder('metrics')

  @property
  def plots_folder_name(self):
    """Directory for plots (created on access)."""
    return self._ensure_folder('plots')

  @property
  def checkpoint_dir(self):
    """Directory for evaluation checkpoints (created on access)."""
    return self._ensure_folder('evaluation_checkpoints')

  @property
  def nfe_results_folder_name(self):
    """Directory for NFE results (created on access)."""
    return self._ensure_folder('nfe_results')

  @property
  def optuna_dir(self):
    """Directory for Optuna studies (created on access)."""
    return self._ensure_folder('optuna_studies')

  @property
  def study_storage_path(self):
    """SQLite URL of the database storing this experiment's Optuna study."""
    # optuna_dir already creates the directory holding the database file.
    path = os.path.join(self.optuna_dir, f"{self.study_name}.db")
    return f"sqlite:///{path}"

  def load_best_hyper_params(self):
    """Delegate to ``Models.hyper_params.load_best_hyper_params`` with this identifier."""
    # NOTE(review): create_config passes a config dict to this same helper,
    # whereas this passes the identifier; confirm which argument it expects.
    from Models.hyper_params import load_best_hyper_params
    return load_best_hyper_params(self)

  @property
  def better_model_name(self):
    """
    Maps the internal model_name to a more presentable name for papers.

    A trailing ``_rnn`` suffix is ignored for the lookup.  For
    ``my_autoregressive`` with a non-zero ``freq`` the name includes the
    upsampling factor ``freq + 1``.  Prints a warning when no mapping exists.

    Returns:
        str: The paper-friendly model name, or the original if not found in the mapping
    """
    if self.freq != 0:
      my_ar_name = f'AR MSE (Upsampled {self.freq + 1}x)'
    else:
      my_ar_name = 'AR MSE'
    mapping = dict(
        my_non_probabilistic='MSE',
        my_autoregressive=my_ar_name,
        my_autoregressive_reparam='AR (Latent MSE)',
        baseline_autoregressive='AR (Latent MLE)',
        my_diffusion_model='Diffusion (Latent FM)',
        my_neural_ode='Neural ODE (Latent FM)',
        my_neural_sde='Neural SDE (Latent DM)',
        true_baseline_autoregressive='AR (MLE)',
        baseline_diffusion_model='Diffusion (FM)',
        my_neural_sde_rnn_bwd='Neural SDE (MSE-bwd)',
        my_autoregressive_reparam_rnn_bwd='AR (MSE-bwd)',
    )

    lookup_model_name = self.model_name.removesuffix('_rnn')
    if lookup_model_name in mapping:
      return mapping[lookup_model_name]
    print(f"Warning: No better name mapping found for model '{self.model_name}'")
    return self.model_name

  def experiment_training_status(self):
    """Delegate to ``Models.training_tracker.get_training_status``."""
    from Models.training_tracker import get_training_status
    return get_training_status(self)

  def experiment_evaluation_status(self):
    """Delegate to ``Models.result_data_checkpointer.get_evaluation_status``."""
    from Models.result_data_checkpointer import get_evaluation_status
    return get_evaluation_status(self)

  def get_metric_status(self):
    """Get the status of computed metrics for this experiment."""
    from Models.check_metric_status import get_metric_status
    return get_metric_status(self)

  def get_raw_data(self) -> dict[str, TimeSeries]:
    """Return ``main.get_dataset_no_leakage(..., return_raw_data=True)`` for this experiment."""
    from main import get_dataset_no_leakage
    config = self.create_config()
    return get_dataset_no_leakage(self, config, return_raw_data=True)

  def get_data(self) -> dict[str, TimeSeries]:
    """Deprecated: always raises AssertionError.

    ``get_data_fixed`` returns the same train/val/test dict built with
    ``main.get_dataset_no_leakage``.
    """
    # Raise explicitly instead of `assert 0` so the guard survives `python -O`.
    raise AssertionError('Deprecated')

  def get_data_fixed(self) -> dict[str, TimeSeries]:
    """Return ``train_data``/``val_data``/``test_data`` from ``main.get_dataset_no_leakage``."""
    from main import get_dataset_no_leakage
    config = self.create_config()
    train_data, val_data, test_data = get_dataset_no_leakage(self, config)
    return dict(train_data=train_data, val_data=val_data, test_data=test_data)

  def get_train_state(self) -> TrainingState:
    """Load the data and the saved model via ``main.load_jax_model``; return its train state."""
    from main import get_dataset_no_leakage, load_jax_model
    config = self.create_config()
    train_data, val_data, test_data = get_dataset_no_leakage(self, config)
    _trainer, train_state = load_jax_model(train_data, val_data, test_data, self, config)
    return train_state

  def train_state_training_complete(self, train_state: TrainingState) -> bool:
    """Return True if early stopping triggered or ``max_train_steps`` was reached.

    Early stopping triggers when ``train_state.number_of_steps_since_best_validation_loss``
    exceeds a patience of 3.  NOTE(review): the unit of that counter
    (steps vs. validation rounds) is defined elsewhere — confirm.
    """
    config = self.create_config()
    max_train_steps = config['model']['optimizer']['max_train_steps']
    early_stopping_patience = 3
    early_stopping_patience_reached = train_state.number_of_steps_since_best_validation_loss > early_stopping_patience
    max_train_steps_reached = train_state.i >= max_train_steps
    return early_stopping_patience_reached or max_train_steps_reached

  def training_is_complete(self):
    """Load the train state and check it with ``train_state_training_complete``."""
    train_state = self.get_train_state()
    return self.train_state_training_complete(train_state)

  def get_training_metadata(self) -> 'TrainingMetadata':
    """Delegate to ``Models.training_tracker.load_training_metadata``."""
    from Models.training_tracker import load_training_metadata
    return load_training_metadata(self)

  def get_evaluation_metadata(self):
    """Return evaluation metadata from a ``ResultDataCheckpointer`` for this experiment."""
    from Models.result_data_checkpointer import ResultDataCheckpointer
    checkpointer = ResultDataCheckpointer(self)
    return checkpointer.get_evaluation_metadata()

  def get_experiment_result_data(self) -> 'ExperimentResultData':
    """Return all results loaded by a ``ResultDataCheckpointer`` for this experiment."""
    from Models.result_data_checkpointer import ResultDataCheckpointer
    checkpointer = ResultDataCheckpointer(self)
    return checkpointer.load_all_results()

  @staticmethod
  def find_all_experiments():
    """Find all experiments with training metadata using an efficient glob approach."""
    # Resolves to the module-level function of the same name.
    return find_all_experiments()

  @staticmethod
  def make_experiment_id(config_name: str, objective: Union[str, None], model_name: str,
                         sde_type: str, freq: Union[int, None], group: str, seed: int):
    """
    Create an ExperimentIdentifier directly from individual parameters.

    Args:
      config_name: Name of the configuration file (without extension)
      objective: The training objective, or None for the model's default
      model_name: Name of the model
      sde_type: Type of SDE
      freq: Interpolation frequency, or None for the model's default
      group: Experiment group
      seed: Random seed

    Returns:
      ExperimentIdentifier: A new experiment identifier instance

    Raises:
      KeyError: If a default is requested for a model without one.
    """
    # Strip a trailing "_bwd" and then "_rnn" so variants share the base
    # model's default objective and freq.
    lookup_model_name = model_name.removesuffix('_bwd').removesuffix('_rnn')

    if objective is None:
      default_objectives = dict(my_autoregressive='mse',
                                my_autoregressive_reparam='mse',
                                my_neural_sde='drift_matching',
                                my_neural_ode='flow_matching',
                                my_diffusion_model='flow_matching',
                                baseline_autoregressive='ml',
                                baseline_diffusion_model='flow_matching',
                                true_baseline_autoregressive='ml',
                                my_non_probabilistic='mse')
      objective = default_objectives[lookup_model_name]

    if freq is None:
      default_freqs = dict(my_autoregressive=0,
                           my_autoregressive_reparam=0,
                           my_neural_sde=1,
                           my_neural_ode=1,
                           my_diffusion_model=0,
                           baseline_autoregressive=0,
                           baseline_diffusion_model=0,
                           true_baseline_autoregressive=0,
                           my_non_probabilistic=0)
      freq = default_freqs[lookup_model_name]

    return ExperimentIdentifier._from_fields(config_name, objective, model_name,
                                             sde_type, freq, group, seed)

def find_all_experiments():
  """Find all experiments under ``SAVE_PATH`` that have training metadata.

  Walks ``SAVE_PATH/model_checkpoints`` for directories named ``seed_*``.  It
  keeps those that contain a ``training_metadata.json`` file and rebuilds an
  ``ExperimentIdentifier`` from the last seven path components (config,
  objective, model, sde type, freq, group, seed).  Folders whose identifier
  or config cannot be built are reported and skipped rather than aborting the scan.

  Returns:
    list[ExperimentIdentifier]: Experiments whose config could be created;
    an empty list if ``SAVE_PATH`` does not exist.
  """
  import time
  from tqdm import tqdm

  base_dir = SAVE_PATH
  if not os.path.exists(base_dir):
    return []

  start_time = time.time()
  print(f"Searching for seed folders in {base_dir}...")

  # Find seed folders using os.walk (faster than recursive glob)
  seed_folders = []
  model_checkpoints_path = os.path.join(base_dir, "model_checkpoints")
  if os.path.exists(model_checkpoints_path):
    for root, dirs, _ in os.walk(model_checkpoints_path):
      seed_folders.extend(os.path.join(root, d) for d in dirs if d.startswith("seed_"))

  print(f"Found {len(seed_folders)} seed folders in {time.time() - start_time:.1f} seconds")

  experiments = []
  for seed_folder in tqdm(seed_folders, desc="Processing experiments"):
    # Only folders with training metadata count as experiments.
    if not os.path.exists(os.path.join(seed_folder, "training_metadata.json")):
      continue

    path_parts = os.path.relpath(seed_folder, base_dir).split(os.sep)
    if len(path_parts) < 7:
      tqdm.write(f"Warning: Path does not have enough components ({len(path_parts)}): {seed_folder}")
      continue

    # The last component is always a "seed_*" folder (only those were collected).
    config_name, objective, model_name, sde_type, freq, group, seed = path_parts[-7:]

    # Normalise the freq folder name to the "freq_<n>" form from_model_identifier parses.
    if not freq.startswith("freq_"):
      freq = f"freq_{freq.split('_')[1] if '_' in freq else freq}"

    model_id = (config_name, objective, model_name, sde_type, freq, group, seed)
    try:
      # Both steps can fail on stray folders: int() parsing of freq/seed in
      # from_model_identifier, and YAML loading in create_config.
      experiment_id = ExperimentIdentifier.from_model_identifier(model_id)
      experiment_id.create_config()  # validation only; the config is discarded
    except Exception as e:
      tqdm.write(f"Error creating config for {seed_folder}: {e}")
      continue
    experiments.append(experiment_id)

  return experiments
