from typing import Dict, Optional

import gym
import pytorch_lightning as pl
from ray.rllib.policy.sample_batch import SampleBatch
import torch
import torch.nn as nn

from offline_rl.rewards.reward_model import RewardModel


class DirectRegressionRewardModel(RewardModel, pl.LightningModule):
    """A reward model trained to directly regress known reward values.

    Training minimizes the mean-squared error between the submodel's forward output and
    the reward labels stored in each batch. Optionally, those labels are first replaced
    by the output of `label_overwriting_reward_model`.

    Args:
        submodel_cls: The sub-reward-model delegated to in computing the reward value.
            This class delegates like this so that it can easily train different types
            of networks for different types of input spaces. The sub-model must be
            instantiated inside the class so torch recognizes its parameters.
        submodel_kwargs: The keyword arguments passed to the submodel init method.
        label_overwriting_reward_model: If not `None`, overwrites the reward labels with
            the output from this reward model. This exists to allow for cloning a known
            reward model on an otherwise pre-defined dataset. The labels are computed
            without tracking gradients, so the optimizer does not change this model's
            parameters.
        obs_key: The key to use in accessing the observations.
        act_key: The key to use in accessing the actions.
        next_obs_key: The key to use in accessing the next observations.
        terminals_key: The key to use in accessing the terminal indicators.
        rewards_key: The key to use in accessing the rewards in the batch.
        learning_rate: The learning rate passed to the optimizer.
    """
    def __init__(
            self,
            submodel_cls: type,
            submodel_kwargs: Dict,
            label_overwriting_reward_model: Optional[RewardModel] = None,
            obs_key: str = SampleBatch.OBS,
            act_key: str = SampleBatch.ACTIONS,
            next_obs_key: str = SampleBatch.NEXT_OBS,
            terminals_key: str = SampleBatch.DONES,
            rewards_key: str = SampleBatch.REWARDS,
            learning_rate: float = 5e-4,
    ):
        super().__init__()
        # Instantiated here (not passed in) so it is registered as a submodule and its
        # parameters are picked up by `self.parameters()`.
        self.model = submodel_cls(**submodel_kwargs)
        self.label_overwriting_reward_model = label_overwriting_reward_model
        self.obs_key = obs_key
        self.act_key = act_key
        self.next_obs_key = next_obs_key
        self.terminals_key = terminals_key
        self.rewards_key = rewards_key
        self.learning_rate = learning_rate
        self.loss_fn = nn.MSELoss()

    def forward(self, *args) -> torch.Tensor:
        """Runs a forward pass through the wrapped submodel.

        Args:
            *args: Positional inputs passed unchanged to the submodel's `forward`. In
                `generic_step` these are the observations, actions, next observations
                and terminal indicators, in that order.

        Returns:
            The submodel's output, which is regressed directly onto the reward labels.
        """
        return self.model.forward(*args)

    def reward(self, *args) -> torch.Tensor:
        """Computes the reward value by delegating to the submodel's `reward` method.

        Args:
            *args: Positional inputs passed unchanged to the submodel's `reward`.

        Returns:
            The reward as computed by the submodel.
        """
        return self.model.reward(*args)

    def generic_step(self, batch: Dict, batch_idx: int, split: str) -> torch.Tensor:
        """Performs a generic step of this model.

        Args:
            batch: The batch of data used in this step.
            batch_idx: The index of the batch (unused).
            split: The train/val/test split; used as the prefix of the logged loss name.

        Returns:
            The MSE loss between the predictions and the reward labels.

        Raises:
            ValueError: If the prediction shape differs from the reward label shape.
        """
        batch = self._maybe_overwrite_reward_labels(batch)
        pred = self.forward(
            batch[self.obs_key],
            batch[self.act_key],
            batch[self.next_obs_key],
            batch[self.terminals_key],
        )
        target = batch[self.rewards_key]
        # Explicit check instead of `assert` (stripped under `python -O`): on a shape
        # mismatch MSELoss would otherwise broadcast with only a warning.
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match reward label "
                f"shape {tuple(target.shape)}."
            )
        loss = self.loss_fn(pred, target)

        # Log per step only while training; every split is also aggregated per epoch.
        self.log(
            f"{split}_loss",
            loss,
            on_step=split == "train",
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        del batch_idx  # Unused; deleted to mark it as intentionally ignored.
        return loss

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "train")

    def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "val")

    def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """See documentation for `generic_step`."""
        return self.generic_step(batch, batch_idx, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Sets up an Adam optimizer over this module's parameters."""
        # If the label-overwriting model is an `nn.Module`, its parameters are included
        # here too; its labels are computed under `torch.no_grad()`, so those parameters
        # never receive gradients and Adam leaves them unchanged.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space of the wrapped submodel."""
        return self.model.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space of the wrapped submodel."""
        return self.model.action_space

    def _maybe_overwrite_reward_labels(self, batch: Dict) -> Dict:
        """Overwrites the reward labels using the output from an existing reward model.

        The rewards entry is replaced in place, and the same batch object is returned.

        Args:
            batch: The batch to potentially overwrite the rewards in.

        Returns:
            The batch with rewards potentially overwritten.
        """
        if self.label_overwriting_reward_model is None:
            return batch
        # The overwritten labels are fixed regression targets. Without `no_grad`, the MSE
        # loss would backpropagate through the target into the labeling model, and the
        # optimizer would update it (it is a submodule when it is an `nn.Module`).
        with torch.no_grad():
            labels = self.label_overwriting_reward_model.reward(
                batch[self.obs_key],
                batch[self.act_key],
                batch[self.next_obs_key],
                batch[self.terminals_key],
            )
        batch[self.rewards_key] = labels
        return batch
