from typing import Dict

import gym
import pytorch_lightning as pl
from ray.rllib.policy.sample_batch import SampleBatch
import torch
import torch.nn as nn

from offline_rl.rewards.reward_model import RewardModel


class DiscriminativeRewardModel(RewardModel, pl.LightningModule):
    """A discriminatively-trained reward model.

    This class implements both a reward model and the discriminative training of that
    reward model. It is not called "DiscriminativeRewardTrainer" for two reasons:
    (i) to avoid a naming conflict with `pytorch_lightning.Trainer`, and (ii) because
    this class really is a reward model in its own right. It is _distinct from the
    submodel it instantiates_: it passes the submodel's output through a sigmoid, so
    its reward values lie between 0 and 1.

    This class implements a positive-unlabeled / discriminative reward learning algorithm.

    It trains a classifier to predict whether a given (s, a, s') triple comes from one
    of two sources. The first is demonstrations known to be of good quality. The second
    is unlabeled demonstrations, which may be of either poor or good quality.

    For more details, see the following paper:

    Zolna, Konrad, Alexander Novikov, Ksenia Konyushkova, Caglar Gulcehre, Ziyu Wang,
    Yusuf Aytar, Misha Denil, Nando de Freitas, and Scott Reed. "Offline learning from
    demonstrations and unlabeled experience." arXiv preprint arXiv:2011.13885 (2020).

    Args:
        submodel_cls: The sub-reward-model delegated to in computing the reward value.
            Delegating this way makes it easy to train different types of networks for
            different types of input spaces. The submodel must be instantiated inside
            this class so that torch registers its parameters.
        submodel_kwargs: The keyword arguments passed to the submodel's init method.
        obs_key: The key to use in accessing the observations.
        act_key: The key to use in accessing the actions.
        next_obs_key: The key to use in accessing the next observations.
        terminals_key: The key to use in accessing the terminal indicators.
        use_next_obs: Whether to pass the next obs to the underlying reward model.
        use_terminals: Whether to pass the terminals to the underlying reward model.
        learning_rate: The learning rate passed to the optimizer.
    """
    def __init__(self,
                 submodel_cls: type,
                 submodel_kwargs: Dict,
                 obs_key: str = SampleBatch.OBS,
                 act_key: str = SampleBatch.ACTIONS,
                 next_obs_key: str = SampleBatch.NEXT_OBS,
                 terminals_key: str = SampleBatch.DONES,
                 use_next_obs: bool = True,
                 use_terminals: bool = True,
                 learning_rate: float = 5e-4):
        super().__init__()
        # Instantiated here (rather than passed in) so its parameters are registered
        # on this module and picked up by `self.parameters()` in the optimizer.
        self.model = submodel_cls(**submodel_kwargs)
        self.obs_key = obs_key
        self.act_key = act_key
        self.next_obs_key = next_obs_key
        self.terminals_key = terminals_key
        self.use_next_obs = use_next_obs
        self.use_terminals = use_terminals
        self.learning_rate = learning_rate
        # The submodel's forward output is treated as a logit for the binary
        # "known-good vs. unlabeled" classification.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, *args) -> torch.Tensor:
        """Runs a forward pass through the submodel.

        Args:
            *args: Forwarded unchanged to the submodel's `forward`.

        Returns:
            The unnormalized reward values, which this model interprets as logits.
        """
        return self.model.forward(*args)

    def reward(self, *args) -> torch.Tensor:
        """Computes the reward value.

        Args:
            *args: Forwarded unchanged to the submodel's `reward`.

        Returns:
            The sigmoid of the submodel's reward, so a value between 0 and 1.
        """
        # NOTE(review): training optimizes the submodel's `forward` output as logits,
        # while this uses the submodel's `reward`; confirm the two agree for every
        # submodel type used with this class.
        rewards = self.model.reward(*args)
        return torch.sigmoid(rewards)

    def generic_step(self, batch: Dict, batch_idx: int, split: str) -> torch.Tensor:
        """Computes and logs the classification loss on a single batch.

        Args:
            batch: The batch of data used in this step. It is indexed with `obs_key`,
                `act_key`, `"label"`, and (when enabled) `next_obs_key` and
                `terminals_key`.
            batch_idx: The index of the batch (unused).
            split: The train/val/test split; used as the prefix of the logged loss name.

        Returns:
            The binary cross-entropy loss computed during this step.
        """
        del batch_idx  # Unused; accepted only to match the Lightning step signature.

        obs = batch[self.obs_key]
        actions = batch[self.act_key]
        # To adhere to the same interface for all reward models, pass None in place
        # of any input this model is configured not to use.
        next_obs = batch[self.next_obs_key] if self.use_next_obs else None
        terminals = batch[self.terminals_key] if self.use_terminals else None
        logits = self.forward(obs, actions, next_obs, terminals)

        # BCEWithLogitsLoss requires a floating-point target with exactly the input's
        # shape. Labels commonly arrive as integer tensors and/or as (batch,) while the
        # logits are (batch, 1). `reshape_as` still raises if the element counts differ,
        # so nothing is silently broadcast.
        labels = torch.as_tensor(
            batch["label"], dtype=logits.dtype, device=logits.device).reshape_as(logits)
        loss = self.loss_fn(logits, labels)

        self.log(f"{split}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "train")

    def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "val")

    def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Sets up an Adam optimizer over all parameters, using `learning_rate`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space, delegated to the submodel."""
        return self.model.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space, delegated to the submodel."""
        return self.model.action_space
