# TODO: this file needs to be refactored (e.g. TunedTrainerKwargs and TunedQFKwargs are each defined twice below).

from rlkit.variants.base import (
    EnsembleQFKwargs,
    OfflineAlgoKwargs,
    OfflineSarsaVariant,
    OfflineVariant,
    PolicyKwargs,
    QFKwargs,
    TrainerKwargs,
    w,
)
from rlkit.torch.algorithms.sarsa import (
    SarsaTrainer,
    SarsaPipeline,
    SarsaWithValPipeline,
)
import torch.optim as optim
from rlkit.torch.networks.mlp import ParallelMlp, QuantileMlp


class SarsaAlgoKwargs(OfflineAlgoKwargs):
    """Algorithm-level settings shared by the SARSA variants in this module.

    Overrides four fields of ``OfflineAlgoKwargs``: a negative starting epoch,
    and zero for the per-epoch evaluation steps, the maximum path length and
    the number of epochs.
    """

    # NOTE(review): presumably the epoch counter runs from this negative value
    # up to ``num_epochs`` -- confirm against the pipeline implementation.
    start_epoch = -2000
    num_eval_steps_per_epoch = 0
    max_path_length = 0
    num_epochs = 0


class SarsaTrainerKwargs(TrainerKwargs):
    """Trainer settings for the SARSA variants: ``TrainerKwargs`` with one override."""

    # NOTE(review): presumably the interval between target-network updates;
    # the unit is not visible here -- confirm against SarsaTrainer.
    target_update_period = 2


class SarsaVariant(OfflineVariant):
    """Baseline SARSA experiment variant.

    Extends ``OfflineVariant`` with the SARSA trainer and pipeline and a
    ``ParallelMlp`` Q-function, targeting the ``pen-cloned-v1`` environment.
    ``SarsaNormalizeVariant`` in this module derives from this class.
    """

    # Experiment identification.
    algorithm = "sarsa"
    version = "ant-maze-pen"
    env_id = "pen-cloned-v1"

    # Q-function class and the kwargs objects for each component.
    # NOTE(review): ``w`` comes from rlkit.variants.base; what the wrapper does
    # with the class reference is not visible from this file.
    qf_class = w(ParallelMlp)
    qf_kwargs = EnsembleQFKwargs()
    algorithm_kwargs = SarsaAlgoKwargs()
    trainer_kwargs = SarsaTrainerKwargs()

    trainer_cls = w(SarsaTrainer)
    pipeline = w(SarsaPipeline)


# NOTE(review): a second class named ``TunedTrainerKwargs`` is defined later in
# this module. Once the module has been imported, the module-level name refers
# to that later definition, and nothing between the two definitions references
# this one -- consider removing one of them.
class TunedTrainerKwargs(TrainerKwargs):
    """``TrainerKwargs`` with the optimizer class set to wrapped ``optim.AdamW``."""

    optimizer_class = w(optim.AdamW)


# NOTE(review): a second class named ``TunedQFKwargs`` is defined later in this
# module and replaces this one at module level. The later definition sets
# ``dropout = True`` where this one sets ``layer_norm = True``; nothing between
# the two definitions references this one.
class TunedQFKwargs(EnsembleQFKwargs):
    """``EnsembleQFKwargs`` with leaky-ReLU activation, layer norm and three hidden sizes."""

    hidden_activation = "leaky_relu"
    layer_norm = True
    hidden_sizes = [256, 256, 256]


class SarsaVariantTuned(OfflineVariant):
    """SARSA variant labelled ``final-0``, derived directly from ``OfflineVariant``.

    It does not set ``env_id`` itself, so whatever ``OfflineVariant`` provides
    applies.

    NOTE(review): despite the name, this class uses the plain
    ``EnsembleQFKwargs`` and ``SarsaTrainerKwargs`` rather than the ``Tuned*``
    kwargs classes, and its ``version`` string is identical to the one on
    ``TunedSarsaVariant`` later in this module -- confirm this is intended.
    """

    algorithm = "sarsa"
    version = "final-0"

    qf_class = w(ParallelMlp)
    qf_kwargs = EnsembleQFKwargs()
    algorithm_kwargs = SarsaAlgoKwargs()
    trainer_kwargs = SarsaTrainerKwargs()

    trainer_cls = w(SarsaTrainer)
    pipeline = w(SarsaPipeline)


class SarsaNormalizeVariant(SarsaVariant):
    """``SarsaVariant`` with ``normalize_env`` enabled, on ``hopper-medium-replay-v2``.

    The Q-function, trainer, algorithm kwargs and pipeline are inherited from
    ``SarsaVariant`` unchanged.
    """

    normalize_env = True
    # Same value as on SarsaVariant; restated here rather than inherited.
    algorithm = "sarsa"
    version = "normalize-env"
    env_id = "hopper-medium-replay-v2"


class ThreeLayerSarsaQFKwargs(EnsembleQFKwargs):
    """``EnsembleQFKwargs`` with three hidden sizes of 256 and ``num_heads`` of 10."""

    hidden_sizes = [256, 256, 256]
    # NOTE(review): presumably the number of ensemble members/output heads of
    # the Q-network -- confirm against ParallelMlp.
    num_heads = 10


class ThreeLayerSarsaTrainerKwargs(TrainerKwargs):
    """Trainer settings used by the three-layer variants.

    Derives from ``TrainerKwargs`` directly (not from ``SarsaTrainerKwargs``)
    and sets ``target_update_period`` to 1, where ``SarsaTrainerKwargs`` uses 2.
    """

    qf_lr = 3e-4
    target_update_period = 1


class ThreeLayerSarsaWithValAlgoKwargs(SarsaAlgoKwargs):
    """``SarsaAlgoKwargs`` plus a batch size of 256.

    ``start_epoch`` is not overridden here, so the ``SarsaAlgoKwargs`` value
    is the one inherited.
    """

    batch_size = 256


class ThreeLayerSarsaAlgoKwargs(SarsaAlgoKwargs):
    """``SarsaAlgoKwargs`` with a batch size of 256 and a starting epoch of -400."""

    batch_size = 256
    # Overrides the -2000 starting epoch set on SarsaAlgoKwargs.
    start_epoch = -400


class SarsaNormalizeWithValVariant(SarsaNormalizeVariant):
    """Normalized-environment SARSA variant run through ``SarsaWithValPipeline``.

    Swaps in the three-layer Q-function, trainer and algorithm kwargs and adds
    ``train_ratio`` / ``fold_idx`` settings.
    """

    # These three repeat the values already set on SarsaNormalizeVariant /
    # SarsaVariant (see also normalize_env below); only ``version`` differs.
    algorithm = "sarsa"
    version = "normalize-env-with-val-loss-3layers"
    env_id = "hopper-medium-replay-v2"

    normalize_env = True
    # NOTE(review): presumably the train/validation split fraction and the
    # index of the fold to hold out -- confirm against SarsaWithValPipeline.
    train_ratio = 0.95
    fold_idx = 2

    qf_kwargs = ThreeLayerSarsaQFKwargs()
    trainer_kwargs = ThreeLayerSarsaTrainerKwargs()
    algorithm_kwargs = ThreeLayerSarsaWithValAlgoKwargs()

    pipeline = w(SarsaWithValPipeline)


class SarsaNormalizeThreeLayerVariant(SarsaNormalizeVariant):
    """Normalized-environment SARSA variant using the three-layer kwargs.

    Targets ``hopper-medium-expert-v2`` and replaces the Q-function, trainer
    and algorithm kwargs inherited from ``SarsaNormalizeVariant`` with their
    ``ThreeLayer*`` counterparts. The pipeline and trainer class are inherited.
    """

    version = "normalize-env-3layers"
    env_id = "hopper-medium-expert-v2"

    qf_kwargs = ThreeLayerSarsaQFKwargs()
    trainer_kwargs = ThreeLayerSarsaTrainerKwargs()
    algorithm_kwargs = ThreeLayerSarsaAlgoKwargs()

    # NOTE(review): presumably the number of epochs between saved snapshots --
    # confirm against the logger/pipeline that reads it.
    snapshot_gap = 50


# NOTE(review): this is the second definition of ``TunedTrainerKwargs`` in this
# module; it replaces the earlier one at module level. No class in this module
# assigns it to ``trainer_kwargs``.
class TunedTrainerKwargs(TrainerKwargs):
    """``TrainerKwargs`` with the optimizer class set to wrapped ``optim.AdamW``."""

    # Wrapped in ``w(...)`` like every other class reference in this module
    # (qf_class, trainer_cls, pipeline, and the earlier TunedTrainerKwargs).
    optimizer_class = w(optim.AdamW)


# NOTE(review): this is the second definition of ``TunedQFKwargs`` in this
# module and the one the variants below pick up. The earlier definition sets
# ``layer_norm = True`` instead of ``dropout = True`` -- confirm which is meant.
class TunedQFKwargs(EnsembleQFKwargs):
    """``EnsembleQFKwargs`` with leaky-ReLU activation, dropout and three hidden sizes.

    Used as ``qf_kwargs`` by ``TunedSarsaVariant`` and
    ``TunedSarsaNormalizeWithValVariant``.
    """

    hidden_activation = "leaky_relu"
    dropout = True
    hidden_sizes = [256, 256, 256]


class TunedSarsaVariant(OfflineSarsaVariant):
    """SARSA variant labelled ``final-0`` that uses ``TunedQFKwargs``.

    Derives from ``OfflineSarsaVariant`` (imported from rlkit.variants.base),
    not from ``SarsaVariant`` in this module, and does not set ``env_id``.

    NOTE(review): ``trainer_kwargs`` is the plain ``SarsaTrainerKwargs``, not
    ``TunedTrainerKwargs``, and the ``version`` string matches the one on
    ``SarsaVariantTuned`` earlier in this module -- confirm both are intended.
    """

    algorithm = "sarsa"
    version = "final-0"

    qf_class = w(ParallelMlp)
    qf_kwargs = TunedQFKwargs()
    algorithm_kwargs = SarsaAlgoKwargs()
    trainer_kwargs = SarsaTrainerKwargs()

    trainer_cls = w(SarsaTrainer)
    pipeline = w(SarsaPipeline)


class TunedSarsaNormalizeWithValVariant(SarsaNormalizeWithValVariant):
    """``SarsaNormalizeWithValVariant`` with the Q-function kwargs swapped for ``TunedQFKwargs``.

    Everything else (environment, trainer and algorithm kwargs, pipeline) is
    inherited from the parent.
    """

    version = "tuned-normalize-env-with-val-loss"
    # Replaces the parent's ThreeLayerSarsaQFKwargs. TunedQFKwargs does not set
    # ``num_heads``, so the parent's ``num_heads = 10`` no longer applies here.
    # NOTE(review): confirm dropping it is intended.
    qf_kwargs = TunedQFKwargs()
