from debug import *
import os
from Utils.io_utils import load_yaml_config
import equinox as eqx
import jax
from typing import Optional, Dict, Any, TypeVar, Union, cast, Callable
from Models.experiment_identifier import ExperimentIdentifier
import optuna
import wadler_lindig as wl
import jax.tree_util as jtu

def create_study(experiment_identifier: ExperimentIdentifier):
  """Open the Optuna study for an experiment, creating it if it does not exist yet.

  The study is identified by ``experiment_identifier.study_name`` and persisted in
  ``experiment_identifier.study_storage_path``. It uses a TPE sampler, a median
  pruner with 3 warm-up steps, and minimizes the objective.
  """
  study_kwargs = dict(
    study_name=experiment_identifier.study_name,
    storage=experiment_identifier.study_storage_path,
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=3),
    load_if_exists=True,  # resume an existing study of the same name
    direction="minimize",
  )
  return optuna.create_study(**study_kwargs)

def inject_hyperparameters(config: dict, trial: optuna.trial.Trial) -> dict:
  """
  Inject hyperparameters from an Optuna trial into the configuration.

  The configuration may contain ``hyper_params`` sections that mirror the
  structure of the surrounding config. A node inside such a section that holds
  an ``optuna_hyper_param`` specification is replaced by a value suggested by
  ``trial``. Every value found under a ``hyper_params`` section then overwrites
  the value at the same path with the ``hyper_params`` component removed, and
  the ``hyper_params`` entries themselves are dropped from the result.

  Example layout::

      dataset:
        noise_std: 0.1
        hyper_params:
          noise_std:
            optuna_hyper_param: {type: float, min: 0.0, max: 1.0}

  Supported ``optuna_hyper_param`` fields:
    * ``type``: one of ``'categorical'``, ``'int'``, ``'float'``.
    * ``choices``: list of options (categorical only).
    * ``min_val``/``min`` and ``max_val``/``max``: bounds (defaults 0..100 for
      int, 0.0..1.0 for float). Numeric strings such as ``"1e-3"`` are parsed.
    * ``step``: optional step (default 1 for int, None for float); numeric
      strings are parsed.
    * ``log``: sample on a log scale (default False).

  Args:
      config: Configuration dictionary with hyperparameter specifications.
        It is not modified.
      trial: Optuna trial used to suggest parameter values.

  Returns:
      A new nested configuration with the suggested hyperparameter values
      written in place of the defaults and the ``hyper_params`` sections
      removed.

  Raises:
      TypeError: If ``config`` is not a dictionary.
      ValueError: If a parameter type is unsupported, or if a ``hyper_params``
        entry has no matching value at the corresponding path in the config.
  """
  import copy

  # Clone the config so that leaf values (e.g. lists) are never shared with the caller
  result = copy.deepcopy(config)

  def parse_numeric_value(value):
    """
    Convert a possibly-string value to an int or float when it is numeric.

    YAML loaders may leave values such as ``1e-3`` as strings, so bounds and
    steps are normalized here. Integral values are returned as ``int``.
    Non-numeric values are returned unchanged.
    """
    if isinstance(value, (int, float)):
      return value

    try:
      float_val = float(value)
    except (ValueError, TypeError):
      # Not numeric: let Optuna report the bad value
      return value

    # is_integer() is False for inf/nan, so those stay floats instead of
    # raising from int()
    return int(float_val) if float_val.is_integer() else float_val

  def flat_dict_key_to_name(key_tuple):
    """
    Convert a flattened dictionary tuple key to a string name for Optuna.

    Example: ('dataset', 'hyper_params', 'noise_std') -> 'dataset.noise_std'
    """
    return '.'.join(str(k) for k in key_tuple if k != 'hyper_params')

  def suggest_optuna_hyper_param(name: str, hyper_param: dict) -> Any:
    """
    Suggest a value for one hyperparameter specification using ``trial``.

    Args:
        name: Parameter name registered with Optuna.
        hyper_param: The ``optuna_hyper_param`` specification dictionary.

    Returns:
        The suggested parameter value.

    Raises:
        ValueError: If ``hyper_param['type']`` is not supported.
    """
    param_type = hyper_param['type']
    print(f'Suggesting hyperparameter for: {name}')
    wl.pprint(hyper_param)

    if param_type == 'categorical':
      choices = hyper_param['choices']
      return trial.suggest_categorical(name=name, choices=choices)

    elif param_type == 'int':
      # Prefer min_val/max_val, falling back to min/max
      min_val = parse_numeric_value(hyper_param.get('min_val', hyper_param.get('min', 0)))
      max_val = parse_numeric_value(hyper_param.get('max_val', hyper_param.get('max', 100)))
      min_val, max_val = int(min_val), int(max_val)

      step = hyper_param.get('step', 1)
      if step is None:
        step = 1
      step = int(parse_numeric_value(step))

      log = hyper_param.get('log', False)

      print(f"INT parameter {name}: range={min_val}-{max_val}, step={step}, log={log}")

      return trial.suggest_int(
        name=name,
        low=min_val,
        high=max_val,
        step=step,
        log=log
      )

    elif param_type == 'float':
      # Prefer min_val/max_val, falling back to min/max
      min_val = parse_numeric_value(hyper_param.get('min_val', hyper_param.get('min', 0.0)))
      max_val = parse_numeric_value(hyper_param.get('max_val', hyper_param.get('max', 1.0)))

      log = hyper_param.get('log', False)
      step = hyper_param.get('step')
      if step is not None:
        # A step like "1e-3" must be numeric, just like the bounds
        step = parse_numeric_value(step)

      print(f"FLOAT parameter {name}: range={min_val}-{max_val}, step={step}, log={log}")

      return trial.suggest_float(
        name=name,
        low=min_val,
        high=max_val,
        step=step,
        log=log
      )
    else:
      raise ValueError(f"Unsupported parameter type: {param_type}")

  def flatten_dictionary(prefix: tuple, d: Any) -> dict:
    """
    Flatten a nested dictionary into ``{path_tuple: leaf_value}``.

    A dictionary containing an ``optuna_hyper_param`` key is replaced by the
    value Optuna suggests for it, stored at that dictionary's own path.
    Empty dictionaries are kept as leaves so they survive reconstruction.
    The exception is empty dictionaries inside a ``hyper_params`` section,
    which carry nothing to inject and are dropped.

    Raises:
        TypeError: If ``d`` is not a dictionary.
    """
    if not isinstance(d, dict):
      raise TypeError(f"Expected a dictionary at path {prefix}, got {type(d).__name__}")

    flattened = {}
    for key, val in d.items():
      current_key = prefix + (key,)

      if key == 'optuna_hyper_param':
        # Name and store the suggestion at the parent path (without this key)
        hyper_name = flat_dict_key_to_name(prefix)
        flattened[prefix] = suggest_optuna_hyper_param(hyper_name, val)
      elif isinstance(val, dict) and val:
        flattened.update(flatten_dictionary(current_key, val))
      elif isinstance(val, dict) and 'hyper_params' in current_key:
        continue
      else:
        flattened[current_key] = val

    return flattened

  # Flatten the config, suggesting hyperparameters along the way
  flattened_config = flatten_dictionary((), result)

  # Each path through a 'hyper_params' section overrides the same path with
  # the 'hyper_params' component removed
  overrides = {
    key: tuple(part for part in key if part != 'hyper_params')
    for key in flattened_config
    if 'hyper_params' in key
  }

  # Every override must target an existing config value; otherwise the
  # hyper_params tree does not mirror the config
  missing = [target for target in overrides.values() if target not in flattened_config]
  if missing:
    raise ValueError(
      f"No config value found for hyperparameter path(s): {missing}. "
      "Check that the structure of the yaml file is so that the hyperparameters have the "
      "same tree structure as the corresponding section of the config."
    )

  for hyper_key, target_key in overrides.items():
    new_value = flattened_config.pop(hyper_key)  # drop the hyper_params entry
    old_value = flattened_config[target_key]
    flattened_config[target_key] = new_value
    print(f"Replacing {target_key} (value: {old_value}) with new value: {new_value}")

  # Rebuild the nested dictionary from the flattened paths
  reconstructed = {}
  for key_path, value in flattened_config.items():
    if not key_path:
      continue
    current = reconstructed
    for part in key_path[:-1]:
      current = current.setdefault(part, {})
    current[key_path[-1]] = value

  return reconstructed


def load_best_hyper_params(config: dict) -> dict:
  """Return ``config`` with the hyperparameters of the study's best trial injected.

  The study is located through ``ExperimentIdentifier(config)``. Its
  ``best_trial`` is passed to ``inject_hyperparameters``, which returns a new
  config dictionary.
  """
  identifier = ExperimentIdentifier(config)
  best_trial = optuna.load_study(
    study_name=identifier.study_name,
    storage=identifier.study_storage_path,
  ).best_trial
  return inject_hyperparameters(config, best_trial)







