from typing import Callable, Tuple

import hydra
import numpy as np
import torch
from torch import nn

from lambda_ac.rl_types import (
    EncoderActorModule,
    EncoderCriticModule,
    EncoderModelNetwork,
    FeatureInput,
    PlanningStrategy,
)
from lambda_ac.util.model_util import evaluate_trajectory, rollout_model_with_actions
from lambda_ac.util.schedulers import LinearSchedule
from lambda_ac.util.torch_util import as_eval
from lambda_ac.lambda_ac_vaml import ViperVAML


class ActorFowardStrategy(PlanningStrategy):
    """Planning strategy that acts directly with the actor network (no search).

    While the training-call counter is below ``seed_steps`` it returns uniform
    random actions in ``[-1, 1)``; otherwise it queries the actor: the
    distribution mean during evaluation, a reparameterised sample
    (``rsample``) during training.
    """

    def __init__(
        self,
        model: EncoderModelNetwork,
        critic: EncoderCriticModule,
        actor: EncoderActorModule,
        seed_steps: int = 0,
        device: str = "cuda",
        share_encoder: bool = True,
    ) -> None:
        # device / share_encoder are stored here rather than in the parent
        # class init, to keep each strategy's configuration explicit.
        super().__init__(model, critic, actor)
        self.action_dim = actor.action_dim
        self.device = device
        self.share_encoder = share_encoder
        # Number of training (non-eval) calls to ``plan`` seen so far.
        self._total_steps = 0
        self.seed_steps = seed_steps

    def set_networks(self, actor, critic, model):
        """Replace the actor, critic and model networks used for acting."""
        self.actor = actor
        self.critic = critic
        self.model = model

    def plan(
        self,
        state: torch.Tensor,
        alpha: torch.Tensor,
        eval: bool = False,
        step: int = 0,
        episode: int = 0,
    ):
        """Return an action for ``state`` by querying the actor.

        state: observation passed straight to ``self.actor.forward``.
        alpha: unused here; accepted for interface compatibility.
        eval: if True, return the actor's mean action; the seed-phase random
            sampling is skipped and the call does not advance the step counter.
        step, episode: unused here; accepted for interface compatibility.

        Returns the action with the leading (batch) dimension squeezed out.
        """
        # Only training calls count towards the seed phase, so that
        # interleaved evaluation does not consume the exploration budget
        # (same convention as TdmpcStrategy).
        if not eval:
            self._total_steps += 1
            if self._total_steps < self.seed_steps:
                # Uniform in [-1, 1): rand is in [0, 1).
                return 2 * (
                    torch.rand(size=(self.action_dim,)).to(self.device) - 0.5
                )
        with as_eval(self.actor):
            dist = self.actor.forward(state)
            if eval:
                action = dist.mean.squeeze(0)
            else:
                action = dist.rsample().squeeze(0)
        return action


class TdmpcStrategy(PlanningStrategy):
    """TD-MPC style planner: iterative Gaussian refit over action sequences.

    Each call to :meth:`plan` samples ``num_samples`` action sequences of length
    ``horizon``. A fraction ``mixture_coef`` of them is produced by rolling the
    actor through the learned model. All sequences are scored via
    ``rollout_model_with_actions`` + ``evaluate_trajectory``, and a Gaussian
    (mean/std per step) is refit to the top ``num_elites`` sequences for
    ``iterations`` rounds, with softmax weights scaled by ``temperature``.
    """

    def __init__(
        self,
        model: EncoderModelNetwork,
        critic: EncoderCriticModule,
        actor: EncoderActorModule,
        std_schedule: LinearSchedule,
        horizon_schedule: LinearSchedule,
        horizon: int = 5,
        seed_steps: int = 0,
        action_dim: int = 1,
        mixture_coef: float = 0.1,
        num_samples: int = 500,
        momentum: float = 0.99,
        iterations: int = 10,
        temperature: float = 0.1,
        num_elites: int = 5,  # K from the paper
        min_std: float = 0.01,
        discount: float = 1.0,
        discretize_done: bool = True,
        update_encoder_actor: bool = False,
        device: str = "cuda",
        share_encoder: bool = True,
    ) -> None:
        super().__init__(model, critic, actor)
        self.share_encoder = share_encoder
        self.device = device
        self.detach_encoder = not update_encoder_actor and self.share_encoder

        # hyper-params
        self.std_schedule = std_schedule
        self.horizon_schedule = horizon_schedule
        self.horizon = horizon
        self.seed_steps = seed_steps
        self.action_dim = action_dim
        self.mixture_coef = mixture_coef
        self.num_samples = num_samples
        self.momentum = momentum
        self.iterations = iterations
        self.min_std = min_std
        self.discount = discount
        self.temperature = temperature
        self.num_elites = num_elites
        self.discretize_done = discretize_done

        # Number of training (non-eval) calls to ``plan``; drives both schedules
        # and the seed phase.
        self._total_steps = 0
        # Episode index and horizon of the previous plan, plus its final mean,
        # used to warm-start the next plan within the same episode.
        self._episode = -1
        self._last_horizon = 0

        self._prev_mean = None

    def set_networks(self, actor, critic, model):
        """Replace the actor, critic and model networks used for planning."""
        self.actor = actor
        self.critic = critic
        self.model = model

    def _rollout_policy(
        self, state: torch.Tensor, horizon: int, num_pi_trajs: int
    ) -> torch.Tensor:
        """Roll the actor through the model to get ``num_pi_trajs`` action sequences.

        Returns a ``(num_pi_trajs, horizon, action_dim)`` tensor. When
        ``num_pi_trajs`` is 0 the tensor is empty and no network is called.
        """
        pi_actions = torch.empty(
            num_pi_trajs, horizon, self.action_dim, device=self.device
        )
        if num_pi_trajs > 0:
            # NOTE(review): this rollout encodes with critic.encoder, whereas
            # plan()'s CEM rollouts use model.encoder; presumably the same
            # module when share_encoder=True — confirm.
            next_feature = FeatureInput.from_output(self.critic.encoder(state))
            next_feature = next_feature.repeat(num_pi_trajs, 1)

            for t in range(horizon):
                pi_actions[:, t], _, _ = self.actor.head.forward_sample_log_prob(
                    next_feature
                )
                # Perturb with Gaussian noise of std ``min_std``, stay in [-1, 1].
                pi_actions[:, t] = torch.clamp(
                    pi_actions[:, t]
                    + torch.randn_like(pi_actions[:, t]) * self.min_std,
                    -1,
                    1,
                )
                next_feature_dist, _, _ = self.model(next_feature, pi_actions[:, t])
                next_feature = FeatureInput.from_output(next_feature_dist)
        return pi_actions

    def _warm_start_mean(self, horizon: int, episode: int) -> torch.Tensor:
        """Initial ``(horizon, action_dim)`` mean for this plan.

        Zeros, except that within the same episode the previous plan's mean is
        shifted one step forward in time. The shift length is clamped so that
        a horizon that shrinks between calls (or a previous horizon of 0) does
        not cause a shape mismatch; when the horizon does not shrink this is
        exactly ``mean[:L-1] = prev_mean[1:L]`` with ``L`` the previous horizon.
        """
        mean = torch.zeros(horizon, self.action_dim, device=self.device)
        if episode == self._episode and self._prev_mean is not None:
            shift = max(0, min(horizon, self._last_horizon - 1))
            mean[:shift] = self._prev_mean[1 : shift + 1]
        return mean

    @torch.no_grad()
    def plan(
        self,
        state: torch.Tensor,
        alpha: torch.Tensor,
        eval: bool = False,
        step: int = 0,
        episode: int = 0,
    ):
        """
        Plan next action using TD-MPC inference.
        state: raw input observation. Assumed to be a single observation with a
            leading batch dimension of 1 (it is repeated with
            ``state.repeat(num_samples, 1)``) — TODO confirm.
        alpha: forwarded unchanged to ``evaluate_trajectory``.
        eval: uniform sampling and action noise is disabled during evaluation;
            evaluation calls also do not advance the step counter.
        step: current time step in the episode (currently unused; the horizon
            and std schedules are driven by the internal training-step counter).
        episode: episode index; the previous plan's mean is reused as a warm
            start only while this matches the previous call's episode.

        Returns a 1-D action tensor of length ``action_dim``.
        """
        # Only training calls advance the counter.
        if not eval:
            self._total_steps += 1

        with as_eval(self.model, self.actor, self.critic):
            self.std = self.std_schedule(self._total_steps)
            # Seed phase: uniform random exploration, for training calls only,
            # as promised in the docstring.
            if not eval and self._total_steps < self.seed_steps:
                return torch.empty(
                    self.action_dim, dtype=torch.float32, device=self.device
                ).uniform_(-1, 1)

            state = state.float().to(self.device)
            horizon = int(
                min(self.horizon, self.horizon_schedule(self._total_steps).item())
            )

            # Sample policy trajectories
            num_pi_trajs = int(self.mixture_coef * self.num_samples)
            pi_actions = self._rollout_policy(state, horizon, num_pi_trajs)

            # Initialize sampling distribution
            mean = self._warm_start_mean(horizon, episode)
            std = torch.ones(horizon, self.action_dim, device=self.device) * 2.0
            mean_expanded = mean.unsqueeze(0).repeat(self.num_samples, 1, 1)
            std_expanded = std.unsqueeze(0).repeat(self.num_samples, 1, 1)

            state = state.repeat(self.num_samples, 1)

            # Iterate CEM
            for _ in range(self.iterations):
                actions = torch.clamp(
                    mean_expanded + std_expanded * torch.randn_like(std_expanded),
                    -1,
                    1,
                )
                # The first num_pi_trajs candidates are the policy trajectories.
                if num_pi_trajs > 0:
                    actions[:num_pi_trajs] = pi_actions

                # Score all candidates.
                # NOTE(review): model.encoder(state) does not change across
                # iterations; it could be hoisted out of the loop if the encoder
                # is deterministic in eval mode — confirm before changing.
                trajectory = rollout_model_with_actions(
                    self.model.encoder(state),
                    actions,
                    torch.zeros_like(actions[:, 0, :1]),
                    self.model,
                    discretize_done=self.discretize_done,
                    detach_encoder=True,
                    predict_done=True,
                )
                _, _, values = evaluate_trajectory(
                    trajectory,
                    self.actor,
                    self.critic,
                    torch.ones_like(trajectory.rewards[:, 0]) * self.discount,
                    alpha,
                    torch.min,
                    use_td_average=False,
                    add_first=True,
                )

                # Assumes values is (num_samples, 1) or (num_samples,) — TODO
                # confirm; with num_samples == 1 squeeze() yields a 0-d tensor.
                values = values.squeeze()

                elite_idxs = torch.topk(values, self.num_elites, dim=0).indices
                elite_value = values[elite_idxs]
                elite_actions = actions[elite_idxs]

                # Softmax-style elite weights (shifted by the max for stability).
                max_value = elite_value.max(0)[0]
                score = torch.exp(self.temperature * (elite_value - max_value))
                score /= score.sum(0)
                _mean = torch.sum(score.view(-1, 1, 1) * elite_actions, dim=0) / (
                    score.sum(0) + 1e-9
                )
                _std = torch.sqrt(
                    torch.sum(
                        score.view(-1, 1, 1)
                        * (elite_actions - _mean.view(1, horizon, -1)) ** 2,
                        dim=0,
                    )
                    / (score.sum(0) + 1e-9)
                )
                # std is floored by the scheduled minimum and capped at 2.
                std = _std.clamp_(self.std, 2)
                mean = self.momentum * mean + (1 - self.momentum) * _mean
                mean_expanded = mean.unsqueeze(0).repeat(self.num_samples, 1, 1)
                std_expanded = std.unsqueeze(0).repeat(self.num_samples, 1, 1)

        # Outputs: pick one elite sequence with probability ``score`` and take
        # its first action.
        score = score.cpu().numpy()
        action = elite_actions[np.random.choice(np.arange(score.shape[0]), p=score)][0]
        self._prev_mean = mean
        self._episode = episode
        self._last_horizon = horizon

        if not eval:
            # Exploration noise using the first-step std of the final iteration.
            action += std[0] * torch.randn_like(action)
        return action
