from typing import Dict, Tuple

import gym
import pytorch_lightning as pl
from ray.rllib.policy.sample_batch import SampleBatch
import torch
import torch.nn as nn

from offline_rl.rewards.reward_model import RewardModel


class PreferenceBasedRewardModel(RewardModel, pl.LightningModule):
    """A reward model trained based on preferences between trajectory segments.

    This class is designed for use only in an artificial preference-based reward learning setting.
    This manifests mainly in the requirement that it be provided a `target_model`, which is a
    reward model that is used by this class to artificially define preferences between segments.

    The algorithm implemented by this class is inspired by the paper:
    Christiano, Paul, Jan Leike, Tom B. Brown, Miljan Martic, Shane Legg, and Dario Amodei.
    "Deep reinforcement learning from human preferences." arXiv preprint arXiv:1706.03741 (2017).

    The implementation is loosely based on the one here:
    https://github.com/HumanCompatibleAI/evaluating-rewards/blob/master/src/evaluating_rewards/rewards/preferences.py

    Batches are dicts whose observation, action and next-observation tensors (looked up via
    `obs_key`, `act_key` and `next_obs_key`) have leading dimensions (batch_size, 2, segment_length),
    i.e., every example is a pair of equal-length segments to be compared.

    Args:
        submodel_cls: The sub-reward-model delegated to in computing the reward value.
            This class delegates like this so that it can easily train different types
            of networks for different types of input spaces. The sub-model must be
            instantiated inside the class so torch recognizes its parameters.
        submodel_kwargs: The keyword arguments passed to the submodel init method.
        target_model: The model defining the preferences between trajectory segments.
            It is only evaluated by this class, never optimized.
        obs_key: The key to use in accessing the observations.
        act_key: The key to use in accessing the actions.
        next_obs_key: The key to use in accessing the next observations.
        terminals_key: The key to use in accessing the terminal indicators. Currently unused:
            segment returns are undiscounted, so terminals are not consulted.
        learning_rate: The learning rate passed to the optimizer.
        reward_reg_weight: Weight to apply to the reward norm loss component.
    """
    def __init__(
            self,
            submodel_cls: type,
            submodel_kwargs: Dict,
            target_model: RewardModel,
            obs_key: str = SampleBatch.OBS,
            act_key: str = SampleBatch.ACTIONS,
            next_obs_key: str = SampleBatch.NEXT_OBS,
            terminals_key: str = SampleBatch.DONES,
            learning_rate: float = 5e-4,
            reward_reg_weight: float = 1e-5,
    ):
        super().__init__()
        self.model = submodel_cls(**submodel_kwargs)
        # If `target_model` is an `nn.Module`, this assignment registers it as a submodule (so it follows
        # device moves); `configure_optimizers` explicitly keeps its parameters out of the optimizer.
        self.target_model = target_model
        self.obs_key = obs_key
        self.act_key = act_key
        self.next_obs_key = next_obs_key
        # Stored for completeness only; see the note on terminals in `_segment_rewards`.
        self.terminals_key = terminals_key
        self.learning_rate = learning_rate
        self.reward_reg_weight = reward_reg_weight
        # Two-way classification: the target class is the index of the preferred segment.
        self.labeling_loss_fn = nn.CrossEntropyLoss()

    def forward(self, *args) -> torch.Tensor:
        """Runs a forward pass through the sub-model.

        Note that preference logits in this class are computed from `reward`, not from this method.

        Returns:
            Whatever the sub-model's forward pass returns (described by the original author as the
            unnormalized reward values).
        """
        # Call the module itself rather than `.forward` so any registered hooks run.
        return self.model(*args)

    def reward(self, *args) -> torch.Tensor:
        """Computes the reward value by delegating to the sub-model's `reward` method.

        Returns:
            The reward computed by the sub-model.
            NOTE(review): previously documented as lying between 0 and 1; that range is determined
            by `submodel_cls` and is not enforced here — confirm.
        """
        return self.model.reward(*args)

    def generic_step(self, batch: Dict, batch_idx: int, split: str) -> torch.Tensor:
        """Performs a generic step of this model.

        The loss is the cross-entropy between the predicted segment-return logits and the target
        model's preferred segment, plus `reward_reg_weight` times the mean per-segment reward norm.

        Args:
            batch: The batch of data used in this step.
            batch_idx: The index of the batch.
            split: The train/val/test split; used as the prefix of every logged metric.

        Returns:
            The loss computed during this step.
        """
        del batch_idx
        prefs = self._compute_target_preferences(batch)
        pred_pref_logits, rewards = self._predict_preferences(batch)
        labeling_loss = self.labeling_loss_fn(pred_pref_logits, prefs)

        # Compute the norm over the rewards of each segment, then average over those norms.
        reward_reg_loss = torch.norm(rewards, dim=2).mean()
        overall_loss = labeling_loss + reward_reg_loss * self.reward_reg_weight

        # Compute and log accuracy for a more interpretable measure of performance.
        # `prefs` holds the index of the preferred segment, so 0 means the first segment is preferred.
        first_preferred = prefs == 0
        pred_first_preferred = pred_pref_logits[:, 0] > pred_pref_logits[:, 1]
        accuracy = (first_preferred == pred_first_preferred).float().mean()

        on_step = split == "train"
        self.log(f"{split}_labeling_loss", labeling_loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log(f"{split}_reward_reg_loss", reward_reg_loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log(f"{split}_loss", overall_loss, on_step=on_step, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{split}_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return overall_loss

    def _segment_rewards(self, reward_model: RewardModel, batch: Dict) -> torch.Tensor:
        """Evaluates `reward_model.reward` on every transition of every segment in the batch.

        Args:
            reward_model: The model whose `reward` method is evaluated.
            batch: The batch of segments. Its obs, action and next_obs tensors (looked up via
                `obs_key`, `act_key` and `next_obs_key`) must have leading dimensions
                (batch_size, 2, segment_length).

        Returns:
            The per-transition rewards, shaped (batch_size, 2, segment_length).

        Raises:
            ValueError: If the batch does not contain exactly two segments per example.
        """
        obs = batch[self.obs_key]
        act = batch[self.act_key]
        next_obs = batch[self.next_obs_key]
        batch_size, num_segments, segment_length = obs.shape[:3]
        if num_segments != 2:
            raise ValueError(f"Expected exactly 2 segments per example, got {num_segments}.")
        # Merge the (batch, segment, time) dimensions so that each row is a single transition.
        # Flattening exactly the leading three dims (rather than "all but the last") also handles
        # scalar per-step inputs (e.g. discrete actions) and multi-dimensional ones (e.g. images).
        rewards = reward_model.reward(
            obs.flatten(0, 2),
            act.flatten(0, 2),
            next_obs.flatten(0, 2),
            # By design this class doesn't care about terminals because it doesn't consider a discount factor,
            # which is a bit weird because it means that segments that span episodes can validly be compared
            # at least in the artificial case (in the real-world case such a preference is hard to elicit).
            terminals=None,
        )
        return rewards.reshape(batch_size, num_segments, segment_length)

    def _compute_target_preferences(self, batch: Dict) -> torch.Tensor:
        """Computes the target preferences between segments in the batch.

        Args:
            batch: The batch of segments, including the obs, action, next_obs of each segment.

        Returns:
            The preferences between segments represented as the index of the preferred segment,
            shaped (batch_size,).
        """
        # The target model only provides labels, and argmax is not differentiable anyway, so avoid
        # building an autograd graph through it.
        with torch.no_grad():
            rewards = self._segment_rewards(self.target_model, batch)
        returns = rewards.sum(dim=-1)
        # On equal returns, torch.argmax picks the first maximal index, i.e. the first segment.
        return returns.argmax(dim=-1)

    def _predict_preferences(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the predicted preferences of the model being trained.

        Args:
            batch: The batch of segments, including the obs, action, next_obs of each segment.

        Returns:
            A tuple. The first element holds the preference logits, i.e. the undiscounted segment
            returns shaped (batch_size, 2). The second holds the per-transition rewards shaped
            (batch_size, 2, segment_length).
        """
        rewards = self._segment_rewards(self.model, batch)
        returns = rewards.sum(dim=-1)
        return returns, rewards

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "train")

    def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "val")

    def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Sets up the Adam optimizer over this model's trainable parameters.

        Parameters of `target_model` are excluded: if it is an `nn.Module`, it is registered as a
        submodule in `__init__`, so it would otherwise show up in `self.parameters()`.
        """
        if isinstance(self.target_model, nn.Module):
            target_param_ids = {id(p) for p in self.target_model.parameters()}
        else:
            target_param_ids = set()
        params = [p for p in self.parameters() if id(p) not in target_param_ids]
        optimizer = torch.optim.Adam(params, lr=self.learning_rate)
        return optimizer

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space of the sub-model."""
        return self.model.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space of the sub-model."""
        return self.model.action_space
