import numpy as np
import torch

from erl_lib.agent.sac import SACAgent
from erl_lib.agent.svg import SVGAgent


class MPC(SVGAgent):
    """SVG agent that selects exploration actions with sampling-based MPC.

    When an action is requested with ``sample=True`` and ``mpc_horizon`` is
    positive, the action comes from :meth:`plan`; otherwise the parent's
    ``_act`` is used. ``observe``, ``update``, ``update_actor`` and
    ``update_critic`` are forwarded explicitly to ``SACAgent``, and
    ``update_model`` is a no-op.
    """

    def __init__(
        self,
        *args,
        mpc_horizon=5,
        action_sample=10,
        num_elite_actions=3,
        actor_mpc_scale=0.1,
        mpc_iter=5,
        **kwargs,
    ):
        """Store the planner settings on top of the parent agent.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            mpc_horizon: Planning horizon H; planning is skipped in ``_act``
                unless this is positive.
            action_sample: Number of candidate action sequences S drawn per
                observation at each planner iteration.
            num_elite_actions: Number of best sequences E used to refit the
                sampling distribution.
            actor_mpc_scale: Multiplier applied to the scales returned by the
                initial policy rollout before they seed the sampler.
            mpc_iter: Number of sample/evaluate/refit iterations per plan.
            **kwargs: Keyword arguments forwarded to the parent constructor.

        Raises:
            ValueError: If ``mpc_iter`` is smaller than 1.
        """
        super().__init__(*args, **kwargs)
        if mpc_iter < 1:
            # plan() reads the elite statistics produced by the last
            # iteration, so at least one iteration is required.
            raise ValueError(f"mpc_iter must be >= 1, got {mpc_iter}")
        self.mpc_horizon = mpc_horizon
        self.action_sample = action_sample
        self.num_elite_actions = num_elite_actions
        self.actor_mpc_scale = actor_mpc_scale
        self.mpc_iter = mpc_iter

    def observe(self, obs, action, reward, next_obs, terminated, truncated, info):
        """Forward the transition to ``SACAgent.observe``."""
        # Called on SACAgent explicitly, so any SVGAgent override is bypassed.
        SACAgent.observe(
            self, obs, action, reward, next_obs, terminated, truncated, info
        )

    def _act(self, obs, sample):
        """Return an action: planned when sampling, else the parent's action.

        Args:
            obs: Observation batch passed to :meth:`plan` or the parent.
            sample: Whether a stochastic (exploration) action is requested.

        Returns:
            A NumPy array from the planner, or whatever the parent's ``_act``
            returns when planning is not used.
        """
        if 0 < self.mpc_horizon and sample:
            action = self.plan(
                obs, self.mpc_horizon, self.action_sample, self.num_elite_actions
            )
            return action.detach().cpu().numpy()
        return super()._act(obs, sample)

    def update_model(self, first_update=False):
        """Do nothing; this agent performs no model update here."""
        pass

    def update(self, opt_step, log=False, buffer=None):
        """Run ``SACAgent.update`` with the given arguments."""
        SACAgent.update(self, opt_step, log, buffer)

    def update_actor(self, obs, reward, log=False, **kwargs):
        """Run ``SACAgent.update_actor`` and return its result."""
        return SACAgent.update_actor(self, obs, reward, log, **kwargs)

    def update_critic(self, replay_buffer, log=False):
        """Run ``SACAgent.update_critic`` and return its result."""
        return SACAgent.update_critic(self, replay_buffer, log)

    def plan(self, obs, horizon, num_action_samples, num_elites):
        """MPPI-style trajectory optimization.

        An initial policy rollout seeds a per-step Gaussian over action
        sequences. For ``self.mpc_iter`` iterations, sequences are sampled,
        evaluated with ``self.rollout`` plus a critic value at the last step,
        and the Gaussian is refit to the softmax-weighted top ``num_elites``
        sequences. One elite sequence is then drawn per observation according
        to those weights and its first action is returned.

        Args:
            obs: Observations of shape [B, D]; NumPy array or tensor.
            horizon: Number of planned steps H.
            num_action_samples: Number of sampled sequences S per observation.
            num_elites: Number of top-valued sequences E kept per observation.

        Returns:
            Tensor of shape [B, dim_act] with actions clamped to [-1, 1].
        """
        with torch.no_grad(), self.policy_evaluation_context() as ctx_modules:
            # Preprocess before feeding obs into actor
            batch_size = obs.shape[0]
            if isinstance(obs, np.ndarray):
                obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
            if self.normalize_input:
                actor_obs = self.input_normalizer.normalize(obs)
            else:
                actor_obs = obs
            # Take initial actions
            action, _, pi = self.sample_action(ctx_modules.actor, actor_obs)

            output = self.rollout(
                ctx_modules,
                obs,
                action,
                horizon - 1,
                scales=[pi.scale],
            )
            actions = output[1]  # [(H-1) * B, D]
            last_action = output[6]  # [B, D]

            actions = torch.vstack([actions, last_action]).view(
                (horizon, batch_size, self.dim_act)
            )  # [H, B, D]
            mean = actions.repeat_interleave(num_action_samples, dim=1)  # [H, B * S, D]
            scale = (
                (output[2] * self.actor_mpc_scale)
                .clamp(0.05, 2.0)
                .reshape((horizon, batch_size, self.dim_act))
                .repeat_interleave(num_action_samples, dim=1)
            )  # [H, B * S, D]
            obs = obs.repeat_interleave(num_action_samples, dim=0)

            for _ in range(self.mpc_iter):
                # Sample population
                epsilon = torch.randn(
                    (horizon, batch_size * num_action_samples, self.dim_act),
                    device=self.device,
                    dtype=torch.float32,
                )
                actions = mean + scale * epsilon
                actions.clamp_(-1.0, 1.0)
                # Evaluate the population
                # NOTE(review): this call passes (alpha, obs, actions,
                # model_step, horizon) while the seeding call above passes
                # (ctx_modules, obs, action, horizon - 1) -- confirm both match
                # the parent's rollout signature.
                output = self.rollout(
                    ctx_modules.alpha,
                    obs,
                    actions,
                    ctx_modules.model_step,
                    horizon,
                )
                masks = output[3]
                rewards = output[4]
                last_obs = output[5]
                last_action = output[6]
                last_oa = torch.cat([last_obs, last_action], 1)
                last_q = ctx_modules.critic(last_oa)
                last_q = self._reduce(last_q, self.actor_reduction)
                # Append the critic value as the final entry of the return.
                rewards = torch.stack(rewards + [last_q[:, 0]], 0)
                values = self.discount_mat[:1, : len(rewards)].mm(rewards * masks).t()
                # Update parameters
                elite_values, elite_idx = values.view(-1, num_action_samples).topk(
                    num_elites, dim=1
                )  # [B, E]
                max_value = elite_values.max(1, keepdim=True).values
                elite_action_idx = elite_idx[None, :, :, None].expand(
                    horizon, batch_size, num_elites, self.dim_act
                )
                actions_view = actions.view(
                    horizon, batch_size, num_action_samples, self.dim_act
                )
                elite_actions = torch.take_along_dim(
                    actions_view, elite_action_idx, 2
                )  # [H, B, E, D]
                # Subtracting the per-row maximum keeps exp() from overflowing.
                score = torch.exp(elite_values - max_value)[
                    None, ..., None
                ]  # [1, B, E, 1]
                score = score / score.sum(2, keepdim=True)
                # The weights are already normalized, so the mean and the scale
                # must share this (approximately one) denominator; dividing the
                # mean by the pre-normalization sum would shrink it toward zero.
                weight_sum = score.sum(2) + 1e-9  # [1, B, 1]
                mean = torch.sum(score * elite_actions, dim=2) / weight_sum  # [H, B, D]
                scale = torch.sqrt(
                    torch.sum(
                        score * torch.square(elite_actions - mean.unsqueeze(2)),
                        dim=2,
                    )
                    / weight_sum
                ).clamp(0.05, 2.0)  # [H, B, D]
                mean = mean.repeat_interleave(num_action_samples, dim=1)  # [H, B * S, D]
                scale = scale.repeat_interleave(num_action_samples, dim=1)
        # Select action
        # OneHotCategorical treats the LAST dimension as the category axis, so
        # the trailing singleton dimension is dropped to sample over the elites.
        elite_probs = score[0, :, :, 0]  # [B, E]
        onehot_idx = torch.distributions.OneHotCategorical(
            probs=elite_probs
        ).sample()  # [B, E]
        actions = (elite_actions[0, ...] * onehot_idx.unsqueeze(-1)).sum(1)  # [B, D]
        actions.clamp_(-1.0, 1.0)
        self._info.update(
            mpc_scale=scale.mean(),
            mpc_max_score=score.max(),
            mpc_min_score=score.min(),
            elite_value=elite_values.mean(),
        )
        return actions

    def sample_action(self, actor, obs, log=False, scale=None, **kwargs):
        """Sample from the actor, or replay the next pre-computed action.

        Args:
            actor: Callable returning a distribution for ``obs``.
            obs: Observations given to ``actor``.
            log: Forwarded to ``actor``.
            scale: When ``None`` the actor is sampled; otherwise the next
                entry of ``kwargs["actions"]`` is consumed instead.
            **kwargs: Must contain ``actions`` when ``scale`` is not ``None``.

        Returns:
            ``(action, pi.scale, pi)`` when sampling from the actor -- the
            second item is the distribution scale, not a log-probability --
            or ``(action, action, None)`` when replaying stored actions.
        """
        if scale is None:
            pi = actor(obs, log=log)
            action = pi.rsample()
            return action, pi.scale, pi
        actions = kwargs["actions"]
        action = actions[0]
        # Drop the consumed entry in place so the caller's sequence advances.
        # NOTE(review): this slice assignment shortens a list; presumably
        # `actions` is list-like here -- confirm against the rollout caller.
        actions[:] = actions[1:]
        return action, action, None
