import numpy as np
import torch

from lfrl.policies.mpc.mpc import MPCPolicy
import lfrl.torch.pytorch_util as ptu
from lfrl.util.eval_util import create_stats_ordered_dict


class MPCPolicyController(MPCPolicy):
    """MPC planner whose plan variables are latents fed to a wrapped policy.

    Every plan is appended to the observation along the last axis and the
    wrapped policy turns the combined input into an action.
    """

    def __init__(self, policy, latent_dim, *args, **kwargs):
        """Initialise the base planner and keep a handle on the policy.

        Args:
            policy: Control policy that is queried with the observation and
                the latent concatenated along the last axis.
            latent_dim: Dimension of the latent; handed to ``MPCPolicy`` as
                its ``plan_dim`` argument.
            *args: Remaining positional arguments, forwarded to
                ``MPCPolicy``.
            **kwargs: Remaining keyword arguments, forwarded to
                ``MPCPolicy``.
        """
        super().__init__(*args, plan_dim=latent_dim, **kwargs)
        self.policy = policy

    def convert_plan_to_action(self, obs, plan, deterministic=False):
        """Map one plan to an action through ``policy.get_action``.

        ``obs`` and ``plan`` are joined with ``np.concatenate`` on the last
        axis before being given to the policy.

        Args:
            obs: Observation to condition on.
            plan: Latent plan to append to the observation.
            deterministic: Accepted for interface compatibility; see the
                note in the body.

        Returns:
            The first item of what ``policy.get_action`` returns.
        """
        policy_input = np.concatenate((obs, plan), axis=-1)
        # NOTE(review): the ``deterministic`` argument is not forwarded; the
        # policy is always queried with deterministic=True. Confirm this is
        # intended before relying on the flag.
        policy_output = self.policy.get_action(policy_input, deterministic=True)
        action, *_ = policy_output
        return action

    def convert_plans_to_actions(self, obs, plans, deterministic=True):
        """Map plans to actions by calling the policy directly.

        ``obs`` and ``plans`` are joined with ``torch.cat`` on the last
        dimension before being given to the policy.

        Args:
            obs: Observations to condition on.
            plans: Latent plans to append to the observations.
            deterministic: Forwarded unchanged to the policy call.

        Returns:
            The first item of what the policy call returns.
        """
        policy_input = torch.cat((obs, plans), dim=-1)
        policy_output = self.policy(policy_input, deterministic=deterministic)
        actions, *_ = policy_output
        return actions
