from dataclasses import dataclass, field
from enum import Enum
from enum import IntEnum
from typing import Optional


class TBVariant(int, Enum):
    """Selector for the loss variant used by trajectory balance.

    Members are plain integers (the class mixes in `int`), so they compare
    equal to their numeric values. The meaning of each variant is defined in
    algo.trajectory_balance.TrajectoryBalance.
    """

    # NOTE(review): the names suggest trajectory balance, sub-trajectory
    # balance and detailed balance respectively -- confirm against
    # TrajectoryBalance.
    TB = 0
    SubTB1 = 1
    DB = 2



class Backward(IntEnum):
    """Selector for the backward policy used by trajectory balance.

    See algo.trajectory_balance.TrajectoryBalance for details.

    The A variants of `Maxent` and `GSQL` require the environment to provide
    $n$. This is true for sEH but not QM9.

    NOTE(review): only `Free` and `MaxentA` are defined here; no `GSQL`
    member exists in this enum.
    """

    Free = 1
    MaxentA = 2



@dataclass
class TBConfig:
    """Trajectory Balance config.

    Attributes
    ----------
    bootstrap_own_reward : bool
        Whether to bootstrap the reward with the own reward. (deprecated)
    epsilon : Optional[float]
        The epsilon parameter in log-flow smoothing (see paper)
    reward_loss_multiplier : float
        The multiplier for the reward loss when bootstrapping the reward. (deprecated)
    variant : TBVariant
        The loss variant. See algo.trajectory_balance.TrajectoryBalance for details.
    do_correct_idempotent : bool
        Whether to correct for idempotent actions
    do_parameterize_p_b : bool
        Whether to parameterize the P_B distribution (otherwise it is uniform)
    do_predict_n : bool
        Whether to predict the number of paths in the graph
    do_sample_p_b : bool
        NOTE(review): previously undocumented; presumably whether to sample from P_B --
        confirm against algo.trajectory_balance.TrajectoryBalance.
    do_length_normalize : bool
        Whether to normalize the loss by the length of the trajectory
    subtb_max_len : int
        The maximum length trajectories, used to cache subTB computation indices
    Z_learning_rate : float
        The learning rate for the logZ parameter (only relevant when do_subtb is False)
    Z_lr_decay : float
        The learning rate decay for the logZ parameter (only relevant when do_subtb is False)
    cum_subtb : bool
        NOTE(review): previously undocumented; presumably selects a cumulative form of the
        subTB loss -- confirm against algo.trajectory_balance.TrajectoryBalance.
    backward_policy : Backward
        The backward policy variant (see `Backward`). Defaults to `Backward.Free`.

    NOTE(review): `do_subtb` is mentioned above but is not a field of this config; it was
    presumably superseded by `variant` -- confirm.
    """

    bootstrap_own_reward: bool = False
    epsilon: Optional[float] = None
    reward_loss_multiplier: float = 1.0
    variant: TBVariant = TBVariant.TB
    do_correct_idempotent: bool = False
    do_parameterize_p_b: bool = False
    do_predict_n: bool = False
    do_sample_p_b: bool = False
    do_length_normalize: bool = False
    subtb_max_len: int = 128
    Z_learning_rate: float = 1e-4
    Z_lr_decay: float = 50_000
    cum_subtb: bool = True
    # Use the enum member rather than the bare int 1 so the default really is a
    # `Backward` instance; `Backward.Free == 1` still holds for existing callers.
    backward_policy: Backward = Backward.Free

@dataclass
class MOQLConfig:
    """Hyperparameters for the `moql` section of `AlgoConfig`.

    Attributes
    ----------
    gamma : float
        Defaults to 1. NOTE(review): presumably a discount factor -- confirm.
    num_omega_samples : int
        Defaults to 32.
    num_objectives : int
        Defaults to 2.
    lambda_decay : int
        Defaults to 10_000.
    penalty : float
        Defaults to -10. NOTE(review): presumably applied to invalid outcomes -- confirm.
    """

    gamma: float = 1
    num_omega_samples: int = 32
    num_objectives: int = 2
    lambda_decay: int = 10_000
    penalty: float = -10


@dataclass
class A2CConfig:
    """Hyperparameters for the `a2c` section of `AlgoConfig`.

    Attributes
    ----------
    entropy : float
        Defaults to 0.1. NOTE(review): presumably an entropy-term coefficient -- confirm.
    gamma : float
        Defaults to 1. NOTE(review): presumably a discount factor -- confirm.
    penalty : float
        Defaults to -10. NOTE(review): presumably applied to invalid outcomes -- confirm.
    """

    entropy: float = 0.1
    gamma: float = 1
    penalty: float = -10


@dataclass
class FMConfig:
    """Hyperparameters for the `fm` section of `AlgoConfig`.

    Attributes
    ----------
    epsilon : float
        Defaults to 1e-38. NOTE(review): presumably a small constant for numerical
        stability -- confirm.
    balanced_loss : bool
        Defaults to False.
    leaf_coef : float
        Defaults to 10.
    correct_idempotent : bool
        Defaults to False. NOTE(review): presumably mirrors
        `TBConfig.do_correct_idempotent` -- confirm.
    """

    epsilon: float = 1e-38
    balanced_loss: bool = False
    leaf_coef: float = 10
    correct_idempotent: bool = False


@dataclass
class SQLConfig:
    """Hyperparameters for the `sql` section of `AlgoConfig`.

    Attributes
    ----------
    alpha : float
        Defaults to 0.01.
    gamma : float
        Defaults to 1. NOTE(review): presumably a discount factor -- confirm.
    penalty : float
        Defaults to -10. NOTE(review): presumably applied to invalid outcomes -- confirm.
    """

    alpha: float = 0.01
    gamma: float = 1
    penalty: float = -10


@dataclass
class AlgoConfig:
    """Generic configuration for algorithms

    Attributes
    ----------
    method : str
        The name of the algorithm to use (e.g. "TB")
    global_batch_size : int
        The batch size for training
    max_len : int
        The maximum length of a trajectory
    max_nodes : int
        The maximum number of nodes in a generated graph
    max_edges : int
        The maximum number of edges in a generated graph
    illegal_action_logreward : float
        The log reward an agent gets for illegal actions
    offline_ratio: float
        The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
        `self.sampling_model`
    valid_offline_ratio: float
        Idem but for validation, and `self.test_data`.
    train_random_action_prob : float
        The probability of taking a random action during training
    train_det_after: Optional[int]
        Do not take random actions after this number of steps
    valid_random_action_prob : float
        The probability of taking a random action during validation
    sampling_tau : float
        The EMA factor for the sampling model (theta_sampler = tau * theta_sampler + (1-tau) * theta)
    tb : TBConfig
        Trajectory Balance settings (see `TBConfig`)
    moql : MOQLConfig
        Algorithm-specific settings (see `MOQLConfig`)
    a2c : A2CConfig
        Algorithm-specific settings (see `A2CConfig`)
    fm : FMConfig
        Algorithm-specific settings (see `FMConfig`)
    sql : SQLConfig
        Algorithm-specific settings (see `SQLConfig`)
    """

    method: str = "TB"
    global_batch_size: int = 64
    max_len: int = 128
    max_nodes: int = 128
    max_edges: int = 128
    illegal_action_logreward: float = -100
    offline_ratio: float = 0.5
    valid_offline_ratio: float = 1
    train_random_action_prob: float = 0.0
    train_det_after: Optional[int] = None
    valid_random_action_prob: float = 0.0
    sampling_tau: float = 0.0
    # Sub-configs are mutable dataclass instances, so each AlgoConfig must get
    # its own copy. A plain instance default is rejected by dataclasses on
    # Python 3.11+ ("mutable default ... use default_factory") and, on older
    # versions, would be shared between every AlgoConfig instance.
    tb: TBConfig = field(default_factory=TBConfig)
    moql: MOQLConfig = field(default_factory=MOQLConfig)
    a2c: A2CConfig = field(default_factory=A2CConfig)
    fm: FMConfig = field(default_factory=FMConfig)
    sql: SQLConfig = field(default_factory=SQLConfig)
