#!/usr/bin/env python

# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field


@dataclass
class TDMPCConfig:
    """Configuration class for TDMPCPolicy.

    Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
    camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.

    Args:
        n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
            action repeats in Q-learning or ask your favorite chatbot)
        horizon: Horizon for model predictive control.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
            [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
            match the original implementation.
        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
            original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
            to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
            normalization mode here.
        image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
        state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
        latent_dim: Observation's latent embedding dimension.
        q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
        mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
            (π), Q ensemble, and V.
        discount: Discount factor (γ) to use for the reinforcement learning formalism.
        use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
            (π) for each step.
        cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
        max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
        min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
            Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
        n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
            be non-zero.
        n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
            be zero.
        uncertainty_regularizer_coeff: Coefficient for the uncertainty regularization used when estimating
            trajectory values (this is the λ coeffiecient in eqn 4 of FOWM).
        n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
        elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
            elites, when updating the gaussian parameters for CEM.
        gaussian_mean_momentum: Momentum (α) used for EMA updates of the mean parameter μ of the gaussian
            parameters optimized in CEM. Updates are calculated as μ⁻ ← αμ⁻ + (1-α)μ.
        max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
            image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
            is applied. Note that the input images are assumed to be square for this augmentation.
        reward_coeff: Loss weighting coefficient for the reward regression loss.
        expectile_weight: Weighting (τ) used in expectile regression for the state value function (V).
            v_pred < v_target is weighted by τ and v_pred >= v_target is weighted by (1-τ). τ is expected to
            be in [0, 1]. Setting τ closer to 1 results in a more "optimistic" V. This is sensible to do
            because v_target is obtained by evaluating the learned state-action value functions (Q) with
            in-sample actions that may not be always optimal.
        value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
            value (V) expectile regression loss.
        consistency_coeff: Loss weighting coefficient for the consistency loss.
        advantage_scaling: A factor by which the advantages are scaled prior to exponentiation for advantage
            weighted regression of the policy (π) estimator parameters. Note that the exponentiated advantages
            are clamped at 100.0.
        pi_coeff: Loss weighting coefficient for the action regression loss.
        temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
            steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
            current time step.
        target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
            as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
            model being trained.
    """

    # Input / output structure.
    n_action_repeats: int = 2
    horizon: int = 5

    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.image": [3, 84, 84],
            "observation.state": [4],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [4],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: dict[str, str] | None = None
    output_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {"action": "min_max"},
    )

    # Architecture / modeling.
    # Neural networks.
    image_encoder_hidden_dim: int = 32
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 50
    q_ensemble_size: int = 5
    mlp_dim: int = 512
    # Reinforcement learning.
    discount: float = 0.9

    # Inference.
    use_mpc: bool = True
    cem_iterations: int = 6
    max_std: float = 2.0
    min_std: float = 0.05
    n_gaussian_samples: int = 512
    n_pi_samples: int = 51
    uncertainty_regularizer_coeff: float = 1.0
    n_elites: int = 50
    elite_weighting_temperature: float = 0.5
    gaussian_mean_momentum: float = 0.1

    # Training and loss computation.
    max_random_shift_ratio: float = 0.0476
    # Loss coefficients.
    reward_coeff: float = 0.5
    expectile_weight: float = 0.9
    value_coeff: float = 0.1
    consistency_coeff: float = 20.0
    advantage_scaling: float = 3.0
    pi_coeff: float = 0.5
    temporal_decay_coeff: float = 0.5
    # Target model.
    target_model_momentum: float = 0.995

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if len(image_keys) != 1:
            raise ValueError(
                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
            )
        image_key = next(iter(image_keys))
        if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
            # TODO(alexander-soare): This limitation is solely because of code in the random shift
            # augmentation. It should be able to be removed.
            raise ValueError(
                f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
            )
        if self.n_gaussian_samples <= 0:
            raise ValueError(
                f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
            )
        if self.output_normalization_modes != {"action": "min_max"}:
            raise ValueError(
                "TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
                f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
                "information."
            )
