import torch
import torch.nn as nn
import torch_geometric
import torch_scatter
from gym import spaces
from models.tsp_reasoner import TSPReasoner
from models.algorithm_processor import LitAlgorithmProcessor
from models.gnns import LitProcessorSet
from sb3_contrib.ppo_mask.policies import MaskableActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import MultiInputActorCriticPolicy


class AlgorithmicFeaturesExtractor(BaseFeaturesExtractor):
    """Feature extractor that runs a ``TSPReasoner`` over graph observations.

    The reasoner is driven by a freshly built ``LitProcessorSet``, optionally
    extended with the first processor of a pre-trained
    ``LitAlgorithmProcessor`` checkpoint. The extractor threads a latent state:
    each ``forward`` call receives the previous latent and returns the new one.
    """

    def __init__(self,
                 spec,
                 data,
                 latent_features: int,
                 node_features: int,
                 edge_features: int,
                 output_features: int,
                 bias: bool = True,
                 processors=None,
                 load_processor: str = None,
                 freeze_processor: bool = False,
                 ):
        """
        :param spec: Dataset spec passed to ``TSPReasoner`` and also stored as ``net.dataset_spec``
        :param data: Passed through to ``TSPReasoner``
        :param latent_features: Latent width; the processor set is built with input width ``2 * latent_features``
        :param node_features: Passed to ``TSPReasoner``; also declared as ``features_dim``
        :param edge_features: Passed to ``TSPReasoner``
        :param output_features: Passed to ``TSPReasoner``
        :param bias: Forwarded to ``LitProcessorSet`` as ``bias``
        :param processors: Processor identifiers for ``LitProcessorSet``; defaults to ``['MPNN']``
        :param load_processor: Optional checkpoint path; its first processor is appended to the set
        :param freeze_processor: Freeze the loaded processor (only used when ``load_processor`` is given)
        """
        # NOTE(review): features_dim is declared as node_features, but forward()
        # returns a 3-way concatenation, so this may not match its width —
        # confirm nothing downstream relies on features_dim.
        super().__init__(observation_space=None, features_dim=node_features)

        # Avoid a shared mutable default list across instances.
        if processors is None:
            processors = ['MPNN']

        processor = LitProcessorSet(2*latent_features,
                                    latent_features,
                                    latent_features,
                                    bias=bias,
                                    processors=processors,
                                    use_gate=True)
        if load_processor is not None:
            algo_processor = LitAlgorithmProcessor.load_from_checkpoint(load_processor)
            # Report only after the checkpoint has actually been loaded.
            print("Processor loaded.")
            if freeze_processor:
                print("Processor frozen.")
                algo_processor.freeze()

            # Only the first processor of the checkpoint's set is reused.
            processor.processors.append(algo_processor.processor_set.processors[0])

        self.net = TSPReasoner(
            spec=spec,
            data=data,
            latent_features=latent_features,
            node_features=node_features,
            edge_features=edge_features,
            output_features=output_features,
            algo_processor=processor,
            timeit=False,
            xavier_on_scalars=False
        )

        self.net.dataset_spec = spec
        # Embeds the per-node visited flag into the latent space (see encode()).
        self.encode_visited_nodes = nn.Linear(1, latent_features)
        self.latent_features = latent_features
        self.net.epoch = 0
        # Not referenced elsewhere in this class.
        self.max_iter = 16

    def reset(self, obs):
        """Reinitialise the reasoner for ``obs``.

        Resets the epoch counter and calls the reasoner's
        ``set_initial_states``, ``prepare_initial_masks`` and ``zero_steps``.
        """
        self.net.epoch = 0
        self.net.set_initial_states(obs)
        self.net.prepare_initial_masks(obs)
        self.net.zero_steps()

    def encode(self, obs):
        """Encode node/edge inputs and add an embedding of ``obs.visited_nodes``.

        ``visited_nodes`` must be 1-D (one value per node): it is unsqueezed to
        ``(N, 1)`` before the ``Linear(1, latent_features)`` layer.

        :return: ``(node_fts, edge_fts)``
        """
        node_fts, edge_fts = self.net.encode_inputs(obs)
        # Kept in-place: the reasoner may rely on aliasing of this tensor.
        node_fts += self.encode_visited_nodes(obs.visited_nodes.unsqueeze(1))
        return node_fts, edge_fts

    def forward(self, obs, last_latent):
        """Run one reasoning step.

        :param obs: Preprocessed graph observation
        :param last_latent: Latent node state from the previous step
        :return: ``(out, edge_fts, current_latents)``. ``out`` concatenates, along
            dim 1, the encoded node features, ``self.net.last_latent`` as it
            stands after the reasoner call (the ``last_latent`` passed in unless
            the reasoner updates it), and ``current_latents``.
        """
        self.net.last_latent = last_latent
        node_fts, edge_fts = self.encode(obs)
        current_latents, _, _ = self.net(obs, node_fts, edge_fts)
        out = torch.cat((node_fts, self.net.last_latent, current_latents), dim=1)
        # Carry the new latent forward for the next step.
        self.net.last_latent = current_latents

        return out, edge_fts, current_latents


class ACNet(nn.Module):
    """Identity actor/critic network.

    Both the actor and the critic branch hand their input back unchanged, so
    any real computation has to happen outside this module.

    ``latent_dim_pi`` and ``latent_dim_vf`` are accepted for signature
    compatibility but ignored: both latent dimensions are always 1.
    """

    def __init__(self, latent_dim_pi, latent_dim_vf):
        super().__init__()
        # Deliberately hard-coded regardless of the constructor arguments.
        self.latent_dim_pi = self.latent_dim_vf = 1

    def forward(self, features):
        """Return ``(actor_output, critic_output)``, both equal to ``features``."""
        actor_out = self.forward_actor(features)
        critic_out = self.forward_critic(features)
        return actor_out, critic_out

    def forward_actor(self, features):
        """Actor branch: identity."""
        return features

    def forward_critic(self, features):
        """Critic branch: identity."""
        return features


class PolicyNet(MaskableActorCriticPolicy):
    """Maskable actor-critic policy built on an algorithmic-reasoner feature extractor.

    Features come from ``AlgorithmicFeaturesExtractor``, which threads a latent
    state: ``forward``, ``evaluate_actions``, ``predict_values`` and ``predict``
    all take the previous latent explicitly. Actions are pointer-style: a logit
    is computed for every edge ``(src, dst)``, and the row belonging to the
    current node is used as the distribution over next nodes.
    """
    name = 'NAR_Policy'

    def __init__(
            self,
            observation_space: spaces.Space,
            action_space: spaces.Space,
            lr_schedule,
            spec,
            data,
            net_arch=None,
            activation_fn=nn.Tanh,
            latent_features: int = 128,
            node_features: int = 1,
            edge_features: int = 1,
            output_features: int = 1,
            processors=None,
            load_processor: str = None,
            freeze_processor: bool = False,
            *args,
            **kwargs,
    ):
        """
        :param observation_space: Observation space of the agent
        :param action_space: Action space of the agent
        :param lr_schedule: Learning-rate schedule; ``lr_schedule(1)`` is the optimizer's initial learning rate
        :param spec: Dataset spec forwarded to the feature extractor
        :param data: Forwarded to the feature extractor
        :param net_arch: Forwarded to the base policy
        :param activation_fn: Forwarded to the base policy
        :param latent_features: Latent width of the reasoner; the heads consume ``3 * latent_features`` per node
        :param node_features: Number of input node features
        :param edge_features: Number of input edge features
        :param output_features: Number of reasoner output features
        :param processors: Processor identifiers for the reasoner; defaults to ``['MPNN']``
        :param load_processor: Optional checkpoint path of a pre-trained algorithm processor
        :param freeze_processor: Whether to freeze the loaded processor
        :param args: Extra positional arguments for the base policy
        :param kwargs: Extra keyword arguments for the base policy
        """
        # Avoid a shared mutable default list across instances.
        if processors is None:
            processors = ['MPNN']

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

        # Replace the default extractor built by the base class.
        self.features_extractor = AlgorithmicFeaturesExtractor(
            spec=spec,
            data=data,
            processors=processors,
            latent_features=latent_features,
            node_features=node_features,
            edge_features=edge_features,
            output_features=output_features,
            load_processor=load_processor,
            freeze_processor=freeze_processor
        )

        # Both heads take per-node features of width 3 * latent_features.
        self.value_head = nn.Sequential(
            nn.Linear(latent_features*3, 1),
            nn.ReLU6()
        )
        # [0]: source-node projection, [1]: destination-node projection,
        # [2]: edge projection, [3]: final scalar logit (see _decode_logits).
        self.policy_head = nn.ModuleList([nn.Linear(latent_features*3, latent_features),
                                          nn.Linear(latent_features*3, latent_features),
                                          nn.Linear(latent_features, latent_features),
                                          nn.Linear(latent_features, 1)])

        # The base-class heads are superseded by value_head / policy_head.
        delattr(self, "value_net")
        delattr(self, "action_net")

        # Rebuild the optimizer so it covers the replaced modules, using the
        # same initial learning rate the SB3 base class uses in ``_build``.
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
        # TODO: decide whether orthogonal initialisation (``ortho_init``)
        # should also apply to the modules created above.

    def _build_mlp_extractor(self):
        """Use an identity extractor; all processing is done by the custom heads."""
        self.mlp_extractor = ACNet(1, 1)

    def extract_features(self, obs, last_latent):
        """
        Run the feature extractor on a preprocessed observation.
        :param obs: Preprocessed (PyG batch) observation
        :param last_latent: Latent state from the previous step
        :return: ``(feats, edge_feats, current_latents)``
        """
        feats, edge_feats, current_latents = self.features_extractor(obs, last_latent)
        return feats, edge_feats, current_latents

    def _decode_logits(self, pi_feats, edge_feats, src, dst):
        """Compute one pointer logit per edge ``(src, dst)``."""
        fr = self.policy_head[0](pi_feats[src])
        to = self.policy_head[1](pi_feats[dst])
        edge = self.policy_head[2](edge_feats)
        # Element-wise maximum of the source and (destination + edge) projections.
        pointer_logits = self.policy_head[3](torch.maximum(fr, to + edge)).squeeze(-1)
        return pointer_logits

    def _decode_pi(self, pi_feats, edge_feats, src, dst, num_nodes, current_node):
        """Decode the next-node distribution for a single graph.

        The log-softmax is taken over the outgoing edges of each source node.
        The ``(num_nodes, num_nodes)`` reshape requires ``num_nodes ** 2`` edges
        ordered by source node.

        :return: ``(log_probs, logits)``, both indexed at ``current_node``
        """
        pointer_logits = self._decode_logits(pi_feats, edge_feats, src, dst)
        log_action_probs = torch_scatter.scatter_log_softmax(
            pointer_logits, src, dim_size=num_nodes).reshape(num_nodes, num_nodes)
        logits = pointer_logits.reshape(num_nodes, num_nodes)
        return log_action_probs[current_node], logits[current_node]

    def _decode_pi_batch(self, batch_size, pi_feats, edge_feats, src, dst, num_nodes, current_node):
        """Batched variant of ``_decode_pi``.

        ``num_nodes`` is the total node count of the batch. The reshapes require
        every graph to have ``num_nodes // batch_size`` nodes and a complete,
        source-ordered edge set.

        :return: ``(log_probs, logits)``, each of shape
            ``(batch_size, 1, num_nodes // batch_size)``, i.e. the row of each
            graph's current node
        """
        nodes_per_graph = num_nodes // batch_size
        pointer_logits = self._decode_logits(pi_feats, edge_feats, src, dst)
        log_action_probs = torch_scatter.scatter_log_softmax(
            pointer_logits, src, dim_size=num_nodes).reshape(batch_size, nodes_per_graph, nodes_per_graph)
        logits = pointer_logits.reshape(batch_size, nodes_per_graph, nodes_per_graph)
        indices = current_node.view(-1, 1, 1).expand(batch_size, 1, log_action_probs.shape[2])
        return torch.gather(log_action_probs, 1, indices), torch.gather(logits, 1, indices)

    def forward(self, obs, last_latent, action_masks, deterministic=False):
        """Select an action for a single (unbatched) observation.

        :param obs: Observation dict; must contain ``current_node``
        :param last_latent: Latent state from the previous step
        :param action_masks: Optional mask over next nodes. Entries equal to 1
            are excluded (their logits get ``-inf``). When ``None``, no masking
            is applied.
            NOTE(review): sb3_contrib's convention is ``True`` = valid action;
            confirm the environment uses the inverse convention assumed here.
        :param deterministic: Take the argmax instead of sampling
        :return: ``(action, values, log_prob, current_latents)``. ``log_prob`` is
            read from the unmasked distribution, as in ``evaluate_actions``.
        """
        current_node = obs['current_node']
        obs = self.preprocess_obs(obs)
        src, dst = obs.edge_index[0], obs.edge_index[1]
        feats, edge_feats, current_latents = self.extract_features(obs, last_latent)

        # Evaluate the values for the given observations
        values = -self.value_head(feats)[current_node]

        log_action_probs, _ = self._decode_pi(feats,
                                              edge_feats,
                                              src,
                                              dst,
                                              obs.num_nodes,
                                              current_node)

        masked_log_probs = log_action_probs
        if action_masks is not None:
            # Convert to an additive mask: entries equal to 1 -> -inf, entries equal to 0 -> 0.
            additive_mask = (1-torch.FloatTensor(action_masks)).to(log_action_probs)
            additive_mask[additive_mask == 0] = -float('inf')
            additive_mask[additive_mask == 1] = 0.
            masked_log_probs = log_action_probs + additive_mask

        if deterministic:
            action = masked_log_probs.argmax(-1)
        else:
            m = torch.distributions.categorical.Categorical(logits=masked_log_probs)
            action = m.sample()
        log_prob = log_action_probs[:, action]

        return action, values, log_prob, current_latents

    def preprocess_obs(self, obs):
        """Wrap a single observation dict into a one-graph PyG ``Batch``.

        Every entry is squeezed; ``edge_index`` is additionally cast to long.
        """
        from torch_geometric.data import Batch, Data
        return Batch.from_data_list([
            Data(**{
                key: obs[key].squeeze().long()
                if key in ['edge_index'] else obs[key].squeeze()
                for key in obs
            })
        ])

    def preprocess_batch(self, obs):
        """Turn a dict of batched tensors into a PyG ``Batch``.

        Each entry is unbound along dim 0, one ``Data`` is built per batch
        element, and the results are collated; ``edge_index`` is cast to long.
        """
        from torch_geometric.data import Batch, Data

        d = {
            key: obs[key].unbind()
            for key in obs
        }

        d = [Data(**dict(zip(d, i))) for i in zip(*d.values())]
        d = Batch.from_data_list(d)
        d.edge_index = d.edge_index.long()
        return d

    def evaluate_actions(self, obs, actions, last_latent, action_masks):
        """Evaluate ``actions`` on a batch of observations.

        :param obs: Dict of batched observation tensors (see ``preprocess_batch``)
        :param actions: Actions taken, one per graph
        :param last_latent: Latent state from the previous step
        :param action_masks: Accepted for interface compatibility but unused:
            log-probs and entropy come from the unmasked distribution, matching
            how ``forward`` reports ``log_prob``.
        :return: ``(values, log_prob, entropy)``; ``values`` has shape ``(num_graphs,)``
        """
        obs = self.preprocess_batch(obs)
        current_node = obs.current_node.long()
        src, dst = obs.edge_index[0], obs.edge_index[1]
        feats, edge_feats, current_latents = self.extract_features(obs, last_latent.float())
        _, action_logits = self._decode_pi_batch(obs.num_graphs,
                                                 feats,
                                                 edge_feats,
                                                 src,
                                                 dst,
                                                 obs.num_nodes,
                                                 current_node)
        distribution = self.action_dist.proba_distribution(action_logits=action_logits)
        log_prob = distribution.log_prob(actions)

        node_values = self.value_head(feats).reshape(obs.num_graphs, obs.num_nodes // obs.num_graphs)
        # squeeze(1) (not squeeze()) keeps a 1-D result even when num_graphs == 1.
        values = -torch.gather(node_values, 1, current_node.unsqueeze(1)).squeeze(1)
        return values, log_prob, distribution.entropy()

    def predict_values(self, obs, latents):
        """
        Get the estimated values according to the current policy given the observations.
        :param obs: Observation dict for a single graph
        :param latents: Latent state from the previous step
        :return: the estimated values, read at the current node
        """
        obs = self.preprocess_obs(obs)
        feats, _, _ = self.extract_features(obs, latents)
        current_node = obs['current_node']
        return -self.value_head(feats)[current_node]

    @torch.no_grad()
    def predict(self, obs, latents, action_masks=None, deterministic=False):
        """Switch to eval mode and pick an action without tracking gradients.

        :return: ``(actions on CPU, new latents)``
        """
        self.eval()
        actions, _, _, latents = self.forward(obs, latents, action_masks, deterministic)
        return actions.cpu(), latents
