"""PyTorch policy class used for SlateQ"""

from typing import Dict, List, Sequence, Tuple

import gym
import numpy as np

import ray
from src.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from src.rllib.models.torch.torch_action_dist import (TorchCategorical,
                                                      TorchDistributionWrapper)
from src.rllib.models.torch.torch_modelv2 import TorchModelV2
from src.rllib.policy.policy import Policy
from src.rllib.policy.policy_template import build_policy_class
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.framework import try_import_torch
from src.rllib.utils.typing import (ModelConfigDict, TensorType,
                                    TrainerConfigDict)

torch, nn = try_import_torch()
F = None
if nn:
    F = nn.functional


class QValueModel(nn.Module):
    """The Q-value model for SlateQ"""

    def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
        super().__init__()

        # construct hidden layers
        layers = []
        ins = 2 * embedding_size
        for n in q_hiddens:
            layers.append(nn.Linear(ins, n))
            layers.append(nn.LeakyReLU())
            ins = n
        layers.append(nn.Linear(ins, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        user_repeated = user.repeat(num_docs, 1)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)
        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)


class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ

    This class implements a multinomial logit model for predicting user clicks.

    Under this model, the click probability of a document is proportional to:

    .. math::
        \exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click})
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
        self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional user click probabilities.
        Here we return the scores unnormalized because because only some of the
        documents will be selected and shown to the user.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        s = torch.einsum("be,bde->bd", user, doc)
        s = s * self.beta
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
        return s


class SlateQModel(TorchModelV2, nn.Module):
    """The SlateQ model class

    It includes both the user choice model and the Q-value model.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            model_config: ModelConfigDict,
            name: str,
            *,
            embedding_size: int,
            q_hiddens: Sequence[int],
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has no
            # real imact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name)
        self.choice_model = UserChoiceModel()
        self.q_model = QValueModel(embedding_size, q_hiddens)
        self.slate_size = len(action_space.nvec)

    def choose_slate(self, user: TensorType,
                     doc: TensorType) -> Tuple[TensorType, TensorType]:
        """Build a slate by selecting from candidate documents

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size,
                num_docs, embedding_size).

        Returns:
            slate_selected (TensorType): Indices of documents selected for
                the slate, with shape (batch_size, slate_size).
            best_slate_q_value (TensorType): The Q-value of the selected slate,
                with shape (batch_size).
        """
        # Step 1: compute item scores (proportional to click probabilities)
        # raw_scores.shape=[batch_size, num_docs+1]
        raw_scores = self.choice_model(user, doc)
        # max_raw_scores.shape=[batch_size, 1]
        max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
        # deduct scores by max_scores to avoid value explosion
        scores = torch.exp(raw_scores - max_raw_scores)
        scores_doc = scores[:, :-1]  # shape=[batch_size, num_docs]
        scores_no_click = scores[:, [-1]]  # shape=[batch_size, 1]

        # Step 2: calculate the item-wise Q values
        # q_values.shape=[batch_size, num_docs+1]
        q_values = self.q_model(user, doc)
        q_values_doc = q_values[:, :-1]  # shape=[batch_size, num_docs]
        q_values_no_click = q_values[:, [-1]]  # shape=[batch_size, 1]

        # Step 3: construct all possible slates
        _, num_docs, _ = doc.shape
        indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
        # slates.shape = [num_slates, slate_size]
        slates = torch.combinations(indices, r=self.slate_size)
        num_slates, _ = slates.shape

        # Step 4: calculate slate Q values
        batch_size, _ = q_values_doc.shape
        # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
        slate_decomp_q_values = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_scores.shape: [batch_size, num_slates, slate_size]
        slate_scores = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_q_values.shape: [batch_size, num_slates]
        slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
                          (q_values_no_click * scores_no_click)) / (
                              slate_scores.sum(dim=2) + scores_no_click)

        # Step 5: find the slate that maximizes q value
        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
        # slates_selected.shape: [batch_size, slate_size]
        slates_selected = slates[max_idx]
        return slates_selected, best_slate_q_value

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        # user.shape: [batch_size, embedding_size]
        user = input_dict[SampleBatch.OBS]["user"]
        # doc.shape: [batch_size, num_docs, embedding_size]
        doc = torch.cat([
            val.unsqueeze(1)
            for val in input_dict[SampleBatch.OBS]["doc"].values()
        ], 1)

        slates_selected, _ = self.choose_slate(user, doc)

        state_out = []
        return slates_selected, state_out


def build_slateq_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build models for SlateQ

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict):

    Returns:
        (q_model, TorchCategorical)
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        embedding_size=config["recsim_embedding_size"],
        q_hiddens=config["hiddens"],
    )
    return model, TorchCategorical


def build_slateq_losses(policy: Policy, model: SlateQModel, _,
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the losses for SlateQPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # action.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build user choice model loss
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # clicks.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    # Step 2: Build qvalue loss
    # Fields in available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(
                q_values * scores, dim=1) / torch.sum(
                    scores, dim=1)
            next_q_values[dones] = 0.0
    elif learning_strategy == "MYOP":
        next_q_values = 0.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            _, next_q_values = model.choose_slate(next_user, next_doc)
        next_q_values[dones.long()] = 0.0
    else:
        raise ValueError(learning_strategy)
    # target_q_values.shape: [batch_size]
    target_q_values = next_q_values + train_batch["rewards"]

    q_values = model.q_model(user, selected_doc)  # shape: [batch_size, slate_size+1]
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)  # shape=[batch_size]

    q_value_loss = nn.MSELoss()(q_values, target_q_values)
    return [choice_loss, q_value_loss]


def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
                            ) -> List["torch.optim.Optimizer"]:
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"])
    return [optimizer_choice, optimizer_q_value]


def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
                      explore, timestep):
    """Determine which action to take"""
    # First, we transform the observation into its unflattened form
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS],
        policy.observation_space,
        tensorlib=torch)

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    selected_slates, _ = model.choose_slate(user, doc)

    action = selected_slates
    logp = None
    state_out = []
    return action, logp, state_out


def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training"""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch


SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

    # build model, loss functions, and optimizers
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,

    # define how to act
    action_sampler_fn=action_sampler_fn,

    # post processing batch sampled data
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
