import os
import pickle
import jax
import jax.numpy as jnp
from jax import random
import logging
from copy import deepcopy
import importlib.util

from .config_setup import (
    setup_environment_and_jax,
    create_mesh_and_shardings
)

from .models import BlendedMLPRegression
from .problems import basis_function_expansion


def import_fixed_params_from_above(results_dir):
    """Import ``fixed_params.py`` located two directories above *results_dir*.

    The file is loaded as a fresh module object named ``"fixed_params"``.
    Its top-level code is executed. The module is not registered in
    ``sys.modules``.

    Args:
        results_dir: Path to a results directory. ``fixed_params.py`` is
            expected at ``<results_dir>/../../fixed_params.py`` (resolved
            lexically via ``os.path.abspath``).

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If ``fixed_params.py`` does not exist at that path.
        ImportError: If no import spec or loader can be created for the file.
    """
    abs_path = os.path.abspath(os.path.join(results_dir, "../../fixed_params.py"))
    if not os.path.isfile(abs_path):
        raise FileNotFoundError(f"fixed_params.py not found at {abs_path}")

    spec = importlib.util.spec_from_file_location("fixed_params", abs_path)
    # spec_from_file_location may return None, and spec.loader may be None.
    # Fail with a clear ImportError rather than an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {abs_path}")
    fixed_params = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(fixed_params)
    return fixed_params


def _get_loader_logger() -> logging.Logger:
    """Return the "ModelLoader" logger at INFO, adding a stream handler on first use."""
    logger = logging.getLogger("ModelLoader")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    return logger


def _check_results_layout(results_dir, results_file, state_dir, logger) -> bool:
    """Log an error and return False if an expected path is missing, else True."""
    if not os.path.isdir(results_dir):
        logger.error(f"Results directory not found: {results_dir}")
        return False
    if not os.path.isfile(results_file):
        logger.error(f"Results pickle file not found: {results_file}")
        return False
    if not os.path.isdir(state_dir):
        logger.error(f"Saved state directory ('state/') not found in {results_dir}")
        return False
    return True


def _run_load_test(model, logger, levels=(0, 1)) -> None:
    """Log ``model.compute_l2_error`` at each level as a smoke test of the loaded state."""
    logger.info("--- Running Load Test ---")
    for level in levels:
        logger.info(f"Testing error for level {level}...")
        # Level 0 is evaluated with the coarse coefficients, others with the fine ones.
        coeffs = model.coeffs_c if level == 0 else model.coeffs_f
        error = model.compute_l2_error(level, model.params_c, model.params_f, coeffs)
        logger.info(f"error={error:1.2e}")


def load_model_from_results(
    results_dir: str, *, run_load_test: bool = True
) -> BlendedMLPRegression | None:
    """Rebuild a ``BlendedMLPRegression`` from a saved results directory.

    Expected layout:
        ``<results_dir>/data.pkl``  -- pickle with a ``'runtime_args'`` entry
        ``<results_dir>/state/``    -- saved model state passed as ``load_state_dir``
        ``<results_dir>/../../fixed_params.py`` -- provides
            ``get_training_params``, ``problem_params`` and
            ``get_model_setup_params``

    The stored ``runtime_args.backend`` is overridden to ``'cpu'`` before the
    JAX environment is set up.

    Args:
        results_dir: Path to the results directory (made absolute internally).
        run_load_test: If True (the default), log the L2 error at levels 0 and 1
            after loading as a sanity check.

    Returns:
        The initialized model, or ``None`` on a logged failure. Failures covered:
        a missing path, a ``fixed_params.py`` import error, an unreadable pickle,
        missing ``runtime_args``, or a model init/state-load error.

    Raises:
        Exceptions from JAX environment/mesh setup, config reconstruction and
        ``basis_function_expansion`` are not caught here and propagate.
    """
    results_dir = os.path.abspath(results_dir)
    results_file = os.path.join(results_dir, "data.pkl")
    state_dir = os.path.join(results_dir, "state")

    temp_logger = _get_loader_logger()

    if not _check_results_layout(results_dir, results_file, state_dir, temp_logger):
        return None

    # Load fixed_params from two levels up.
    try:
        fixed_params = import_fixed_params_from_above(results_dir)
        get_training_params = fixed_params.get_training_params
        base_problem_params = fixed_params.problem_params
        get_model_setup_params = fixed_params.get_model_setup_params
    except Exception as e:
        temp_logger.error(f"Failed to import fixed_params.py: {e}", exc_info=True)
        return None

    # Load saved data (contains runtime_args).
    # SECURITY: pickle.load can execute arbitrary code. Only load results
    # directories that you produced or otherwise trust.
    temp_logger.info(f"\nLoading results data from: {results_file}")
    try:
        with open(results_file, 'rb') as f:
            saved_data = pickle.load(f)
    except Exception as e:
        temp_logger.error(f"Failed to load results pickle {results_file}: {e}", exc_info=True)
        return None
    temp_logger.info("Results data loaded.")

    if 'runtime_args' not in saved_data:
        temp_logger.error("Saved data pickle does not contain 'runtime_args'. Cannot proceed.")
        return None
    runtime_args = saved_data['runtime_args']
    # Loading is always done on CPU, regardless of the backend used for training.
    runtime_args.backend = 'cpu'

    # Set up the JAX environment. Rank/size and data sharding are not needed here.
    _rank, _size = setup_environment_and_jax(runtime_args, temp_logger)
    mesh, replicate_sharding, _data_sharding = create_mesh_and_shardings(
        mesh_axis_name='nodes', logger=temp_logger
    )

    # Reconstruct the configuration dictionaries.
    temp_logger.info("\nReconstructing configuration dictionaries...")
    precision = jnp.float64 if runtime_args.use_float64 else jnp.float32
    training_params = get_training_params(runtime_args)
    training_params['nIter'] = runtime_args.outer_iters
    # Copy so the module-level problem_params in fixed_params is not mutated.
    problem_cfg = deepcopy(base_problem_params)
    net_setup_params = get_model_setup_params(runtime_args)

    # Prepare problem data with a fixed key for reproducibility.
    temp_logger.info("\nPreparing problem data...")
    problem_params_eval = basis_function_expansion(
        problem_cfg,
        rng=random.PRNGKey(42),
        precision=precision,
        sharding=replicate_sharding,
        logger=temp_logger,
    )

    # Initialize the model and load its saved state.
    temp_logger.info(f"\nInitializing model and loading saved state from: {state_dir}")
    try:
        model = BlendedMLPRegression(
            key=random.PRNGKey(99),
            net_setup_params=net_setup_params,
            training_params=training_params,
            problem_params=problem_params_eval,
            mesh=mesh,
            replicate_sharding=replicate_sharding,
            logger=temp_logger,
            load_state_dir=state_dir,
        )
        temp_logger.info("Model initialized successfully with loaded state.")
    except Exception as e:
        temp_logger.error(f"Failed to initialize model or load state: {e}", exc_info=True)
        return None

    if run_load_test:
        _run_load_test(model, temp_logger)

    return model


