# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Implementation of the soft routing network and MLP described in
"Multi-Task Reinforcement Learning with Soft Modularization"
Link: https://arxiv.org/abs/2003.13661
"""


from typing import List
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

from mtrl.agent.components.base import Component as BaseComponent
from mtrl.agent.components.moe_layer import Linear, OrthogonalLayer1D
from mtrl.agent import utils as agent_utils
from mtrl.agent.ds.mt_obs import MTObs
from mtrl.utils.types import TensorType

class MOEMLP(BaseComponent):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        task_num: int,
        hidden_features: int,
        hidden_layers: int,
        emb_hidden_layers: int,
        num_layers: int,
        num_experts: int,
        module_hidden_features: int,
        gating_hidden_features: int,
        gating_hidden_layers: int,
        cond_obs: bool = True,
        bias: bool = True,
        use_moore: bool = False,
    ):
        """Mixture-of-experts MLP with a task-conditioned gating network.

        The input's last dimension is split into an observation part (the
        first ``in_features - task_num`` features) and a task part (the last
        ``task_num`` features). The observation part is encoded by ``obs_net``
        and, replicated once per expert, run through a stack of ``num_layers``
        expert ``Linear`` layers. The task part is encoded by ``emb_net``
        (optionally multiplied element-wise by the observation encoding) and
        fed to ``gating_network``, whose ``num_experts`` outputs are used as
        weights for a weighted sum of the experts' final outputs. That sum is
        passed through a final ``nn.Linear``.

        Note: experts are mixed once, after the whole expert stack, and the
        gating outputs are used as-is (no softmax is applied here).

        Args:
            in_features: total input width (observation + task features).
            out_features: output width.
            task_num: width of the trailing task-feature slice of the input.
            hidden_features: output width of ``obs_net`` and ``emb_net``; also
                the input width of the expert stack and the gating network.
            hidden_layers: depth of ``obs_net`` (passed to ``build_mlp`` as
                ``hidden_layers - 1``).
            emb_hidden_layers: depth of ``emb_net`` (passed to ``build_mlp``
                as ``emb_hidden_layers - 1``).
            num_layers: number of expert ``Linear`` layers. ``0`` is allowed,
                in which case the expert stack is empty apart from the
                optional ``OrthogonalLayer1D``.
            num_experts: number of experts (and of gating outputs).
            module_hidden_features: output width of every expert layer.
            gating_hidden_features: hidden width of the gating network.
            gating_hidden_layers: passed to ``build_mlp`` as ``num_layers``
                for the gating network.
            cond_obs: if True, the gating input is the task embedding
                multiplied element-wise by the observation encoding.
            bias: whether the expert layers and the final layer use a bias.
            use_moore: if True, append ``OrthogonalLayer1D`` after the expert
                layers (presumably orthogonalizing expert outputs as in
                MOORE -- confirm against ``moe_layer``).
        """
        super().__init__()
        self.task_num = task_num
        self.in_task_features = task_num
        self.in_obs_features = in_features - task_num
        self.cond_obs = cond_obs
        self.num_layers = num_layers
        self.num_experts = num_experts

        self.obs_net = agent_utils.build_mlp(
            input_dim=self.in_obs_features,
            hidden_dim=hidden_features,
            num_layers=hidden_layers - 1,
            output_dim=hidden_features,
        )
        self.emb_net = agent_utils.build_mlp(
            input_dim=self.in_task_features,
            hidden_dim=hidden_features,
            num_layers=emb_hidden_layers - 1,
            output_dim=hidden_features,
        )
        # Not used by forward(); retained for backward compatibility.
        self.emb_obs_activation = nn.ReLU()

        layers: List[nn.Module] = []
        current_in_features = hidden_features
        for i in range(num_layers):
            layers.append(
                Linear(
                    num_experts=num_experts,
                    in_features=current_in_features,
                    out_features=module_hidden_features,
                    bias=bias,
                )
            )
            # No activation after the last expert layer: its output is mixed
            # and fed straight into `self.last`.
            if i < num_layers - 1:
                layers.append(nn.ReLU())
            current_in_features = module_hidden_features
        if use_moore:
            layers.append(OrthogonalLayer1D())
        self.layers = nn.ModuleList(layers)
        # Use the width actually produced by the expert stack; this equals
        # `module_hidden_features` when num_layers >= 1 and `hidden_features`
        # when the stack is empty (num_layers == 0).
        self.last = nn.Linear(current_in_features, out_features, bias=bias)

        # NOTE(review): unlike obs_net/emb_net, the depth here is passed
        # without the `- 1` offset -- confirm this is the intended depth.
        self.gating_network = agent_utils.build_mlp(
            input_dim=hidden_features,
            hidden_dim=gating_hidden_features,
            num_layers=gating_hidden_layers,
            output_dim=num_experts,
        )

    def forward(self, env_obs: TensorType) -> TensorType:
        """Run the mixture-of-experts MLP.

        Args:
            env_obs: tensor whose last dimension holds ``in_features`` values:
                observation features followed by ``task_num`` task features.
                Any number of leading (batch) dimensions is accepted,
                including none.

        Returns:
            Tensor with the same leading dimensions as ``env_obs`` and a last
            dimension of ``out_features``.
        """
        # Flatten all leading dims into one batch dim so every sub-network and
        # the expert stack see [batch, features] / [num_experts, batch, features];
        # the leading shape is restored on the output. For 2-D input this is a
        # no-op view.
        lead_shape = env_obs.shape[:-1]
        flat_inp = env_obs.reshape(-1, env_obs.shape[-1])

        obs_part, task_part = torch.split(
            flat_inp, [self.in_obs_features, self.in_task_features], dim=-1
        )
        obs_feat = self.obs_net(obs_part)  # [B, hidden_features]
        gating_inp = self.emb_net(task_part)  # [B, hidden_features]
        if self.cond_obs:
            assert gating_inp.shape == obs_feat.shape, (
                f"task embedding shape {tuple(gating_inp.shape)} does not match "
                f"observation encoding shape {tuple(obs_feat.shape)}"
            )
            gating_inp = gating_inp * obs_feat

        # Raw gating outputs (no normalization): [B, num_experts] -> [B, 1, num_experts]
        weights = self.gating_network(gating_inp).unsqueeze(1)

        # One copy of the encoding per expert: [num_experts, B, hidden_features]
        expert_out = obs_feat.unsqueeze(0).repeat(self.num_experts, 1, 1)
        for layer in self.layers:
            expert_out = layer(expert_out)

        expert_out = expert_out.permute(1, 0, 2)  # [num_experts, B, d] -> [B, num_experts, d]
        mixed = (weights @ expert_out).squeeze(1)  # [B, 1, d] -> [B, d]
        out = self.last(mixed)
        return out.reshape(*lead_shape, out.shape[-1])
