"""Defines the policy class, including the main training step logic."""

from typing import Any, Dict, Optional, Tuple, Type, Union

import gym
from gym import spaces
import numpy as np
import pytorch_lightning as pl
from stable_baselines3.common import policies, type_aliases, utils
import torch
from torch import nn, optim

from rvs import dataset, layers, step, util


def make_obs_goal_space(
    observation_space: gym.Space,
    unconditional_policy: bool = False,
    reward_conditioning: bool = False,
    xy_conditioning: bool = False,
) -> gym.Space:
    """Build the space the policy network receives as input.

    The input is the environment observation plus, depending on the selected
    conditioning mode, some representation of a goal.

    Args:
        observation_space: The observation space of the environment. When no
            flag is set, it is handed to
            ``util.create_observation_goal_space`` to build the
            observation-plus-goal space.
        unconditional_policy: If True, no goal is used and the
            observation_space is returned unchanged.
        reward_conditioning: If True, the result is
            ``util.add_scalar_to_space`` applied once to observation_space
            (one extra scalar for the reward).
        xy_conditioning: If True, the result is ``util.add_scalar_to_space``
            applied twice to observation_space (two extra scalars for the
            (x, y) coordinates).

    Returns:
        The policy input space for the selected conditioning mode.

    Raises:
        ValueError: If more than one conditioning flag is set.
    """
    selected_modes = (unconditional_policy, reward_conditioning, xy_conditioning)
    if sum(selected_modes) > 1:
        raise ValueError("You must choose at most one policy conditioning setting.")

    if unconditional_policy:
        return observation_space
    if reward_conditioning:
        return util.add_scalar_to_space(observation_space)
    if xy_conditioning:
        # One added scalar per coordinate.
        space_with_first_scalar = util.add_scalar_to_space(observation_space)
        return util.add_scalar_to_space(space_with_first_scalar)
    # Default: goal-conditioned policy.
    return util.create_observation_goal_space(observation_space)


class RvS(pl.LightningModule):
    """A Reinforcement Learning via Supervised Learning (RvS) policy.

    The module wraps an ``ExtendedActorCriticPolicy`` and trains it by maximizing
    the log-likelihood of dataset actions given (observation, goal) inputs. It
    also exposes convenience methods for querying actions and action
    probabilities from numpy inputs.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_size: int = 1024,
        depth: int = 2,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        learning_rate_scheduler: str = "constant",
        batch_size: int = 256,
        activation_fn: Type[nn.Module] = nn.ReLU,
        dropout_p: float = 0.1,
        obs_noise: float = 0.0,
        unconditional_policy: bool = False,
        reward_conditioning: bool = False,
        env_name: Optional[str] = None,
        mean_obs: Optional[np.ndarray] = None,
        std_obs: Optional[np.ndarray] = None,
    ):
        """Builds RvS.

        Args:
            observation_space: The environment's observation space. The policy's
                actual input space is derived from it by make_obs_goal_space.
            action_space: The policy's action space
            hidden_size: The width of each hidden layer
            depth: The number of hidden layers
            learning_rate: The (initial) AdamW learning rate; it is modified over
                time when learning_rate_scheduler is not "constant"
            weight_decay: The AdamW weight decay
            learning_rate_scheduler: One of "constant", "linear" or "cosine"
            batch_size: Stored on the module and in the saved hyperparameters; not
                otherwise used by this class
            activation_fn: The network's activation function
            dropout_p: The dropout probability
            obs_noise: Scale of the Gaussian noise the policy adds to normalized
                observations during training
            unconditional_policy: If True, ignore goals and act only based on
                observations
            reward_conditioning: If True, condition on a reward scalar instead of future
                observations
            env_name: The name of the environment for which to configure the policy.
                If it is listed in step.d4rl_antmaze and neither
                unconditional_policy nor reward_conditioning is set, the policy is
                conditioned on (x, y) coordinates.
            mean_obs: Initial mean used to normalize the policy input. Replaced by
                the datamodule's statistics in on_fit_start.
            std_obs: Initial standard deviation used to normalize the policy input.
                Replaced by the datamodule's statistics in on_fit_start.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.learning_rate_scheduler = learning_rate_scheduler
        self.unconditional_policy = unconditional_policy
        self.batch_size = batch_size
        # Every scalar setting that affects training is saved so that a module
        # restored from a checkpoint does not silently fall back to defaults.
        self.save_hyperparameters(
            "hidden_size",
            "depth",
            "learning_rate",
            "weight_decay",
            "learning_rate_scheduler",
            "batch_size",
            "activation_fn",
            "dropout_p",
            "obs_noise",
            "unconditional_policy",
            "reward_conditioning",
            "env_name",
        )

        xy_conditioning = (
            env_name in step.d4rl_antmaze
            and not unconditional_policy
            and not reward_conditioning
        )
        observation_goal_space = make_obs_goal_space(
            observation_space,
            unconditional_policy=unconditional_policy,
            reward_conditioning=reward_conditioning,
            xy_conditioning=xy_conditioning,
        )
        # The stable-baselines3 policy requires a schedule for its internal
        # optimizer; training in this module uses configure_optimizers instead.
        lr_schedule = utils.constant_fn(learning_rate)
        net_arch = [hidden_size] * depth
        # NOTE(review): these assignments target attributes of
        # layers.DropoutActivation itself (presumably the class), so the setting
        # appears to be shared by every policy built afterwards -- confirm.
        layers.DropoutActivation.activation_fn = activation_fn
        layers.DropoutActivation.p = dropout_p
        self.model = ExtendedActorCriticPolicy(
            observation_goal_space,
            action_space,
            lr_schedule,
            net_arch=net_arch,
            activation_fn=layers.DropoutActivation,
            obs_mean=mean_obs,
            obs_std=std_obs,
            obs_noise=obs_noise,
        )

    def on_fit_start(self) -> None:
        """Load observation normalization statistics from the trainer's datamodule.

        Calls ``mean_std()`` on ``self.trainer.datamodule`` and forwards the result
        to the policy's ``setup_normalization``, replacing whatever statistics were
        passed to the constructor.
        """
        data_module = self.trainer.datamodule
        obs_mean, obs_std = data_module.mean_std()
        self.model.setup_normalization(obs_mean, obs_std)

    def forward(
        self,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute a forward pass with the model.

        All arguments are passed through to the wrapped policy's ``forward``.
        """
        return self.model.forward(*args, **kwargs)

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
        log_prefix: str = "train",
    ) -> torch.Tensor:
        """Computes loss for a training batch.

        Args:
            batch: A pair (obs_goal, action) of policy inputs and target actions.
            batch_idx: The index of the batch (unused).
            log_prefix: Prefix for every logged metric name.

        Returns:
            The negative mean log-likelihood of the target actions.
        """
        obs_goal, action = batch
        _, log_probs, _, prediction = self.model.evaluate_and_predict(
            obs_goal,
            action,
        )

        prob_true_act = torch.exp(log_probs).mean()
        loss = -log_probs.mean()

        self.log(f"{log_prefix}_prob_true_act", prob_true_act)
        self.log(f"{log_prefix}_loss", loss, prog_bar=True)

        # Not every policy has a log_std attribute. Checking for it explicitly
        # (instead of catching AttributeError around the logging calls) keeps
        # unrelated errors raised while logging from being swallowed.
        log_std = getattr(self.model, "log_std", None)
        if log_std is not None:
            self.log(f"{log_prefix}_std", torch.exp(log_std).mean())
            self.log(f"{log_prefix}_log_std", log_std.mean())

        if prediction is not None:
            self.log(f"{log_prefix}_mse", ((prediction - action) ** 2).mean())

        return loss

    def validation_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        """Computes loss for a validation batch.

        Reuses training_step, logging the metrics under the "val" prefix.
        """
        return self.training_step(batch, batch_idx, log_prefix="val")

    def configure_optimizers(self) -> Union[optim.Optimizer, Dict[str, Any]]:
        """Configures the optimizer used by PyTorch Lightning.

        Returns:
            The AdamW optimizer alone for the "constant" scheduler, otherwise a
            dict with "optimizer" and "lr_scheduler" entries.

        Raises:
            ValueError: If the scheduler name is unknown, or if a "linear" or
                "cosine" scheduler is requested while the trainer's max_epochs is
                not a positive number.
        """
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        if self.learning_rate_scheduler == "constant":
            return optimizer
        if self.learning_rate_scheduler not in ("linear", "cosine"):
            raise ValueError(
                f"Unknown learning rate scheduler: {self.learning_rate_scheduler}"
            )

        # Both schedules are defined relative to the total number of epochs. A
        # missing or non-positive value would silently produce a growing or
        # oscillating learning rate, so fail loudly instead.
        max_epochs = self.trainer.max_epochs
        if max_epochs is None or max_epochs <= 0:
            raise ValueError(
                f"The '{self.learning_rate_scheduler}' learning rate scheduler "
                f"requires a positive trainer.max_epochs, got {max_epochs}."
            )

        if self.learning_rate_scheduler == "linear":
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lr_lambda=lambda epoch: 1 - epoch / max_epochs,
            )
        else:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=max_epochs,
            )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    ##################
    # POLICY FUNCTIONS
    ##################

    def get_probabilities(
        self,
        observations: np.ndarray,
        goals: np.ndarray,
        actions: np.ndarray,
    ) -> torch.Tensor:
        """Get the policy's probabilities.

        Returns a probability for each action given the corresponding observation and
        goal. For an unconditional policy the goals are ignored, mirroring
        get_actions. As a side effect the policy is put in eval mode.

        Args:
            observations: A batch of observations.
            goals: A batch of goals with the same leading dimension.
            actions: A batch of actions with the same leading dimension.

        Returns:
            A CPU tensor holding exp(log-probability) of every action.
        """
        assert actions.shape[0] == observations.shape[0] == goals.shape[0]

        if self.unconditional_policy:
            policy_input = torch.tensor(observations)
        else:
            policy_input = dataset.make_s_g_tensor(observations, goals)
        # Inputs must live on the same device as the model, as in get_actions.
        policy_input = policy_input.to(self.device)
        a_tensor = torch.tensor(actions).to(self.device)

        self.model.eval()
        with torch.no_grad():
            policy_input = self.model.process_obs(policy_input, eval=True)
            _, log_probs, _ = self.model.evaluate_actions(policy_input, a_tensor)
            probs = torch.exp(log_probs)
        return probs.cpu()

    def get_action(
        self,
        observation: np.ndarray,
        goal: np.ndarray,
        deterministic: bool = True,
    ) -> np.ndarray:
        """Get an action for a single observation / goal pair.

        A 2-dimensional observation is treated as a batch and handed to
        get_actions unchanged, in which case a batch of actions is returned.
        """
        if len(observation.shape) == 2:
            return self.get_actions(
                observation,
                goal,
                deterministic=deterministic,
            )
        # Add a batch dimension for the policy, then strip it from the result.
        return self.get_actions(
            observation[np.newaxis],
            goal[np.newaxis],
            deterministic=deterministic,
        )[0]

    def get_actions(
        self,
        observations: np.ndarray,
        goals: np.ndarray,
        deterministic: bool = True,
    ) -> np.ndarray:
        """Get actions for each observation / goal pair.

        For an unconditional policy only the observations are fed to the model.
        As a side effect the policy is put in eval mode.

        Args:
            observations: A batch of observations.
            goals: A batch of goals with the same leading dimension.
            deterministic: Forwarded to the policy's ``_predict``.

        Returns:
            The predicted actions as a numpy array.
        """
        assert observations.shape[0] == goals.shape[0]

        if self.unconditional_policy:
            policy_input = torch.tensor(observations)
        else:
            policy_input = dataset.make_s_g_tensor(observations, goals)
        policy_input = policy_input.to(self.device)

        self.model.eval()
        with torch.no_grad():
            policy_input = self.model.process_obs(policy_input, eval=True)
            actions = self.model._predict(policy_input, deterministic=deterministic)

        return actions.cpu().numpy()


class ExtendedActorCriticPolicy(policies.ActorCriticPolicy):
    """Extends the functionality of stable-baseline3's ActorCriticPolicy.

    The extended functionality includes:
    - Action and value predictions at the same time as evaluating probabilities.
    - The option to skip value function computation.
    - Observation normalization with stored mean / std statistics and optional
      Gaussian observation noise during training.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        lr_schedule: type_aliases.Schedule,
        compute_values: bool = False,
        obs_mean: Optional[np.ndarray] = None,
        obs_std: Optional[np.ndarray] = None,
        obs_noise: float = 0.0,
        **kwargs,
    ):
        """Builds ExtendedActorCriticPolicy.

        Args:
            observation_space: The policy's observation space.
            action_space: The policy's action space.
            lr_schedule: A learning rate schedule.
            compute_values: We'll skip value function computation unless this is True.
            obs_mean: Mean subtracted from observations in process_obs. Defaults to
                zeros with the shape of observation_space.
            obs_std: Standard deviation observations are divided by in process_obs.
                Defaults to ones with the shape of observation_space.
            obs_noise: Scale of the Gaussian noise added to normalized observations
                outside of evaluation. No noise is added when it is 0.
            **kwargs: Keyword arguments passed along to parent class.
        """
        # _build() reads this flag, and the parent constructor is what triggers
        # _build(), so the flag has to exist before super().__init__ runs.
        self.compute_values = compute_values

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            **kwargs,
        )

        # Fall back to the identity normalization (mean 0, std 1).
        if obs_mean is None:
            obs_mean = np.zeros(observation_space.shape)
        if obs_std is None:
            obs_std = np.ones(observation_space.shape)
        self.obs_mean = self._as_frozen_parameter(obs_mean)
        self.obs_std = self._as_frozen_parameter(obs_std)
        self.obs_noise = obs_noise

    @staticmethod
    def _as_frozen_parameter(
        values: Union[np.ndarray, torch.Tensor],
        like: Optional[torch.Tensor] = None,
    ) -> nn.Parameter:
        """Copy values into a non-trainable parameter.

        Storing the statistics as parameters keeps them in the state dict and lets
        them follow the module across devices. They are frozen because they are
        data statistics: an optimizer built from this module's parameters must not
        update or weight-decay them.

        Args:
            values: The statistics as an array or tensor.
            like: If given, the result takes its device and dtype; otherwise the
                result is float32.
        """
        stat = torch.as_tensor(values).detach().clone()
        if like is None:
            stat = stat.to(dtype=torch.float32)
        else:
            stat = stat.to(device=like.device, dtype=like.dtype)
        return nn.Parameter(stat, requires_grad=False)

    def setup_normalization(
        self,
        obs_mean: Union[np.ndarray, torch.Tensor],
        obs_std: Union[np.ndarray, torch.Tensor],
    ) -> None:
        """Update the mean and std parameters and make sure they are on the same device.

        The new statistics are copied and placed on the device / dtype of the
        parameters they replace.
        """
        self.obs_mean = self._as_frozen_parameter(obs_mean, like=self.obs_mean)
        self.obs_std = self._as_frozen_parameter(obs_std, like=self.obs_std)
        print("Set up normalization.")
        print(f"obs_mean: {self.obs_mean}")
        print(f"obs_std: {self.obs_std}")

    def process_obs(self, obs: torch.Tensor, eval: bool = False) -> torch.Tensor:
        """Normalize observations and, outside of evaluation, add Gaussian noise.

        Args:
            obs: The raw policy input.
            eval: If True, never add noise.

        Returns:
            ``(obs - obs_mean) / obs_std``, plus noise scaled by obs_noise when
            obs_noise is positive and eval is False.
        """
        # The constructor always sets obs_mean; the guard only matters if the
        # attribute was explicitly cleared afterwards.
        if self.obs_mean is not None:
            obs = (obs - self.obs_mean) / self.obs_std
        if self.obs_noise > 0 and not eval:
            obs = obs + torch.randn_like(obs) * self.obs_noise
        return obs

    # TODO: once the experiment is done, add resnet support and other custom architectures to this.

    def _build(self, lr_schedule: type_aliases.Schedule) -> None:
        """Build the networks, dropping the value head unless compute_values is set."""
        super()._build(lr_schedule)
        # Skip unused value function computation
        if not self.compute_values:
            self.value_net = nn.Sequential()  # Identity function
            # Recreate the optimizer so it is built from the current parameters,
            # i.e. without those of the value head that was just removed.
            self.optimizer = self.optimizer_class(
                self.parameters(),
                lr=lr_schedule(1),
                **self.optimizer_kwargs,
            )

    def evaluate_and_predict(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Evaluate probability of actions and provide action and value prediction.

        Args:
            obs: The raw (unnormalized) policy input.
            actions: The actions whose log-probability is evaluated.

        Returns:
            A tuple (values, log_prob, entropy, predictions). ``values`` is the
            output of value_net, which is the unchanged value latent when
            compute_values is False. ``predictions`` holds the distribution's
            deterministic actions for Box action spaces and is None otherwise.
        """
        # Normalize obs. Noise is only injected while the module is in training
        # mode, so evaluation passes (e.g. validation) see clean observations.
        obs = self.process_obs(obs, eval=not self.training)
        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        if isinstance(self.action_space, spaces.Box):
            predictions = distribution.get_actions(deterministic=True)
        else:
            predictions = None
        return values, log_prob, distribution.entropy(), predictions
