from typing import Any, Callable

from flax import core
from flax import struct
import optax

class MultiModelTrainState(struct.PyTreeNode):
    """TrainState for an ensemble of models.

    Each model is addressed by a string key; ``params``, ``txs`` and
    ``opt_state`` share that key space (presumably one key per ensemble
    member -- confirm with callers). Models are optimized independently, each
    with its own optax transformation.
    """

    # Number of `apply_gradients` calls made so far (starts at 0 in `create`).
    step: int
    # Model apply function; static (not a pytree leaf) and not called here.
    apply_fn: Callable = struct.field(pytree_node=False)
    # Per-key parameter pytrees.
    params: core.FrozenDict[str, core.FrozenDict[str, Any]] = struct.field(pytree_node=True)
    # Per-key optimizer; static, so changing it requires re-tracing under jit.
    txs: core.FrozenDict[str, optax.GradientTransformation] = struct.field(pytree_node=False)
    # Per-key optimizer state, produced by `txs[key].init` / `txs[key].update`.
    opt_state: core.FrozenDict[str, optax.OptState] = struct.field(pytree_node=True)

    @staticmethod
    def _rebuild_like(template, entries):
        """Return `entries` (a plain dict) in the same mapping type as `template`.

        Keeps the pytree structure of the state stable across steps, so a
        FrozenDict stays a FrozenDict (otherwise jit retraces and scan/loop
        carries fail with a structure mismatch).
        """
        if isinstance(template, core.FrozenDict):
            return core.FrozenDict(entries)
        return entries

    def apply_gradients(self, *, grads, **kwargs):
        """Apply one optimizer step to every model that has a transformation.

        Args:
            grads: Mapping with a gradient pytree for every key in ``txs``;
                each must match the structure of ``params[key]``.
            **kwargs: Extra fields to overwrite on the returned state.

        Returns:
            A new state with ``step`` incremented by one, and updated ``params``
            and ``opt_state``. Both keep their original container type.
            Entries without a matching transformation are carried over
            unchanged.

        Raises:
            KeyError: if ``grads``, ``params`` or ``opt_state`` lacks a key
                present in ``txs``.
        """
        # Start from shallow copies so keys without an optimizer are preserved
        # instead of silently dropped from the state.
        new_params = dict(self.params)
        new_opt_state = dict(self.opt_state)
        for key, tx in self.txs.items():
            updates, new_opt_state[key] = tx.update(
                grads[key], self.opt_state[key], self.params[key])
            new_params[key] = optax.apply_updates(self.params[key], updates)

        return self.replace(
            step=self.step + 1,
            params=self._rebuild_like(self.params, new_params),
            opt_state=self._rebuild_like(self.opt_state, new_opt_state),
            **kwargs,
        )

    @classmethod
    def create(cls, *, apply_fn, params, txs, **kwargs):
        """Build an initial state with ``step=0`` and freshly initialized optimizers.

        Args:
            apply_fn: Model apply function to store on the state.
            params: Mapping from key to that model's parameter pytree; must
                contain every key in ``txs``.
            txs: Mapping from key to the optax transformation for that model.
            **kwargs: Additional dataclass fields (e.g. those declared by a
                subclass), forwarded to the constructor.

        Returns:
            A new ``cls`` instance.
        """
        opt_state = {key: tx.init(params[key]) for key, tx in txs.items()}
        # Forward kwargs so subclasses with extra fields can be created.
        return cls(
            step=0,
            apply_fn=apply_fn,
            params=params,
            txs=txs,
            opt_state=opt_state,
            **kwargs,
        )
