from torch import true_divide
from rlkit.variants.base import (
    EnsembleQFKwargs,
    OfflineAlgoKwargs,
    OfflineSarsaVariant,
    TrainerKwargs,
    w,
)
from rlkit.torch.algorithms.sac import trq
from rlkit.torch.networks.mlp import ParallelMlp


class TRQAlgoKwargs(OfflineAlgoKwargs):
    """Training-loop overrides of ``OfflineAlgoKwargs`` for the TRQ variants."""

    # Negative starting epoch (-500). NOTE(review): presumably the loop runs
    # epochs from start_epoch up to num_epochs, i.e. a purely offline phase
    # when num_epochs is 0 -- confirm against the algorithm implementation.
    start_epoch = -500
    num_eval_steps_per_epoch = 1_000
    max_path_length = 1_000
    num_epochs = 0
    num_trains_per_train_loop = 1_000
    batch_size = 256
    zero_step = False


class TRQTrainerKwargs(TrainerKwargs):
    """Trainer overrides of ``TrainerKwargs`` consumed by the TRQ trainer."""

    # Generic optimisation settings.
    target_update_period = 1
    q_lr = 0.0003  # Q-function learning rate (3e-4).

    # TRQ-specific coefficients. NOTE(review): their exact meaning is defined
    # in trq.TRQTrainer, which is not visible here -- confirm before tuning.
    beta_LB = 1.0
    delta = 0.1
    rollout_delta = 1.0


class TRQEnsembleQFKwargs(EnsembleQFKwargs):
    """Ensemble Q-network size overrides of ``EnsembleQFKwargs`` for TRQ."""

    num_heads = 2
    # Presumably three hidden layers of width 256 each -- confirm in ParallelMlp.
    hidden_sizes = [256] * 3


class TRQFromEnsembleVariant(OfflineSarsaVariant):
    """Experiment configuration running TRQ on ``hopper-medium-v2``.

    Wires together the TRQ pipeline/trainer from ``rlkit.torch.algorithms.sac.trq``,
    a ``ParallelMlp`` Q-function ensemble configured by ``TRQEnsembleQFKwargs``,
    and the TRQ trainer / algorithm keyword overrides defined in this module.
    """

    # Run identification.
    algorithm = "TRQ"
    version = "DEBUG-trq-normalize-env"
    env_id = "hopper-medium-v2"
    seed = 3
    normalize_env = True
    warm_start_q = False

    # NOTE(review): the meaning of "SG" is defined by OfflineSarsaVariant /
    # the checkpoint loader, which is not visible here -- confirm.
    checkpoint_params = "SG"

    # Model and training components. ``pipeline`` is assigned exactly once;
    # a second identical assignment at the end of the class body was redundant.
    qf_class = w(ParallelMlp)
    qf_kwargs = TRQEnsembleQFKwargs()
    pipeline = w(trq.TRQPipeline)
    trainer_cls = w(trq.TRQTrainer)

    trainer_kwargs = TRQTrainerKwargs()
    algorithm_kwargs = TRQAlgoKwargs()
