import jax
# NOTE(review): this flips a process-wide JAX flag as a side effect of merely
# importing this module. Per the JAX docs, `jax_debug_nans` makes JAX raise as
# soon as a NaN is produced (and slows execution), so it likely preempts the
# non-finite gradient sanitisation done in `make_train_step` — confirm it is
# meant to stay enabled outside of debugging sessions.
jax.config.update("jax_debug_nans", True)
import optax
import jax.numpy as jnp
from typing import Any, Callable, NamedTuple
from pipeline.training.state_init import TrainState, EmulatorState, SummaryState, CriticState

class EpochFlags(NamedTuple):
    """Per-epoch switches consumed by the train and validation steps.

    ``use_ot`` is read by plain Python control flow inside the step
    functions, so it is expected to be a concrete boolean rather than a
    traced JAX value.
    """

    # Whether the OT term (and any adversarial maximiser updates) is active.
    use_ot: bool
    # Weight multiplying the OT loss in the total training objective.
    lambda_ot: float

def make_train_step(
        mse_objective: Callable[..., jax.Array],
        ot_objective: Callable[..., jax.Array],
        emulator_optimizer: optax.GradientTransformation,
        summary_optimizer: optax.GradientTransformation,
        critic_optimizer: optax.GradientTransformation,
        summary_has_params: bool,
        critic_is_adversary: bool,
        maximizer_steps: int,
        summary_adversarial: bool,
        summary_apply: Callable[..., jax.Array] | None = None,
        critic_weight_clip: float | None = None,
        max_grad_norm: float = 1.0,
    ) -> Callable[..., tuple[Any, dict[str, jnp.ndarray]]]:
    """Build the training step for the emulator / summary / critic game.

    The returned ``train_step`` optionally runs ``maximizer_steps`` rounds of
    gradient *ascent* on the OT objective for the adversarial players (the
    summary network and/or the critic), then takes one descent step on
    ``mse + lambda_ot * ot`` for the emulator (and, when the summary is
    cooperative, for the summary network as well).

    Args:
        mse_objective: ``mse_objective(θ, batch)`` -> scalar MSE loss.
        ot_objective: ``ot_objective(θ, ψ, ω, batch)`` -> scalar OT loss, where
            θ / ψ / ω are the emulator, summary and critic parameters.
        emulator_optimizer: optax transformation for the emulator parameters.
        summary_optimizer: optax transformation for the summary parameters.
        critic_optimizer: optax transformation for the critic parameters.
        summary_has_params: whether the summary network has trainable params.
        critic_is_adversary: if True, the critic is updated to maximise OT.
        maximizer_steps: number of ascent steps per call for the maximisers.
        summary_adversarial: if True (and the summary has params) the summary
            maximises OT; otherwise it is trained jointly with the emulator to
            minimise the total loss.
        summary_apply: ``summary_apply(ψ, batch)``; only used to report the
            ``summary/std`` diagnostic.
        critic_weight_clip: if set, critic weights are clipped elementwise to
            ``[-critic_weight_clip, critic_weight_clip]`` after each ascent step.
        max_grad_norm: global-norm threshold applied to every gradient before
            it reaches an optimizer (default 1.0).

    Returns:
        ``train_step(state, batch, rng, epoch_flags) -> (new_state, metrics)``.
        ``epoch_flags.use_ot`` drives Python control flow, so it must be a
        concrete (non-traced) value. ``rng`` is currently unused.
    """

    def compute_grad_norm(grads: Any) -> jax.Array:
        """Return the global L2 norm of a gradient pytree (0.0 if it is empty)."""
        leaves = jax.tree_util.tree_leaves(grads)
        if not leaves:
            return jnp.array(0.0)
        return jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))

    def zero_non_finite(grads: Any) -> Any:
        """Replace every NaN/Inf gradient entry with 0."""
        return jax.tree_util.tree_map(lambda g: jnp.where(jnp.isfinite(g), g, 0.0), grads)

    clipper = optax.clip_by_global_norm(max_grad_norm)

    def clip_grads(grads: Any) -> tuple[Any, jax.Array]:
        """Sanitise and global-norm clip ``grads``; return them with their new norm."""
        grads = zero_non_finite(grads)
        clipped, _ = clipper.update(grads, clipper.init(grads))
        # Last line of defence: if the norm itself is still non-finite
        # (e.g. overflow in the sum of squares), drop this update entirely.
        clipped = jax.lax.cond(
            jnp.isfinite(compute_grad_norm(clipped)),
            lambda g: g,
            lambda g: jax.tree_util.tree_map(jnp.zeros_like, g),
            clipped,
        )
        return clipped, compute_grad_norm(clipped)

    def apply_clipped_update(optimizer, params, opt_state, grads, *, ascend=False):
        """Clip ``grads`` and take one optimizer step on ``params``.

        With ``ascend=True`` the clipped gradient is negated so that the
        (minimising) optax optimizer performs gradient ascent.

        Returns ``(new_params, new_opt_state, raw_grad_norm, clipped_grad_norm)``.
        """
        raw_norm = compute_grad_norm(grads)
        clipped, clipped_norm = clip_grads(grads)
        if ascend:
            clipped = jax.tree_util.tree_map(lambda x: -x, clipped)
        updates, new_opt_state = optimizer.update(clipped, opt_state)
        return optax.apply_updates(params, updates), new_opt_state, raw_norm, clipped_norm

    def train_step(
        state: TrainState,
        batch: jnp.ndarray,
        rng: jnp.ndarray,
        epoch_flags: EpochFlags,
    ) -> tuple[Any, dict[str, jnp.ndarray]]:
        """Run the maximiser updates (if any), then one descent step on the total loss."""
        λ = epoch_flags.lambda_ot
        # Used in Python control flow below, so it must be concrete, not traced.
        use_ot = bool(epoch_flags.use_ot)

        summary_state = state.summary
        critic_state = state.critic
        grad_norm_ψ = grad_norm_ψ_clipped = jnp.array(0.0)
        grad_norm_ω = grad_norm_ω_clipped = jnp.array(0.0)

        # "Summary adversary" means the learned embedding maximizes OT (min-max game).
        # "Critic adversary" is the WGAN functional (min-max-max game).
        do_summary_max = use_ot and summary_has_params and summary_adversarial
        do_critic_max = use_ot and critic_is_adversary

        if do_summary_max or do_critic_max:
            # The emulator is frozen while the maximisers move.
            θ_sg = jax.lax.stop_gradient(state.emulator.params)

            def ot_only_objective(ψ: Any, ω: Any) -> jax.Array:
                return ot_objective(θ_sg, ψ, ω, batch)

            for _ in range(maximizer_steps):
                ψ = summary_state.params
                ω = critic_state.params

                if do_summary_max and do_critic_max:
                    grads_ψ, grads_ω = jax.grad(ot_only_objective, argnums=(0, 1))(ψ, ω)
                elif do_summary_max:
                    grads_ψ = jax.grad(ot_only_objective, argnums=0)(ψ, ω)
                else:
                    grads_ω = jax.grad(ot_only_objective, argnums=1)(ψ, ω)

                if do_summary_max:
                    new_ψ, new_ψ_opt_state, grad_norm_ψ, grad_norm_ψ_clipped = apply_clipped_update(
                        summary_optimizer, ψ, summary_state.opt_state, grads_ψ, ascend=True,
                    )
                    summary_state = SummaryState(new_ψ, new_ψ_opt_state)

                if do_critic_max:
                    new_ω, new_ω_opt_state, grad_norm_ω, grad_norm_ω_clipped = apply_clipped_update(
                        critic_optimizer, ω, critic_state.opt_state, grads_ω, ascend=True,
                    )
                    if critic_weight_clip is not None:
                        new_ω = jax.tree_util.tree_map(
                            lambda x: jnp.clip(x, -critic_weight_clip, critic_weight_clip),
                            new_ω,
                        )
                    critic_state = CriticState(new_ω, new_ω_opt_state)

        def objective(θ: Any, ψ: Any) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]:
            """Total loss ``mse + λ * ot`` with ``(mse, ot)`` as auxiliary output."""
            mse = mse_objective(θ, batch)
            if use_ot:
                # Uses the critic as it stands after the maximiser loop above.
                ot = ot_objective(θ, ψ, critic_state.params, batch)
            else:
                # Static switch: the OT objective is not even traced when disabled.
                ot = jnp.array(0.0, dtype=mse.dtype)
            return mse + λ * ot, (mse, ot)

        if summary_has_params and use_ot and not summary_adversarial:
            # Cooperative summary: trained jointly with the emulator to minimise J.
            (J, (mse, ot)), (grads_θ, grads_ψ) = jax.value_and_grad(
                objective, argnums=(0, 1), has_aux=True,
            )(state.emulator.params, summary_state.params)
            new_ψ, new_ψ_opt_state, grad_norm_ψ, grad_norm_ψ_clipped = apply_clipped_update(
                summary_optimizer, summary_state.params, summary_state.opt_state, grads_ψ,
            )
            summary_state = SummaryState(new_ψ, new_ψ_opt_state)
        else:
            (J, (mse, ot)), grads_θ = jax.value_and_grad(objective, has_aux=True)(
                state.emulator.params, summary_state.params,
            )

        new_θ, new_θ_opt_state, grad_norm_θ, grad_norm_θ_clipped = apply_clipped_update(
            emulator_optimizer, state.emulator.params, state.emulator.opt_state, grads_θ,
        )

        new_state = TrainState(
            emulator=EmulatorState(new_θ, new_θ_opt_state),
            summary=summary_state,
            critic=critic_state,
            step=state.step + 1,
        )

        # Diagnostic: std of the (post-update) summary embedding of the batch;
        # 0.0 when the summary has no params or no apply function was given.
        summary_std = jnp.array(0.0)
        if summary_has_params and summary_apply is not None:
            summary_std = jnp.std(summary_apply(summary_state.params, batch))

        metrics = {
            "loss/mse": mse,
            "loss/ot": ot,
            "loss/total": J,
            "grad_norm/emulator": grad_norm_θ,
            "grad_norm/summary": grad_norm_ψ,
            "grad_norm/critic": grad_norm_ω,
            "grad_norm_clipped/emulator": grad_norm_θ_clipped,
            "grad_norm_clipped/summary": grad_norm_ψ_clipped,
            "grad_norm_clipped/critic": grad_norm_ω_clipped,
            "summary/std": summary_std,
        }

        return new_state, metrics

    return train_step

def make_val_step(
        rollout_mse: Callable[..., jax.Array],
        rollout_ot: Callable[..., jax.Array],
        summary_apply: Callable[..., jax.Array],
    ) -> Callable[..., dict[str, jnp.ndarray]]:
    """Build a validation step that rolls out the emulator and embeds the results.

    Args:
        rollout_mse: ``rollout_mse(θ, batch)`` -> predicted trajectory.
        rollout_ot: ``rollout_ot(θ, batch)`` -> predicted trajectory; only
            called when ``epoch_flags.use_ot`` is truthy.
        summary_apply: ``summary_apply(ψ, u)`` -> summary embedding of ``u``.

    Returns:
        ``val_step(state, batch, rng, epoch_flags) -> dict`` with keys
        ``u_true``, ``u_hat``, ``u_hat_full``, ``s_true`` and ``s_hat``.
    """

    def val_step(
        state: TrainState,
        batch: jnp.ndarray,
        rng: jnp.ndarray,
        epoch_flags: EpochFlags,
    ) -> dict[str, jnp.ndarray]:
        """Return true/predicted trajectories and their summary embeddings.

        ``rng`` is unused; it is accepted so the signature matches the train
        step's ``(state, batch, rng, epoch_flags)``.
        """
        θ = state.emulator.params

        anchored = rollout_mse(θ, batch)
        # Without OT the "full" prediction falls back to the MSE rollout.
        full = rollout_ot(θ, batch) if epoch_flags.use_ot else anchored

        ψ = state.summary.params
        return {
            "u_true": batch,
            "u_hat": anchored,
            "u_hat_full": full,
            "s_true": summary_apply(ψ, batch),
            "s_hat": summary_apply(ψ, full),
        }

    return val_step
