from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Self

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp


class EnvironmentEnum(Enum):
    """Identifiers for the environment kinds defined by this package.

    The string values are kebab-case names, presumably the spelling used in
    configuration / CLI selection — confirm against the config loader.
    """

    SETS = "sets"
    SEQUENCES = "sequences"
    HYPERGRIDS = "hypergrids"
    LINES = "lines"
    SMALL_GRAPHS = "small-graphs"
    DISCRETE_DIFF = "discrete-diff"


@struct.dataclass(frozen=True)
class EnvState:
    """Immutable snapshot of a batched environment, usable as a JAX pytree.

    Array fields are pytree leaves. ``max_trajectory_length`` and
    ``batch_size`` are declared with ``pytree_node=False``, so they are treated
    as static (non-leaf) data by JAX transformations.
    """

    # Environment-specific state array; shape/dtype depend on the concrete
    # environment — TODO confirm per subclass.
    state: jax.Array

    # Action masks for forward and backward steps; presumably
    # (batch_size, num_actions) — TODO confirm.
    forward_mask: jax.Array
    backward_mask: jax.Array

    # Row identifiers within the batch.
    batch_ids: jax.Array

    # Per-row flags; assumes shape (batch_size,) like the corresponding
    # attributes on ``Environment`` — TODO confirm.
    stopped: jax.Array
    is_initial: jax.Array

    # Static (non-traced) configuration values.
    max_trajectory_length: int = struct.field(pytree_node=False)
    batch_size: int = struct.field(pytree_node=False)

    # Optional environment-specific extra pytree; defaults to None.
    metadata: Optional[struct.PyTreeNode] = None


class LogRewardEnum(Enum):
    """Identifiers for the available log-reward functions.

    Members are grouped by name prefix (``SEQS_``, ``LINES_``, ``HYPERGRIDS_``,
    ``DISCRETE_``), presumably matching the environment family each reward is
    meant for — confirm against the reward registry.
    """

    # Presumably a standard/default reward not tied to one environment — confirm.
    STD = "std"

    SEQS_TFN8 = "seqs-tfn8"
    SEQS_TFN10 = "seqs-tfn10"
    SEQS_PREFERENCES = "seqs-preferences"
    SEQS_BITS = "seqs-bits"

    LINES_NORMAL = "lines-normal"
    LINES_UNIFORM = "lines-uniform"
    LINES_CENTER = "lines-center"
    LINES_STEEP = "lines-steep"

    HYPERGRIDS_GAUSSIAN = "hypergrids-gaussian"

    DISCRETE_RINGS = "discrete-rings"
    DISCRETE_FOURGAUSS = "discrete-fourgauss"
    DISCRETE_GRADED = "discrete-graded"


@dataclass
class EnvironmentConfig:
    """Construction parameters read by ``Environment.__init__``."""

    # Number of rows in the batch (required).
    batch_size: int
    # Optional cap on trajectory length; None when not specified.
    max_trajectory_length: Optional[int] = None
    # Optional device identifier; presumably a JAX device/platform name — confirm.
    device: Optional[str] = None


class Environment(metaclass=ABCMeta):
    """Abstract, stateful batched environment.

    Holds per-row bookkeeping arrays for a batch of ``batch_size`` rows and
    declares the interface concrete environments implement (``apply``,
    ``backward``, ``fwd_to_bcw_actions``). ``state``, ``forward_mask`` and
    ``backward_mask`` start as ``None`` here and are presumably populated by
    subclasses — confirm per subclass.
    """

    def __init__(self, config: EnvironmentConfig):
        """Initialise bookkeeping arrays from ``config``.

        Args:
            config: Supplies ``batch_size``, ``device`` and
                ``max_trajectory_length``; also retained as ``self.config``.
        """
        self.batch_size = config.batch_size
        self.device = config.device
        self.max_trajectory_length = config.max_trajectory_length
        # Row ids 0..batch_size-1; ``merge`` offsets ids of appended rows.
        self.batch_ids = jnp.arange(self.batch_size)
        self.traj_size = jnp.ones((self.batch_size,))
        # int8 flags: initially no row is stopped and every row is initial.
        self.stopped = jnp.zeros((self.batch_size,), dtype=jnp.int8)
        self.is_initial = jnp.ones((self.batch_size,), dtype=jnp.int8)

        # Placeholders; None until a subclass fills them in.
        self.state: Optional[jax.Array] = None
        self.forward_mask: Optional[jax.Array] = None
        self.backward_mask: Optional[jax.Array] = None

        self.config = config

    @abstractmethod
    def apply(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
        """Apply ``actions`` (presumably forward steps) to the rows selected by
        ``active_mask``; ``None`` presumably means all rows (see
        ``_active_mask``) — confirm in subclasses."""
        pass

    @abstractmethod
    def backward(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
        """Apply backward ``actions`` to the rows selected by ``active_mask``
        (semantics defined by subclasses)."""
        pass

    @abstractmethod
    def fwd_to_bcw_actions(self, actions):
        """Presumably maps forward actions to their backward counterparts —
        confirm in subclasses."""
        pass

    def merge(self, batch_state: Self):
        """Append the rows of another environment to this one, in place.

        Bookkeeping arrays are concatenated along the batch axis, and the
        incoming ``batch_ids`` are shifted past this environment's current
        ``batch_size`` so ids stay unique. ``state`` is NOT merged here;
        subclasses presumably handle it — confirm.

        Args:
            batch_state: Environment whose rows are appended; it is only read.
        """
        # Capture the offset before ``batch_size`` is updated below.
        offset = self.batch_size
        self.batch_ids = jnp.hstack([self.batch_ids, offset + batch_state.batch_ids])
        self.batch_size = offset + batch_state.batch_size
        self.stopped = jnp.hstack([self.stopped, batch_state.stopped])
        self.is_initial = jnp.hstack([self.is_initial, batch_state.is_initial])
        self.traj_size = jnp.hstack([self.traj_size, batch_state.traj_size])

        # ``node_indices`` is never set by this base class (only by some
        # subclasses), so look it up defensively instead of raising
        # AttributeError.
        node_indices = getattr(self, "node_indices", None)
        if isinstance(node_indices, jax.Array):
            self.node_indices = jnp.hstack([node_indices, batch_state.node_indices])

        # Merge each mask independently so an unset backward mask does not
        # crash ``vstack`` and a set one is never left stale.
        if self.forward_mask is not None:
            self.forward_mask = jnp.vstack([self.forward_mask, batch_state.forward_mask])
        if self.backward_mask is not None:
            self.backward_mask = jnp.vstack([self.backward_mask, batch_state.backward_mask])

    def _active_mask(self, mask: jax.Array | None) -> jax.Array:
        """Return ``mask`` as a bool array, or an all-True mask of length
        ``batch_size`` when ``mask`` is None."""
        if mask is None:
            return jnp.ones((self.batch_size,), dtype=bool)
        return mask.astype(bool)

    @property
    def log_space_size(self):
        """Log of the state-space size; the base class returns None and
        subclasses should override it."""
        return None

    def get_state(self):
        """Return the current ``state`` array (None until a subclass sets it)."""
        return self.state


@struct.dataclass
class LogRewardConfig:
    """Configuration passed to ``LogRewardBase.__init__``."""

    # Temperature value copied onto the reward module; stored as a pytree
    # leaf. Presumably a scalar array — confirm.
    temperature: jax.Array


class LogRewardBase(nnx.Module, metaclass=ABCMeta):
    """Abstract flax NNX module computing log-rewards for an ``EnvState``.

    Concrete subclasses implement ``__call__``. The base class only keeps the
    configured temperature and provides a default ``mode_th`` of 0.
    """

    def __init__(self, config: LogRewardConfig):
        # Only the temperature is kept; the config object itself is not stored.
        temperature = config.temperature
        self.temperature = temperature

    @property
    def mode_th(self):
        """Threshold value (presumably for mode detection — confirm); the base
        class returns 0."""
        return 0

    @abstractmethod
    def __call__(self, state: EnvState) -> jax.Array:
        """Return the log-reward array for ``state``; implemented by subclasses."""
