from rlkit.variants.base import (
    OfflineAlgoKwargs,
    OfflineVariant,
    PolicyKwargs,
    Three256,
    TrainerKwargs,
    w,
)
from rlkit.policies.gaussian_policy import (
    GaussianMixturePolicy,
    GaussianPolicy,
    TanhGaussianMixturePolicy,
)
from rlkit.torch.algorithms.bc import BCPipeline, BCTrainer
from torch.nn import functional as F
from rlkit.torch.algorithms.bc import (
    BCPipeline,
    BCWithValPipeline,
)
from torch import optim


# * Setup --------------------------------------------------


class BCAlgoKwargs(OfflineAlgoKwargs):
    """Algorithm-loop settings shared by the behavior-cloning variants.

    Overrides three attributes of ``OfflineAlgoKwargs``; how each one is
    consumed is decided by the training loop, which is not part of this file.
    """

    # NOTE(review): a negative start combined with ``num_epochs = 0``
    # presumably makes the loop run from epoch -1000 up to 0 -- confirm
    # against the algorithm that reads these values.
    start_epoch = -1000
    num_eval_steps_per_epoch = 1000
    num_epochs = 0


class BCTrainerKwargs(TrainerKwargs):
    """Trainer keyword arguments for behavior cloning.

    Only the policy learning rate is set here; everything else comes from
    ``TrainerKwargs``.
    """

    policy_lr = 0.0001


class BaseBCVariant(OfflineVariant):
    """Common base configuration for the behavior-cloning variants below.

    Selects ``BCTrainer`` / ``BCPipeline`` (both passed through ``w``), the
    ``Three256`` policy kwargs, and the BC-specific trainer and algorithm
    kwargs. Subclasses override individual attributes.
    """

    # Boolean switches read by the surrounding framework.
    # NOTE(review): why ``IQN`` is enabled for a BC variant is not visible
    # from this file -- confirm against ``OfflineVariant``.
    IQN = True
    d4rl = True
    normalize_env = True

    # Component classes, wrapped with ``w``.
    pipeline = w(BCPipeline)
    trainer_cls = w(BCTrainer)

    # Keyword-argument bundles for the policy, trainer and algorithm loop.
    policy_kwargs = Three256()
    trainer_kwargs = BCTrainerKwargs()
    algorithm_kwargs = BCAlgoKwargs()


# * Runnable ------------------------------------------------


class VanillaBCVariant(BaseBCVariant):
    """Plain behavior-cloning run on ``hopper-medium-v2`` with seed 0.

    Adds the run-identifying fields (algorithm label, version tag, seed and
    environment id) on top of ``BaseBCVariant``.
    """

    # Which environment / seed to run.
    env_id = "hopper-medium-v2"
    seed = 0

    # Labels identifying this experiment.
    algorithm = "behavior-cloning"
    version = "normalized-256-256-256"


class TwoLayerFourGaussianPolicyKwargs(PolicyKwargs):
    """Policy kwargs requesting a mixture of four Gaussians.

    NOTE(review): the "TwoLayer" part of the name presumably reflects the
    default hidden sizes of ``PolicyKwargs`` -- confirm in the base class.
    """

    num_gaussians = 4


class ThreeLayerTwoGaussianPolicyKwargs(Three256):
    """``Three256`` policy kwargs extended with a two-component mixture."""

    num_gaussians = 2


class ThreeLayerFourGaussianPolicyKwargs(Three256):
    """``Three256`` policy kwargs extended with a four-component mixture."""

    num_gaussians = 4


# * Mixture Gaussian
class MGBCVariant(VanillaBCVariant):
    """Mixture-of-Gaussians behavior cloning.

    Swaps the policy for ``TanhGaussianMixturePolicy`` with four mixture
    components and relabels the run; everything else is inherited from
    ``VanillaBCVariant``.
    """

    # Labels identifying this experiment.
    algorithm = "mg-behavior-cloning"
    version = "4-gaussian-fixed-atanh"
    seed = 0

    # Policy selection (class wrapped with ``w``) and its kwargs.
    policy_class = w(TanhGaussianMixturePolicy)
    policy_kwargs = TwoLayerFourGaussianPolicyKwargs()


# * Normalized -------------------


class TwoGaussianMGBCNormalizeVariant(MGBCVariant):
    """Two-component mixture BC with environment normalization enabled.

    Uses ``ThreeLayerTwoGaussianPolicyKwargs`` and ``BCPipeline``. Several
    attributes restate values already inherited from the parent classes so
    that the full setup is visible in one place.
    """

    # Run identity.
    env_id = "hopper-medium-v2"
    seed = 0
    version = "2g-normalize-env-all-data"

    # Setup.
    normalize_env = True
    pipeline = w(BCPipeline)
    policy_kwargs = ThreeLayerTwoGaussianPolicyKwargs()


class FourGaussianMGBCNormalizeVariant(MGBCVariant):
    """Four-component mixture BC with environment normalization enabled.

    Uses ``ThreeLayerFourGaussianPolicyKwargs`` and ``BCPipeline``; the
    remaining fields identify the run.
    """

    # Run identity.
    env_id = "hopper-medium-v2"
    seed = 0
    version = "4-gaussian-3-layers-normalize-env-all-data"

    # Setup.
    normalize_env = True
    pipeline = w(BCPipeline)
    policy_kwargs = ThreeLayerFourGaussianPolicyKwargs()


# * Normalized with validation -------------
class MGBCNormalizeWithValVariant(MGBCVariant):
    """Two-component mixture BC run through ``BCWithValPipeline``.

    Compared with the other normalized variants this one switches the
    pipeline and adds ``train_ratio``, ``fold_idx`` and ``snapshot_gap``.
    """

    # Run identity.
    env_id = "hopper-medium-v2"
    seed = 0
    version = "2-gaussian-3-layers-normalize-env-with-val"

    # Data-split settings.
    # NOTE(review): presumably the train fraction and the index of the
    # fold used by the validation pipeline -- confirm against its consumer.
    train_ratio = 0.95
    fold_idx = 2

    # Setup.
    normalize_env = True
    pipeline = w(BCWithValPipeline)
    policy_kwargs = ThreeLayerTwoGaussianPolicyKwargs()
    algorithm_kwargs = BCAlgoKwargs()

    # NOTE(review): presumably the number of epochs between saved
    # snapshots -- confirm against the logger/launcher that reads it.
    snapshot_gap = 50


# * Tuned --------------
class TunedBCTrainerKwargs(TrainerKwargs):
    """Trainer kwargs for the tuned variants: selects ``optim.AdamW``.

    Derives from ``TrainerKwargs`` directly, so the ``policy_lr`` set on
    ``BCTrainerKwargs`` is not inherited here.
    """

    optimizer_class = w(optim.AdamW)


class TunedPolicyKwargs(ThreeLayerFourGaussianPolicyKwargs):
    """Policy kwargs for the tuned mixture policy.

    Overrides the parent's mixture size with ten components and sets four
    hidden layers of width 256, a leaky-ReLU activation (wrapped with
    ``w``) and dropout with ``p = 0.1``.
    """

    # Mixture and network shape.
    num_gaussians = 10
    hidden_sizes = [256] * 4
    hidden_activation = w(F.leaky_relu)
    # layer_norm = True

    # Regularization.
    dropout = True
    dropout_kwargs = dict(p=0.1)


class TunedSingleGaussianPolicyKwargs(PolicyKwargs):
    """Policy kwargs for the tuned single-Gaussian policy.

    Sets four hidden layers of width 256, a leaky-ReLU activation (wrapped
    with ``w``) and dropout with ``p = 0.1``; no mixture size is set.
    """

    # Network shape.
    hidden_sizes = [256] * 4
    hidden_activation = w(F.leaky_relu)
    # layer_norm = True

    # Regularization.
    dropout = True
    dropout_kwargs = dict(p=0.1)


class TunedBCVariant(BaseBCVariant):
    """Tuned behavior cloning with a single ``GaussianPolicy``.

    Combines ``TunedBCTrainerKwargs`` with
    ``TunedSingleGaussianPolicyKwargs`` on top of ``BaseBCVariant``.
    """

    # Labels identifying this experiment.
    # NOTE(review): the label carries the "mg-" prefix although the policy
    # is a single Gaussian -- confirm whether that is intentional.
    algorithm = "mg-behavior-cloning"
    version = "gaussian-tanh-before"
    seed = 0

    # Policy (class wrapped with ``w``) and trainer configuration.
    policy_class = w(GaussianPolicy)
    policy_kwargs = TunedSingleGaussianPolicyKwargs()
    trainer_kwargs = TunedBCTrainerKwargs()


class TunedMGBCVariant(BaseBCVariant):
    """Tuned mixture-of-Gaussians BC on ``antmaze-umaze-diverse-v0``.

    Uses ``GaussianMixturePolicy`` with ``TunedPolicyKwargs`` and
    ``TunedBCTrainerKwargs``, and turns off the environment normalization
    that ``BaseBCVariant`` enables.
    """

    # Run identity.
    algorithm = "mg-behavior-cloning"
    version = "10-gaussian-correct-GM"
    env_id = "antmaze-umaze-diverse-v0"
    seed = 0

    # Setup.
    normalize_env = False
    policy_class = w(GaussianMixturePolicy)
    policy_kwargs = TunedPolicyKwargs()
    trainer_kwargs = TunedBCTrainerKwargs()
