from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp

from nais.gym.base import EnvState


class RecurrenceType(Enum):
    """Identifiers for the supported recurrence variants, keyed by lowercase string value.

    NOTE(review): the expansion of the abbreviations (SRWM, FWP, LC, PP) is not
    visible in this file — confirm against the modules that consume this enum.
    """

    NONE = "none"
    LSTM = "lstm"
    SRWM = "srwm"
    FWP = "fwp"
    LC = "lc"
    PP = "pp"  # Not actually recurrent; listed here for convenience


# There is no need to differentiate through the sampling in GFlowNets
def _sample_categorical_with_key(key: jax.Array, probs: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Draw one categorical sample per row of ``probs`` and advance the PRNG key.

    Args:
        key: PRNG key. It is split; the input key itself is never used for sampling.
        probs: Probabilities along the last axis. They are passed through ``log``,
            so zero entries become ``-inf`` logits.

    Returns:
        ``(new_key, samples)``. ``new_key`` is an unconsumed key for the caller's
        next draw. ``samples`` holds integer indices into the last axis of ``probs``.
    """
    # Keep one half of the split for the caller and consume the other half here.
    # Returning the consumed subkey would make the caller's next draw reuse it.
    key, subkey = jax.random.split(key)
    samples = jax.random.categorical(subkey, logits=jnp.log(probs), axis=-1)
    return key, samples


@dataclass
class ForwardPolicyConfig:
    """Hyper-parameters for a forward policy.

    A plain ``@dataclass`` (eq=True, not frozen) is unhashable by default, so
    ``__hash__`` is defined explicitly. It covers every field so that it stays in
    agreement with the generated ``__eq__``.
    """

    # If None, the input_dim, hidden_dim, and output_dim will be inferred from the state
    input_dim: Optional[int] = None
    hidden_dim: Optional[int] = None
    output_dim: Optional[int] = None  # maximum number of actions
    hidden_dim_F_nn: Optional[int] = 64
    device: Optional[str] = None
    eps: float = 0.05  # weight of the uniform component in the sampling mixture
    masked_value: float = -1e5  # logit assigned to masked-out actions
    seed: int = 42

    def __hash__(self):
        # Include every declared field; omitting one (previously hidden_dim_F_nn)
        # makes configs that differ only in that field always collide.
        return hash(
            (
                self.input_dim,
                self.hidden_dim,
                self.output_dim,
                self.hidden_dim_F_nn,
                self.device,
                self.eps,
                self.masked_value,
                self.seed,
            )
        )


@struct.dataclass
class FlowFunctionConfig:
    """Layer sizes for a flow-function network; every field is static (non-pytree) metadata."""

    input_dim: Optional[int] = struct.field(default=None, pytree_node=False)
    hidden_dim: Optional[int] = struct.field(default=None, pytree_node=False)
    output_dim: Optional[int] = struct.field(default=None, pytree_node=False)


@struct.dataclass
class ForwardPolicyOutput:
    """Result returned by ``ForwardPolicyBase.sample_actions``."""

    # Chosen (or supplied) action indices, one per row of the batch
    actions: jax.Array
    # Log-probability of each action under the masked softmax policy (not the eps-mixture)
    log_pf: jax.Array
    # PRNG key after sampling (the caller's key, unchanged, when actions were supplied)
    key: jax.Array


class FlowFunctionBase(nnx.Module, metaclass=ABCMeta):
    """Abstract base for flow-function modules.

    Subclasses implement ``__call__`` to map an ``EnvState`` to an array.
    """

    def __init__(self, config: FlowFunctionConfig):
        self.config = config
        # Mirror the configured sizes as attributes for convenient access.
        self.input_dim, self.hidden_dim = config.input_dim, config.hidden_dim

    @abstractmethod
    def __call__(self, state: EnvState) -> jax.Array:
        """Evaluate the flow function on ``state`` (implemented by subclasses)."""


class ForwardPolicyBase(nnx.Module, metaclass=ABCMeta):
    """Abstract base class for forward policies.

    Subclasses implement ``__call__`` to produce per-action logits for a batch of
    states. This base class:

    - masks invalid actions using ``state.forward_mask``;
    - samples from an epsilon-uniform mixture;
    - computes policy diagnostics (KL to uniform, entropy).
    """

    def __init__(self, config: ForwardPolicyConfig):
        self.input_dim = config.input_dim
        self.output_dim = config.output_dim
        self.device = config.device
        self.eps = config.eps
        self.masked_value = config.masked_value

        self.config = config

        self.is_off_policy = True
        # Exploration rate restored by ``reset_eps``; overwritten by each ``set_eps``.
        self._eps = self.eps

    def _get_masked_policy(self, logits: jax.Array, mask: jax.Array) -> jax.Array:
        """Replace the logits of masked-out actions with ``self.masked_value``.

        Presumably ``mask`` is 0/1 (or boolean). With such a mask, entries where it
        is 1 keep their logit and entries where it is 0 become ``masked_value``.
        """
        return logits * mask + (1 - mask) * self.masked_value

    def _masked_probs(self, state: EnvState) -> jax.Array:
        """Return the softmax over axis 1 of this policy's forward-masked logits."""
        logits = self._get_masked_policy(self(state), state.forward_mask)
        return nnx.softmax(logits, axis=1)

    def get_uniform_pol(self, state: EnvState) -> jax.Array:
        """Return the uniform distribution over each row's valid forward actions.

        A row whose ``forward_mask`` sums to a non-positive value is returned as
        the raw mask instead of being divided by zero.
        """
        uniform_pol = state.forward_mask
        uniform_norm = uniform_pol.sum(axis=1, keepdims=True)
        return jnp.where(uniform_norm > 0, uniform_pol / uniform_norm, uniform_pol)

    def sample_actions(
        self,
        state: EnvState,
        key: Optional[jax.Array] = None,
        actions: Optional[jax.Array] = None,
    ) -> ForwardPolicyOutput:
        """Sample forward actions, or score the supplied ones.

        Actions are drawn from ``(1 - eps) * pi + eps * uniform``, where ``pi`` is
        the masked softmax policy. ``log_pf`` is always the log-probability under
        ``pi`` itself, not under the mixture.

        Args:
            state: Batch of environment states providing ``forward_mask`` and
                ``batch_ids``.
            key: PRNG key; required when ``actions`` is None.
            actions: Optional actions to evaluate instead of sampling.

        Returns:
            ``ForwardPolicyOutput`` holding:

            - the actions;
            - their log-probabilities under ``pi`` (0.0 for actions that
              ``forward_mask`` marks invalid);
            - the PRNG key after sampling (``key`` unchanged when ``actions``
              was given).
        """
        probs = self._masked_probs(state)

        # Sampling distribution: mix the policy with a uniform over valid actions.
        uniform_pol = self.get_uniform_pol(state)
        sampling_probs = (1 - self.eps) * probs + self.eps * uniform_pol

        if actions is None:
            key, actions = _sample_categorical_with_key(key, sampling_probs)

        chosen_probs = probs[state.batch_ids, actions]
        valid = state.forward_mask[state.batch_ids, actions] > 0
        # jnp.log(0) yields -inf; it does not raise. Masked actions get 0.0 instead.
        # Double-where: invalid entries get log(1.0) so that the discarded branch
        # (d log p / dp = 1/0) cannot turn the gradient into NaN.
        safe_probs = jnp.where(valid, chosen_probs, 1.0)
        log_probs = jnp.where(valid, jnp.log(safe_probs), 0.0)
        return ForwardPolicyOutput(actions=actions, log_pf=log_probs, key=key)

    @abstractmethod
    def __call__(self, state: EnvState) -> jax.Array:
        """Return per-action logits for ``state``; masked and softmaxed over axis 1 here."""

    def set_eps(self, eps: float):
        """Override the exploration weight used by ``sample_actions``.

        The value in effect before this call is saved, so ``reset_eps`` undoes only
        the most recent ``set_eps``. Two consecutive ``set_eps`` calls therefore
        lose the earlier value.
        """
        self._eps = self.eps
        self.eps = eps  # e.g. eps=1.0 makes sampling fully uniform over valid actions

    def reset_eps(self):
        """Restore the exploration weight saved by the last ``set_eps`` (or ``__init__``)."""
        self.eps = self._eps

    def kl_to_uniform(self, initial_states: EnvState) -> jax.Array:
        """Return the batch mean of KL[pi || uniform over valid actions].

        KL[p || q] = sum_i p_i * log(p_i / q_i). Only actions with p_i > 0 are
        summed, so masked actions contribute 0.
        """
        probs = self._masked_probs(initial_states)
        mask = initial_states.forward_mask
        uniform_dist = mask / mask.sum(axis=1, keepdims=True)

        support = probs > 0
        # Double-where: off-support entries become 1/1 so that neither the division
        # nor log() puts NaN into the gradient of the discarded branch.
        safe_ratio = jnp.where(support, probs, 1.0) / jnp.where(support, uniform_dist, 1.0)
        kl_div = jnp.sum(jnp.where(support, probs * jnp.log(safe_ratio), 0), axis=1)
        # Presumably every row is the same initial state, so the mean equals any row.
        return kl_div.mean()

    def entropy(self, states: EnvState) -> jax.Array:
        """Return the batch mean of the entropy of the masked policy.

        Only actions with probability > 0 are summed.
        """
        probs = self._masked_probs(states)
        support = probs > 0
        # Double-where: log(1.0) = 0 off-support avoids 0 * -inf NaNs in the gradient.
        safe_probs = jnp.where(support, probs, 1.0)
        entropy = -jnp.sum(jnp.where(support, probs * jnp.log(safe_probs), 0), axis=1)
        # Presumably every row is the same state, so the mean equals any row.
        return entropy.mean()

    def lazy_init(self, state: EnvState):
        """No-op here; presumably a hook for subclasses that initialise from a sample state."""
        return None


@struct.dataclass
class BackwardPolicyConfig:
    """Configuration for a backward policy.

    Every field is static (``pytree_node=False``), matching ``FlowFunctionConfig``.
    The dimensions are plain Python ints, presumably used to size layers. As
    pytree leaves they would become traced values if the config passed through
    a JAX transformation.
    """

    key: int = struct.field(pytree_node=False, default=42)  # integer PRNG seed
    masked_value: float = struct.field(pytree_node=False, default=-1e5)

    input_dim: Optional[int] = struct.field(pytree_node=False, default=None)
    hidden_dim: Optional[int] = struct.field(pytree_node=False, default=None)
    output_dim: Optional[int] = struct.field(pytree_node=False, default=None)


@struct.dataclass
class BackwardPolicyOutput:
    """Result returned by ``BackwardPolicyBase.sample_actions``."""

    # Chosen (or supplied) action indices, one per row of the batch
    actions: jax.Array
    # Log-probability of each action under the backward-masked softmax policy
    log_pb: jax.Array
    # PRNG key after sampling (the caller's key, unchanged, when actions were supplied)
    key: jax.Array


class BackwardPolicyBase(nnx.Module, metaclass=ABCMeta):
    """Abstract base class for backward policies.

    Subclasses implement ``__call__`` to produce per-action logits for a batch of
    states. ``sample_actions`` masks them with ``state.backward_mask`` and applies
    a softmax over axis 1.
    """

    def __init__(self, config: BackwardPolicyConfig):
        self.config = config

        # ``config.key`` is an integer seed; turn it into a typed PRNG key.
        self.key = jax.random.key(config.key)
        self.masked_value = config.masked_value

    @abstractmethod
    def __call__(self, state: EnvState) -> jax.Array:
        """Return backward logits for ``state`` (implemented by subclasses)."""
        return None

    def sample_actions(self, state: EnvState, key: jax.Array = None, actions: jax.Array = None):
        """Sample backward actions, or score the supplied ones, under the masked policy.

        Args:
            state: Batch of environment states providing ``backward_mask`` and
                ``batch_ids``.
            key: PRNG key; only needed when ``actions`` is None.
            actions: Optional actions to evaluate instead of sampling.

        Returns:
            ``BackwardPolicyOutput`` holding:

            - the actions;
            - their log-probabilities (``-inf`` for zero-probability actions);
            - the PRNG key after sampling (``key`` unchanged when ``actions``
              was given).
        """
        raw_logits = self(state)
        bmask = state.backward_mask
        policy = nnx.softmax(raw_logits * bmask + (1 - bmask) * self.masked_value, axis=1)

        if actions is None:
            key, actions = _sample_categorical_with_key(key, policy)

        log_pb = jnp.log(policy[state.batch_ids, actions])
        return BackwardPolicyOutput(actions=actions, log_pb=log_pb, key=key)

    def lazy_init(self, state: EnvState):
        """No-op initialisation hook; subclasses may override it."""
        return None
