from typing import Any

from flax.training.train_state import TrainState


class TrainState(TrainState):
    """Flax ``TrainState`` carrying an additional ``target_params`` pytree.

    Behaves exactly like the Flax base class. It adds one extra dataclass
    field, ``target_params``, which is kept alongside the regular ``params``.
    ``copy_target_params`` refreshes that field from ``params``.

    Note: this subclass intentionally reuses the base class's name, so within
    this module ``TrainState`` refers to the extended version after this
    definition.
    """

    # Secondary parameter pytree. Defaults to None unless supplied at
    # construction; overwritten with ``params`` by ``copy_target_params``.
    target_params: Any = None

    def copy_target_params(self) -> "TrainState":
        """Return a new state whose ``target_params`` equals the current ``params``.

        ``self`` is left untouched; the result is produced via the dataclass
        ``replace`` method. No copy of the underlying arrays is taken, so the
        returned ``target_params`` is the very same pytree object as
        ``params``.
        """
        current_params = self.params
        return self.replace(target_params=current_params)
