from typing import Callable, Dict, Tuple, Literal

import jax

from deeperwin.mcmc import MetropolisHastingsMonteCarlo, MCMCState


def _run_mcmc_with_cache(
    log_psi_sqr_func: Callable,
    cache_func: Callable,
    mcmc: MetropolisHastingsMonteCarlo,
    params: Dict,
    spin_state: Tuple[int, int],
    mcmc_state: MCMCState,
    fixed_params: Dict,
    split_mcmc=True,
    merge_mcmc=True,
    mode: Literal["burnin", "intersteps"] = "intersteps",
):
    """Run MCMC burn-in or inter-steps, optionally refreshing a pmapped cache first.

    Steps, in order:
      1. Optionally split ``mcmc_state`` across devices.
      2. If ``cache_func`` is given, evaluate it (pmapped over the "devices" axis) on the
         current batch and store the result in ``fixed_params["cache"]``.
      3. Recompute ``mcmc_state.log_psi_sqr`` with ``log_psi_sqr_func`` (pmapped) so the
         stored values correspond to the current ``params``.
      4. Run ``mcmc.run_burn_in`` or ``mcmc.run_inter_steps`` depending on ``mode``.
      5. Optionally merge the state back from the devices.

    Args:
        log_psi_sqr_func: Function ``(params, s0, s1, *batch) -> log|psi|^2``. Arguments
            1 and 2 (the two entries of ``spin_state``) are treated as static under pmap.
        cache_func: Function with the same call signature as ``log_psi_sqr_func`` whose
            output is stored as ``fixed_params["cache"]``; pass ``None`` to skip caching.
        mcmc: Sampler providing ``run_burn_in`` and ``run_inter_steps``.
        params: Wavefunction parameters, broadcast to each pmapped call.
        spin_state: Exactly two ints, presumably ``(n_up, n_dn)``. They are unpacked into
            the static positional arguments 1 and 2 of the pmapped functions.
        mcmc_state: Current MCMC state. Its ``log_psi_sqr`` attribute is overwritten.
            NOTE(review): when ``split_mcmc`` is False this assignment hits the caller's
            object directly.
        fixed_params: Extra inputs for ``build_batch``. **Mutated in place** when
            ``cache_func`` is given (its ``"cache"`` key is set).
        split_mcmc: Call ``mcmc_state.split_across_devices()`` before any evaluation.
        merge_mcmc: Call ``merge_devices()`` on the resulting state before returning.
        mode: ``"burnin"`` or ``"intersteps"``.

    Returns:
        Tuple ``(mcmc_state, fixed_params)`` after sampling. ``fixed_params`` is the same
        dict object that was passed in.

    Raises:
        ValueError: If ``mode`` is not one of the supported values. This is checked before
            any device splitting, evaluation or mutation of ``fixed_params``.
    """
    # Resolve the sampler step up front so an invalid mode fails fast, before the
    # expensive pmapped evaluations and before fixed_params is mutated.
    if mode == "burnin":
        run_mcmc = mcmc.run_burn_in
    elif mode == "intersteps":
        run_mcmc = mcmc.run_inter_steps
    else:
        raise ValueError(f"Unknown MCMC mode: {mode}")

    if split_mcmc:
        mcmc_state = mcmc_state.split_across_devices()

    if cache_func is not None:
        cache_func_pmapped = jax.pmap(cache_func, axis_name="devices", static_broadcasted_argnums=(1, 2))
        fixed_params["cache"] = cache_func_pmapped(params, *spin_state, *mcmc_state.build_batch(fixed_params))

    # Order matters: this build_batch call must run after the cache is stored above, so
    # that (if build_batch forwards fixed_params["cache"]) the log-psi evaluation sees it.
    log_psi_squared_pmapped = jax.pmap(log_psi_sqr_func, axis_name="devices", static_broadcasted_argnums=(1, 2))
    mcmc_state.log_psi_sqr = log_psi_squared_pmapped(params, *spin_state, *mcmc_state.build_batch(fixed_params))

    mcmc_state = run_mcmc(log_psi_sqr_func, mcmc_state, params, *spin_state, fixed_params)

    if merge_mcmc:
        mcmc_state = mcmc_state.merge_devices()
    return mcmc_state, fixed_params
