from dataclasses import dataclass

from offline.lbp import arguments as lbp
from offline.utils.parser import ArgumentParser


@dataclass(frozen=True)
class Arguments(lbp.Arguments):
    """Immutable configuration extending ``lbp.Arguments`` with this module's options.

    Every field name matches the argparse ``dest`` of a flag registered by
    ``build_argument_parser`` in this module (e.g. ``--classifier-steps`` ->
    ``classifier_steps``); see that function for the flags and their defaults.
    Where instances are built from a parsed namespace is not visible here.

    NOTE(review): ``frozen=True`` requires ``lbp.Arguments`` to be a frozen
    dataclass (or not a dataclass at all). Also, none of the fields below has a
    default, so any field inherited from ``lbp.Arguments`` that does have a
    default (and is not keyword-only) would make ``dataclass`` raise
    ``TypeError`` — confirm against ``lbp``.
    """

    classifier_steps: int
    clip_eigenvalues: bool  # store_true flag: False unless --clip-eigenvalues
    codebook_size: int
    commitment_cost: float
    deltas_multiplier: float
    deterministic_reward: bool  # store_true flag: False unless passed
    deterministic_transition: bool  # store_true flag: False unless passed
    latent_dim: int
    max_timescale: float
    min_timescale: float
    num_blocks: int
    observation_embedding_dim: int
    reward_embedding_dim: int
    reward_weight: float
    ssm_base_size: int
    subsample_decodes: int
    subsample_latents: int
    tc_batch_size: int
    tc_decay: float
    tc_hidden_features: int
    tc_phase_split: float
    tc_reweight: bool  # inverted flag: True unless --do-not-reweight-tc
    tc_steps: int
    tc_threshold: float
    threshold: float
    transition_weight: float
    use_next_observation: bool  # inverted flag: True unless --omit-next-observation


def build_argument_parser(parser: ArgumentParser | None = None, **kwargs):
    if parser is None:
        parser = lbp.build_argument_parser(**kwargs)

    parser.add_argument("--classifier-steps", type=int, default=500000)
    parser.add_argument("--clip-eigenvalues", action="store_true")
    parser.add_argument("--codebook-size", type=int, default=4)
    parser.add_argument("--commitment-cost", type=float, default=0.25)
    parser.add_argument('--deltas-multiplier', type=float, default=1)
    parser.add_argument("--deterministic-reward", action="store_true")
    parser.add_argument("--deterministic-transition", action="store_true")
    parser.add_argument(
        "--do-not-reweight-tc", action="store_false", dest="tc_reweight"
    )
    parser.add_argument("--latent-dim", type=int, default=32)
    parser.add_argument("--max-timescale", type=float, default=0.1)
    parser.add_argument("--min-timescale", type=float, default=0.001)
    parser.add_argument("--num-blocks", type=int, default=8)
    parser.add_argument("--observation-embedding-dim", type=int, default=64)
    parser.add_argument(
        "--omit-next-observation",
        action="store_false",
        dest="use_next_observation",
    )
    parser.add_argument("--reward-embedding-dim", type=int, default=0)
    parser.add_argument("--reward-weight", type=float, default=0)
    parser.add_argument("--ssm-base-size", type=int, default=256)
    parser.add_argument("--subsample-decodes", type=int, default=64)
    parser.add_argument("--subsample-latents", type=int, default=64)
    parser.add_argument("--tc-batch-size", type=int, default=8)
    parser.add_argument("--tc-decay", type=float, default=0.99)
    parser.add_argument("--tc-hidden-features", type=int, default=64)
    parser.add_argument("--tc-phase-split", type=float, default=0.5)
    parser.add_argument('--tc-steps', type=int, default=500000)
    parser.add_argument("--tc-threshold", type=float, default=0.1)
    parser.add_argument("--threshold", type=float, default=5)
    parser.add_argument("--transition-weight", type=float, default=1)
    return parser
