from pydantic import BaseSettings, BaseModel
from rlkit.core.trainer import Trainer
from rlkit.data_management.env_replay_buffer import (
    EnvReplayBuffer,
    EnvReplayBufferNextAction,
)
from rlkit.policies.gaussian_policy import TanhGaussianPolicy
from rlkit.torch.algorithms.sarsa import SarsaPipeline, SarsaTrainer
from rlkit.torch.networks.mlp import ConcatMlp
from rlkit.torch.torch_rl_algorithm import OfflineTorchBatchRLAlgorithm, TorchTrainer
from rlkit.launchers.pipeline import Pipelines
from rlkit.samplers.rollout_functions import rollout
from pydantic import BaseModel as PydanticBaseModel

class BaseModel(PydanticBaseModel):
    """Project-wide pydantic base model for the ``*Kwargs`` config groups.

    Shadows the ``pydantic.BaseModel`` imported at the top of the file so that
    every kwargs class in this module inherits the shared ``Config`` below.
    """

    class Config:
        # Permit fields whose type pydantic has no built-in validator for;
        # pydantic then validates such values with a plain isinstance() check.
        arbitrary_types_allowed = True


class FuncWrapper:
    """Box a callable (function or class) so it can serve as a config default.

    The wrapped object is exposed unchanged as ``self.f``. Presumably the box
    exists because pydantic does not turn bare functions/classes assigned in a
    model body into fields -- TODO confirm against the pydantic version in use.
    """

    def __init__(self, f) -> None:
        self.f = f

    def __repr__(self) -> str:
        # Show the wrapped callable's name instead of the default
        # ``<FuncWrapper object at 0x...>`` so printed/logged configs are
        # readable and identical across runs.
        label = getattr(self.f, "__qualname__", None)
        if label is None:
            label = repr(self.f)
        return f"{type(self).__name__}({label})"


w = FuncWrapper  # short alias so wrapped defaults stay compact, e.g. ``w(rollout)``

#! Algo Kwargs
class AlgoKwargs(BaseModel):
    """Loop and scheduling settings shared by the algorithm kwargs groups.

    Each field is annotated with exactly the type pydantic would otherwise
    infer from its default value.
    """

    start_epoch: int = -1_000  # offline epochs
    num_epochs: int = 0
    batch_size: int = 256
    max_path_length: int = 1_000
    num_trains_per_train_loop: int = 1_000


class OnlineAlgoKwargs(AlgoKwargs):
    """Algorithm schedule plus the extra step counts used for online training."""

    num_expl_steps_per_train_loop: int = 0
    min_num_steps_before_training: int = 1_000


class OfflineAlgoKwargs(AlgoKwargs):
    """Algorithm schedule for offline training, adding a per-epoch eval budget."""

    num_eval_steps_per_epoch: int = 5_000


#! Policy Kwargs
class PolicyKwargs(BaseModel):
    """Default policy network config: two hidden layers of 1024 units."""

    hidden_sizes: list = [1024] * 2


class Three256(PolicyKwargs):
    """Policy network config with three hidden layers of 256 units."""

    hidden_sizes = [256] * 3


class Four256(PolicyKwargs):
    """Policy network config with four hidden layers of 256 units."""

    # Four entries, matching the class name (this used to be a three-layer
    # copy of Three256's list).
    hidden_sizes = [256] * 4


#! QF Kwargs
class QFKwargs(BaseModel):
    """Default Q-function network config: two hidden layers of 1024 units."""

    hidden_sizes: list = [1024] * 2


class EnsembleQFKwargs(QFKwargs):
    """Q-function config with a configurable head count (10 by default)."""

    num_heads: int = 10


class QuantileMLPKwargs(QFKwargs):
    """Quantile Q-network config: three 256-unit hidden layers plus quantile settings."""

    # Override of the inherited field; left un-annotated so the parent's type is kept.
    hidden_sizes = [256] * 3
    num_quantiles: int = 8
    embedding_size: int = 64


#! Trainer Kwargs
class TrainerKwargs(BaseModel):
    """Base trainer hyperparameters shared by the trainer kwargs groups."""

    discount: float = 0.99
    policy_lr: float = 3e-4
    qf_lr: float = 1e-4
    # NOTE(review): the integer default makes this field int-typed, so pydantic
    # would coerce a fractional scale with int() (e.g. 0.5 -> 0) -- consider
    # float if non-integer scales are needed.
    reward_scale: int = 1
    soft_target_tau: float = 0.005
    target_update_period: int = 1


class SacTrainerKwargs(TrainerKwargs):
    """Trainer hyperparameters plus the SAC entropy-tuning switch (off by default)."""

    use_automatic_entropy_tuning: bool = False


class PathLoaderKwargs(BaseModel):
    """Path-loader options; currently defines no fields."""


#! Variants
class OfflineVariant(BaseSettings):
    """Base experiment settings for offline RL runs.

    As a pydantic ``BaseSettings``, field values may also be supplied through
    environment variables. Classes and functions (networks, trainer, algorithm,
    replay buffer, rollout function) are stored boxed in ``FuncWrapper``; the
    underlying object is available as ``.f``.
    """

    # Run identity: empty/-1 placeholders that child variants are expected to set.
    algorithm: str = ""
    version: str = ""
    env_id: str = ""
    seed: int = -1

    # Per-component keyword-argument groups.
    algorithm_kwargs: OfflineAlgoKwargs = OfflineAlgoKwargs()
    policy_kwargs: PolicyKwargs = PolicyKwargs()
    qf_kwargs: QFKwargs = QFKwargs()
    trainer_kwargs: TrainerKwargs = TrainerKwargs()
    path_loader_kwargs: PathLoaderKwargs = PathLoaderKwargs()

    # Boxed classes/functions used to assemble the experiment.
    policy_class: FuncWrapper = w(TanhGaussianPolicy)
    qf_class: FuncWrapper = w(ConcatMlp)
    trainer_cls: FuncWrapper = w(TorchTrainer)
    alg_class: FuncWrapper = w(OfflineTorchBatchRLAlgorithm)
    replay_buffer_class: FuncWrapper = w(EnvReplayBuffer)
    rollout_fn: FuncWrapper = w(rollout)

    replay_buffer_size: int = 2_000_000

    # Snapshot policy.
    snapshot_mode: str = "gap_and_last"
    snapshot_gap: int = 100


class PacAlgoKwargs(OfflineAlgoKwargs):
    """Offline schedule for PAC runs: 100 epochs and a 1000-step eval budget."""

    # NOTE(review): meaning of ``zero_step`` is defined by the consuming
    # pipeline/algorithm -- not visible here.
    zero_step: bool = True
    # Overrides of inherited fields (un-annotated, so the parent types are kept).
    num_eval_steps_per_epoch = 1_000
    num_epochs = 100


class PacTrainerKwargs(TrainerKwargs):
    """Base trainer hyperparameters plus the settings specific to the PAC trainer."""

    beta_LB: float = 1.0
    # NOTE(review): presumably ``num_delta`` delta values spread over
    # ``delta_range`` -- confirm against the PAC trainer.
    delta_range: list = [0.2, 2.0]
    num_delta: int = 10

    target_quantile: float = 0.7


class PacVariant(OfflineVariant):
    """Offline settings wired to the zero-step PAC pipeline.

    Swaps in the PAC-specific network, trainer and algorithm kwargs groups and
    adds a few boolean switches. ``trainer_cls`` and ``checkpoint_params`` are
    placeholders that a concrete child variant must replace.
    """

    # Placeholders a child variant must specify.
    trainer_cls = w(Trainer)
    checkpoint_params: str = "SPECIFY"

    # Overridden kwargs groups; left un-annotated so the parent field types apply.
    policy_kwargs = Three256()
    qf_kwargs = QuantileMLPKwargs()
    trainer_kwargs = PacTrainerKwargs()
    algorithm_kwargs = PacAlgoKwargs()

    # Feature switches.
    IQN: bool = True
    d4rl: bool = True
    normalize_env: bool = True

    pipeline: FuncWrapper = w(Pipelines.offline_zerostep_pac_pipeline)


class OfflineSarsaVariant(OfflineVariant):
    """Offline settings that run the SARSA pipeline and trainer.

    Uses the replay buffer variant that also records the next action,
    presumably because SARSA-style targets need it.
    """

    pipeline: FuncWrapper = w(SarsaPipeline)
    # Overrides of inherited boxed classes.
    trainer_cls = w(SarsaTrainer)
    replay_buffer_class = w(EnvReplayBufferNextAction)
