from rlkit.variants.base import (
    PacTrainerKwargs,
    PacVariant,
    QFKwargs,
    w,
)
from math import sqrt, log
from rlkit.torch.algorithms.sac import mgpac, mgpac_trunc

# * Setup --------------------------------------------------
class BaseMGTrainerKwargs(PacTrainerKwargs):
    """Default trainer keyword arguments shared by the MG-PAC variants.

    Extends ``PacTrainerKwargs``; only the values assigned here are visible
    in this file, everything else is inherited.  How each value is
    interpreted is decided by whatever consumes these kwargs (not shown in
    this file).
    """

    use_max_lambda = True
    # NOTE(review): this list is a class attribute, so it is shared by every
    # subclass/instance that does not rebind it -- rebind, do not mutate.
    delta_range = [0.0, 0.0]  #! lower, upper
    # Disabled alternative: an explicit list of `delta` values.
    # delta = [0.2, 0.5, 1.0, 2.0]
    target_quantile = 0.9
    # Disabled alternative: `delta` computed as sqrt(-2 * log(x)) for each x
    # (the only use of the module-level `sqrt`/`log` imports in this file).
    # delta = [sqrt(log(x) * -2) for x in [0.99, 0.9, 0.6, 0.3, 0.15]]
    beta_LB = 1.0

    action_selection_mode = "max"


class BaseMGPacVariant(PacVariant):
    """Base experiment variant for the MG-PAC runs defined in this file.

    Sets the default trainer kwargs, trainer class and pipeline (both taken
    from the ``mgpac`` module).  The variants below subclass it and override
    individual attributes.
    """

    checkpoint_params = "MG4"

    # One BaseMGTrainerKwargs instance, created when this class body runs.
    # Subclasses that do not override `trainer_kwargs` share this same object.
    trainer_kwargs = BaseMGTrainerKwargs()
    # `w` comes from rlkit.variants.base; presumably it wraps a class so it can
    # be stored as a config value -- TODO confirm against its definition.
    trainer_cls = w(mgpac.MGPacTrainer)
    pipeline = w(mgpac.MGBasePipeline)


class TwoGaussianBaseMGPacVariant(BaseMGPacVariant):
    """``BaseMGPacVariant`` with ``checkpoint_params`` switched to ``"MG2"``.

    NOTE(review): judging by the class name, "MG2" presumably selects a
    two-Gaussian checkpoint instead of the base "MG4" one -- confirm against
    the code that reads ``checkpoint_params``.
    """

    checkpoint_params = "MG2"


# * Runnable ------------------------------------------------
class VanillaMGPacVariant(BaseMGPacVariant):
    """Runnable variant that keeps every ``BaseMGPacVariant`` default.

    Only the run identification (algorithm label, version tag), the
    environment id and the seed are set here.
    """

    algorithm = "PAC-MG4"
    version = "delta-random-sampling-V5"
    env_id = "antmaze-umaze-v0"
    seed = 1


class BCVariant(BaseMGPacVariant):
    """Runnable variant that swaps the pipeline for ``mgpac.MGEvalBCPipeline``.

    Trainer kwargs, trainer class and ``checkpoint_params`` are inherited
    unchanged from ``BaseMGPacVariant``.
    """

    algorithm = "PAC-MG"
    version = "behavior-cloning-evaluation"
    env_id = "walker2d-medium-v2"
    seed = 2

    # Replaces the inherited mgpac.MGBasePipeline.
    pipeline = w(mgpac.MGEvalBCPipeline)


# * IQL -----------------------------------------------------
class IQLQFKwargs(QFKwargs):
    """Q-function kwargs used by ``IQLVariant``; sets only ``hidden_sizes``."""

    # Presumably two hidden layers of width 256 -- confirm against the network
    # class that consumes QFKwargs.
    hidden_sizes = [256, 256]


class IQLTrainerKwargs(BaseMGTrainerKwargs):
    """Trainer kwargs for the IQL variants.

    Overrides ``delta_range`` from ``BaseMGTrainerKwargs`` and adds ``IQN``
    and ``num_delta``; all other values are inherited.
    """

    IQN = False
    # Rebinds (does not mutate) the inherited [0.0, 0.0] list; per the base
    # class comment the order is lower, upper.
    delta_range = [0.0, 0.5]
    num_delta = 5


class IQLVariant(BaseMGPacVariant):
    """Runnable variant using the IQL pipeline, Q-function and trainer kwargs.

    Keeps the inherited ``trainer_cls`` and ``checkpoint_params`` from
    ``BaseMGPacVariant`` but replaces ``trainer_kwargs`` and ``pipeline`` and
    adds ``qf_kwargs``, ``normalize_env`` and ``IQN``.
    """

    algorithm = "PAC-MG"
    version = "iql-eval-1000-policy"
    env_id = "hopper-medium-v2"
    seed = 2

    # Instances are created once, when this class body runs; subclasses that
    # do not override these attributes share the same objects.
    qf_kwargs = IQLQFKwargs()
    trainer_kwargs = IQLTrainerKwargs()
    normalize_env = True
    # NOTE(review): IQN is also set to False on IQLTrainerKwargs; this file
    # does not show which of the two is actually read -- confirm before
    # changing only one of them.
    IQN = False
    # Replaces the inherited mgpac.MGBasePipeline.
    pipeline = w(mgpac.MGIQLPipeline)


# * IQL (trunc) ---------------------------------------------


class TruncIQLVariant(IQLVariant):
    """``IQLVariant`` run with the trainer and pipeline from ``mgpac_trunc``.

    Inherits ``qf_kwargs``, ``trainer_kwargs`` and ``IQN`` from
    ``IQLVariant``; overrides the run identification, ``checkpoint_params``,
    ``normalize_env``, ``trainer_cls`` and ``pipeline``.
    """

    # Same value as inherited from IQLVariant; restated explicitly.
    algorithm = "PAC-MG"
    version = "iql-trunc-correct-GM"
    env_id = "antmaze-large-play-v0"
    checkpoint_params = "MGTrunc"
    seed = 6

    # IQLVariant sets this to True; it is switched back off here.
    normalize_env = False
    trainer_cls = w(mgpac_trunc.MGPacTruncTrainer)
    pipeline = w(mgpac_trunc.MGTruncPipeline)
