import dataclasses
import typing
from typing import Protocol, Sequence, TypeVar

import chex
import jax

from tabular_mvdrl.types import MRPTransitionBatch

# Type of the state object threaded through every ParticleUpdater: each
# updater receives a state of this type and returns one of the same type.
StateT = TypeVar("StateT")


@typing.runtime_checkable
class ParticleUpdater(Protocol):
    """Structural interface for callables that update particle state.

    An updater is called with a PRNG key, the current state and a batch of
    MRP transitions, and returns the resulting state (of the same type as
    the input state).
    """

    def __call__(
        self,
        rng: chex.PRNGKey,
        state: StateT,
        batch: MRPTransitionBatch,
    ) -> StateT:
        ...


@dataclasses.dataclass(frozen=True, kw_only=True)
class NoOpParticleUpdater(ParticleUpdater):
    """Identity updater: hands ``state`` back untouched.

    ``rng`` and ``batch`` are accepted only so the signature matches the
    ``ParticleUpdater`` protocol; neither is used.
    """

    def __call__(
        self,
        rng: chex.PRNGKey,
        state: StateT,
        batch: MRPTransitionBatch,
    ) -> StateT:
        del rng, batch  # Unused by design.
        return state


@dataclasses.dataclass(frozen=True, kw_only=True)
class PeriodicParticleUpdater(ParticleUpdater):
    """Run ``base`` only on every ``update_period``-th step.

    The wrapped updater is applied when ``(state.step + 1) % update_period``
    is zero; on every other step the state is returned unchanged. The choice
    is made with ``jax.lax.cond``, so ``base`` must return a state with the
    same pytree structure, shapes and dtypes as the state it receives.

    Attributes:
        base: Updater applied on update steps.
        update_period: Number of steps between updates; must be >= 1.
    """

    base: ParticleUpdater
    update_period: int

    def __post_init__(self) -> None:
        # A zero period would turn the modulo in ``__call__`` into a division
        # by zero, and a negative period is meaningless; reject both at
        # construction time rather than failing (or misbehaving) under jit.
        if self.update_period < 1:
            raise ValueError(
                "update_period must be a positive integer, "
                f"got {self.update_period!r}"
            )

    def __call__(
        self, rng: chex.PRNGKey, state: StateT, batch: MRPTransitionBatch
    ) -> StateT:
        """Apply ``base`` if this is an update step, otherwise return ``state``.

        Assumes ``state`` exposes an integer-like ``step`` attribute.
        """

        def _skip(_rng, _state, _batch):
            # Non-update steps: pass the state through unchanged.
            return _state

        is_update_step = (state.step + 1) % self.update_period == 0
        return jax.lax.cond(is_update_step, self.base, _skip, rng, state, batch)


@dataclasses.dataclass(frozen=True, kw_only=True)
class ComposedParticleUpdater(ParticleUpdater):
    """Apply several updaters one after another.

    Each component receives the state produced by the previous component,
    together with its own key derived via ``jax.random.fold_in(rng, index)``,
    so no two components are handed the same key. With no components the
    state is returned unchanged.

    Attributes:
        components: Updaters applied in order. Normalised to a tuple on
            construction so that the frozen dataclass's generated
            ``__hash__`` works (a list field would make hashing raise
            ``TypeError``, e.g. when the instance is a static jit argument)
            and so that one-shot iterables are not exhausted after the
            first call.
    """

    components: Sequence[ParticleUpdater]

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the generated __setattr__ to normalise.
        object.__setattr__(self, "components", tuple(self.components))

    def __call__(
        self, rng: chex.PRNGKey, state: StateT, batch: MRPTransitionBatch
    ) -> StateT:
        for i, component in enumerate(self.components):
            key = jax.random.fold_in(rng, i)
            state = component(key, state, batch)
        return state


@dataclasses.dataclass(frozen=True, kw_only=True)
class ParticleFilter(ParticleUpdater):
    """Placeholder particle-filter updater; not implemented yet.

    TODO: Even when the support is state-independent, is an update like this
    feasible? Resampling the support usually resets the weights to uniform,
    which seems hard to achieve when the weights are the output of a neural
    network.
    """

    def __call__(
        self,
        rng: chex.PRNGKey,
        state: StateT,
        batch: MRPTransitionBatch,
    ) -> StateT:
        """Always raises ``NotImplementedError``."""
        del rng, state, batch
        raise NotImplementedError
