#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple, Optional
import abc
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from habitat.tasks.nav.nav import (
    IntegratedPointGoalGPSAndCompassSensor,
    PointGoalSensor,
)
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.utils import CategoricalNet
from habitat_baselines.common.running_mean_and_var import RunningMeanAndVar
from habitat_baselines.rl.models.rnn_state_encoder import build_rnn_state_encoder
from habitat_baselines.rl.models.simple_cnn import SimpleCNN


@torch.jit.script
def _process_depth(observations: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    if "depth" in observations:
        depth_observations = observations["depth"]

        depth_observations = torch.clamp(depth_observations, 0.0, 10.0)
        depth_observations /= 10.0

        observations["depth"] = depth_observations

    return observations


class SNIBottleneck(nn.Module):
    r"""Optional Gaussian bottleneck applied to a feature vector.

    When ``active`` is False the module is a pass-through and
    ``output_size`` equals ``input_size``. When ``active`` is True a single
    linear layer produces ``2 * output_size`` values that are split along
    the last dimension into a mean and a (pre-softplus) scale.
    """

    active: bool
    __constants__ = ["active"]

    def __init__(self, input_size, output_size, active=False):
        r"""Build the bottleneck.

        :param input_size: width of the incoming features.
        :param output_size: width of the bottleneck (only used when active).
        :param active: whether the bottleneck is enabled.
        """
        super().__init__()
        self.active: bool = active

        # A disabled bottleneck keeps the feature width and has no layers.
        self.output_size = output_size if active else input_size
        layers = [nn.Linear(input_size, 2 * output_size)] if active else []
        self.bottleneck = nn.Sequential(*layers)

    def forward(
        self, x
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        r"""Return ``(mu, sample, kl)``.

        * Inactive: ``(x, None, None)``.
        * Active, eval mode: ``(mu, None, None)``.
        * Active, training mode: ``mu``, a sample ``mu + sigma * eps`` with
          ``eps`` drawn from a standard normal, and the element-wise term
          ``-log(sigma) + 0.5 * mu**2 + 0.5 * sigma**2``.
        """
        if not self.active:
            return x, None, None

        projected = self.bottleneck(x)
        mu, sigma = torch.chunk(projected, 2, projected.dim() - 1)

        sample: Optional[torch.Tensor] = None
        kl: Optional[torch.Tensor] = None

        if self.training:
            # softplus keeps the scale strictly positive.
            sigma = F.softplus(sigma)
            noise = torch.randn_like(sigma)
            sample = torch.addcmul(mu, sigma, noise, value=1.0)

            # KL with a standard normal, keeping only the terms that
            # influence the gradient (constant terms are dropped).
            kl_mean_term = torch.addcmul(-torch.log(sigma), mu, mu, value=0.5)
            kl = torch.addcmul(kl_mean_term, sigma, sigma, value=0.5)

        return mu, sample, kl


class ScriptableAC(nn.Module):
    r"""Actor-critic module bundling a feature network with its two heads.

    The wrapped ``net`` produces features and updated recurrent state; a
    ``CategoricalNet`` turns features into action logits and a
    ``CriticHead`` turns them into a value estimate. Public entry points
    are marked with ``torch.jit.export``.
    """

    def __init__(self, net, dim_actions):
        r"""
        :param net: feature network; must expose ``output_size``.
        :param dim_actions: number of discrete actions.
        """
        super().__init__()
        self.net = net

        feature_size = self.net.output_size
        self.action_distribution = CategoricalNet(feature_size, dim_actions)
        self.critic = CriticHead(feature_size)

    def post_net(
        self, features, rnn_hidden_states, deterministic: bool
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        r"""Run both heads on already-computed features.

        Returns ``(value, dist_result, rnn_hidden_states)`` where
        ``dist_result`` is whatever ``action_distribution.dist.act`` returns
        for the logits; sampling is disabled when ``deterministic`` is True.
        """
        action_logits = self.action_distribution(features)
        state_value = self.critic(features)["value"]

        acted = self.action_distribution.dist.act(
            action_logits, sample=not deterministic
        )

        return state_value, acted, rnn_hidden_states

    @torch.jit.export
    def act(
        self,
        observations: Dict[str, torch.Tensor],
        rnn_hidden_states,
        prev_actions,
        masks,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        r"""Run the full network on ``observations`` and pick an action."""
        net_features, next_hidden = self.net(
            observations, rnn_hidden_states, prev_actions, masks
        )

        return self.post_net(net_features, next_hidden, deterministic)

    @torch.jit.export
    def act_post_visual(
        self,
        visual_out,
        observations: Dict[str, torch.Tensor],
        rnn_hidden_states,
        prev_actions,
        masks,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        r"""Pick an action from precomputed visual features.

        Skips the visual part of the network: ``visual_out`` and the
        ``"pointgoal_with_gps_compass"`` observation are passed straight to
        ``net.rnn_forward``.
        """
        goal_observation = observations["pointgoal_with_gps_compass"]
        net_features, next_hidden = self.net.rnn_forward(
            visual_out, goal_observation, rnn_hidden_states, prev_actions, masks,
        )

        return self.post_net(net_features, next_hidden, deterministic)

    @torch.jit.export
    def get_value(
        self,
        observations: Dict[str, torch.Tensor],
        rnn_hidden_states,
        prev_actions,
        masks,
    ):
        r"""Return only the critic's ``"value"`` output for ``observations``."""
        net_features, _ = self.net(
            observations, rnn_hidden_states, prev_actions, masks
        )
        critic_out = self.critic(net_features)
        return critic_out["value"]

    @torch.jit.export
    def evaluate_actions(
        self,
        observations: Dict[str, torch.Tensor],
        rnn_hidden_states,
        prev_actions,
        masks,
        action,
    ) -> Dict[str, torch.Tensor]:
        r"""Evaluate ``action`` under the current policy.

        The returned dictionary merges the output of
        ``action_distribution.dist.evaluate_actions`` with the critic's
        output (critic keys are written last).
        """
        net_features, _ = self.net(
            observations, rnn_hidden_states, prev_actions, masks
        )
        action_logits = self.action_distribution(net_features)

        # Start from a fresh dict so neither head's own output is mutated.
        merged: Dict[str, torch.Tensor] = {}
        merged.update(
            self.action_distribution.dist.evaluate_actions(action_logits, action)
        )
        merged.update(self.critic(net_features))

        return merged


class Policy(nn.Module):
    r"""Policy wrapper that preprocesses observations before the actor-critic.

    Holds a ``ScriptableAC`` built around ``net`` and, when the observation
    space contains ``"rgb"``, a ``RunningMeanAndVar`` normaliser covering the
    RGB channels plus the depth channels (if depth is present).

    ``accelerated_net`` and ``accel_out`` start as ``None``; ``init_trt`` and
    ``update_trt_weights`` are left unimplemented here.
    NOTE(review): presumably subclasses populate them for the accelerated
    path used by ``act_fast`` -- confirm against the implementing subclass.
    """

    def __init__(self, net, observation_space, dim_actions):
        r"""
        :param net: feature network; must expose ``num_recurrent_layers``,
            ``is_blind`` and whatever ``ScriptableAC`` requires.
        :param observation_space: space whose ``spaces`` mapping is checked
            for ``"rgb"`` and ``"depth"``.
        :param dim_actions: number of discrete actions.
        """
        super().__init__()
        self.dim_actions = dim_actions

        self.num_recurrent_layers = net.num_recurrent_layers
        self.is_blind = net.is_blind

        self.ac = ScriptableAC(net, self.dim_actions)
        self.accelerated_net = None
        self.accel_out = None

        if "rgb" in observation_space.spaces:
            # Normaliser width = rgb channels (+ depth channels if present).
            self.running_mean_and_var = RunningMeanAndVar(
                observation_space.spaces["rgb"].shape[0]
                + (
                    observation_space.spaces["depth"].shape[0]
                    if "depth" in observation_space.spaces
                    else 0
                ),
                initial_count=1e4,
            )
        else:
            self.running_mean_and_var = None

    def script_net(self):
        r"""Replace the actor-critic with its TorchScript-compiled version."""
        self.ac = torch.jit.script(self.ac)

    def init_trt(self):
        r"""Not implemented in this class."""
        raise NotImplementedError

    def update_trt_weights(self):
        r"""Not implemented in this class."""
        raise NotImplementedError

    def trt_enabled(self):
        r"""Return True when an accelerated network has been installed."""
        # Identity check: an object with a custom __eq__/__ne__ must not be
        # able to masquerade as "no accelerated net".
        return self.accelerated_net is not None

    def forward(self, *x):
        r"""Not supported; use ``act``/``get_value``/``evaluate_actions``."""
        raise NotImplementedError

    def _preprocess_obs(self, observations):
        r"""Cast, rescale and normalise raw observations.

        All tensors are cast to the dtype of the module's parameters, depth
        is rescaled by ``_process_depth``, and -- when ``"rgb"`` is present --
        RGB is divided by 255, concatenated with depth along dim 1, passed
        through ``running_mean_and_var`` and split back.

        :return: a new dictionary; the caller's dictionary is not modified.
        """
        dtype = next(self.parameters()).dtype
        observations = {k: v.to(dtype=dtype) for k, v in observations.items()}

        observations = _process_depth(observations)

        if "rgb" in observations:
            rgb = observations["rgb"] / 255.0
            # Split point is taken from the tensor itself rather than
            # assuming exactly 3 colour channels.
            n_rgb_channels = rgb.size(1)

            stacked = [rgb]
            if "depth" in observations:
                stacked.append(observations["depth"])

            normalized = self.running_mean_and_var(torch.cat(stacked, 1))

            observations["rgb"] = normalized[:, :n_rgb_channels]
            if "depth" in observations:
                observations["depth"] = normalized[:, n_rgb_channels:]

        return observations

    def act(
        self, observations, rnn_hidden_states, prev_actions, masks, deterministic=False,
    ):
        r"""Preprocess ``observations`` and delegate to ``ScriptableAC.act``."""
        observations = self._preprocess_obs(observations)
        return self.ac.act(
            observations, rnn_hidden_states, prev_actions, masks, deterministic
        )

    def act_fast(
        self, observations, rnn_hidden_states, prev_actions, masks, deterministic=False,
    ):
        r"""Like ``act`` but uses the accelerated visual network if installed.

        Without an accelerated network this is identical to ``act``. With
        one, the ``"rgb"`` observation (or ``"depth"`` when there is no RGB)
        is handed to ``accelerated_net.infer`` by data pointer on the current
        CUDA stream, and ``accel_out`` is fed to
        ``ScriptableAC.act_post_visual``.

        :raises AssertionError: if the accelerated path is enabled but the
            observations contain neither ``"rgb"`` nor ``"depth"``.
        """
        observations = self._preprocess_obs(observations)
        if self.accelerated_net is None:
            return self.ac.act(
                observations, rnn_hidden_states, prev_actions, masks, deterministic
            )

        if "rgb" in observations:
            trt_input = observations["rgb"]
        elif "depth" in observations:
            trt_input = observations["depth"]
        else:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is unchanged.
            raise AssertionError(
                "act_fast requires an 'rgb' or 'depth' observation when an "
                "accelerated network is enabled"
            )

        # NOTE(review): infer() returns nothing that is used here; the result
        # is presumably written into ``self.accel_out`` -- confirm against
        # the accelerated net's implementation.
        self.accelerated_net.infer(
            trt_input.data_ptr(), torch.cuda.current_stream().cuda_stream
        )
        return self.ac.act_post_visual(
            self.accel_out,
            observations,
            rnn_hidden_states,
            prev_actions,
            masks,
            deterministic,
        )

    def get_value(self, observations, rnn_hidden_states, prev_actions, masks):
        r"""Preprocess ``observations`` and return the critic's value."""
        observations = self._preprocess_obs(observations)
        return self.ac.get_value(observations, rnn_hidden_states, prev_actions, masks)

    def evaluate_actions(
        self, observations, rnn_hidden_states, prev_actions, masks, action,
    ):
        r"""Preprocess ``observations`` and evaluate ``action`` under the policy."""
        observations = self._preprocess_obs(observations)
        return self.ac.evaluate_actions(
            observations, rnn_hidden_states, prev_actions, masks, action
        )


class CriticHead(nn.Module):
    r"""Single linear layer mapping features to a scalar value estimate."""

    def __init__(self, input_size):
        r"""
        :param input_size: width of the incoming feature vector.
        """
        super().__init__()
        self.fc = nn.Linear(input_size, 1)

        self.layer_init()

    def layer_init(self):
        r"""Re-initialise all sub-modules.

        Every module that defines ``reset_parameters`` is reset; each linear
        layer then has every weight row rescaled to L2 norm 0.1 and its bias
        (if any) set to zero.
        """
        for module in self.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

            if not isinstance(module, nn.Linear):
                continue

            # Work on ``.data`` so the rescale is not recorded by autograd.
            weight = module.weight.data
            weight.mul_(0.1 / weight.norm(p=2, dim=1, keepdim=True))

            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x) -> Dict[str, torch.Tensor]:
        r"""Return ``{"value": fc(x)}``."""
        value = self.fc(x)
        return {"value": value}


@baseline_registry.register_policy
class SimpleCNNPolicy(Policy):
    r"""Registered policy that pairs ``Policy`` with a ``SimpleCNNNet``."""

    goal_sensor_uuid = "pointgoal_with_gps_compass"

    def __init__(
        self, observation_space, action_space, hidden_size=512, *args, **kwargs
    ):
        r"""
        :param observation_space: passed to both the network and ``Policy``.
        :param action_space: its ``n`` attribute is used as the action count.
        :param hidden_size: hidden width handed to ``SimpleCNNNet``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        backbone = SimpleCNNNet(
            observation_space=observation_space, hidden_size=hidden_size
        )
        num_actions = action_space.n
        super().__init__(backbone, observation_space, num_actions)


class Net(nn.Module, metaclass=abc.ABCMeta):
    r"""Abstract interface for the feature networks used by the policies.

    Concrete subclasses must implement ``forward`` and the three read-only
    properties ``output_size``, ``num_recurrent_layers`` and ``is_blind``.
    """

    @abc.abstractmethod
    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        r"""Compute features from observations and recurrent state."""

    @property
    @abc.abstractmethod
    def output_size(self):
        r"""Width of the features produced by ``forward``."""

    @property
    @abc.abstractmethod
    def num_recurrent_layers(self):
        r"""Number of recurrent layers in the network."""

    @property
    @abc.abstractmethod
    def is_blind(self):
        r"""Whether the network is blind."""


class SimpleCNNNet(Net):
    r"""Network which passes the input image through CNN and concatenates
    goal vector with CNN's output and passes that through RNN.

    The goal vector is read from the integrated GPS+compass point-goal
    sensor when present, otherwise from the plain point-goal sensor.
    """

    def __init__(self, observation_space, hidden_size):
        r"""
        :param observation_space: space whose ``spaces`` mapping must contain
            one of the two supported point-goal sensors.
        :param hidden_size: width of the CNN embedding and of the RNN state.
        :raises ValueError: if neither point-goal sensor is in the space.
        """
        super().__init__()

        if IntegratedPointGoalGPSAndCompassSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ].shape[0]
        elif PointGoalSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                PointGoalSensor.cls_uuid
            ].shape[0]
        else:
            # Previously this fell through and failed later with an
            # AttributeError on ``_n_input_goal``.
            raise ValueError(
                "SimpleCNNNet requires a point-goal sensor in the observation "
                "space, but neither supported sensor uuid was found"
            )

        self._hidden_size = hidden_size

        self.visual_encoder = SimpleCNN(observation_space, hidden_size)

        # A blind network feeds only the goal vector to the RNN.
        rnn_input_size = (
            0 if self.is_blind else self._hidden_size
        ) + self._n_input_goal
        # ``RNNStateEncoder`` is not imported by this module; the encoder is
        # created through the imported factory instead.
        # NOTE(review): assumes build_rnn_state_encoder(input_size,
        # hidden_size) with defaults for the remaining arguments -- confirm.
        self.state_encoder = build_rnn_state_encoder(
            rnn_input_size, self._hidden_size
        )

        self.train()

    @property
    def output_size(self):
        r"""Width of the features returned by ``forward``."""
        return self._hidden_size

    @property
    def is_blind(self):
        r"""Forwarded from the visual encoder's ``is_blind``."""
        return self.visual_encoder.is_blind

    @property
    def num_recurrent_layers(self):
        r"""Forwarded from the state encoder's ``num_recurrent_layers``."""
        return self.state_encoder.num_recurrent_layers

    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        r"""Encode observations and advance the recurrent state.

        ``prev_actions`` is accepted for interface compatibility but is not
        used by this network.

        :return: ``(features, rnn_hidden_states)`` from the state encoder.
        :raises ValueError: if ``observations`` has no point-goal entry.
        """
        if IntegratedPointGoalGPSAndCompassSensor.cls_uuid in observations:
            target_encoding = observations[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ]
        elif PointGoalSensor.cls_uuid in observations:
            target_encoding = observations[PointGoalSensor.cls_uuid]
        else:
            # Previously this surfaced as an UnboundLocalError below.
            raise ValueError(
                "observations contain no point-goal sensor reading"
            )

        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            # Visual embedding goes first, goal vector last.
            x = [perception_embed] + x

        x = torch.cat(x, dim=1)
        x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks)

        return x, rnn_hidden_states
