import os
import logging
from .config_setup import (
    create_mesh_and_shardings,
    print_separator
)
from .logs import sanitize_for_pickle

import jax
import jax.numpy as jnp
from jax import random, jit
from jax.sharding import Mesh, NamedSharding
from jax.profiler import trace
import optax
from functools import partial

from contextlib import nullcontext
from time import time
import argparse
import numpy as np
from copy import deepcopy
import pickle

from .custom_types import Callable
from .models import BlendedMLPRegression, TrainingState
from .problems import get_problem_data, get_2d_problem_data

# Load test: rebuild a model from a saved state and check it matches the original
def run_load_test(
    original_model: BlendedMLPRegression,
    problem_params: dict,
    net_setup_params: dict,
    training_params: dict,
    mesh: Mesh,
    replicate_sharding: NamedSharding,
    load_state_dir: str,
    base_seed: int,
    logger: logging.Logger,
    levels_to_test: tuple = (0, 1),
    rtol: float = 1e-6,
    atol: float = 1e-6,
) -> bool:
    """Reload a saved model state and check it reproduces the original model's ``u_net``.

    A fresh ``BlendedMLPRegression`` is constructed with ``load_state_dir`` set.
    ``u_net`` is then evaluated on ``original_model.x`` at every requested level,
    once with the original training state and once with the loaded one. The two
    outputs are compared with ``np.allclose``.

    Args:
        original_model: Model whose state was previously saved to ``load_state_dir``.
        problem_params: Problem configuration forwarded to the new model.
        net_setup_params: Network configuration forwarded to the new model.
        training_params: Training configuration forwarded to the new model.
        mesh: Device mesh forwarded to the new model.
        replicate_sharding: Replicated sharding forwarded to the new model.
        load_state_dir: Directory the new model loads its state from.
        base_seed: Base seed. The new model is keyed with ``base_seed + 1``, so
            a correct load must not depend on the initialisation key.
        logger: Logger for progress and results.
        levels_to_test: Levels at which ``u_net`` is compared. Level 0 uses
            ``coeffs_c``; any other level uses ``coeffs_f``.
        rtol: Relative tolerance forwarded to ``np.allclose``.
        atol: Absolute tolerance forwarded to ``np.allclose``.

    Returns:
        True if the outputs match at every tested level. False if any level
        differs, or if the test raised. Exceptions are logged, not propagated.
    """

    # Defined once per call, not once per loop iteration, so the same jit
    # wrapper (and its compilation cache) is reused across levels and models.
    # u_net and level are static: level drives the Python branch below.
    @partial(jit, static_argnums=(0, 2))
    def evaluate_unet(u_net, state, level, x):
        coeffs = state.coeffs_c if level == 0 else state.coeffs_f
        # Map over the leading axis of x only; parameters and coeffs are shared.
        return jax.vmap(u_net, in_axes=(None, None, None, None, 0), out_axes=0)(
            level, state.params_c, state.params_f, coeffs, x)

    logger.info("--- Running Load Test")
    test_passed = False
    try:
        logger.info("Creating a new model instance for loading...")
        rng_test = random.PRNGKey(base_seed + 1)  # Sanity check - should have no effect

        test_model = BlendedMLPRegression(
            key=rng_test,
            net_setup_params=net_setup_params,
            training_params=training_params,
            problem_params=problem_params,
            mesh=mesh,
            replicate_sharding=replicate_sharding,
            logger=logger,
            load_state_dir=load_state_dir
        )
        logger.info("Load test model initialized and state loaded.")

        # Compare u_net output
        logger.info("Comparing u_net output for original vs loaded model...")

        original_state = original_model.get_training_state()
        loaded_state = test_model.get_training_state()
        test_x_point = original_model.x
        test_passed = True
        for level in levels_to_test:
            logger.info(f"Testing u_net for level {level}...")

            output_orig = evaluate_unet(original_model.u_net, original_state, level, test_x_point)
            output_load = evaluate_unet(test_model.u_net, loaded_state, level, test_x_point)

            output_orig_np = np.array(jax.device_get(output_orig))
            output_load_np = np.array(jax.device_get(output_load))

            are_close = np.allclose(output_orig_np, output_load_np, rtol=rtol, atol=atol)
            # Frobenius/L2 norm of the elementwise difference (not a max-abs diff).
            diff = np.linalg.norm(output_orig_np - output_load_np)

            logger.info(f"  Level {level}: L2 norm of diff: {diff:.4e}")
            logger.info(f"  Level {level}: Outputs close?   {are_close}")

            if not are_close:
                test_passed = False

        if test_passed:
            logger.info("Load Test: PASSED (u_net outputs match for all tested levels).")
        else:
            logger.warning("Load Test: FAILED (u_net outputs differ for at least one level).")

    except Exception as e:
        # Best-effort diagnostic: log with traceback and report failure
        # instead of aborting the surrounding training session.
        test_passed = False
        logger.error(f"Load test failed: {e}", exc_info=True)
    logger.info("--- End Simple Load Test")
    return test_passed


# Core Training Execution Function
def run_training_session(
    init_var : tuple,
    runtime_args: argparse.Namespace,
    net_setup_params_getter: Callable,
    training_params_getter: Callable,
    problem_params_getter: Callable,
    base_seed: int = 42
):
    """
    Build the mesh and model, train it, benchmark the solver, and save results.

    Args:
        init_var: ``(logger, rank, size)`` tuple for this process.
        runtime_args: Parsed runtime arguments. Fields read here: ``backend``,
            ``profile``, ``track_solver_iters``, ``bench_mark_runs``,
            ``metrics_file`` and ``load_state_dir``.
        net_setup_params_getter: Called with ``runtime_args``. Returns the
            network-setup dict.
        training_params_getter: Called with ``runtime_args``. Returns the
            training dict, which must contain ``"nIter"``.
        problem_params_getter: Called with ``runtime_args``. Returns the problem
            configuration, which is then passed through ``get_problem_data``.
        base_seed: Seed for the ``PRNGKey`` used for data generation and model
            initialisation.

    Returns:
        The trained ``BlendedMLPRegression``, with ``run_times`` and
        ``runtime_args`` attributes attached. Returns ``None`` if the device
        count does not evenly divide the partition counts.
    """
    logger, rank, size = init_var
    # data_sharding is returned by the helper but not used in this function.
    mesh, replicate_sharding, data_sharding = create_mesh_and_shardings(mesh_axis_name='nodes', logger=logger)

    net_setup_params = net_setup_params_getter(runtime_args)
    training_params = training_params_getter(runtime_args)
    problem_params = problem_params_getter(runtime_args)

    # Current code assumes that # of devices evenly divides the # of partitions.
    print_separator(rank, logger)
    n_devices = jax.device_count()
    Pc = net_setup_params["gating"]["coarse"]["num_partitions"]
    Pf = net_setup_params["gating"]["fine"]["num_partitions"]
    if Pc % n_devices != 0:
        logger.info(f"{n_devices=} does not divide {Pc=}")
        return None
    # NOTE(review): unreachable as written. If n_devices divides Pc it also
    # divides Pc*Pf. Kept as a guard; confirm whether a Pf-only check was intended.
    if (Pc * Pf) % n_devices != 0:
        logger.info(f"{n_devices=} does not divide Pc*Pf={Pc*Pf}")
        return None

    print_separator(rank, logger)
    rng = random.PRNGKey(base_seed)
    if jax.process_index() == 0:
        logger.info(f"Rank {rank+1}/{size} starting training session.")
        logger.info("Runtime Arguments:")
        for arg_name, value in sorted(vars(runtime_args).items()):
            logger.info(f"  {arg_name}: {value}")

        #logger.info("Network Setup Params: %s", net_setup_params)
        logger.info("Training Params: %s", training_params)
        logger.info("Problem Config: %s", problem_params)
        logger.info(f"Using PRNGKey seed: {base_seed}")
    print_separator(rank, logger)

    # Get Problem Data (the config dict is replaced by get_problem_data's result)
    problem_params = get_problem_data(
        problem_params,
        rng_key=rng,
        precision=net_setup_params['dtype'],
        sharding=replicate_sharding,
        logger=logger
    )

    # Model Initialization
    if jax.process_index() == 0:
        logger.info("Initializing the BlendedMLPRegression model...")
    model = BlendedMLPRegression(
        key=rng,
        net_setup_params=net_setup_params,
        training_params=training_params,
        problem_params=problem_params,
        mesh=mesh,
        replicate_sharding=replicate_sharding,
        logger=logger,
        load_state_dir=runtime_args.load_state_dir
    )
    if jax.process_index() == 0:
        logger.info("Model initialized.")
    print_separator(rank, logger)

    # WARM-UP / PRE-COMPILATION
    if jax.process_index() == 0:
        logger.info("Performing warm-up step for JIT compilation...")
    tic_compile = time()
    warmup_state = model.get_training_state()
    _, warmup_metrics = model._train_step(warmup_state, False)
    warmup_l2_c = model.compute_l2_error(0, warmup_state.params_c, warmup_state.params_f, warmup_state.coeffs_c)
    warmup_l2_f = model.compute_l2_error(1, warmup_state.params_c, warmup_state.params_f, warmup_state.coeffs_f)
    # JAX dispatches work asynchronously. Wait on every warm-up output, not
    # only the train-step metrics, so the L2-error executions fall inside the
    # timed window as well.
    jax.block_until_ready((warmup_metrics, warmup_l2_c, warmup_l2_f))
    toc_compile = time()
    compile_time_approx = toc_compile - tic_compile
    if jax.process_index() == 0:
        logger.info(f"Warm-up (compile + 1st exec) approx time: {compile_time_approx:.4f}s")
    print_separator(rank, logger)

    # Training Execution
    trace_context = (
        trace("jax_profile/", create_perfetto_trace=False)
        if runtime_args.profile else nullcontext()
    )
    nIter = training_params["nIter"]
    if jax.process_index() == 0:
        logger.info(f"Starting timed training execution for {nIter} iterations...")
    tic_exec = time()
    with trace_context:
        model.train(
            gpu=(runtime_args.backend == 'gpu'),
            track_solver_iters=runtime_args.track_solver_iters,
        )
        final_state = model.get_training_state()
        # Ensure all queued device work is finished before stopping the clock.
        jax.block_until_ready(final_state)

    toc_exec = time()
    print_separator(rank, logger)

    # Results and Timing
    execution_time = toc_exec - tic_exec
    iterations_per_sec = nIter / execution_time if execution_time > 0 else float('inf')
    if jax.process_index() == 0:
        logger.info("Training completed.")
        logger.info(f"Approx. Compilation Time: {compile_time_approx:.4f}s")
        logger.info(f"Execution Time ({nIter} iterations): {execution_time:.4f}s")
        logger.info(f"Iterations/s (execution only): {iterations_per_sec:.4f}")
    print_separator(rank, logger)

    # Solver Benchmarking (Post-Training)
    benchmark_runs = runtime_args.bench_mark_runs  # Configurable number of benchmark runs
    solver_type = net_setup_params["coef_slv_params"]["slv_type"]
    # Pre-bind every benchmark result so the save step below can never hit an
    # undefined name, whatever combination of conditions is active.
    benchmark_results_c = benchmark_results_f = None
    vmap_benchmark_results_c = vmap_benchmark_results_f = None

    if benchmark_runs > 0 and nIter > 0:
        if jax.process_index() == 0:
            logger.info(f"Starting post-training solver benchmark ({benchmark_runs} runs per level)...")

        benchmark_results_c = model.lstsq_solver.benchmark_solver(0, model, solver_type, num_runs=benchmark_runs, logger=logger)
        print_separator(rank, logger)
        benchmark_results_f = model.lstsq_solver.benchmark_solver(1, model, solver_type, num_runs=benchmark_runs, logger=logger)
        print_separator(rank, logger)

        # NOTE: 'use_vmap' is written into the solver's live config and is not
        # reset afterwards, so the solver stays in vmap mode after this point.
        params_ = model.lstsq_solver.coef_slv_params
        if params_['slv_type'] == 'iterative_block_distributed' and size > 1:
            # NOTE(review): unlike the branch below, the solver is not rebuilt via
            # _define_iterative_solver() here. Confirm that the distributed
            # solver reads 'use_vmap' at call time.
            params_['use_vmap'] = True
            vmap_benchmark_results_c = model.lstsq_solver.benchmark_solver(0, model, solver_type, num_runs=benchmark_runs, logger=logger)
            print_separator(rank, logger)
            vmap_benchmark_results_f = model.lstsq_solver.benchmark_solver(1, model, solver_type, num_runs=benchmark_runs, logger=logger)
        elif params_['slv_type'] == 'iterative_block' or (params_['slv_type'] == 'iterative_block_distributed' and size == 1):
            params_['use_vmap'] = True
            model.lstsq_solver._define_iterative_solver()
            vmap_benchmark_results_c = model.lstsq_solver.benchmark_solver(0, model, solver_type, num_runs=benchmark_runs, logger=logger)
            print_separator(rank, logger)
            vmap_benchmark_results_f = model.lstsq_solver.benchmark_solver(1, model, solver_type, num_runs=benchmark_runs, logger=logger)

        if jax.process_index() == 0:
            logger.info("Solver benchmark complete.")
        print_separator(rank, logger)

    model.run_times = {'compile_time_approx': compile_time_approx,
                       'execution_time': execution_time,
                       'iterations_per_sec': iterations_per_sec,
                       }
    model.runtime_args = runtime_args
    # Save Combined Results (Metrics, Params, Times) using Pickle
    if rank == 0 and runtime_args.metrics_file and nIter > 0:
        logger.info(f"Saving combined results to {runtime_args.metrics_file}...")
        try:
            logs_raw = jax.device_get(model.logs)
            logs_np = jax.tree_util.tree_map(np.array, logs_raw)

            # Combine
            save_data = {
                'logs': logs_np,
                'net_setup_params': sanitize_for_pickle(deepcopy(net_setup_params)),
                'training_params': sanitize_for_pickle(deepcopy(training_params)),
                'problem_cfg': sanitize_for_pickle(deepcopy(problem_params)),
                'n_processes': dict(mesh.shape),
                'run_times': model.run_times,
                'runtime_args': runtime_args,
            }

            if benchmark_runs > 0:
                # save_data['run_times'] is the same dict object as
                # model.run_times, so this also adds the benchmark entries
                # to the model's attribute (on rank 0 only).
                save_data['run_times'].update({
                    'solver_benchmark_coarse': benchmark_results_c,
                    'solver_benchmark_fine': benchmark_results_f,
                    'solver_benchmark_coarse_vmap': vmap_benchmark_results_c,
                    'solver_benchmark_fine_vmap': vmap_benchmark_results_f,
                })

            # pickle save
            with open(runtime_args.metrics_file, 'wb') as f:
                pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info("Combined results (.pkl) saved successfully.")

        except Exception as e:
            # Saving is best-effort: log with traceback and keep going.
            logger.error(f"Error saving combined results pickle: {e}", exc_info=True)
    print_separator(rank, logger)

    # Disabled: state save + reload check. Kept as an inert string literal so
    # it can be re-enabled; it has no runtime effect.
    """
    if size <= 1:
        # Save State and Data
        if runtime_args.save_state_dir:
            model.save_model_state(runtime_args.save_state_dir)
        print_separator(rank, logger)

        # Load Test
        if rank == 0 and runtime_args.save_state_dir:
            run_load_test(
                original_model=model,
                problem_params=problem_params,
                net_setup_params=net_setup_params,
                training_params=training_params,
                mesh=mesh,
                replicate_sharding=replicate_sharding,
                load_state_dir=runtime_args.save_state_dir,
                base_seed=base_seed,
                logger=logger
            )
    print_separator(rank, logger)
    """

    logger.info(f"Rank {rank+1}/{size} finished training session.")
    return model

