# config_setup.py
import os
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils



import logging
from .logs import setup_logging, sanitize_for_pickle


import argparse

def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that use ``parse_args()`` are unaffected.

    Returns:
        argparse.Namespace: The parsed options.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or a value is invalid.
    """
    parser = argparse.ArgumentParser(description="Hierarchical POU Training")
    # Process/Cluster Args
    parser.add_argument("--cluster_detection_method", type=str, default="auto",
                        choices=["auto", "slurm", "mpi", "none"],
                        help="Method for detecting cluster environment ('auto', 'slurm', 'mpi', 'none')")
    parser.add_argument("--coordinator_address", type=str, default="localhost:1234",
                        help="Coordinator address for 'mpi' method.")
    # Backend Args
    parser.add_argument("--backend", type=str, default="gpu", choices=["gpu", "cpu"],
                        help="JAX backend to use ('gpu', 'cpu')")
    # JAX Config Args
    parser.add_argument("--debug", action="store_true", help="Enable JAX NaN/Inf checks.")
    parser.add_argument("--disable_jit", action="store_true", help="Disable JAX JIT compilation.")
    parser.add_argument("--profile", action="store_true", help="Enable JAX profiling.")
    parser.add_argument("--use_float64", action="store_true", default=False, help="Enable float64 precision.")

    parser.add_argument("--log_spectral", action="store_true", default=False, help="document spectral metrics")

    # Application Args
    parser.add_argument("--track_solver_iters", action="store_true", help="Track solver iterations (adds overhead).")
    parser.add_argument("--solver_type", type=str, default='direct',
                        choices=['direct', 'iterative_monolithic', 'iterative_block', 'iterative_block_distributed'],
                        help="Type of linear solver for coefficients.")
    # The options below are mandatory. They used to also carry `default=`
    # values (2, 3, 20, 20, 20, 2000 and, for --problem_dim, 1), but argparse
    # never applies a default to a required option, so those values were dead
    # and misleading. They are removed here; parsing behavior is unchanged.
    parser.add_argument("--Pc", type=int, required=True, help="Number of coarse partitions.")
    parser.add_argument("--Pf", type=int, required=True, help="Number of fine partitions per coarse.")
    parser.add_argument("--gate_n_hidden", type=int, required=True, help="hidden dim for gating.")
    parser.add_argument("--poly_n_hidden", type=int, required=True, help="hidden dim for poly.")
    parser.add_argument("--poly_basis_size", type=int, required=True, help="hidden dim for basis_size.")
    parser.add_argument("--outer_iters", type=int, required=True, help="Number of training iterations.")
    parser.add_argument("--unjit_training_step", action="store_true", default=False, help="un-Jit training step.")
    parser.add_argument("--bench_mark_runs", type=int, required=False, default=100, help="Number of blended least square solve benchmark runs.")

    parser.add_argument("--problem_dim", type=int, required=True, help="dimension of problem data.")

    # Output destinations
    parser.add_argument("--log_file", type=str, default=None,
                        help="Optional file path to redirect stdout/stderr.")
    parser.add_argument("--metrics_file", type=str, default=None,
                        help="Optional file path to save training metrics (.npz format).")

    # For reproducibility
    parser.add_argument("--results_dir_name", type=str, default="results",
                        help="File path to save the results.")
    parser.add_argument("--save_state_dir", type=str, default=None,
                        help="File path to save the final model state (params, coeffs).")
    parser.add_argument("--load_state_dir", type=str, default=None,
                        help="File path to load initial model state (params, coeffs) from.")

    return parser.parse_args(argv)


def init_distrbutive_env(runtime_args):
    """Normalize output paths, set up logging, then initialize JAX.

    Args:
        runtime_args: Parsed options object. Its ``log_file``,
            ``metrics_file``, ``save_state_dir`` and ``load_state_dir``
            attributes are rewritten in place to absolute paths when set.

    Returns:
        tuple: ``(logger, rank, size)`` where ``logger`` is the root logger
        and ``rank``/``size`` come from ``setup_environment_and_jax``.
    """
    # Make every configured path absolute; unset (falsy) entries are left alone.
    for attr_name in ("log_file", "metrics_file", "save_state_dir", "load_state_dir"):
        configured = getattr(runtime_args, attr_name)
        if configured:
            setattr(runtime_args, attr_name, os.path.abspath(configured))

    # Logging is configured with rank 0 here because the real process rank
    # is only known after the JAX setup call below.
    setup_logging(log_level_str="INFO", log_file=runtime_args.log_file, rank=0)
    root_logger = logging.getLogger()

    # This call HAS to happen before any other call to JAX, otherwise
    # jax.distributed.initialize will not initialize correctly and the code
    # will default to serial execution mode.
    rank, size = setup_environment_and_jax(runtime_args, root_logger)

    return (root_logger, rank, size)


def setup_environment_and_jax(args, logger):
    """
    Sets environment variables, initializes distributed JAX if needed,
    and configures JAX based on args.

    Args:
        args: Options object providing ``backend``,
            ``cluster_detection_method``, ``coordinator_address``,
            ``disable_jit``, ``debug`` and ``use_float64``.
        logger: Logger used for progress and warning messages.

    Returns:
        tuple: (rank, size). Falls back to (0, 1) when no cluster is
        detected, when distributed initialization fails, or when only one
        process is reported.
    """
    # Environment Setup
    backend = args.backend.lower()
    if backend == "cpu":
        os.environ["JAX_PLATFORMS"] = "cpu"
        # `jax` is already imported at module load, and JAX config flags take
        # their environment defaults at import time, so the environment
        # variable alone may not affect this process. Set the config option
        # explicitly as well (must happen before any backend is initialized).
        jax.config.update("jax_platforms", "cpu")
    else:
        # Default to GPU if not CPU, clear platform setting to let JAX find GPUs
        # NOTE(review): if JAX_PLATFORMS was set before `import jax`, popping
        # it here presumably does not reset the already-loaded config value.
        os.environ.pop("JAX_PLATFORMS", None)

    # Unset proxy variables
    for var in ['HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
                'no_proxy', 'NO_PROXY', 'FTP_PROXY', 'ftp_proxy']:
        os.environ.pop(var, None)

    # Distributed Initialization
    rank = 0
    size = 1
    cluster_method = args.cluster_detection_method.lower()

    # Try auto-detection first if requested. If SLURM is not detected the
    # method stays "auto" and the single-process branch below is taken.
    if cluster_method == "auto":
        if "SLURM_PROCID" in os.environ:
            logger.info("Auto-detected SLURM environment.")
            cluster_method = "slurm"

    if cluster_method == "slurm":
        try:
            jax.distributed.initialize()  # Auto-detects from SLURM env vars
            rank = jax.process_index()
            size = jax.process_count()
            logger.info(f"SLURM init successful: Rank {rank}/{size}")
        except Exception as e:
            # Best-effort: any failure degrades to single-process execution.
            logger.warning(f"SLURM initialization failed: {e}. Falling back to single process.")
            cluster_method = "none"  # Fallback

    elif cluster_method == "mpi":
        try:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            mpi_rank = comm.Get_rank()
            mpi_size = comm.Get_size()
            jax.distributed.initialize(
                coordinator_address=args.coordinator_address,
                num_processes=mpi_size,
                process_id=mpi_rank,
            )
            rank = jax.process_index()  # Should match mpi_rank
            size = jax.process_count()  # Should match mpi_size
            logger.info(f"MPI init successful: Rank {rank}/{size}")
        except ImportError:
            logger.warning("mpi4py not found. MPI initialization skipped. Falling back to single process.")
            cluster_method = "none"  # Fallback
        except Exception as e:
            logger.warning(f"MPI initialization failed: {e}. Falling back to single process.")
            cluster_method = "none"  # Fallback

    if cluster_method == "none" or size <= 1:
        rank = 0
        size = 1
        logger.info("Running in single-process mode.")
        # No need to call jax.distributed.initialize

    # JAX Configuration
    if args.disable_jit:
        jax.config.update("jax_disable_jit", True)
        logger.info("JIT disabled.")
    if args.debug:
        jax.config.update("jax_debug_nans", True)
        jax.config.update("jax_debug_infs", True)
        logger.info("JAX debug mode enabled (NaN/Inf checks).")
    if args.use_float64:
        jax.config.update("jax_enable_x64", True)
        logger.info("JAX float64 precision enabled.")
    else:
        jax.config.update("jax_enable_x64", False)  # Explicitly disable if not requested

    logger.info(f"JAX process count: {jax.process_count()}, process index: {jax.process_index()}")
    logger.info(f"JAX local device count: {jax.local_device_count()}, global device count: {jax.device_count()}")

    # A zero device count alone misses the case where JAX falls back to the
    # CPU backend, so also warn when the active default backend is the CPU.
    if backend == "gpu" and (jax.local_device_count() == 0 or jax.default_backend() == "cpu"):
        logger.warning("Backend set to GPU, but no local GPU devices found by JAX!")

    return rank, size


def create_mesh_and_shardings(mesh_axis_name='nodes', logger=None):
    """
    Creates a 1D device mesh based on the number of JAX processes
    and defines standard shardings. Assumes 1 device per process relevant for sharding.

    Args:
        mesh_axis_name (str): Name for the mesh axis.
        logger: Optional logger for informational messages. When ``None``
            the root logger is used.

    Returns:
        tuple: (mesh, replicate_sharding, data_sharding). With a single
               process, the mesh holds only the first available device.
    """
    if logger is None:
        # The parameter defaults to None; fall back to the root logger so
        # the function can actually be called without one.
        logger = logging.getLogger()

    num_processes = jax.process_count()

    if num_processes <= 1:
        # Create dummy mesh/shardings for single-process compatibility
        # Use first available device
        devices = [jax.devices()[0]]
        mesh = Mesh(devices, axis_names=(mesh_axis_name,))
        logger.info("Created dummy Mesh for single process.")
    else:
        # Create a mesh across all devices detected by JAX, assuming 1D layout over processes
        # NOTE(review): the mesh shape is (num_processes,); presumably this
        # only works when there is one device per process -- confirm.
        mesh_shape = (num_processes,)
        devices = mesh_utils.create_device_mesh(mesh_shape)
        mesh = Mesh(devices, axis_names=(mesh_axis_name,))
        logger.info(f"Created Mesh: {mesh} with devices {devices}")

    # Define standard shardings ON THE MESH
    replicate_sharding = NamedSharding(mesh, P())           # Replicated across all devices in mesh
    data_sharding = NamedSharding(mesh, P(mesh_axis_name))  # Sharded along the mesh axis

    return mesh, replicate_sharding, data_sharding


def print_separator(rank, logger):
    """Log a separator line from rank 0, after a best-effort local sync.

    Args:
        rank: Process rank of the caller; only rank 0 logs the separator.
        logger: Logger that receives the separator line.
    """
    if jax.process_count() > 1:
        # Imported locally: this module does not import jax.numpy at the top,
        # so the bare `jnp` name was previously undefined and the resulting
        # NameError was silently swallowed by the handler below.
        import jax.numpy as jnp
        try:
            # Best-effort pause: create a small array and wait for it.
            # NOTE(review): this only blocks on this one array; it is not a
            # cross-process barrier -- confirm whether a real one is needed.
            dummy = jnp.zeros(len(jax.local_devices()))
            dummy.block_until_ready()
        except Exception as exc:
            # Deliberately non-fatal (e.g. during shutdown), but no longer silent.
            logger.debug("print_separator sync skipped: %s", exc)

    if rank == 0:
        logger.info(("---" * 10) + "\n\n")

