from typing import Callable

from pydantic import BaseModel

""" ENVS """


class EnvConfig(BaseModel):
    """Common base class for environment configurations; declares no fields of its own."""


class BitSeqConfig(EnvConfig):
    """Settings for the "BitSeq" environment.

    With the defaults, ``min_len`` and ``max_len`` are both 120.
    """

    name: str = "BitSeq"
    min_len: int = 120
    max_len: int = 120
    substring_len: int = 8  # Number of bits added at a time
    num_modes: int = 60  # presumably the number of modes to generate — TODO confirm
    mode_seed: int = 0  # presumably the seed used when generating the modes — TODO confirm


class AMPConfig(EnvConfig):
    """Settings for the "AMP" environment."""

    name: str = "AMP"
    # Minimum sequence length (before which termination is not allowed, lengths include BOS/EOS tokens)
    min_len: int = 14
    # Maximum sequence length (after which termination is forced)
    max_len: int = 62
    # Min distance between modes
    mode_delta: int = 10


class GFPConfig(EnvConfig):
    """Settings for the "GFP" environment.

    With the defaults, ``min_len`` and ``max_len`` are both 239.
    """

    name: str = "GFP"
    min_len: int = 239
    max_len: int = 239
    mode_delta: int = 60  # presumably the min distance between modes, as in AMPConfig — TODO confirm


class UTRConfig(EnvConfig):
    """Settings for the "UTR" environment.

    With the defaults, ``min_len`` and ``max_len`` are both 52.
    """

    name: str = "UTR"
    min_len: int = 52
    max_len: int = 52
    mode_delta: int = 13  # presumably the min distance between modes, as in AMPConfig — TODO confirm


""" ALGS """


class AlgConfig(BaseModel):
    """Common base class for algorithm configurations; declares no fields of its own."""


class TGMConfig(AlgConfig):
    """Settings for the "TGM" algorithm.

    ``data_type`` is "trajectory" here, whereas SACConfig and PPOConfig use "transition".
    """

    name: str = "TGM"
    data_type: str = "trajectory"

    alpha: float = 1.0
    omega: float = 1.0
    q: float = 1.0
    gen_inv_temp: float = 1.0  # Inverse temperature modifier for generating samples

    # TGM can technically work with the softmax of any function of Q-values; the default is the identity.
    # NOTE(review): a function-valued field is unlikely to survive JSON serialization — confirm it is
    # excluded (or never dumped) wherever configs are saved.
    q_fn: Callable = lambda x: x


class SACConfig(AlgConfig):
    """Settings for the "SAC" algorithm; ``data_type`` is "transition"."""

    name: str = "SAC"
    data_type: str = "transition"

    omega: float = 1.0


class PPOConfig(AlgConfig):
    """Settings for the "PPO" algorithm; ``data_type`` is "transition"."""

    name: str = "PPO"
    data_type: str = "transition"

    omega: float = 1.0
    gae_lambda: float = 0.95  # presumably the lambda of generalized advantage estimation — TODO confirm
    clip_eps: float = 0.1  # Clipping in PPO loss
    val_coef: float = 0.5  # Value loss coefficient

    num_minibatches: int = 4  # After generating, how many minibatches to split the data into


""" NETWORK """


class NetworkConfig(BaseModel):
    """Common base class for network configurations; declares no fields of its own."""


class BaseTransformerConfig(NetworkConfig):
    """Default hyperparameters for a transformer network.

    The per-field notes below are inferred from the names only — confirm against the model code.
    """

    embed_dim: int = 64  # presumably the token embedding width
    hid_dim: int = 64  # presumably the hidden (feed-forward) width
    num_layers: int = 3  # presumably the number of transformer layers
    num_head: int = 8  # presumably the number of attention heads
    dropout: float = 0.1  # presumably the dropout probability
    causal: bool = True  # presumably whether attention is causally masked


""" MAIN """


class Config(BaseModel):
    """Top-level run configuration combining environment, algorithm and network settings.

    ``env``, ``alg`` and ``network`` are required; every other field has a default.

    Defaults of ``float``-annotated fields are written as float literals. Pydantic does not
    validate default values unless asked to, so an ``int`` literal default would be stored as
    an ``int``, while the same value passed explicitly would be coerced to ``float``.
    """

    # NOTE(review): these fields are annotated with the base classes, which declare no fields.
    # Depending on the pydantic version, dict input and/or serialization may keep only the
    # base-class fields — confirm that subclass settings survive a save/load round trip.
    env: EnvConfig
    alg: AlgConfig
    network: NetworkConfig

    seed: int = 0
    save: bool = False  # Whether to save final checkpoint/metrics
    save_path: str = ""  # Where to save

    # Optimization
    lr: float = 1e-4
    weight_decay: float = 0.0001
    grad_clip_norm: float = 10.0

    # Sample gen
    eps: float = 0.01  # Probability of uniformly random sampling when generating
    num_envs: int = 16
    replay_buffer_size: int = 16  # replay_buffer_size == num_envs => online training

    # Training
    num_gen_samples: int = 100_000  # Number of samples to train for
    minibatch_size: int = 16  # Size of minibatch trained on
    reward_exp: float = 4.0  # Beta in the paper (final reward is multiplied by `reward_exp`)

    # Number of training steps before updating target network (0 to not use target network)
    target_update_steps: int = 10  # Only supported for SAC

    # Eval
    samples_per_eval: int = 10_000  # Number of samples trained on before evaluating
    num_eval_samples: int = 512  # Number of samples generated for evaluation

    # Sweep eval
    num_sweep_eval_samples: int = 512  # Number of samples generated for each temperature
    top_k: int = 100  # Number of modes
    # Temperatures swept over to evaluate modes
    eval_sweep_temps: list[float] = [
        0.005,
        0.01,
        0.02,
        0.05,
        0.1,
        0.2,
        0.5,
        1.0,
        2.0,
        5.0,
    ]
    # Whether to evaluate top_k modes during training
    # (otherwise only done at the end of training as it is expensive)
    eval_modes: bool = False
