from collections.abc import Callable
from copy import deepcopy
import inspect
from typing import Generic, NamedTuple, TypeVar

from flax import nnx
import optax

from offline.utils.jax import soft_update


# Name -> optax factory lookup used by ``get_optimizer``. The selected factory
# is called with ``learning_rate`` plus whichever extra keyword arguments its
# own signature declares.
OPTIMIZERS: dict[str, Callable[..., optax.GradientTransformation]] = dict(
    adam=optax.adam,
    adamw=optax.adamw,
    lamb=optax.lamb,
    sgd=optax.sgd,
)

# Type parameter for containers that hold a specific ``nnx.Module`` subclass.
ModelT = TypeVar("ModelT", bound=nnx.Module)
# Filter matching every variable that is not an ``nnx.RngState``.
NotRng = nnx.Not(nnx.RngState)


def get_optimizer(
    model: nnx.Module,
    learning_rate,
    every_k_schedule: int = 1,
    max_gradient_norm: float = 0,
    optimizer_type: str = "adam",
    wrt: nnx.filterlib.Filter = nnx.Param,
    **kwargs,
) -> nnx.Optimizer:
    """Build an ``nnx.Optimizer`` for ``model`` from a named optax optimizer.

    Args:
        model: Module whose variables selected by ``wrt`` are optimized.
        learning_rate: Forwarded unchanged to the optax factory as its
            ``learning_rate`` argument.
        every_k_schedule: When not equal to 1, the transformation is wrapped
            in ``optax.MultiSteps`` with this ``every_k_schedule``.
        max_gradient_norm: When > 0, gradients are clipped to this global
            norm before reaching the optimizer.
        optimizer_type: Key into ``OPTIMIZERS``.
        wrt: Filter selecting which variables the optimizer updates.
        **kwargs: Extra hyper-parameters. Only names that appear in the chosen
            factory's signature are forwarded; all others are silently dropped.

    Returns:
        An ``nnx.Optimizer`` over ``model`` using the assembled transformation.

    Raises:
        KeyError: If ``optimizer_type`` is not a key of ``OPTIMIZERS``.
    """
    try:
        optimizer_fn = OPTIMIZERS[optimizer_type]
    except KeyError:
        # Same exception type as before, but name the valid choices.
        raise KeyError(
            f"Unknown optimizer_type {optimizer_type!r}; "
            f"expected one of {sorted(OPTIMIZERS)}"
        ) from None

    # Resolve the factory signature once rather than once per kwarg.
    accepted = inspect.signature(optimizer_fn).parameters
    kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    optimizer = optimizer_fn(learning_rate=learning_rate, **kwargs)

    if max_gradient_norm > 0:
        # Clip first so the optimizer only ever sees bounded gradients.
        optimizer = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm), optimizer
        )

    tx = (
        optimizer
        if every_k_schedule == 1
        else optax.MultiSteps(optimizer, every_k_schedule=every_k_schedule)
    )
    # ``wrt`` goes by keyword: newer Flax releases make it keyword-only on
    # ``nnx.Optimizer``, while older releases accept it either way.
    return nnx.Optimizer(model, tx, wrt=wrt)  # type: ignore


class TargetModel(nnx.Module, Generic[ModelT]):
    """Deep copy of a model, put in eval mode, synced from an online model.

    Only state matched by ``self.poi`` is ever read from or written to the
    copy. ``self.poi`` always excludes ``nnx.RngState`` variables and is
    optionally intersected with a caller-supplied filter.
    """

    def __init__(self, model: ModelT, poi: nnx.filterlib.Filter | None = None):
        # Independent copy so the target does not share Python objects with
        # the online model.
        self.model = deepcopy(model)
        if poi is None:
            self.poi = NotRng
        else:
            self.poi = nnx.All(NotRng, poi)
        # Switch the wrapped copy into evaluation mode via nnx.Module.eval.
        self.eval()

    def __call__(self, *args, **kwargs):
        """Forward every positional and keyword argument to the copy."""
        wrapped = self.model
        return wrapped(*args, **kwargs)  # type: ignore

    def update(self, model: ModelT, tau: float):
        """Blend ``model``'s filtered state into the copy via ``soft_update``.

        NOTE(review): presumably ``tau`` is the weight given to the online
        model (Polyak averaging); confirm against ``soft_update``.
        """
        online_state = nnx.state(model, self.poi)
        target_state = nnx.state(self.model, self.poi)
        nnx.update(self.model, soft_update(online_state, target_state, tau))

    def hard_update(self, model: ModelT):
        """Overwrite the copy's filtered state with that of ``model``."""
        online_state = nnx.state(model, self.poi)
        nnx.update(self.model, online_state)


class TrainState(NamedTuple, Generic[ModelT]):
    """Pairs a model with the ``nnx.Optimizer`` used to update it."""

    # NOTE: subclassing both NamedTuple and Generic requires Python 3.11+.
    model: ModelT  # the module being trained
    optimizer: nnx.Optimizer  # optimizer built for ``model``


class TrainStateWithTarget(NamedTuple, Generic[ModelT]):
    """Model, its ``nnx.Optimizer``, and a ``TargetModel`` copy of the model."""

    # NOTE: subclassing both NamedTuple and Generic requires Python 3.11+.
    model: ModelT  # the module being trained
    optimizer: nnx.Optimizer  # optimizer built for ``model``
    target: TargetModel[ModelT]  # target copy synced from ``model``
