import jax
from jax import Array
from typing import NamedTuple, Any
from pipeline.training.train_configs import TrainConfig
from pipeline.training.runtime_resolve import Runtime

class EmulatorState(NamedTuple):
    """Trainable state (parameters + optimizer state) of the emulator network."""

    # Parameters as returned by ``runtime.emulator_init`` in ``init_state``.
    params: Any
    # Result of ``runtime.emulator_optimizer.init(params)`` in ``init_state``.
    opt_state: Any

class SummaryState(NamedTuple):
    """Trainable state (parameters + optimizer state) of the summary network."""

    # Parameters as returned by ``runtime.summary_init`` in ``init_state``.
    params: Any
    # Result of ``runtime.summary_optimizer.init(params)`` in ``init_state``.
    opt_state: Any

class CriticState(NamedTuple):
    """Trainable state (parameters + optimizer state) of the critic network."""

    # Parameters as returned by ``runtime.distance_init`` in ``init_state``.
    params: Any
    # Result of ``runtime.critic_optimizer.init(params)`` in ``init_state``.
    opt_state: Any

class TrainState(NamedTuple):
    """Full training state: one sub-state per network plus a step counter."""

    emulator: EmulatorState
    summary: SummaryState
    critic: CriticState
    # Step counter; ``init_state`` starts it at the Python int 0.
    # NOTE(review): if incremented under jit this may become a JAX Array
    # rather than an ``int`` — confirm against the training loop.
    step: int

def init_state(
    train_config: TrainConfig,
    runtime: Runtime,
    rng: Array,
) -> TrainState:
    """Build the initial ``TrainState`` for the emulator, summary and critic.

    For each network, parameters are created by the matching ``runtime``
    init function (called as ``init_fn(rng=..., dtype=runtime.dtype)``).
    The optimizer state is then created by the matching
    ``runtime.<name>_optimizer.init(params)``.

    Args:
        train_config: Not read by this function; kept for interface
            compatibility with existing callers.
        runtime: Resolved runtime. It must provide ``emulator_init``,
            ``summary_init``, ``distance_init`` (used for the critic),
            ``emulator_optimizer``, ``summary_optimizer``,
            ``critic_optimizer`` and ``dtype``.
        rng: JAX PRNG key. It is split so that every network is initialised
            from its own independent key.

    Returns:
        A ``TrainState`` holding the three initialised sub-states and
        ``step=0``.
    """
    # One independent key per network; a key must never be reused.
    # The first split output (previously discarded) seeds the critic, so the
    # emulator and summary keys — and their initial params — are unchanged.
    critic_rng, emu_rng, sum_rng = jax.random.split(rng, 3)

    emulator_state = _init_component(
        EmulatorState,
        runtime.emulator_init,
        runtime.emulator_optimizer,
        emu_rng,
        runtime.dtype,
    )
    summary_state = _init_component(
        SummaryState,
        runtime.summary_init,
        runtime.summary_optimizer,
        sum_rng,
        runtime.dtype,
    )
    # The critic's params come from ``runtime.distance_init``.
    critic_state = _init_component(
        CriticState,
        runtime.distance_init,
        runtime.critic_optimizer,
        critic_rng,
        runtime.dtype,
    )

    return TrainState(
        emulator=emulator_state,
        summary=summary_state,
        critic=critic_state,
        step=0,
    )


def _init_component(state_cls, init_fn, optimizer, rng, dtype):
    """Create one network's params and optimizer state, wrapped in ``state_cls``.

    ``init_fn`` is called with keyword arguments ``rng`` and ``dtype``.
    ``optimizer`` must expose ``init(params)``; presumably an optax-style
    gradient transformation — TODO confirm.
    """
    params = init_fn(rng=rng, dtype=dtype)
    return state_cls(params=params, opt_state=optimizer.init(params))
