import jax
import pickle
from pathlib import Path
from pipeline.training.runtime_resolve import Experiment
from pipeline.training.state_init import TrainState
from pipeline.training.train_configs import TrainConfig
from pipeline.training.state_init import EmulatorState, SummaryState

def save_checkpoint(
    path: Path,
    state: TrainState,
    train_config: TrainConfig,
    epoch: int,
    experiment: Experiment
):
    """
    Pickle the host-side training state to ``path / "checkpoint.pkl"``.

    The directory ``path`` is created if needed. The checkpoint is first
    written to a temporary sibling file and then atomically moved into place,
    so a crash mid-write never clobbers the previous good checkpoint with a
    truncated file.

    Args:
        path: Directory that will hold ``checkpoint.pkl``.
        state: Training state; its emulator, summary and critic
            params/optimizer states are copied to host memory before pickling.
        train_config: Currently unused; kept for interface compatibility.
        epoch: Epoch number stored alongside the state.
        experiment: Currently unused; kept for interface compatibility.
    """
    import os  # local import: only needed for fsync of the temp file

    path.mkdir(parents=True, exist_ok=True)

    ckpt = {
        "epoch": epoch,
        # Pull the step counter to host like every other field; stored as a
        # plain int so the pickle does not depend on device-array types.
        "step": int(jax.device_get(state.step)),
        "emulator_params": jax.device_get(state.emulator.params),
        "emulator_opt_state": jax.device_get(state.emulator.opt_state),
        "summary_params": jax.device_get(state.summary.params),
        "summary_opt_state": jax.device_get(state.summary.opt_state),
        "critic_params": jax.device_get(state.critic.params),
        "critic_opt_state": jax.device_get(state.critic.opt_state),
    }

    final_path = path / "checkpoint.pkl"
    tmp_path = path / "checkpoint.pkl.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(ckpt, f)
            f.flush()
            # Make sure the bytes hit disk before the rename publishes them.
            os.fsync(f.fileno())
        # Atomic on POSIX (and on Windows within one volume): readers see
        # either the old checkpoint or the complete new one, never a partial file.
        tmp_path.replace(final_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise

def load_train_state_from_checkpoint(ckpt: dict) -> TrainState:
    """
    Rebuild a TrainState directly from a loaded checkpoint dict.

    Every value is taken verbatim from ``ckpt``; nothing is freshly
    initialized and no randomness is involved. Only the step counter is
    converted, via ``int``.

    Args:
        ckpt: Mapping with the keys ``emulator_params``, ``emulator_opt_state``,
            ``summary_params``, ``summary_opt_state`` and ``step``. A missing
            key raises ``KeyError``.

    Returns:
        A TrainState holding the restored emulator and summary component
        states and the integer step.

    NOTE(review): save_checkpoint also stores ``critic_params`` and
    ``critic_opt_state``, but they are not restored here, and TrainState is
    constructed without a ``critic`` argument. Confirm that TrainState
    tolerates a missing critic, or restore it explicitly.
    """
    # (state class, params key, optimizer-state key), constructed in this order.
    component_specs = (
        (EmulatorState, "emulator_params", "emulator_opt_state"),
        (SummaryState, "summary_params", "summary_opt_state"),
    )
    emulator, summary = (
        state_cls(params=ckpt[params_key], opt_state=ckpt[opt_key])
        for state_cls, params_key, opt_key in component_specs
    )

    restored_step = int(ckpt["step"])
    return TrainState(emulator=emulator, summary=summary, step=restored_step)
