from enum import Enum

# Discrete diffusion
from nais.gym.discrete_diff import (
    LogReward as DiscreteDiffLogReward,
)
from nais.gym.discrete_diff import (
    LogRewardFourGaussians as DiscreteDiffLogRewardFourGaussians,
)
from nais.gym.discrete_diff import (
    LogRewardGraded as DiscreteDiffLogRewardGraded,
)
from nais.gym.discrete_diff import (
    LogRewardRings as DiscreteDiffLogRewardRings,
)
from nais.gym.discrete_diff import (
    apply_fn as discrete_diff_apply_fn,
)
from nais.gym.discrete_diff import (
    backward_fn as discrete_diff_backward_fn,
)
from nais.gym.discrete_diff import (
    factory as discrete_diff_factory,
)
from nais.gym.discrete_diff import (
    get_discrete_diff_grid,
)

# Hypergrids
from nais.gym.hypergrids import (
    LogReward as HypergridsLogReward,
)
from nais.gym.hypergrids import (
    LogRewardGaussian as HypergridsLogRewardGaussian,
)
from nais.gym.hypergrids import (
    apply_fn as hypergrids_apply_fn,
)
from nais.gym.hypergrids import (
    backward_fn as hypergrids_backward_fn,
)
from nais.gym.hypergrids import (
    factory as hypergrids_factory,
)
from nais.gym.hypergrids import (
    get_grid,
)

# Lines
from nais.gym.lines import (
    LogReward as LinesLogReward,
)
from nais.gym.lines import (
    LogRewardCentered as LinesLogRewardCentered,
)
from nais.gym.lines import (
    LogRewardNormal as LinesLogRewardNormal,
)
from nais.gym.lines import (
    LogRewardSteep as LinesLogRewardSteep,
)
from nais.gym.lines import (
    LogRewardUniform as LinesLogRewardUniform,
)
from nais.gym.lines import (
    apply_fn as lines_apply_fn,
)
from nais.gym.lines import (
    backward_fn as lines_backward_fn,
)
from nais.gym.lines import (
    factory as lines_factory,
)

# Sequences
from nais.gym.sequences import (
    LogReward as SequencesLogReward,
)
from nais.gym.sequences import (
    LogRewardBits as SequencesLogRewardBits,
)
from nais.gym.sequences import (
    LogRewardPreferences as SequencesLogRewardPreferences,
)
from nais.gym.sequences import (
    LogRewardTFN as SequencesLogRewardTFN,
)
from nais.gym.sequences import (
    apply_fn as sequences_apply_fn,
)
from nais.gym.sequences import (
    backward_fn as sequences_backward_fn,
)
from nais.gym.sequences import (
    factory as sequences_factory,
)
from nais.gym.sequences import (
    get_predictive_marginal,
)

# Sets
from nais.gym.sets import (
    LogReward as SetsLogReward,
)
from nais.gym.sets import (
    apply_fn as sets_apply_fn,
)
from nais.gym.sets import (
    backward_fn as sets_backward_fn,
)
from nais.gym.sets import (
    factory as sets_factory,
)

# Small graphs
from nais.gym.small_graphs import (
    NUM_STATES as SG_NUM_STATES,
)
from nais.gym.small_graphs import (
    LogReward as SmallGraphsLogReward,
)
from nais.gym.small_graphs import (
    apply_fn as small_graphs_apply_fn,
)
from nais.gym.small_graphs import (
    backward_fn as small_graphs_backward_fn,
)
from nais.gym.small_graphs import (
    factory as small_graphs_factory,
)


class EnvironmentEnum(Enum):
    """String identifiers for the environments re-exported by this package.

    Each member's value is a plain string. Multi-word identifiers use
    hyphens (``"small-graphs"``, ``"discrete-diff"``), while the matching
    ``nais.gym`` submodules use underscores. The value is presumably what
    configuration/CLI code matches against; verify at the call sites.

    Look up a member by value with ``EnvironmentEnum("lines")`` or by name
    with ``EnvironmentEnum["LINES"]``.
    """

    # Declaration order is significant: iterating the enum yields members
    # in exactly this order.
    SETS = "sets"  # nais.gym.sets
    SEQUENCES = "sequences"  # nais.gym.sequences
    HYPERGRIDS = "hypergrids"  # nais.gym.hypergrids
    LINES = "lines"  # nais.gym.lines
    SMALL_GRAPHS = "small-graphs"  # nais.gym.small_graphs
    DISCRETE_DIFF = "discrete-diff"  # nais.gym.discrete_diff


# Public API of the package, assembled one environment section at a time.
# Only a list literal and `__all__ += [...]` are used. Both are forms that
# static analyzers (e.g. pyright) document as understood for `__all__`, so
# tooling still sees the full export list. The resulting list, including
# its order, is identical to a single literal.
__all__ = ["EnvironmentEnum"]

# Discrete diffusion
__all__ += [
    "DiscreteDiffLogReward",
    "DiscreteDiffLogRewardRings",
    "DiscreteDiffLogRewardFourGaussians",
    "DiscreteDiffLogRewardGraded",
    "discrete_diff_apply_fn",
    "discrete_diff_backward_fn",
    "discrete_diff_factory",
    "get_discrete_diff_grid",
]

# Hypergrids
__all__ += [
    "HypergridsLogReward",
    "HypergridsLogRewardGaussian",
    "hypergrids_apply_fn",
    "hypergrids_backward_fn",
    "hypergrids_factory",
    "get_grid",
]

# Lines
__all__ += [
    "LinesLogReward",
    "LinesLogRewardCentered",
    "LinesLogRewardNormal",
    "LinesLogRewardSteep",
    "LinesLogRewardUniform",
    "lines_apply_fn",
    "lines_backward_fn",
    "lines_factory",
]

# Sequences
__all__ += [
    "SequencesLogReward",
    "SequencesLogRewardBits",
    "SequencesLogRewardPreferences",
    "SequencesLogRewardTFN",
    "sequences_apply_fn",
    "sequences_backward_fn",
    "sequences_factory",
    "get_predictive_marginal",
]

# Sets
__all__ += [
    "SetsLogReward",
    "sets_apply_fn",
    "sets_backward_fn",
    "sets_factory",
]

# Small Graphs
__all__ += [
    "SG_NUM_STATES",
    "SmallGraphsLogReward",
    "small_graphs_apply_fn",
    "small_graphs_backward_fn",
    "small_graphs_factory",
]
