from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math
import time
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from gym.spaces import Discrete, MultiDiscrete
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, normc_initializer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
from typing import Sequence
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.view_requirement import ViewRequirement
from .iterative_normalization_original import IterNorm

# from torch.nn.utils.parametrizations import spectral_norm
from torch.nn.utils import spectral_norm
from gym.spaces import Tuple

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
import tree

from models.model_utils import (
    Conv2dSame,
    Flatten,
    ResidualBlock,
    make_n_orderd_dict,
)


class PPOAttention(TorchModelV2, nn.Module):
    """Actor-critic model that attends over opponent and ally encodings.

    The observation is read as a stack of per-agent rows (split along
    ``dim=-2``): row 0 is the controlled agent, the remaining rows are allies
    and opponents in an order selected by ``is_guard``. A concept vector is
    predicted from the flattened observation, concatenated with the self
    encoding and used as the attention query over the opponent encodings; the
    result then attends over the ally encodings. The final feature vector
    feeds both the policy head and the value head.

    Required ``customized_model_kwargs`` keys: ``num_agents``,
    ``num_opp_agents``, ``input_size``, ``embed_dim``, ``n_heads``,
    ``policy_layers`` (1 or 2), ``is_guard``, ``conceptdim``, ``bottleneck``
    and ``include_concepts``.

    Raises:
        KeyError: if a required keyword argument is missing.
        ValueError: if ``policy_layers`` is neither 1 nor 2.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **customized_model_kwargs,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.num_agents = customized_model_kwargs["num_agents"]
        self.num_allies = self.num_agents - 1
        self.num_opp_agents = customized_model_kwargs["num_opp_agents"]
        self.input_size = customized_model_kwargs["input_size"]
        self.embed_dim = customized_model_kwargs["embed_dim"]
        self.n_heads = customized_model_kwargs["n_heads"]
        self.n_policy_layers = customized_model_kwargs["policy_layers"]
        self.is_guard = customized_model_kwargs["is_guard"]
        self.conceptdim = customized_model_kwargs["conceptdim"]
        self.bottleneck = customized_model_kwargs["bottleneck"]
        self.include_classifer = customized_model_kwargs["include_concepts"]

        self.nonlin = nn.ReLU

        self.K = 1  # message passing rounds

        self.c_dim = self.embed_dim

        # Caches written by forward() and read by value_function() /
        # concept_function(). Initialised to None so the "must call forward()
        # first" assertions can actually fire.
        self._features = None
        self._concept_features = None

        self.selfencoder = nn.Sequential(
            nn.Linear(self.input_size, self.embed_dim), self.nonlin(inplace=True)
        )

        # Encodes the states of opponents.
        self.oppEncoder = nn.Sequential(
            nn.Linear(self.input_size, self.embed_dim), self.nonlin(inplace=True)
        )

        # Encodes the states of allies.
        self.allyencoder = nn.Sequential(
            nn.Linear(self.input_size, self.embed_dim), self.nonlin(inplace=True)
        )

        # Concept encoding. The first layer is sized for exactly 10 agent rows
        # (self + allies + opponents) flattened into one vector.
        self.conceptEncoder = nn.Sequential(
            nn.Linear(self.input_size * 10, self.embed_dim),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, self.conceptdim),
        )

        # Input is [query, attention output], each embed_dim + conceptdim wide.
        self.oppIntermediaryEncoder = nn.Sequential(
            nn.Linear(2 * self.embed_dim + 2 * self.conceptdim, self.embed_dim),
            self.nonlin(inplace=True),
        )

        # The query is the self encoding concatenated with the concept vector,
        # while keys/values are plain opponent encodings (hence kdim/vdim).
        self.oppAttention = nn.MultiheadAttention(
            self.embed_dim + self.conceptdim,
            self.n_heads,
            kdim=self.embed_dim,
            vdim=self.embed_dim,
            batch_first=True,
        )

        self.allyAttention = nn.MultiheadAttention(
            self.embed_dim, self.n_heads, batch_first=True
        )

        self.allyIntermediaryEncoder = nn.Sequential(
            nn.Linear(2 * self.embed_dim, self.embed_dim), self.nonlin(inplace=True)
        )

        self.value_head = nn.Sequential(
            nn.Linear(self.c_dim, self.embed_dim),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, 1),
        )

        if self.n_policy_layers == 1:
            # NOTE(review): the single-layer head applies ReLU to the logits;
            # kept as in the original, confirm that this is intended.
            self.policy_head = nn.Sequential(
                nn.Linear(self.c_dim, num_outputs), self.nonlin(inplace=True)
            )
        elif self.n_policy_layers == 2:
            self.policy_head = nn.Sequential(
                nn.Linear(self.c_dim, self.embed_dim),
                self.nonlin(inplace=True),
                nn.Linear(self.embed_dim, num_outputs),
            )
        else:
            # Fail at construction time instead of with a missing
            # ``policy_head`` attribute on the first forward pass.
            raise ValueError(
                f"policy_layers must be 1 or 2, got {self.n_policy_layers!r}"
            )

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Compute policy logits and cache features for the value/concept heads.

        Args:
            input_dict: must contain ``"obs"``, whose second-to-last dimension
                indexes agents (row 0 is the controlled agent).
            state: passed through unchanged.
            seq_lens: unused.

        Returns:
            ``(logits, state)``.
        """
        # One (..., 1, input_size) chunk per agent row.
        agent_obs = torch.split(input_dict["obs"], 1, dim=-2)
        xself = agent_obs[0]
        if self.is_guard:
            # Guard layout: rows 1-4 are allies, the rest are opponents.
            xally = torch.cat(agent_obs[1:5], dim=-2)
            xopp = torch.cat(agent_obs[5:], dim=-2)
        else:
            # Otherwise: rows 1-5 are opponents, the rest are allies.
            xopp = torch.cat(agent_obs[1:6], dim=-2)
            xally = torch.cat(agent_obs[6:], dim=-2)

        # Flatten every agent row into a single vector per batch element,
        # always ordered self, allies, opponents.
        all_agents = torch.cat([xself, xally, xopp], dim=-2)
        all_agents = all_agents.reshape(all_agents.shape[0], 1, -1)

        xselfenc = self.selfencoder(xself)
        xoppenc = self.oppEncoder(xopp)

        self._concept_features = self.conceptEncoder(all_agents)

        # Attention query: self encoding plus predicted concepts.
        xselfpf = torch.cat((xselfenc, self._concept_features), dim=-1)

        q1, _ = self.oppAttention(xselfpf, xoppenc, xoppenc)
        h = self.oppIntermediaryEncoder(torch.cat((xselfpf, q1), dim=-1))

        xallyenc = self.allyencoder(xally)

        for _ in range(self.K):
            q2, _ = self.allyAttention(h, xallyenc, xallyenc)
            h = self.allyIntermediaryEncoder(torch.cat((h, q2), dim=-1))

        # Drop the singleton agent axis left over from the self row.
        self._features = h.squeeze(-2)

        logits = self.policy_head(self._features)

        return logits, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        """Return the value estimate for the features cached by ``forward``."""
        assert self._features is not None, "must call forward() first"

        return self.value_head(self._features).squeeze(-1)

    def concept_function(self) -> TensorType:
        """Return the concept vector cached by ``forward``, agent axis removed."""
        assert self._concept_features is not None, "must call forward() first"

        return self._concept_features.squeeze(-2)


class FeedForwardPPO(TorchModelV2, nn.Module):
    """Feed-forward actor-critic model built on a concept encoder.

    The whole observation is flattened per batch element and mapped by an MLP
    to a ``conceptdim``-wide vector, which is the shared input of the policy
    head and the value head.

    Required ``customized_model_kwargs`` keys: ``num_agents``,
    ``num_opp_agents``, ``input_size``, ``embed_dim``, ``n_heads``,
    ``policy_layers`` (1 or 2), ``is_guard``, ``conceptdim``, ``bottleneck``
    and ``include_concepts``.

    Raises:
        KeyError: if a required keyword argument is missing.
        ValueError: if ``policy_layers`` is neither 1 nor 2.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **customized_model_kwargs,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.num_agents = customized_model_kwargs["num_agents"]
        self.num_allies = self.num_agents - 1
        self.num_opp_agents = customized_model_kwargs["num_opp_agents"]
        self.input_size = customized_model_kwargs["input_size"]
        self.embed_dim = customized_model_kwargs["embed_dim"]
        self.n_heads = customized_model_kwargs["n_heads"]
        self.n_policy_layers = customized_model_kwargs["policy_layers"]
        self.is_guard = customized_model_kwargs["is_guard"]
        self.conceptdim = customized_model_kwargs["conceptdim"]
        self.bottleneck = customized_model_kwargs["bottleneck"]
        self.include_classifer = customized_model_kwargs["include_concepts"]

        self.nonlin = nn.ReLU

        self.K = 1  # message passing rounds

        self.c_dim = self.embed_dim

        # Caches written by forward(); None until the first forward pass so the
        # "must call forward() first" assertions can actually fire.
        self._features = None
        self._concept_features = None

        # Concept encoding over the flattened observation of all agents.
        self.conceptEncoder = nn.Sequential(
            nn.Linear(
                self.input_size * (self.num_agents + self.num_opp_agents),
                self.embed_dim,
            ),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, self.conceptdim),
        )

        self.value_head = nn.Sequential(
            nn.Linear(self.conceptdim, self.embed_dim),
            self.nonlin(inplace=True),
            nn.Linear(self.embed_dim, 1),
        )

        if self.n_policy_layers == 1:
            # NOTE(review): the single-layer head applies ReLU to the logits;
            # kept as in the original, confirm that this is intended.
            self.policy_head = nn.Sequential(
                nn.Linear(self.conceptdim, num_outputs), self.nonlin(inplace=True)
            )
        elif self.n_policy_layers == 2:
            self.policy_head = nn.Sequential(
                nn.Linear(self.conceptdim, self.embed_dim),
                self.nonlin(inplace=True),
                nn.Linear(self.embed_dim, num_outputs),
            )
        else:
            # Fail at construction time instead of with a missing
            # ``policy_head`` attribute on the first forward pass.
            raise ValueError(
                f"policy_layers must be 1 or 2, got {self.n_policy_layers!r}"
            )

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Compute policy logits from the flattened observation.

        Args:
            input_dict: must contain ``"obs"``; everything after the leading
                (batch) dimension is flattened.
            state: passed through unchanged.
            seq_lens: unused.

        Returns:
            ``(logits, state)`` where ``logits`` keeps the batch dimension.
        """
        obs = input_dict["obs"]
        flat_obs = obs.reshape(obs.shape[0], -1)

        self._concept_features = self.conceptEncoder(flat_obs)

        # The encoder output is already (batch, conceptdim). The previous
        # ``squeeze(-2)`` was a no-op for batch > 1 but dropped the batch
        # dimension for batch == 1, so it is intentionally not applied.
        self._features = self._concept_features

        logits = self.policy_head(self._features)

        return logits, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        """Return one value estimate per batch element of the last forward."""
        assert self._features is not None, "must call forward() first"

        return self.value_head(self._features).squeeze(-1)

    def concept_function(self) -> TensorType:
        """Return the (batch, conceptdim) concept vector of the last forward."""
        assert self._concept_features is not None, "must call forward() first"

        # No squeeze: the tensor is 2-D, and squeezing dim -2 would remove the
        # batch axis whenever the batch size is 1.
        return self._concept_features


class RNNModel(TorchRNN, nn.Module):
    """Recurrent (LSTM) actor-critic model with a concept layer.

    Pipeline: ``fc1`` (two linear layers) -> ReLU -> LSTM -> ``fc2`` -> ReLU
    -> optional per-concept activations -> action and value branches.

    Configuration is read from ``customized_model_kwargs`` or, when none are
    passed, from ``model_config["custom_model_config"]``. Required keys:
    ``num_agents``, ``num_opp_agents``, ``input_size``, ``embed_dim``,
    ``n_heads``, ``policy_layers``, ``is_guard``, ``conceptdim``,
    ``bottleneck``, ``include_concepts`` and, when ``include_concepts`` is
    truthy, ``concept_configs`` (an iterable of mappings with the keys
    ``"start idx"``, ``"end idx"``, ``"type"`` and ``"name"``).
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **customized_model_kwargs,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        if not customized_model_kwargs:
            customized_model_kwargs = model_config["custom_model_config"]

        self.config = customized_model_kwargs

        self.num_agents = customized_model_kwargs["num_agents"]
        self.num_allies = self.num_agents - 1
        self.num_opp_agents = customized_model_kwargs["num_opp_agents"]
        self.input_size = customized_model_kwargs["input_size"]
        self.embed_dim = customized_model_kwargs["embed_dim"]
        self.n_heads = customized_model_kwargs["n_heads"]
        self.n_policy_layers = customized_model_kwargs["policy_layers"]
        self.is_guard = customized_model_kwargs["is_guard"]
        self.conceptdim = customized_model_kwargs["conceptdim"]
        self.bottleneck = customized_model_kwargs["bottleneck"]
        self.include_classifer = customized_model_kwargs["include_concepts"]
        # Width of the concept layer: supervised concepts plus bottleneck.
        self.combined_dim = self.conceptdim + self.bottleneck
        self.name = name
        self.model_config = model_config

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = self.embed_dim
        self.lstm_state_size = self.embed_dim
        self.nonlin = nn.ReLU

        # Build the Module from fc + LSTM + fc + 2 output branches
        # (action + value).
        self.fc1 = nn.Sequential(
            nn.Linear(
                self.obs_size,
                self.fc_size,
            ),
            self.nonlin(inplace=True),
            nn.Linear(self.fc_size, self.fc_size),
        )
        self.lstm = nn.LSTM(self.fc_size, self.lstm_state_size, batch_first=True)
        self.fc2 = nn.Linear(self.lstm_state_size, self.combined_dim)

        self.action_branch = nn.Linear(self.combined_dim, num_outputs)
        self.value_branch = nn.Linear(self.combined_dim, 1)

        # Holds the LSTM output of the last forward pass.
        self._features = None
        # Holds the concept-layer output of the last forward pass. Initialised
        # to None so the "must call forward() first" assertions can fire.
        self._concept_features = None

        # Per-concept split sizes and activation types along the concept layer.
        self.activation_length = []
        self.activation_type = []
        self.concept_name = []
        self.activation_num = 0

        if self.config["include_concepts"]:
            total_length = 0
            for concept in self.config["concept_configs"]:
                length = concept["end idx"] - concept["start idx"]
                self.activation_length.append(length)
                self.activation_type.append(concept["type"])
                self.concept_name.append(concept["name"])
                total_length += length
                self.activation_num += 1
            self.total_length = total_length
            # Whatever is left of the concept layer is an untyped remainder.
            self.activation_length.append(self.combined_dim - total_length)
            self.activation_type.append("None")
            self.activation_num += 1

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled initial LSTM states as two 1-D tensors."""
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model: `.new()` allocates with
        # the dtype/device of the reference weight.
        ref_weight = self.fc1[2].weight
        return [
            ref_weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            ref_weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the flattened value estimates for the last forward pass."""
        assert self._concept_features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._concept_features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feed `inputs` (B x T x ..) through fc1, the LSTM and the concept layer.

        The concept-layer output is cached in ``self._concept_features`` for
        ``value_function()`` and ``concept_function()``.

        Args:
            inputs: batch-first sequence input (B x T x obs_size).
            state: list of the two LSTM state tensors, without the leading
                layer dimension.
            seq_lens: unused.

        Returns:
            NN outputs (B x T x num_outputs) as a sequence, and the two new
            LSTM state tensors as a list.
        """
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects a leading num_layers axis on both state tensors.
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
        )

        cf = nn.functional.relu(self.fc2(self._features))

        if self.config["include_concepts"]:
            # Split the concept layer into per-concept slices; classification
            # concepts are normalised with a softmax, all others pass through.
            cf_list = torch.split(cf, self.activation_length, dim=-1)
            cf_out = []
            for i in range(self.activation_num):
                if self.activation_type[i] == "classification":
                    cf_out.append(torch.nn.Softmax(dim=-1)(cf_list[i]))
                else:
                    cf_out.append(cf_list[i])
            self._concept_features = torch.cat(cf_out, dim=-1)
        else:
            self._concept_features = cf

        action_out = self.action_branch(self._concept_features)

        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    def concept_function(self) -> TensorType:
        """Return the cached concept features flattened to (-1, combined_dim)."""
        assert self._concept_features is not None, "must call forward() first"

        return torch.reshape(self._concept_features, [-1, self.combined_dim])


class RNNRewModel(TorchRNN, TorchModelV2, nn.Module):
    """Recurrent (LSTM) actor-critic model fed with previous action/reward.

    Pipeline: ``fc1`` on the observation -> concatenate previous action (and,
    if ``use_reward``, previous reward) -> LSTM -> ``fc2`` (concept layer of
    width ``conceptdim + bottleneck``) -> optional ``IterNorm`` -> optional
    per-concept activations / overrides -> ``fc3`` -> action and value
    branches.

    Configuration is read from ``customized_model_kwargs`` or, when none are
    passed, from ``model_config["custom_model_config"]``. Required keys:
    ``num_agents``, ``num_opp_agents``, ``input_size``, ``embed_dim``,
    ``n_heads``, ``fc1_layers``, ``fc2_layers``, ``fc3_layers``,
    ``policy_layers``, ``is_guard``, ``conceptdim``, ``bottleneck``,
    ``include_concepts``, ``include_whitening``, ``affine_whitening``,
    ``T_whitening``, ``use_reward`` and, when ``include_concepts`` is truthy,
    ``concept_configs`` (a mapping with ``"guard"`` and ``"attacker"``
    entries, each exposing ``.configs``).
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **customized_model_kwargs,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        if not customized_model_kwargs:
            customized_model_kwargs = model_config["custom_model_config"]

        self.config = customized_model_kwargs

        self.num_agents = customized_model_kwargs["num_agents"]
        self.num_allies = self.num_agents - 1
        self.num_opp_agents = customized_model_kwargs["num_opp_agents"]
        self.input_size = customized_model_kwargs["input_size"]
        self.embed_dim = customized_model_kwargs["embed_dim"]
        self.n_heads = customized_model_kwargs["n_heads"]
        self.n_fc1_layers = customized_model_kwargs["fc1_layers"]
        self.n_fc2_layers = customized_model_kwargs["fc2_layers"]
        self.n_fc3_layers = customized_model_kwargs["fc3_layers"]
        self.n_policy_layers = customized_model_kwargs["policy_layers"]
        self.is_guard = customized_model_kwargs["is_guard"]
        self.conceptdim = customized_model_kwargs["conceptdim"]
        self.bottleneck = customized_model_kwargs["bottleneck"]
        self.include_classifer = customized_model_kwargs["include_concepts"]
        self.include_whitening = customized_model_kwargs["include_whitening"]
        self.affine_whitening = customized_model_kwargs["affine_whitening"]
        self.T_whitening = customized_model_kwargs["T_whitening"]
        self.use_reward = customized_model_kwargs["use_reward"]
        # Width of the concept layer: supervised concepts plus bottleneck.
        self.combined_dim = self.conceptdim + self.bottleneck
        self.name = name
        self.model_config = model_config

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = self.embed_dim
        self.lstm_state_size = self.embed_dim
        self.nonlin = nn.ReLU
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0

        # Caches written by forward_rnn(); None until the first forward pass so
        # the "must call forward() first" assertions can actually fire.
        self._features = None  # LSTM output
        self.cf = None  # concept layer after optional whitening
        self._concept_features = None  # concept layer after activations/overrides
        self.fc3_out = None  # shared input of the action and value branches

        if self.include_whitening:
            # IterNorm is project-local; presumably an iterative-normalization
            # (whitening) layer over the concept dimension - TODO confirm.
            self.ItN = IterNorm(
                self.combined_dim,
                num_channels=self.combined_dim,
                dim=2,
                T=self.T_whitening,
                momentum=1,
                affine=self.affine_whitening,
            )

        # Width of the (one-hot / flattened) previous-action input. Everything
        # is cast to a plain int so layer sizes never become NumPy scalars.
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += int(space.n)
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                # np.prod instead of np.product, which NumPy 2.0 removed.
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        fc1_dict = make_n_orderd_dict(
            n=self.n_fc1_layers,
            input_size=self.obs_size,
            hidden_size=self.fc_size,
            output_size=self.fc_size,
        )
        self.fc1 = nn.Sequential(fc1_dict)

        # LSTM input: encoded observation + previous action (+ previous reward).
        lstm_input_size = self.fc_size + self.action_dim
        if self.use_reward:
            lstm_input_size += 1

        self.lstm = nn.LSTM(lstm_input_size, self.lstm_state_size, batch_first=True)

        fc2_dict = make_n_orderd_dict(
            n=self.n_fc2_layers,
            input_size=self.lstm_state_size,
            hidden_size=self.fc_size,
            output_size=self.combined_dim,
        )
        self.fc2 = nn.Sequential(fc2_dict)

        fc3_dict = make_n_orderd_dict(
            n=self.n_fc3_layers,
            input_size=self.combined_dim,
            hidden_size=self.fc_size,
            output_size=self.fc_size,  # temp should be self.fc_size
        )
        self.fc3 = nn.Sequential(fc3_dict)

        self.action_branch = nn.Linear(
            self.fc_size, num_outputs
        )  # temp should be self.fc_size
        self.value_branch = nn.Linear(self.fc_size, 1)  # temp should be self.fc_size

        # Per-concept split sizes, activation types and names along the
        # concept layer; the last entry is the untyped "Residual" remainder.
        self.activation_length = []
        self.activation_type = []
        self.concept_name = []
        self.activation_num = 0
        # Names of concepts whose activations forward() replaces with the
        # values supplied in input_dict["concept_infos"].
        self.concepts_to_update = []

        if self.config["include_concepts"]:
            role = "guard" if self.is_guard else "attacker"
            concept_configs = self.config["concept_configs"][role]
            total_length = 0
            for concept in concept_configs.configs:
                length = concept.end_idx - concept.start_idx
                self.activation_length.append(length)
                self.activation_type.append(concept.concept_type)
                self.concept_name.append(concept.name)
                total_length += length
                self.activation_num += 1
            self.total_length = total_length
            self.activation_length.append(self.combined_dim - total_length)
            self.activation_type.append("None")
            self.concept_name.append("Residual")
            self.activation_num += 1

        # Ask RLlib to provide the previous step's action (and reward).
        self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
            SampleBatch.ACTIONS, space=self.action_space, shift=-1
        )
        if self.use_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled initial LSTM states as two 1-D tensors."""
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model: `.new()` allocates with
        # the dtype/device of the reference weight. This relies on the last
        # module of fc1 having a `.weight` attribute.
        ref_weight = self.fc1[-1].weight
        return [
            ref_weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            ref_weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the flattened value estimates for the last forward pass."""
        assert self.fc3_out is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self.fc3_out), [-1])

    @override(TorchRNN)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Add a time dimension to the batch and run ``forward_rnn()``.

        Reads ``"obs_flat"`` and the previous actions (plus previous rewards
        when ``use_reward`` is set) from ``input_dict``. If ``input_dict``
        also carries ``"concept_infos"`` and a truthy ``"do_update"``, the
        concept values are forwarded so that the concepts named in
        ``self.concepts_to_update`` are overridden.

        Returns:
            ``(logits, new_state)`` with ``logits`` reshaped to
            ``(-1, num_outputs)``.
        """
        flat_inputs = input_dict["obs_flat"].float()

        if (
            "concept_infos" in input_dict
            and "do_update" in input_dict
            and input_dict["do_update"] is not None
            and input_dict["do_update"]
        ):
            concept_update = input_dict["concept_infos"]
        else:
            concept_update = None

        concepts_to_update = self.concepts_to_update

        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()

        # The flat batch holds B * T rows and seq_lens has one entry per
        # sequence, so this recovers the padded sequence length T.
        max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get("_time_major", False)

        obs_inputs = add_time_dimension(
            flat_inputs,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=self.time_major,
        )

        # Previous actions: one-hot for discrete spaces, raw floats otherwise.
        prev_a = input_dict[SampleBatch.PREV_ACTIONS]
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            prev_a = one_hot(prev_a.float(), self.action_space)
        else:
            prev_a = prev_a.float()
        prev_a_r = [torch.reshape(prev_a, [-1, self.action_dim])]
        if self.use_reward:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), [-1, 1])
            )
        action_plus_reward = torch.cat(prev_a_r, dim=1)

        action_reward_inputs = add_time_dimension(
            action_plus_reward,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=self.time_major,
        )

        output, new_state = self.forward_rnn(
            obs_inputs,
            action_reward_inputs,
            state,
            seq_lens,
            concept_update=concept_update,
            concepts_to_update=concepts_to_update,
        )
        output = torch.reshape(output, [-1, self.num_outputs])
        return output, new_state

    @override(TorchRNN)
    def forward_rnn(
        self,
        obs_inputs,
        action_reward_inputs,
        state,
        seq_lens,
        concept_update=None,
        concepts_to_update=None,
    ):
        """Feed the inputs (B x T x ..) through the LSTM and the concept layer.

        Caches ``self.cf`` (concept layer after optional whitening),
        ``self._concept_features`` (after activations / overrides) and
        ``self.fc3_out`` (input of the action and value branches).

        Args:
            obs_inputs: observation sequence (B x T x obs_size).
            action_reward_inputs: previous action (+ reward) sequence
                (B x T x ..), concatenated to the encoded observation.
            state: list of the two LSTM state tensors, without the leading
                layer dimension.
            seq_lens: unused.
            concept_update: optional tensor with replacement concept values,
                split along the last axis by the configured concept lengths.
            concepts_to_update: names of the concepts to replace; ``None`` is
                treated as "none".

        Returns:
            NN outputs (B x T x num_outputs) as a sequence, and the two new
            LSTM state tensors as a list.
        """
        if concepts_to_update is None:
            # The default used to crash on the ``in`` test below.
            concepts_to_update = ()

        B_shape = obs_inputs.shape[0]
        T_shape = obs_inputs.shape[1]
        x = self.fc1(obs_inputs)
        x = torch.cat([x, action_reward_inputs], dim=2)

        # nn.LSTM expects a leading num_layers axis on both state tensors.
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
        )
        temp = self.fc2(self._features)

        if self.include_whitening:
            if self.training:
                self.ItN.train()
            else:
                self.ItN.eval()
            # fc2 and IterNorm are both combined_dim wide, so the flatten /
            # unflatten must use combined_dim (not conceptdim), otherwise the
            # reshape fails as soon as bottleneck > 0.
            temp = torch.reshape(temp, [B_shape * T_shape, self.combined_dim])
            self.cf = torch.reshape(
                self.ItN(temp), [B_shape, T_shape, self.combined_dim]
            )
        else:
            self.cf = temp

        cf = self.cf
        cf_out = []

        if self.config["include_concepts"]:
            # Drop the residual slot from the split when it is empty.
            if self.activation_length[-1] > 0:
                cf_list = torch.split(cf, self.activation_length, dim=-1)
            else:
                cf_list = torch.split(cf, self.activation_length[:-1], dim=-1)

            # NOTE(review): concept_update is used as given (no time dimension
            # is added here); it presumably has to match the layout of the
            # concept tensor for the torch.cat below - confirm against caller.
            if concept_update is not None and len(self.activation_length) > 1:
                ci_list = torch.split(
                    concept_update, self.activation_length[:-1], dim=-1
                )
            else:
                ci_list = [concept_update]

            for i in range(self.activation_num):
                if self.activation_length[i] == 0:
                    continue
                if self.activation_type[i] == "classification":
                    cf_out.append(torch.nn.Softmax(dim=-1)(cf_list[i]))
                else:
                    cf_out.append(cf_list[i])
                # Override the current concept with the externally supplied
                # values when requested.
                if (
                    concept_update is not None
                    and self.concept_name[i] in concepts_to_update
                ):
                    cf_out[-1] = ci_list[i]
                    print(f"{self.concept_name[i]} updated")
            self._concept_features = torch.cat(cf_out, dim=-1)
        else:
            self._concept_features = cf

        self.fc3_out = self.fc3(self._concept_features)

        action_out = self.action_branch(self.fc3_out)

        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    def concept_function(self) -> TensorType:
        """Return the concept layer before activations, as (-1, combined_dim)."""
        assert self.cf is not None, "must call forward() first"

        return torch.reshape(self.cf, [-1, self.combined_dim])

    def return_concept(self):
        """Return the concept layer after activations, as (-1, combined_dim)."""
        assert self._concept_features is not None, "must call forward() first"

        return torch.reshape(self._concept_features, [-1, self.combined_dim])
