from rlkit.variants.base import (
    OfflineSarsaVariant,
    OfflineVariant,
    QuantileMLPKwargs,
    TrainerKwargs,
    OfflineAlgoKwargs,
    w,
)
from rlkit.torch.algorithms.sarsa_iqn import (
    SarsaIQNTrainer,
    SarsaIQNPipeline,
)
from rlkit.torch.networks.mlp import QuantileMlp


class SarsaIQNAlgoKwargs(OfflineAlgoKwargs):
    """Offline-algorithm settings used by the SARSA-IQN variant.

    Overrides a handful of ``OfflineAlgoKwargs`` fields; everything else is
    inherited from the base class in ``rlkit.variants.base``.
    """

    # Zero evaluation steps per epoch -- presumably disables environment
    # evaluation rollouts during training; confirm against the algorithm
    # that consumes these kwargs.
    num_eval_steps_per_epoch = 0
    max_path_length = 0
    num_epochs = 0
    batch_size = 256
    # Negative start epoch. NOTE(review): presumably training runs from
    # epoch -500 up to ``num_epochs`` (0), i.e. 500 purely offline epochs;
    # verify against the training loop. Written as a plain literal -- the
    # previous ``-int(500)`` cast was redundant and yields the same value.
    start_epoch = -500


class SarsaIQNTrainerKwargs(TrainerKwargs):
    """Trainer hyper-parameters for the SARSA-IQN variant.

    Values here override the corresponding ``TrainerKwargs`` defaults; any
    field not listed is inherited unchanged.
    """

    # Number of quantiles; presumably the quantile sample count used by
    # SarsaIQNTrainer -- confirm in the trainer implementation.
    num_quantiles = 8
    # Q-function learning rate (same double as 3e-4).
    qf_lr = 0.0003
    # Presumably the target network is refreshed every training step.
    target_update_period = 1


class SarsaNormalizeIQNVariant(OfflineSarsaVariant):
    """Experiment variant: offline SARSA-IQN on a normalized D4RL antmaze env.

    Wires the quantile Q-network, the SARSA-IQN trainer and pipeline, and the
    kwargs classes defined in this module into an ``OfflineSarsaVariant``.
    """

    # Presumably selects a D4RL offline dataset for ``env_id`` -- confirm in
    # OfflineSarsaVariant.
    d4rl = True
    algorithm = "sarsa-iqn"
    # Free-form tag identifying this configuration.
    version = "normalize-env-neg-one-reward"
    env_id = "antmaze-umaze-v0"
    normalize_env = True
    seed = 2

    # Class references are passed through ``w(...)`` from rlkit.variants.base;
    # the exact wrapping semantics are defined there.
    qf_class = w(QuantileMlp)
    # NOTE(review): these kwargs objects are class attributes and therefore
    # shared by every instance/subclass unless the framework copies them.
    qf_kwargs = QuantileMLPKwargs()
    trainer_cls = w(SarsaIQNTrainer)
    trainer_kwargs = SarsaIQNTrainerKwargs()
    algorithm_kwargs = SarsaIQNAlgoKwargs()

    pipeline = w(SarsaIQNPipeline)

    # Presumably the number of epochs between saved snapshots -- confirm in
    # the logging/snapshot code.
    snapshot_gap = 100
