import copy
from typing import Type, Dict, Any, Optional, Tuple

import gym
import numpy as np
import torch
from torch import nn, Tensor, optim

from networks.gnn import GNN
from stable_baselines3.common.policies import BaseModel, BasePolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class TransitionModelGNN(BaseModel):
    """Transition model built on a ``GNN``.

    ``forward`` adds the first element of the GNN's return value to the input
    embedding, i.e. the network output is used as a residual on the input.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            ignore_action,
            copy_action,
            use_interactions,
            edge_actions,
            normalize_images: bool = False,
    ):
        """Build the model.

        :param observation_space: stored by ``BaseModel``.
        :param action_space: stored by ``BaseModel``; must expose ``.n``, which
            is passed to the GNN as ``action_dim``.
        :param embedding_dim: passed to the GNN as ``input_dim``.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param ignore_action: passed through to the GNN.
        :param copy_action: passed through to the GNN.
        :param use_interactions: passed through to the GNN.
        :param edge_actions: passed through to the GNN.
        :param normalize_images: forwarded to ``BaseModel`` (default ``False``).
            Accepted so that the dict returned by
            ``_get_constructor_parameters`` can be passed back to this
            constructor as keyword arguments.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=normalize_images,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.ignore_action = ignore_action
        self.copy_action = copy_action
        self.use_interactions = use_interactions
        self.edge_actions = edge_actions
        self.network = GNN(input_dim=self.embedding_dim, hidden_dim=self.hidden_dim,
                           action_dim=self.action_space.n, num_objects=self.num_objects,
                           ignore_action=self.ignore_action, copy_action=self.copy_action,
                           use_interactions=self.use_interactions, edge_actions=self.edge_actions)

    def forward(self, embedding, action):
        """Return ``embedding`` plus the first element of the GNN output.

        :param embedding: input passed to the GNN and used as the residual base.
        :param action: action input passed to the GNN.
        """
        # The GNN returns an indexable; only its first element is used here.
        delta = self.network(embedding, action)[0]
        return delta + embedding

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to re-create this model."""
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_objects=self.num_objects,
            ignore_action=self.ignore_action,
            copy_action=self.copy_action,
            use_interactions=self.use_interactions,
            edge_actions=self.edge_actions
        )


class RewardModelGNN(BaseModel):
    """Reward model: a ``GNN`` followed by a linear head producing one scalar.

    ``forward`` averages the first element of the GNN's return value over
    dimension 1 and maps the result through ``nn.Linear(embedding_dim, 1)``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            ignore_action,
            copy_action,
            use_interactions,
            edge_actions,
            normalize_images: bool = False,
    ):
        """Build the model.

        :param observation_space: stored by ``BaseModel``.
        :param action_space: stored by ``BaseModel``; must expose ``.n``, which
            is passed to the GNN as ``action_dim``.
        :param embedding_dim: GNN ``input_dim`` and input size of the linear head.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param ignore_action: stored for ``_get_constructor_parameters`` only.
        :param copy_action: stored for ``_get_constructor_parameters`` only.
        :param use_interactions: stored for ``_get_constructor_parameters`` only.
        :param edge_actions: stored for ``_get_constructor_parameters`` only.
        :param normalize_images: forwarded to ``BaseModel`` (default ``False``).
            Accepted so that the dict returned by
            ``_get_constructor_parameters`` can be passed back to this
            constructor as keyword arguments.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=normalize_images,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.ignore_action = ignore_action
        self.copy_action = copy_action
        self.use_interactions = use_interactions
        self.edge_actions = edge_actions
        # NOTE(review): the four action/interaction flags are hard-coded below
        # and the constructor arguments of the same names are not used by the
        # network -- presumably intentional; confirm before wiring them through.
        self.gnn = GNN(input_dim=self.embedding_dim, hidden_dim=self.hidden_dim,
                       action_dim=self.action_space.n, num_objects=self.num_objects,
                       ignore_action=False, copy_action=True, act_fn='relu', layer_norm=True, num_layers=3,
                       use_interactions=True, edge_actions=True)
        self.mlp = nn.Linear(self.embedding_dim, 1)

    def forward(self, embedding, action):
        """Return one scalar per batch element.

        :param embedding: input passed to the GNN
            (presumably ``(batch, num_objects, embedding_dim)`` -- TODO confirm).
        :param action: action input passed to the GNN.
        """
        node_features = self.gnn(embedding, action)[0]
        pooled = node_features.mean(dim=1)
        # The linear head yields a trailing axis of size 1; drop it.
        return self.mlp(pooled).squeeze(dim=1)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to re-create this model."""
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_objects=self.num_objects,
            ignore_action=self.ignore_action,
            copy_action=self.copy_action,
            use_interactions=self.use_interactions,
            edge_actions=self.edge_actions
        )


class ValueModelGNN(BaseModel):
    """Action-free value model: a ``GNN`` followed by a linear head.

    ``forward`` sums the first element of the GNN's return value over
    dimension 1 and maps the result through ``nn.Linear(embedding_dim, 1)``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            embedding_dim,
            hidden_dim,
            num_objects,
            use_interactions,
            action_space: Optional[gym.spaces.Space] = None,
            normalize_images: bool = False,
    ):
        """Build the model.

        :param observation_space: stored by ``BaseModel``.
        :param embedding_dim: GNN ``input_dim`` and input size of the linear head.
        :param hidden_dim: passed to the GNN as ``hidden_dim``.
        :param num_objects: passed to the GNN as ``num_objects``.
        :param use_interactions: passed through to the GNN.
        :param action_space: forwarded to ``BaseModel`` (default ``None``); the
            network itself is built with ``action_dim=0`` and ignores actions.
        :param normalize_images: forwarded to ``BaseModel`` (default ``False``).

        ``action_space`` and ``normalize_images`` are accepted so that the dict
        returned by ``_get_constructor_parameters`` can be passed back to this
        constructor as keyword arguments.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=normalize_images,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.use_interactions = use_interactions
        self.network = GNN(input_dim=self.embedding_dim, hidden_dim=self.hidden_dim,
                           action_dim=0, num_objects=self.num_objects,
                           ignore_action=True, copy_action=False,
                           use_interactions=self.use_interactions, edge_actions=False)
        self.mlp = nn.Linear(self.embedding_dim, 1)

    def forward(self, embedding):
        """Return one scalar per batch element.

        :param embedding: input passed to the GNN
            (presumably ``(batch, num_objects, embedding_dim)`` -- TODO confirm).
        """
        node_features = self.network(embedding, None)[0]
        # Summing over dim 1 already removes that axis. No extra squeeze is
        # applied: it would only act when the feature axis has size 1, and
        # would then strip the axis the linear head needs.
        pooled = node_features.sum(dim=1)
        # The linear head yields a trailing axis of size 1; drop it.
        return self.mlp(pooled).squeeze(dim=1)

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def extract_features(self, obs: Tensor) -> Tensor:
        """Not supported: this model is built with ``features_extractor=None``."""
        raise NotImplementedError()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments needed to re-create this model."""
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            normalize_images=self.normalize_images,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            num_objects=self.num_objects,
            use_interactions=self.use_interactions,
        )


class QNetwork(BasePolicy):
    """Q-network composed from transition, reward and value models.

    For every action ``a`` the Q-value is computed as
    ``reward_model(z, a) + gamma * value_model(transition_model(z, a))``.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            transition_model,
            reward_model,
            value_model,
            gamma,
    ):
        """Build the Q-network.

        :param observation_space: stored by the base policy.
        :param action_space: stored by the base policy; must expose ``.n``.
        :param transition_model: callable ``(embedding, action) -> next embedding``.
        :param reward_model: callable ``(embedding, action) -> reward``.
        :param value_model: callable ``(embedding) -> value``.
        :param gamma: factor applied to the value of the predicted next embedding.
        """
        super(QNetwork, self).__init__(
            observation_space,
            action_space,
            features_extractor=None,
            normalize_images=False,
        )
        self.transition_model = transition_model
        self.reward_model = reward_model
        self.value_model = value_model
        self.gamma = gamma
        self.num_actions = self.action_space.n
        # Registered as a buffer so it follows the module across devices instead
        # of being pinned to CUDA (which fails on CPU-only hosts). Non-persistent
        # so the state_dict keys are unaffected.
        self.register_buffer('actions', torch.arange(0, self.num_actions), persistent=False)

    def forward(self, embedding: Tensor) -> Tensor:
        """Return Q-values of shape ``(batch_size, num_actions)``.

        :param embedding: 3-D tensor unpacked as
            ``(batch_size, num_slots, embedding_dim)``.
        """
        batch_size, num_slots, embedding_dim = embedding.size()
        # No-op when the buffer already lives on the embedding's device; covers
        # the case where the sub-models were moved to a device before this
        # module was constructed.
        actions = self.actions.to(embedding.device)
        # Repeat every embedding once per action, then flatten (batch, action)
        # into a single leading axis so the models see ordinary batches.
        embedding_expanded = embedding.unsqueeze(1).expand(-1, self.num_actions, -1, -1).reshape(
            batch_size * self.num_actions, num_slots, embedding_dim)
        action_expanded = actions.unsqueeze(0).expand(batch_size, self.num_actions).flatten()
        rewards = self.reward_model(embedding_expanded, action_expanded).reshape(batch_size, self.num_actions)
        next_embeddings = self.transition_model(embedding_expanded, action_expanded)
        values = self.value_model(next_embeddings).reshape(batch_size, self.num_actions)

        return rewards + self.gamma * values

    def _predict(self, embedding: Tensor, deterministic: bool = True) -> Tensor:
        """Return the greedy action per batch element (``deterministic`` is unused)."""
        q_values = self(embedding)
        # Greedy action
        action = q_values.argmax(dim=1).reshape(-1)
        return action

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the base-class constructor parameters plus ``gamma``."""
        # NOTE(review): the three sub-models required by ``__init__`` are not
        # part of this dict, so it cannot rebuild the module on its own --
        # confirm whether a save/load round-trip is expected to work here.
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                gamma=self.gamma,
            )
        )
        return data


class OOQNPolicy(BasePolicy):
    """Policy that pairs a features extractor with a Q-network and a target copy.

    The constructor deep-copies ``q_network`` into ``q_net_target`` (put in
    evaluation mode) and attaches the optimizer(s) to ``q_net`` as attributes.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            lr: float,
            features_extractor,
            q_network,
            optimizer_class: Type[optim.Optimizer] = optim.Adam,
            optimizer_kwargs: Optional[Dict[str, Any]] = None,
            use_wm_optimizer=False,
    ):
        """Build the policy.

        :param observation_space: stored by the base policy.
        :param action_space: stored by the base policy.
        :param lr: learning rate passed to every optimizer created here.
        :param features_extractor: callable used by ``extract_features``.
        :param q_network: online Q-network; also the source of the target copy.
        :param optimizer_class: optimizer type to instantiate.
        :param optimizer_kwargs: extra keyword arguments for the optimizer(s).
        :param use_wm_optimizer: if true, create ``optimizer_wm`` (parameters
            whose names start with ``transition`` or ``reward``) and
            ``optimizer_value`` (names starting with ``value``); otherwise
            create a single ``optimizer`` over all Q-network parameters.
        """
        super().__init__(
            observation_space,
            action_space,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

        self.features_extractor = features_extractor
        self.q_net = q_network
        self.lr = lr
        self.use_wm_optimizer = use_wm_optimizer

        # The target copy is taken before any optimizer is attached to q_net.
        self.q_net_target = copy.deepcopy(self.q_net)
        self.q_net_target.load_state_dict(self.q_net.state_dict())
        self.q_net_target.set_training_mode(False)

        if not self.use_wm_optimizer:
            self.q_net.optimizer = self.optimizer_class(
                self.q_net.parameters(), lr=self.lr, **self.optimizer_kwargs)
        else:
            # Split the parameters by name prefix in a single pass.
            world_model_params = []
            value_params = []
            for param_name, param in self.q_net.named_parameters():
                if param_name.startswith(('transition', 'reward')):
                    world_model_params.append(param)
                elif param_name.startswith('value'):
                    value_params.append(param)
            self.q_net.optimizer_wm = self.optimizer_class(
                world_model_params, lr=self.lr, **self.optimizer_kwargs)
            self.q_net.optimizer_value = self.optimizer_class(
                value_params, lr=self.lr, **self.optimizer_kwargs)

    def forward(self, obs: Tensor, deterministic: bool = True) -> Tensor:
        """Delegate to ``_predict``."""
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, embedding: Tensor, deterministic: bool = True) -> Tensor:
        """Return the action chosen by the online Q-network for ``embedding``."""
        return self.q_net._predict(embedding, deterministic=deterministic)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the base-class constructor parameters plus the optimizer settings."""
        params = super()._get_constructor_parameters()
        params.update(
            lr=self.lr,
            optimizer_class=self.optimizer_class,
            optimizer_kwargs=self.optimizer_kwargs,
            use_wm_optimizer=self.use_wm_optimizer,
        )
        return params

    def extract_features(self, obs: Tensor, prev_slots=None) -> Tensor:
        """Move ``obs`` to the policy device, divide by 255 and run the extractor."""
        scaled_obs = obs.to(self.device) / 255.
        return self.features_extractor(scaled_obs, prev_slots=prev_slots)

    def predict(
            self,
            observation: np.ndarray,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Return ``(actions, state)`` for ``observation``.

        ``state`` is returned unchanged and ``episode_start`` is not used.
        """
        self.set_training_mode(False)
        with torch.no_grad():
            obs_tensor = torch.as_tensor(observation, dtype=torch.float32)
            embedding = self.extract_features(obs=obs_tensor)
            chosen = self._predict(embedding, deterministic=deterministic)

        return chosen.cpu().numpy(), state

    def set_training_mode(self, mode: bool) -> None:
        """Set the training mode of ``q_net`` and this module's ``training`` flag."""
        self.q_net.set_training_mode(mode)
        self.training = mode

    def _update_features_extractor(
            self,
            net_kwargs: Dict[str, Any],
            features_extractor: Optional[BaseFeaturesExtractor] = None,
    ) -> Dict[str, Any]:
        """Not supported by this policy."""
        raise NotImplementedError()

    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Not supported by this policy."""
        raise NotImplementedError()
