import copy
from typing import Type, Dict, Any, Optional, Tuple

import gym
import numpy as np
import torch
from gym import spaces
from torch import nn, Tensor, no_grad, optim

from networks.gnn import GNN
from networks.utils import make_node_mlp_layers
from rl import utils
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, CategoricalDistribution
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, create_mlp, NatureCNN

# Bounds used to clamp the actor's predicted log standard deviation
# (applied with torch.clamp in ActorGNN.get_action_dist_params).
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class TransitionModelGNN(BaseModel):
    """Residual transition model built on top of a ``GNN``.

    The prediction is the input embedding plus the first element of the
    wrapped GNN's output. No features extractor is used: the
    extractor-related hooks all raise ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            ignore_action,
            copy_action,
            use_interactions,
            edge_actions
    ):
        """Store the configuration and build the underlying GNN.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: forwarded to ``BaseModel``; also determines the
            action dimension handed to the GNN.
        :param embedding_dim: passed to the GNN as ``input_dim``.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param ignore_action: passed through to the GNN.
        :param copy_action: passed through to the GNN.
        :param use_interactions: passed through to the GNN.
        :param edge_actions: passed through to the GNN.
        """
        super().__init__(observation_space, action_space, features_extractor=None, normalize_images=False)

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.ignore_action = ignore_action
        self.copy_action = copy_action
        self.use_interactions = use_interactions
        self.edge_actions = edge_actions

        # Discrete spaces use the number of actions, everything else the
        # dimension reported by get_action_dim.
        if isinstance(self.action_space, spaces.Discrete):
            action_dim = self.action_space.n
        else:
            action_dim = get_action_dim(self.action_space)

        gnn_kwargs = {
            'input_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'action_dim': action_dim,
            'num_objects': self.num_objects,
            'ignore_action': self.ignore_action,
            'copy_action': self.copy_action,
            'use_interactions': self.use_interactions,
            'edge_actions': self.edge_actions,
        }
        self.network = GNN(**gnn_kwargs)

    def forward(self, embedding, action):
        """Return ``embedding`` plus the first element of the GNN output."""
        delta = self.network(embedding, action)[0]
        return delta + embedding

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary."""
        return {
            'observation_space': self.observation_space,
            'action_space': self.action_space,
            'normalize_images': self.normalize_images,
            'embedding_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'num_objects': self.num_objects,
            'ignore_action': self.ignore_action,
            'copy_action': self.copy_action,
            'use_interactions': self.use_interactions,
            'edge_actions': self.edge_actions,
        }


class TransitionModelMLP(BaseModel):
    """MLP transition model over a flat embedding.

    The embedding and the action are concatenated along dimension 1 and fed
    through an MLP whose output has ``embedding_dim`` features. No features
    extractor is used: the extractor-related hooks raise
    ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_layers=3,
    ):
        """Build the MLP.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: forwarded to ``BaseModel``; also determines the
            action part of the MLP input size.
        :param embedding_dim: embedding part of the MLP input size and the
            MLP output size.
        :param hidden_dim: hidden size passed to ``make_node_mlp_layers``.
        :param num_layers: number of layers passed to ``make_node_mlp_layers``.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        action_dim = self.action_space.n if isinstance(self.action_space, spaces.Discrete) else get_action_dim(
            self.action_space)
        self.network = nn.Sequential(*make_node_mlp_layers(
            self.num_layers,
            self.embedding_dim + action_dim,
            self.hidden_dim,
            output_dim=self.embedding_dim,
            act_fn='relu',
            layer_norm=True
        ))

    def forward(self, embedding, action):
        """Predict from ``embedding`` and ``action`` concatenated along dim 1.

        Assumes both inputs are 2-D with matching first dimension -- TODO confirm.
        """
        x = torch.cat([embedding, action], dim=1)
        return self.network(x)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary.

        This model has no ``num_objects`` attribute (reading it raised
        ``AttributeError``); ``num_layers`` is the configuration value that
        ``__init__`` actually accepts and stores.
        """
        # NOTE(review): `normalize_images` is kept for parity with the sibling
        # models although `__init__` does not accept it -- confirm how callers
        # consume this dictionary.
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
        )


class RewardModelGNN(BaseModel):
    """Reward model: a ``GNN`` followed by a linear read-out.

    The first element of the GNN output is averaged over dimension 1 and
    mapped to a single value by a linear layer. No features extractor is
    used: the extractor-related hooks raise ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            ignore_action,
            copy_action,
            use_interactions,
            edge_actions
    ):
        """Store the configuration and build the GNN and the linear head.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: forwarded to ``BaseModel``; also determines the
            action dimension handed to the GNN.
        :param embedding_dim: GNN ``input_dim`` and input size of the head.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param ignore_action: stored on the instance only (see note below).
        :param copy_action: stored on the instance only (see note below).
        :param use_interactions: stored on the instance only (see note below).
        :param edge_actions: stored on the instance only (see note below).
        """
        super().__init__(observation_space, action_space, features_extractor=None, normalize_images=False)

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.ignore_action = ignore_action
        self.copy_action = copy_action
        self.use_interactions = use_interactions
        self.edge_actions = edge_actions

        if isinstance(self.action_space, spaces.Discrete):
            action_dim = self.action_space.n
        else:
            action_dim = get_action_dim(self.action_space)

        # NOTE(review): the four action/interaction flags given to __init__
        # are stored above but the GNN is built with fixed values here --
        # confirm this is intentional.
        gnn_kwargs = {
            'input_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'action_dim': action_dim,
            'num_objects': self.num_objects,
            'ignore_action': False,
            'copy_action': True,
            'act_fn': 'relu',
            'layer_norm': True,
            'num_layers': 3,
            'use_interactions': True,
            'edge_actions': True,
        }
        self.gnn = GNN(**gnn_kwargs)
        self.mlp = nn.Linear(self.embedding_dim, 1)

    def forward(self, embedding, action):
        """Return the linear read-out of the dim-1 mean of the GNN output."""
        node_features = self.gnn(embedding, action)[0]
        pooled = node_features.mean(dim=1)
        reward = self.mlp(pooled)
        return reward.squeeze(dim=1)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary."""
        return {
            'observation_space': self.observation_space,
            'action_space': self.action_space,
            'normalize_images': self.normalize_images,
            'embedding_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'num_objects': self.num_objects,
            'ignore_action': self.ignore_action,
            'copy_action': self.copy_action,
            'use_interactions': self.use_interactions,
            'edge_actions': self.edge_actions,
        }


class RewardModelMLP(BaseModel):
    """MLP reward model over a flat embedding.

    The embedding and the action are concatenated along dimension 1 and fed
    through an MLP with a single output feature. No features extractor is
    used: the extractor-related hooks raise ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_layers=3,
    ):
        """Build the MLP.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: forwarded to ``BaseModel``; also determines the
            action part of the MLP input size.
        :param embedding_dim: embedding part of the MLP input size.
        :param hidden_dim: hidden size passed to ``make_node_mlp_layers``.
        :param num_layers: number of layers passed to ``make_node_mlp_layers``.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        action_dim = self.action_space.n if isinstance(self.action_space, spaces.Discrete) else get_action_dim(
            self.action_space)
        self.network = nn.Sequential(*make_node_mlp_layers(
            self.num_layers,
            self.embedding_dim + action_dim,
            self.hidden_dim,
            output_dim=1,
            act_fn='relu',
            layer_norm=True
        ))

    def forward(self, embedding, action):
        """Predict from ``embedding`` and ``action`` concatenated along dim 1.

        The MLP output is returned as-is (it is not squeezed).
        """
        x = torch.cat([embedding, action], dim=1)
        return self.network(x)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary.

        This model has no ``num_objects`` attribute (reading it raised
        ``AttributeError``); ``num_layers`` is the configuration value that
        ``__init__`` actually accepts and stores.
        """
        # NOTE(review): `normalize_images` is kept for parity with the sibling
        # models although `__init__` does not accept it -- confirm how callers
        # consume this dictionary.
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
        )


class ValueModelGNN(BaseModel):
    """State-value model: an action-free ``GNN`` followed by a linear head.

    The first element of the GNN output is summed over dimension 1 and mapped
    to a single value by a linear layer. No features extractor is used: the
    extractor-related hooks raise ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            use_interactions,
    ):
        """Store the configuration and build the GNN and the linear head.

        :param observation_space: forwarded to ``BaseModel`` (the action space
            passed to ``BaseModel`` is ``None``).
        :param embedding_dim: GNN ``input_dim`` and input size of the head.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param use_interactions: passed through to the GNN.
        """
        super().__init__(observation_space, None, features_extractor=None, normalize_images=False)

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.use_interactions = use_interactions

        # The GNN is configured to ignore actions entirely.
        gnn_kwargs = {
            'input_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'action_dim': 0,
            'num_objects': self.num_objects,
            'ignore_action': True,
            'copy_action': False,
            'use_interactions': self.use_interactions,
            'edge_actions': False,
        }
        self.network = GNN(**gnn_kwargs)
        self.mlp = nn.Linear(self.embedding_dim, 1)

    def forward(self, embedding):
        """Return the linear read-out of the dim-1 sum of the GNN output."""
        node_features = self.network(embedding, None)[0]
        pooled = node_features.sum(dim=1)
        pooled = pooled.squeeze(dim=1)
        value = self.mlp(pooled)
        return value.squeeze(dim=1)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary."""
        return {
            'observation_space': self.observation_space,
            'action_space': self.action_space,
            'normalize_images': self.normalize_images,
            'embedding_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'num_objects': self.num_objects,
            'use_interactions': self.use_interactions,
        }


class ValueModelMLP(BaseModel):
    """MLP state-value model over a flat embedding.

    The embedding is fed through an MLP with a single output feature. No
    features extractor is used: the extractor-related hooks raise
    ``NotImplementedError``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_layers=3,
    ):
        """Build the MLP.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: forwarded to ``BaseModel``.
        :param embedding_dim: MLP input size.
        :param hidden_dim: hidden size passed to ``make_node_mlp_layers``.
        :param num_layers: number of layers passed to ``make_node_mlp_layers``.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.network = nn.Sequential(*make_node_mlp_layers(
            self.num_layers,
            self.embedding_dim,
            self.hidden_dim,
            output_dim=1,
            act_fn='relu',
            layer_norm=True
        ))

    def forward(self, embedding):
        """Return the MLP output for ``embedding`` (not squeezed)."""
        return self.network(embedding)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this model."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this model."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported by this model."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the stored configuration values as a dictionary.

        This model has no ``num_objects`` attribute (reading it raised
        ``AttributeError``); ``num_layers`` is the configuration value that
        ``__init__`` actually accepts and stores.
        """
        # NOTE(review): `normalize_images` is kept for parity with the sibling
        # models although `__init__` does not accept it -- confirm how callers
        # consume this dictionary.
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
        )


class ActorGNN(BasePolicy):
    """GNN-based actor operating directly on embeddings.

    For ``Box`` action spaces the output is squashed and sampled from a
    ``SquashedDiagGaussianDistribution``; for ``Discrete`` action spaces the
    project's ``utils.CategoricalDistribution`` is used with the supplied
    temperature. Any other action space type is rejected.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            use_interactions,
            temperature: float = None,
    ):
        """Build the GNN trunk and the distribution heads.

        :param observation_space: forwarded to ``BasePolicy``.
        :param action_space: ``Box`` or ``Discrete`` action space.
        :param embedding_dim: GNN ``input_dim`` and input size of the heads.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param use_interactions: passed through to the GNN.
        :param temperature: required (not ``None``) for ``Discrete`` action
            spaces; handed to ``utils.CategoricalDistribution``.
        :raises ValueError: if the action space is neither ``Box`` nor
            ``Discrete``.
        """
        if isinstance(action_space, gym.spaces.Box):
            squash_output = True
        elif isinstance(action_space, gym.spaces.Discrete):
            squash_output = False
        else:
            raise ValueError(f'Unexpected action space type: {type(action_space)}')

        super().__init__(
            observation_space,
            action_space,
            features_extractor=False,
            normalize_images=False,
            squash_output=squash_output,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.use_interactions = use_interactions
        self.temperature = temperature

        # Action-free GNN trunk.
        self.latent_pi = GNN(input_dim=self.embedding_dim, hidden_dim=self.hidden_dim,
                             action_dim=0, num_objects=self.num_objects,
                             ignore_action=True, copy_action=False,
                             use_interactions=self.use_interactions, edge_actions=False, )
        last_layer_dim = self.embedding_dim
        if squash_output:
            action_dim = get_action_dim(self.action_space)
            self.action_dist = SquashedDiagGaussianDistribution(action_dim)
            self.log_std = nn.Linear(last_layer_dim, action_dim)
        else:
            assert self.temperature is not None, 'Temperature for Gumbel-Softmax is not provided.'
            action_dim = self.action_space.n
            self.action_dist = utils.CategoricalDistribution(action_dim, self.temperature)
            # No learned log-std in the discrete case: the "log_std" returned
            # by get_action_dist_params is then the clamped latent itself.
            self.log_std = nn.Identity()
        self.mu = nn.Linear(last_layer_dim, action_dim)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the parent's parameters extended with this actor's settings.

        ``temperature`` is included because the discrete branch of
        ``__init__`` requires it.
        """
        data = super()._get_constructor_parameters()

        # NOTE(review): `squash_output` is not an `__init__` parameter of this
        # class; it is kept so existing consumers of this dictionary see the
        # same keys as before -- confirm how the dictionary is used.
        data.update(
            dict(
                squash_output=self.squash_output,
                embedding_dim=self.embedding_dim,
                num_objects=self.num_objects,
                hidden_dim=self.hidden_dim,
                use_interactions=self.use_interactions,
                temperature=self.temperature,
            )
        )
        return data

    def get_std(self) -> Tensor:
        """Not supported by this actor."""
        raise NotImplementedError()

    def reset_noise(self, batch_size: int = 1) -> None:
        """Not supported by this actor."""
        raise NotImplementedError()

    def get_action_dist_params(self, embedding) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]:
        """Compute the distribution parameters for ``embedding``.

        :return: ``(mean_actions, log_std, {})`` where ``log_std`` is clamped
            to ``[LOG_STD_MIN, LOG_STD_MAX]``.
        """
        # First GNN output, squeezed on the last dim and averaged over dim -2.
        latent_pi = self.latent_pi(embedding, None)[0].squeeze(-1).mean(-2)
        mean_actions = self.mu(latent_pi)

        # Unstructured exploration (original implementation).
        log_std = self.log_std(latent_pi)
        # Cap the standard deviation (original implementation).
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std, {}

    def proba_distribution(self, embedding):
        """Return the action distribution built from ``embedding``."""
        mean_actions, log_std, kwargs = self.get_action_dist_params(embedding)
        return self.action_dist.proba_distribution(mean_actions, log_std)

    def forward(self, embedding, deterministic: bool = False) -> Tensor:
        """Return actions for ``embedding`` from the action distribution."""
        mean_actions, log_std, kwargs = self.get_action_dist_params(embedding)
        # Note: the action is squashed
        return self.action_dist.actions_from_params(mean_actions=mean_actions, log_std=log_std,
                                                    deterministic=deterministic, **kwargs)

    def action_log_prob(self, embedding) -> Tuple[Tensor, Tensor]:
        """Return an action and its associated log-probability."""
        mean_actions, log_std, kwargs = self.get_action_dist_params(embedding)
        # return action and associated log prob
        return self.action_dist.log_prob_from_params(mean_actions=mean_actions, log_std=log_std, **kwargs)

    def _predict(self, embedding, deterministic: bool = False) -> Tensor:
        """Delegate to ``forward``."""
        return self(embedding, deterministic)

    def predict(
            self,
            embedding,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Predict actions as a NumPy array without tracking gradients.

        :param embedding: input passed to ``_predict``.
        :param state: returned unchanged.
        :param episode_start: unused.
        :param deterministic: forwarded to ``_predict``.
        :return: ``(actions, state)``.
        """
        self.set_training_mode(False)

        with no_grad():
            actions = self._predict(embedding, deterministic=deterministic)
        # Convert to numpy
        actions = actions.cpu().numpy()

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        return actions, state


class DiscreteActor(BasePolicy):
    """Categorical actor over either raw embeddings or extracted features.

    Exactly one of ``embedding_dim`` and ``features_extractor_class`` must be
    truthy. An optional MLP (``net_arch``) sits between the input and the
    action-logit layer.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim=None,
            net_arch=(),
            features_extractor_class=None,
            features_extractor_kwargs=None,
    ):
        """Build the optional features extractor, the MLP and the logit head.

        :param observation_space: forwarded to ``BasePolicy``.
        :param action_space: action space; its ``n`` gives the number of actions.
        :param embedding_dim: input size when no features extractor is used.
        :param net_arch: hidden layer sizes passed to ``create_mlp``; empty
            means no hidden layers.
        :param features_extractor_class: forwarded to ``BasePolicy``; when
            given, a features extractor is created and its ``features_dim``
            is used as the input size.
        :param features_extractor_kwargs: forwarded to ``BasePolicy``.
        """
        assert bool(embedding_dim) != bool(features_extractor_class), \
            (f'Use embedding_dim or features_extractor_class, but not both: '
             f'embedding_dim={embedding_dim} features_extractor_class={features_extractor_class}')
        uses_extractor = features_extractor_class is not None
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=uses_extractor,
            squash_output=False,
        )

        self.net_arch = net_arch
        if uses_extractor:
            self.features_extractor = self.make_features_extractor()
            self.embedding_dim = self.features_extractor.features_dim
        else:
            self.embedding_dim = embedding_dim

        num_actions = self.action_space.n
        if len(net_arch) > 0:
            hidden_layers = create_mlp(self.embedding_dim, -1, net_arch)
            self.latent_pi = nn.Sequential(*hidden_layers)
            head_input_dim = self.net_arch[-1]
        else:
            # No hidden layers: the logits are computed straight from the input.
            self.latent_pi = nn.Identity()
            head_input_dim = self.embedding_dim

        self.action_dist = CategoricalDistribution(num_actions)
        self.action_net = self.action_dist.proba_distribution_net(head_input_dim)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the parent's parameters extended with ``squash_output``."""
        params = super()._get_constructor_parameters()
        params.update({'squash_output': self.squash_output})
        return params

    def get_action_dist_params(self, observation) -> Tensor:
        """Return the output of the action head for ``observation``."""
        if self.features_extractor is None:
            embedding = observation
        else:
            embedding = self.extract_features(observation)

        return self.action_net(self.latent_pi(embedding))

    def forward(self, observation, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(log_prob_action, probs, logits)`` for ``observation``.

        ``deterministic`` is accepted but not used here.
        """
        logits = self.get_action_dist_params(observation)
        _, log_prob_action = self.action_dist.log_prob_from_params(logits)
        return log_prob_action, self.action_dist.distribution.probs, self.action_dist.distribution.logits

    def _predict(self, observation, deterministic: bool = False) -> Tensor:
        """Return actions obtained from the categorical distribution."""
        logits = self.get_action_dist_params(observation)
        dist = self.action_dist.proba_distribution(action_logits=logits)
        return dist.get_actions(deterministic=deterministic)

    def predict(
            self,
            observation,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Predict actions as a NumPy array without tracking gradients.

        :param observation: input passed to ``_predict``.
        :param state: returned unchanged.
        :param episode_start: unused.
        :param deterministic: forwarded to ``_predict``.
        :return: ``(actions, state)``.
        """
        self.set_training_mode(False)

        with no_grad():
            action_tensor = self._predict(observation, deterministic=deterministic)
        actions = action_tensor.cpu().numpy()

        # Only Box action spaces need rescaling or clipping.
        if not isinstance(self.action_space, gym.spaces.Box):
            return actions, state

        if self.squash_output:
            # Map squashed actions back to the action-space bounds.
            actions = self.unscale_action(actions)
        else:
            # Clip to the action-space bounds to avoid out-of-bound actions.
            actions = np.clip(actions, self.action_space.low, self.action_space.high)
        return actions, state


class DiscreteActorGNN(BasePolicy):
    """Categorical actor whose trunk is an action-free ``GNN``."""

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            use_interactions,
    ):
        """Build the GNN trunk and the action-logit head.

        :param observation_space: forwarded to ``BasePolicy``.
        :param action_space: action space; its ``n`` gives the number of actions.
        :param embedding_dim: GNN ``input_dim`` and input size of the head.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param use_interactions: passed through to the GNN.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=False,
            normalize_images=False,
            squash_output=False,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.use_interactions = use_interactions

        num_actions = self.action_space.n
        gnn_kwargs = {
            'input_dim': self.embedding_dim,
            'hidden_dim': self.hidden_dim,
            'action_dim': 0,
            'num_objects': self.num_objects,
            'ignore_action': True,
            'copy_action': False,
            'use_interactions': self.use_interactions,
            'edge_actions': False,
        }
        self.latent_pi = GNN(**gnn_kwargs)
        self.action_dist = CategoricalDistribution(num_actions)
        self.action_net = self.action_dist.proba_distribution_net(self.embedding_dim)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the parent's parameters extended with this actor's settings."""
        params = super()._get_constructor_parameters()
        params.update({
            'squash_output': self.squash_output,
            'embedding_dim': self.embedding_dim,
            'num_objects': self.num_objects,
            'hidden_dim': self.hidden_dim,
            'use_interactions': self.use_interactions,
        })
        return params

    def get_action_dist_params(self, embedding) -> Tensor:
        """Return the output of the action head for ``embedding``."""
        # First GNN output, squeezed on the last dim and averaged over dim -2.
        node_features = self.latent_pi(embedding, None)[0]
        pooled = node_features.squeeze(-1).mean(-2)
        return self.action_net(pooled)

    def forward(self, embedding, deterministic: bool = False) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(log_prob_action, probs, logits)`` for ``embedding``.

        ``deterministic`` is accepted but not used here.
        """
        logits = self.get_action_dist_params(embedding)
        _, log_prob_action = self.action_dist.log_prob_from_params(logits)
        return log_prob_action, self.action_dist.distribution.probs, self.action_dist.distribution.logits

    def _predict(self, embedding, deterministic: bool = False) -> Tensor:
        """Return actions obtained from the categorical distribution."""
        logits = self.get_action_dist_params(embedding)
        dist = self.action_dist.proba_distribution(action_logits=logits)
        return dist.get_actions(deterministic=deterministic)

    def predict(
            self,
            embedding,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Predict actions as a NumPy array without tracking gradients.

        :param embedding: input passed to ``_predict``.
        :param state: returned unchanged.
        :param episode_start: unused.
        :param deterministic: forwarded to ``_predict``.
        :return: ``(actions, state)``.
        """
        self.set_training_mode(False)

        with no_grad():
            action_tensor = self._predict(embedding, deterministic=deterministic)
        actions = action_tensor.cpu().numpy()

        # Only Box action spaces need rescaling or clipping.
        if not isinstance(self.action_space, gym.spaces.Box):
            return actions, state

        if self.squash_output:
            # Map squashed actions back to the action-space bounds.
            actions = self.unscale_action(actions)
        else:
            # Clip to the action-space bounds to avoid out-of-bound actions.
            actions = np.clip(actions, self.action_space.low, self.action_space.high)
        return actions, state

    def proba_distribution(self, embedding):
        """Return the categorical distribution built from ``embedding``."""
        logits = self.get_action_dist_params(embedding)
        return self.action_dist.proba_distribution(action_logits=logits)


class ContinuousCriticGNN(BaseModel):
    """Ensemble of Q-models sharing one architecture.

    The first member is the supplied ``q_model``; every further member is a
    deep copy of it whose layers have been re-initialised.
    """

    def __init__(
            self,
            q_model,
            n_critics: int = 1,
    ):
        """Create the ensemble.

        :param q_model: module used as the first critic and as the template
            for the remaining ones.
        :param n_critics: total number of critics in the ensemble.
        """
        super().__init__(None, None, features_extractor=None, normalize_images=False)

        self.n_critics = n_critics
        self.q_models = nn.ModuleList()
        self.q_models.append(q_model)

        def reset_if_possible(module):
            # Only layers exposing reset_parameters can be re-initialised.
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        extra_critics = self.n_critics - 1
        for _ in range(extra_critics):
            clone = copy.deepcopy(q_model)
            clone.apply(reset_if_possible)
            self.q_models.append(clone)

    def forward(self, embedding: Tensor, action: Tensor, actor: ActorGNN) -> Tuple[Tensor, ...]:
        """Return one output per critic; ``actor`` is not used."""
        outputs = [critic(embedding, action) for critic in self.q_models]
        return tuple(outputs)

    def q1_forward(self, embedding: Tensor, action: Tensor, actor: ActorGNN) -> Tensor:
        """Return the output of the first critic only; ``actor`` is not used."""
        first_critic = self.q_models[0]
        return first_critic(embedding, action)


class ContinuousCriticWMGNN(BaseModel):
    """Critic ensemble that computes Q-values by rolling out a world model.

    Each critic consists of a transition model, a reward model and a value
    model. A Q-value is the discounted sum of predicted rewards along a
    ``depth``-step imagined rollout plus the discounted value at its end.
    """

    def __init__(
            self,
            transition_model,
            reward_model,
            value_model,
            action_space,
            n_critics: int = 1,
            gamma: float = 0.99,
            depth: int = 1,
    ):
        """Create the ensemble.

        :param transition_model: first critic's transition model and template
            for the other critics.
        :param reward_model: first critic's reward model and template.
        :param value_model: first critic's value model and template.
        :param action_space: forwarded to ``BaseModel``.
        :param n_critics: total number of critics.
        :param gamma: discount factor used when accumulating rewards.
        :param depth: number of imagined transition steps.
        """
        super().__init__(None, action_space, features_extractor=None, normalize_images=False)

        self.gamma = gamma
        self.depth = depth
        self.n_critics = n_critics
        self.transition_models = nn.ModuleList()
        self.reward_models = nn.ModuleList()
        self.value_models = nn.ModuleList()

        self.transition_models.append(transition_model)
        self.reward_models.append(reward_model)
        self.value_models.append(value_model)

        def reset_if_possible(module):
            # Only layers exposing reset_parameters can be re-initialised.
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        model_groups = (self.transition_models, self.reward_models, self.value_models)
        for _ in range(self.n_critics - 1):
            for group in model_groups:
                clone = copy.deepcopy(group[0])
                clone.apply(reset_if_possible)
                group.append(clone)

    def forward(self, embedding: Tensor, action: Tensor, actor: ActorGNN) -> Tuple[Tensor, ...]:
        """Return one Q-value tensor per critic."""
        q_values = [self.q1_forward(embedding, action, actor, index) for index in range(self.n_critics)]
        return tuple(q_values)

    def q1_forward(self, embedding: Tensor, action: Tensor, actor: ActorGNN, critic_id: int = 0) -> Tensor:
        """Return the Q-value computed by the critic with index ``critic_id``.

        :param embedding: starting embedding of the rollout.
        :param action: action taken at the first step.
        :param actor: callable producing the action for each imagined embedding.
        :param critic_id: index of the critic to use.
        """
        predict_next = self.transition_models[critic_id]
        predict_reward = self.reward_models[critic_id]
        predict_value = self.value_models[critic_id]

        # Imagined trajectory as (embedding, action) pairs, real step first.
        rollout = [(embedding, action)]
        for _ in range(self.depth):
            current_embedding, current_action = rollout[-1]
            imagined_embedding = predict_next(current_embedding, current_action)
            imagined_action = actor(imagined_embedding)
            rollout.append((imagined_embedding, imagined_action))

        # Bootstrap from the last imagined embedding, then add rewards backwards.
        q_value = predict_value(rollout[self.depth][0])
        for step_embedding, step_action in reversed(rollout[:self.depth]):
            q_value = predict_reward(step_embedding, step_action) + self.gamma * q_value

        return q_value


class DiscreteCritic(BaseModel):
    """Ensemble of MLP Q-networks that scores every discrete action.

    Each network takes the embedding concatenated with a one-hot action and
    outputs one value; all actions are evaluated for every batch element.
    Exactly one of ``embedding_dim`` and ``features_extractor_class`` must be
    truthy.
    """

    def __init__(
            self,
            observation_space,
            action_space,
            embedding_dim=None,
            net_arch=(),
            features_extractor_class=None,
            features_extractor_kwargs=None,
            n_critics: int = 1,
    ):
        """Build the optional features extractor and the Q-networks.

        :param observation_space: forwarded to ``BaseModel``.
        :param action_space: action space; its ``n`` gives the number of actions.
        :param embedding_dim: input size when no features extractor is used.
        :param net_arch: hidden layer sizes passed to ``create_mlp``.
        :param features_extractor_class: forwarded to ``BaseModel``; when
            given, a features extractor is created and its ``features_dim``
            is used as the input size.
        :param features_extractor_kwargs: forwarded to ``BaseModel``.
        :param n_critics: number of Q-networks.
        """
        assert bool(embedding_dim) != bool(features_extractor_class), \
            (f'Use embedding_dim or features_extractor_class, but not both: '
             f'embedding_dim={embedding_dim} features_extractor_class={features_extractor_class}')
        uses_extractor = features_extractor_class is not None
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=uses_extractor,
        )

        self.net_arch = net_arch
        if uses_extractor:
            self.features_extractor = self.make_features_extractor()
            self.embedding_dim = self.features_extractor.features_dim
        else:
            self.embedding_dim = embedding_dim

        self.n_critics = n_critics
        self.num_actions = self.action_space.n
        # One-hot encoding of every action index, one row per action.
        action_indices = torch.arange(0, self.num_actions)
        self.actions = nn.functional.one_hot(action_indices).to(torch.float32)

        q_networks = []
        for _ in range(self.n_critics):
            layers = create_mlp(self.embedding_dim + self.num_actions, 1, net_arch)
            q_networks.append(nn.Sequential(*layers))
        self.q_networks = nn.ModuleList(q_networks)

    def forward(self, observation: Tensor, actor: DiscreteActor) -> Tuple[Tensor, ...]:
        """Return one ``(batch_size, num_actions)`` tensor per critic."""
        q_values = [self.q1_forward(observation, actor, index) for index in range(self.n_critics)]
        return tuple(q_values)

    def q1_forward(self, observation: Tensor, actor: DiscreteActor, critic_id: int = 0) -> Tensor:
        """Return the Q-values of all actions from the critic ``critic_id``.

        ``actor`` is accepted but not used here.

        :return: tensor reshaped to ``(batch_size, num_actions)``.
        """
        if self.features_extractor is None:
            embedding = observation
        else:
            embedding = self.extract_features(observation)

        # Keep the one-hot action table on the same device as the input.
        self.actions = self.actions.to(embedding.device)
        batch_size, embedding_dim = embedding.size()
        num_pairs = batch_size * self.num_actions

        # Pair every batch element with every action (action index varies fastest).
        tiled_embedding = embedding.unsqueeze(1).expand(-1, self.num_actions, -1)
        tiled_embedding = tiled_embedding.reshape(num_pairs, embedding_dim)
        tiled_actions = self.actions.unsqueeze(0).expand(batch_size, -1, -1)
        tiled_actions = tiled_actions.reshape(num_pairs, self.num_actions)

        q_network = self.q_networks[critic_id]
        q_values = q_network(torch.cat([tiled_embedding, tiled_actions], dim=1))
        return q_values.reshape(batch_size, self.num_actions)


class DiscreteCriticWMGNN(BaseModel):
    """World-model critic for discrete actions operating on slot embeddings.

    Every critic is a (transition, reward, value) triple. An action is scored as
    ``reward(embedding, action) + gamma * value(transition(embedding, action))``.
    The first critic uses the models passed to the constructor; each additional
    critic is a deep copy of them whose layers exposing ``reset_parameters`` are
    re-initialised.

    :param transition_model: called as ``transition_model(embedding, action)``
    :param reward_model: called as ``reward_model(embedding, action)``
    :param value_model: called as ``value_model(next_embedding)``
    :param action_space: action space; its ``n`` attribute gives the number of actions
    :param n_critics: number of (transition, reward, value) triples to hold
    :param gamma: weight applied to the value term
    :param depth: stored on the instance; not read anywhere in this class
    """

    def __init__(
            self,
            transition_model,
            reward_model,
            value_model,
            action_space,
            n_critics: int = 1,
            gamma: float = 0.99,
            depth: int = 1,
    ):
        super().__init__(
            None,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )

        self.gamma = gamma
        self.depth = depth
        self.n_critics = n_critics
        self.transition_models = nn.ModuleList()
        self.reward_models = nn.ModuleList()
        self.value_models = nn.ModuleList()

        self.transition_models.append(transition_model)
        self.reward_models.append(reward_model)
        self.value_models.append(value_model)
        self.num_actions = self.action_space.n
        # Non-persistent buffer: it follows the module through ``.to()`` / ``.cuda()``
        # (no host-to-device copy on every forward once the module has been moved)
        # while staying out of ``state_dict``, so existing checkpoints still load.
        self.register_buffer('actions', torch.arange(0, self.num_actions), persistent=False)

        def reinit(layer):
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

        # Extra critics: same architecture as the first one, freshly initialised weights.
        for _ in range(self.n_critics - 1):
            for models in (self.transition_models, self.reward_models, self.value_models):
                model_copy = copy.deepcopy(models[0])
                model_copy.apply(reinit)
                models.append(model_copy)

    def forward(self, embedding: Tensor, actor: DiscreteActorGNN,
                actions: Optional[Tensor] = None) -> Tuple[Tensor, ...]:
        """Evaluate every critic.

        :param embedding: handed unchanged to :meth:`q1_forward`
        :param actor: handed unchanged to :meth:`q1_forward`
        :param actions: handed unchanged to :meth:`q1_forward`
        :return: tuple with one :meth:`q1_forward` result per critic, ordered by critic index
        """
        return tuple(self.q1_forward(embedding, actor, actions, critic_id) for critic_id in range(self.n_critics))

    def q1_forward(self, embedding: Tensor, actor: DiscreteActorGNN, actions: Optional[Tensor] = None,
                   critic_id: int = 0) -> Tensor:
        """Compute Q-values with the critic triple at ``critic_id``.

        :param embedding: tensor unpacked as ``(batch_size, num_slots, embedding_dim)``
        :param actor: not read by this method
        :param actions: when ``None`` every action index ``0..num_actions-1`` is scored for
            every sample; otherwise one action per sample, passed to the models as given
        :param critic_id: index of the (transition, reward, value) triple to use
        :return: ``rewards + gamma * values``, reshaped to ``(batch_size, num_actions)``
            when ``actions`` is ``None`` and to ``(batch_size, 1)`` otherwise
        :raises ValueError: if ``actions`` is given and its first dimension differs
            from ``batch_size``
        """
        transition_model = self.transition_models[critic_id]
        reward_model = self.reward_models[critic_id]
        value_model = self.value_models[critic_id]

        batch_size, num_slots, embedding_dim = embedding.size()

        if actions is None:
            # No-op when the buffer already lives on the input's device; kept as a
            # fallback for modules that were never moved with ``.to()``.
            all_actions = self.actions.to(embedding.device)
            # Row ``b * num_actions + a`` pairs sample ``b`` with action ``a``.
            embedding_expanded = embedding.unsqueeze(1).expand(-1, self.num_actions, -1, -1).reshape(
                batch_size * self.num_actions, num_slots, embedding_dim)
            action_expanded = all_actions.unsqueeze(0).expand(batch_size, self.num_actions).flatten()
            num_actions = self.num_actions
        else:
            # Explicit check instead of ``assert`` so it also runs under ``python -O``.
            if batch_size != actions.size()[0]:
                raise ValueError(f'Embedding size: {embedding.size()}, actions size: {actions.size()}')
            embedding_expanded = embedding
            action_expanded = actions
            num_actions = 1

        rewards = reward_model(embedding_expanded, action_expanded).reshape(batch_size, num_actions)
        next_embeddings = transition_model(embedding_expanded, action_expanded)
        values = value_model(next_embeddings).reshape(batch_size, num_actions)

        return rewards + self.gamma * values


class DiscreteCriticWMMLP(BaseModel):
    """World-model critic for discrete actions operating on flat embeddings.

    Every critic is a (transition, reward, value) triple. An action is scored as
    ``reward(embedding, action) + gamma * value(transition(embedding, action))``,
    with actions fed to the models as one-hot float rows.

    :param transition_model: called as ``transition_model(embedding, one_hot_action)``
    :param reward_model: called as ``reward_model(embedding, one_hot_action)``
    :param value_model: called as ``value_model(next_embedding)``
    :param action_space: action space; its ``n`` attribute gives the number of actions
    :param n_critics: number of (transition, reward, value) triples to hold
    :param gamma: weight applied to the value term
    :param depth: stored on the instance; not read anywhere in this class
    """

    def __init__(
            self,
            transition_model,
            reward_model,
            value_model,
            action_space,
            n_critics: int = 1,
            gamma: float = 0.99,
            depth: int = 1,
    ):
        super().__init__(
            None,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )

        self.gamma = gamma
        self.depth = depth
        self.n_critics = n_critics

        # The first critic is built from the models exactly as they were passed in.
        self.transition_models = nn.ModuleList([transition_model])
        self.reward_models = nn.ModuleList([reward_model])
        self.value_models = nn.ModuleList([value_model])

        self.num_actions = self.action_space.n
        action_ids = torch.arange(0, self.num_actions)
        # One row per action: a float32 identity-like table of one-hot vectors.
        self.actions = nn.functional.one_hot(action_ids, num_classes=self.num_actions).to(torch.float32)

        def reset_if_possible(module):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Every further critic is a re-initialised deep copy of the first one.
        model_families = (self.transition_models, self.reward_models, self.value_models)
        extra_critics = self.n_critics - 1
        for _ in range(extra_critics):
            for family in model_families:
                clone = copy.deepcopy(family[0])
                clone.apply(reset_if_possible)
                family.append(clone)

    def forward(self, embedding: Tensor, actor: DiscreteActor) -> Tuple[Tensor, ...]:
        """Evaluate every critic.

        :param embedding: handed unchanged to :meth:`q1_forward`
        :param actor: handed unchanged to :meth:`q1_forward`
        :return: tuple with one :meth:`q1_forward` result per critic, ordered by critic index
        """
        results = []
        for critic_id in range(self.n_critics):
            results.append(self.q1_forward(embedding, actor, critic_id))
        return tuple(results)

    def q1_forward(self, embedding: Tensor, actor: DiscreteActor, critic_id: int = 0) -> Tensor:
        """Compute the Q-values of all actions with the critic triple at ``critic_id``.

        :param embedding: tensor unpacked as ``(batch_size, embedding_dim)``
        :param actor: not read by this method
        :param critic_id: index of the (transition, reward, value) triple to use
        :return: ``rewards + gamma * values`` reshaped to ``(batch_size, num_actions)``
        """
        # Cache the one-hot table on the input's device so later calls reuse it.
        self.actions = self.actions.to(embedding.device)

        reward_model = self.reward_models[critic_id]
        transition_model = self.transition_models[critic_id]
        value_model = self.value_models[critic_id]

        batch_size, embedding_dim = embedding.size()
        q_shape = (batch_size, self.num_actions)
        n_pairs = batch_size * self.num_actions

        # Row ``b * num_actions + a`` pairs sample ``b`` with action ``a``.
        tiled_embedding = embedding.unsqueeze(1).expand(-1, self.num_actions, -1)
        tiled_embedding = tiled_embedding.reshape(n_pairs, embedding_dim)
        tiled_actions = self.actions.unsqueeze(0).expand(batch_size, self.num_actions, -1)
        tiled_actions = tiled_actions.reshape(n_pairs, -1)

        rewards = reward_model(tiled_embedding, tiled_actions).reshape(*q_shape)
        predicted_next = transition_model(tiled_embedding, tiled_actions)
        values = value_model(predicted_next).reshape(*q_shape)

        return rewards + self.gamma * values


class DiscreteCriticGNN(BaseModel):
    """Model-free critic for discrete actions operating on slot embeddings.

    Each critic is a single Q-model called as ``q_model(embedding, action_index)``.
    The first critic is the model passed in; every further critic is a deep copy
    whose layers exposing ``reset_parameters`` are re-initialised.

    :param q_model: Q-model of the first critic
    :param action_space: action space; its ``n`` attribute gives the number of actions
    :param n_critics: number of Q-models to hold
    """

    def __init__(
            self,
            q_model,
            action_space,
            n_critics: int = 1,
    ):
        super().__init__(
            None,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )

        self.n_critics = n_critics
        self.q_models = nn.ModuleList([q_model])

        self.num_actions = self.action_space.n
        # Integer indices 0..num_actions-1, one per discrete action.
        self.actions = torch.arange(0, self.num_actions)

        def reset_if_possible(module):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        extra_critics = self.n_critics - 1
        for _ in range(extra_critics):
            clone = copy.deepcopy(self.q_models[0])
            clone.apply(reset_if_possible)
            self.q_models.append(clone)

    def forward(self, embedding: Tensor, actor: DiscreteActorGNN) -> Tuple[Tensor, ...]:
        """Evaluate every critic.

        :param embedding: handed unchanged to :meth:`q1_forward`
        :param actor: handed unchanged to :meth:`q1_forward`
        :return: tuple with one :meth:`q1_forward` result per critic, ordered by critic index
        """
        results = []
        for critic_id in range(self.n_critics):
            results.append(self.q1_forward(embedding, actor, critic_id))
        return tuple(results)

    def q1_forward(self, embedding: Tensor, actor: DiscreteActorGNN, critic_id: int = 0) -> Tensor:
        """Compute the Q-values of all actions with the Q-model at ``critic_id``.

        :param embedding: tensor unpacked as ``(batch_size, num_slots, embedding_dim)``
        :param actor: not read by this method
        :param critic_id: index into ``self.q_models``
        :return: Q-model output reshaped to ``(batch_size, num_actions)``
        """
        # Cache the action indices on the input's device so later calls reuse them.
        self.actions = self.actions.to(embedding.device)

        batch_size, num_slots, embedding_dim = embedding.size()
        n_pairs = batch_size * self.num_actions

        # Row ``b * num_actions + a`` pairs sample ``b`` with action ``a``.
        tiled_embedding = embedding.unsqueeze(1).expand(-1, self.num_actions, -1, -1)
        tiled_embedding = tiled_embedding.reshape(n_pairs, num_slots, embedding_dim)
        tiled_actions = self.actions.unsqueeze(0).expand(batch_size, self.num_actions).flatten()

        scorer = self.q_models[critic_id]
        return scorer(tiled_embedding, tiled_actions).reshape(batch_size, self.num_actions)


class SACCustomPolicy(BasePolicy):
    """SAC policy assembled from an already-built features extractor, actor and critic.

    The constructor attaches optimizers to the actor and the critic and creates
    ``critic_target`` as a deep copy of the critic.

    :param observation_space: forwarded to the base policy
    :param action_space: forwarded to the base policy
    :param lr: learning rate given to every optimizer created here
    :param features_extractor: module called by :meth:`extract_features`
    :param actor: module called by :meth:`_predict`
    :param critic: critic module; also the source of ``critic_target``
    :param optimizer_class: optimizer type, forwarded to the base policy
    :param optimizer_kwargs: extra optimizer keyword arguments, forwarded to the base policy
    :param move_channels_axis: if true, :meth:`extract_features` moves axis 3 of the
        observation to position 1 first
    :param use_wm_optimizer: if true, the critic gets two optimizers (``optimizer_wm``
        and ``optimizer_value``) instead of a single ``optimizer``
    :param is_frozen_features_extractor: if false (and ``use_wm_optimizer`` is true),
        the features extractor parameters are added to ``optimizer_wm``
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            lr: float,
            features_extractor,
            actor,
            critic,
            optimizer_class: Type[optim.Optimizer] = optim.Adam,
            optimizer_kwargs: Optional[Dict[str, Any]] = None,
            move_channels_axis=False,
            use_wm_optimizer=False,
            is_frozen_features_extractor=True,
    ):
        super().__init__(
            observation_space,
            action_space,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
        )

        self.features_extractor = features_extractor
        self.is_frozen_features_extractor = is_frozen_features_extractor
        self.actor = actor
        self.critic = critic
        self.lr = lr
        self.move_channels_axis = move_channels_axis
        self.use_wm_optimizer = use_wm_optimizer

        def new_optimizer(parameters):
            return self.optimizer_class(parameters, lr=self.lr, **self.optimizer_kwargs)

        self.actor.optimizer = new_optimizer(self.actor.parameters())

        # The target copy is taken before any optimizer is attached to the critic.
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_target.set_training_mode(False)

        if not self.use_wm_optimizer:
            self.critic.optimizer = new_optimizer(self.critic.parameters())
        else:
            # Split critic parameters by the prefix of their name.
            wm_parameters = []
            value_parameters = []
            for name, parameter in self.critic.named_parameters():
                if name.startswith(('transition', 'reward')):
                    wm_parameters.append(parameter)
                if name.startswith('value'):
                    value_parameters.append(parameter)

            if not self.is_frozen_features_extractor:
                wm_parameters.extend(self.features_extractor.parameters())

            self.critic.optimizer_wm = new_optimizer(wm_parameters)
            self.critic.optimizer_value = new_optimizer(value_parameters)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the base-class constructor parameters plus ``squash_output``."""
        data = super()._get_constructor_parameters()
        data['squash_output'] = self.squash_output
        return data

    def reset_noise(self, batch_size: int = 1) -> None:
        """Not supported by this policy; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ActorGNN:
        """Not supported by this policy; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCriticWMGNN:
        """Not supported by this policy; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def forward(self, embedding, deterministic: bool = False) -> Tensor:
        """Return the actor output for ``embedding`` (delegates to :meth:`_predict`)."""
        return self._predict(embedding, deterministic=deterministic)

    def extract_features(self, obs: Tensor, prev_slots=None) -> Tensor:
        """Run the features extractor on ``obs``.

        The observation is optionally permuted (axis 3 moved to position 1), moved
        to the policy device and divided by 255 before the extractor is called.

        :param obs: observation tensor
        :param prev_slots: forwarded to the features extractor as ``prev_slots``
        :return: output of the features extractor
        """
        if self.move_channels_axis:
            obs = torch.moveaxis(obs, 3, 1)
        scaled_obs = obs.to(self.device) / 255.
        return self.features_extractor(scaled_obs, prev_slots=prev_slots)

    def _predict(self, embedding, deterministic: bool = False) -> Tensor:
        """Call the actor as ``self.actor(embedding, deterministic)``."""
        return self.actor(embedding, deterministic)

    def set_training_mode(self, mode: bool) -> None:
        """
        Switch the actor and the critic between training and evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        for module in (self.actor, self.critic):
            module.set_training_mode(mode)
        self.training = mode
