import logging
import sys
import os

def setup_logging(log_level_str="INFO", log_file=None, rank=0):
    """Configure the root logger for a (possibly multi-process) run.

    Every existing handler on the root logger is removed and closed first, so
    calling this repeatedly (e.g. in a notebook) does not duplicate output.

    On rank 0, messages go to stdout and, if ``log_file`` is given, also to that
    file (truncated on open, UTF-8 encoded). Any other rank gets only a
    ``logging.NullHandler``, so it produces no output through the root logger.

    Args:
        log_level_str: Level name such as ``"INFO"`` or ``"debug"``
            (case-insensitive), or a numeric level such as ``logging.DEBUG``.
            Unrecognised names fall back to ``logging.INFO``.
        log_file: Optional path of a log file. Missing parent directories are
            created. If the file handler cannot be set up, the error is logged
            and console logging continues.
        rank: Process rank; only rank 0 emits output.
    """
    if isinstance(log_level_str, int):
        log_level = log_level_str
    else:
        log_level = getattr(logging, str(log_level_str).upper(), logging.INFO)
        # getattr can hit non-level attributes (e.g. "basic_format" ->
        # logging.BASIC_FORMAT, a str), which setLevel would reject.
        if not isinstance(log_level, int):
            log_level = logging.INFO

    # Verbose alternative, helpful for debugging. Note that %(process)d is the
    # OS process id, not the distributed rank.
    # log_format = '%(asctime)s - %(levelname)s - Rank %(process)d |     %(message)s'
    # date_format = '%m-%d %H:%M:%S'
    log_format = '%(message)s'
    date_format = None

    logger = logging.getLogger()  # root logger
    logger.setLevel(log_level)

    # Drop handlers from earlier calls (important in interactive environments).
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    if rank != 0:
        # Non-zero ranks: a NullHandler keeps the root logger silent and stops
        # logging's last-resort handler from printing warnings to stderr.
        logger.addHandler(logging.NullHandler())
        return

    formatter = logging.Formatter(log_format, datefmt=date_format)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        try:
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            # mode='w': each run overwrites the previous log.
            file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
            file_handler.setLevel(log_level)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except Exception as e:
            # Best effort: a bad log path must not abort the run; console
            # logging stays active and reports the failure.
            logger.error(f"Failed to configure file handler for {log_file}: {e}", exc_info=True)

    logger.info(f"Logging configured for Rank 0. Level: {log_level_str}. File: {log_file}")

def sanitize_for_pickle(d):
    """Return a picklable version of ``d``.

    If ``d`` is a dict, a new dict is returned (``d`` is not mutated) and nested
    dicts are processed recursively. Each value is converted as follows:

    * JAX arrays -> NumPy arrays (``np.array(v)``).
    * jitted ``PjitFunction`` objects -> ``str(v)``.
    * other callables -> ``"<function:NAME>"``.
    * ``int``/``float``/``str``/``bool``/``list``/``tuple``/``np.ndarray``/``None``
      -> kept as-is (list and tuple contents are not inspected).
    * anything else -> kept if ``pickle.dumps`` accepts it, else ``str(v)``.

    If ``d`` is not a dict, that single value is converted by the same rules,
    except that unrecognised types are always replaced by ``str(d)``.
    """
    import pickle
    import numpy as np
    import jax
    import jax.numpy as jnp

    try:
        # PjitFunction's module has moved between jaxlib releases; without it,
        # jitted functions are still handled by the generic callable branch.
        from jaxlib.xla_extension import PjitFunction
        pjit_types = (PjitFunction,)
    except ImportError:
        pjit_types = ()

    array_types = (jnp.ndarray, jax.Array)
    plain_types = (int, float, str, bool, list, tuple, np.ndarray)
    unknown = object()  # sentinel: value matched none of the known rules

    def _convert(v):
        """Apply the known conversion rules; return ``unknown`` if none match."""
        if isinstance(v, array_types):
            return np.array(v)
        if isinstance(v, pjit_types):
            return str(v)
        if callable(v):
            return f"<function:{getattr(v, '__name__', str(v))}>"
        if isinstance(v, dict):
            return sanitize_for_pickle(v)
        if v is None or isinstance(v, plain_types):
            return v
        return unknown

    def _is_picklable(v):
        """Return True if ``pickle.dumps(v)`` succeeds."""
        try:
            pickle.dumps(v)
        except Exception:  # custom __reduce__/__getstate__ hooks may raise anything
            return False
        return True

    if not isinstance(d, dict):
        result = _convert(d)
        return str(d) if result is unknown else result

    sanitized = {}
    for k, v in d.items():
        result = _convert(v)
        if result is unknown:
            # Previously the value was stored unchecked, so unpicklable objects
            # leaked through; probe and stringify instead.
            result = v if _is_picklable(v) else str(v)
        sanitized[k] = result
    return sanitized
