from rlkit.variants.base import (
    PacVariant,
    PacTrainerKwargs,
    QFKwargs,
    w,
)
from rlkit.torch.algorithms.sac import mgpac, pac
from rlkit.variants.variant_mgpac import BaseMGTrainerKwargs, IQLQFKwargs, IQLTrainerKwargs

# * Setup ---------------------------------------------------


class BaseSGPacVariant(PacVariant):
    """Common base for the variants declared in this module.

    Extends ``PacVariant`` by setting ``checkpoint_params`` to ``"SG"`` and by
    selecting ``pac.PACTrainer`` as the trainer class and
    ``pac.SGBasePipeline`` as the pipeline, each passed through ``w``.
    Subclasses override individual attributes as needed.
    """

    # NOTE(review): how this value is consumed (e.g. as a checkpoint tag) is
    # decided by PacVariant, which is not visible here — confirm.
    checkpoint_params = "SG"

    # Both classes are wrapped with `w` from rlkit.variants.base rather than
    # assigned directly; the wrapper's semantics are defined in that module.
    trainer_cls = w(pac.PACTrainer)
    pipeline = w(pac.SGBasePipeline)


# * Runnable ------------------------------------------------
class VanillaSGPacVariant(BaseSGPacVariant):
    """Run configuration labelled ``PAC-SG`` / ``delta-tuning``.

    Only sets run metadata (algorithm label, version label, environment id
    and seed); trainer class, pipeline and ``checkpoint_params`` are inherited
    unchanged from ``BaseSGPacVariant``.
    """

    algorithm = "PAC-SG"
    version = "delta-tuning"
    env_id = "hopper-medium-v2"
    seed = 1


class BCTrainerKwargs(PacTrainerKwargs):
    """``PacTrainerKwargs`` with ``delta`` overridden to the single value 0.0."""

    # NOTE(review): the "BC" prefix suggests delta = 0 corresponds to pure
    # behavior cloning — confirm against the trainer's use of `delta`.
    delta = [0.0]


class BCVariant(BaseSGPacVariant):
    """Behavior-cloning evaluation run on ``walker2d-medium-v2`` with seed 2.

    Trainer class, pipeline and ``checkpoint_params`` are inherited from
    ``BaseSGPacVariant``. The trainer kwargs are the BC-specific
    ``BCTrainerKwargs`` (``delta = [0.0]``) declared in this module.
    """

    # NOTE(review): the label reads "PAC-MG" although this variant derives
    # from the SG base. Left unchanged because the string may be used for run
    # bookkeeping — confirm whether "PAC-SG" was intended.
    algorithm = "PAC-MG"
    version = "behavior-cloning-evaluation"
    env_id = "walker2d-medium-v2"
    seed = 2

    # Use the BC-specific kwargs (delta pinned to [0.0]) rather than the
    # PacTrainerKwargs defaults, so this variant actually runs the
    # configuration its name and version label describe.
    trainer_kwargs = BCTrainerKwargs()


class EpochBCVariant(BaseSGPacVariant):
    """Run configuration that swaps the pipeline for ``pac.EpochBCExperiment``.

    Sets run metadata plus an ``epoch_no`` value, and replaces the inherited
    pipeline; the trainer class still comes from ``BaseSGPacVariant``.
    """

    algorithm = "PAC"
    version = "variable_ensemble_size"
    env_id = "hopper-medium-replay-v2"
    seed = 1
    # NOTE(review): the meaning of a negative epoch number (e.g. an offset
    # counted from the last epoch) is defined by the consumer of this
    # attribute, which is not visible here — confirm.
    epoch_no = -250

    # Overrides the SGBasePipeline selected by the base class.
    pipeline = w(pac.EpochBCExperiment)


class GroundTruthVariant(BaseSGPacVariant):
    """Run configuration that swaps the pipeline for ``pac.GTExperiment``.

    Sets run metadata and replaces the inherited pipeline; the trainer class
    still comes from ``BaseSGPacVariant``.
    """

    algorithm = "PAC"
    version = "ground_truth"
    env_id = "hopper-medium-v2"
    seed = 1

    # Overrides the SGBasePipeline selected by the base class.
    pipeline = w(pac.GTExperiment)


# NOTE(review): this definition rebinds the name `IQLQFKwargs` that is also
# imported from rlkit.variants.variant_mgpac at the top of the file, so the
# imported class is unreachable under this name from here on. Confirm which
# of the two is intended and drop the other.
class IQLQFKwargs(QFKwargs):
    """``QFKwargs`` with ``hidden_sizes`` overridden to ``[256, 256]``."""

    hidden_sizes = [256, 256]


# NOTE(review): this definition rebinds the name `IQLTrainerKwargs` that is
# also imported from rlkit.variants.variant_mgpac at the top of the file, so
# the imported class is unreachable under this name from here on. Confirm
# which of the two is intended and drop the other.
class IQLTrainerKwargs(BaseMGTrainerKwargs):
    """``BaseMGTrainerKwargs`` with ``IQN``, ``delta_range`` and ``num_delta`` overridden."""

    IQN = False
    # NOTE(review): both bounds are 0.0 while num_delta is 5; presumably every
    # sampled delta is then 0.0 — confirm this is intentional.
    delta_range = [0.0, 0.0]
    num_delta = 5


class AllIQLVariant(BaseSGPacVariant):
    """Run configuration labelled ``IQL`` that uses ``mgpac.AllIQLPipeline``.

    Sets run metadata, supplies Q-function and trainer kwargs, enables
    ``normalize_env`` and replaces the inherited pipeline. The trainer class
    still comes from ``BaseSGPacVariant``.
    """

    algorithm = "IQL"
    version = "full-iql-eval"
    env_id = "hopper-medium-replay-v2"
    seed = 2

    # These names resolve to the classes defined in this module, which rebind
    # the same names imported from rlkit.variants.variant_mgpac.
    qf_kwargs = IQLQFKwargs()
    trainer_kwargs = IQLTrainerKwargs()
    normalize_env = True
    # NOTE(review): IQN is also set to False on IQLTrainerKwargs; whether the
    # variant-level attribute is read separately is not visible here — confirm.
    IQN = False
    # Overrides the SGBasePipeline selected by the base class.
    pipeline = w(mgpac.AllIQLPipeline)


