"""
General-purpose networks for PyTorch.

Algorithm-specific networks should go elsewhere.
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

# from rlkit.policies.base import Policy
from torchkit import pytorch_utils as ptu
from torchkit.core import PyTorchModule
from torchkit.modules import LayerNorm
# from rlkit.torchkit.data_management.normalizer import TorchFixedNormalizer


class Mlp(PyTorchModule):
    """
    Fully connected feed-forward network.

    The network is a stack of ``nn.Linear`` hidden layers (one per entry of
    ``hidden_sizes``) followed by a final ``nn.Linear`` layer (``last_fc``)
    that maps to ``output_size``.

    :param hidden_sizes: iterable with the width of each hidden layer.
    :param output_size: number of output features of ``last_fc``.
    :param input_size: number of input features of the first layer.
    :param init_w: weights and biases of ``last_fc`` are drawn uniformly
        from ``[-init_w, init_w]``.
    :param hidden_activation: callable applied after every hidden layer.
    :param output_activation: callable applied to the output of ``last_fc``.
    :param hidden_init: callable invoked on the weight tensor of every
        hidden layer to initialise it.
    :param b_init_value: value every hidden-layer bias is filled with.
    :param layer_norm: if truthy, a ``LayerNorm`` is created for every
        hidden layer.
    :param layer_norm_kwargs: optional dict of extra keyword arguments
        forwarded to each ``LayerNorm`` constructor (``None`` means none).

    Note: when ``layer_norm`` is enabled, ``forward`` applies normalisation
    after every hidden layer except the last one, even though a
    ``LayerNorm`` is created for that layer as well. This is kept as-is so
    that outputs of existing models do not change.
    """

    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=ptu.identity,
            hidden_init=ptu.fanin_init,
            b_init_value=0.1,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        # Keep this as the very first statement so that locals() is captured
        # before any additional local variable is defined.
        # NOTE(review): save_init_params comes from PyTorchModule; presumably
        # it records the constructor arguments for later re-creation of the
        # module -- confirm against torchkit.core.
        self.save_init_params(locals())
        super().__init__()

        if layer_norm_kwargs is None:
            layer_norm_kwargs = {}

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        # Plain Python lists used for ordered iteration in forward(); each
        # layer is additionally stored as a named attribute below.
        # NOTE(review): assuming PyTorchModule derives from nn.Module, that
        # attribute assignment is what registers the layer's parameters.
        self.fcs = []
        self.layer_norms = []

        in_size = input_size
        for i, next_size in enumerate(hidden_sizes):
            fc = nn.Linear(in_size, next_size)
            hidden_init(fc.weight)
            fc.bias.data.fill_(b_init_value)
            setattr(self, "fc{}".format(i), fc)
            self.fcs.append(fc)

            if self.layer_norm:
                # Forward the user-supplied options; previously they were
                # accepted by the constructor but silently ignored.
                ln = LayerNorm(next_size, **layer_norm_kwargs)
                setattr(self, "layer_norm{}".format(i), ln)
                self.layer_norms.append(ln)

            # The next layer consumes what this one produces.
            in_size = next_size

        self.last_fc = nn.Linear(in_size, output_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, input, return_preactivations=False):
        """
        Run the network on ``input``.

        :param input: tensor fed to the first hidden layer.
        :param return_preactivations: if truthy, also return the output of
            ``last_fc`` before ``output_activation`` is applied.
        :return: ``output`` or, when ``return_preactivations`` is truthy,
            the tuple ``(output, preactivation)``.
        """
        h = input
        last_hidden_index = len(self.fcs) - 1
        for i, fc in enumerate(self.fcs):
            h = fc(h)
            # Layer norm is deliberately not applied after the last hidden
            # layer (see the class docstring).
            if self.layer_norm and i < last_hidden_index:
                h = self.layer_norms[i](h)
            h = self.hidden_activation(h)

        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        if return_preactivations:
            return output, preactivation
        return output


class FlattenMlp(Mlp):
    """
    Mlp variant that accepts several input tensors.

    All positional inputs are concatenated along dim 1 and the result is
    passed to ``Mlp.forward``; keyword arguments are forwarded unchanged.
    """

    def forward(self, *inputs, **kwargs):
        # Join the inputs along dim 1 and delegate to the parent forward.
        return super().forward(torch.cat(inputs, dim=1), **kwargs)


# class MlpPolicy(Mlp, Policy):
#     """
#     A simpler interface for creating policies.
#     """
#
#     def __init__(
#             self,
#             *args,
#             obs_normalizer: TorchFixedNormalizer = None,
#             **kwargs
#     ):
#         self.save_init_params(locals())
#         super().__init__(*args, **kwargs)
#         self.obs_normalizer = obs_normalizer
#
#     def forward(self, obs, **kwargs):
#         if self.obs_normalizer:
#             obs = self.obs_normalizer.normalize(obs)
#         return super().forward(obs, **kwargs)
#
#     def get_action(self, obs_np):
#         actions = self.get_actions(obs_np[None])
#         return actions[0, :], {}
#
#     def get_actions(self, obs):
#         return self.eval_np(obs)


# class TanhMlpPolicy(MlpPolicy):
#     """
#     A helper class since most policies have a tanh output activation.
#     """
#     def __init__(self, *args, **kwargs):
#         self.save_init_params(locals())
#         super().__init__(*args, output_activation=torchkit.tanh, **kwargs)


# class MlpEncoder(FlattenMlp):
#     '''
#     encode context via MLP
#     '''
#
#     def reset(self, num_tasks=1):
#         pass


# class RecurrentEncoder(FlattenMlp):
#     '''
#     encode context via recurrent network
#     '''
#
#     def __init__(self,
#                  *args,
#                  **kwargs
#     ):
#         self.save_init_params(locals())
#         super().__init__(*args, **kwargs)
#         self.hidden_dim = self.hidden_sizes[-1]
#         self.register_buffer('hidden', torchkit.zeros(1, 1, self.hidden_dim))
#
#         # input should be (task, seq, feat) and hidden should be (task, 1, feat)
#
#         self.lstm = nn.LSTM(self.hidden_dim, self.hidden_dim, num_layers=1, batch_first=True)
#
#     def forward(self, in_, return_preactivations=False, return_sequence=False):
#         '''
#
#         :param in_:
#         :param return_preactivations:
#         :param return_sequence: returns outputs at all times
#         :return:
#         '''
#         # expects inputs of dimension (task, seq, feat)
#         task, seq, feat = in_.size()
#         out = in_.view(task * seq, feat)
#
#         # embed with MLP
#         for i, fc in enumerate(self.fcs):
#             out = fc(out)
#             out = self.hidden_activation(out)
#
#         out = out.view(task, seq, -1)
#         out, (hn, cn) = self.lstm(out, (self.hidden, torchkit.zeros(self.hidden.size()).to(ptu.device)))
#         self.hidden = hn
#         if return_sequence:
#             out = out.contiguous().view(task * seq, -1)
#         else:
#             # take the last hidden state to predict z
#             out = out[:, -1, :]
#
#         # output layer
#         preactivation = self.last_fc(out)
#         output = self.output_activation(preactivation)
#         if return_sequence:
#             preactivation = preactivation.view(task, seq, -1)
#             output = output.view(task, seq, -1)
#
#         if return_preactivations:
#             return output, preactivation
#         else:
#             return output
#
#     def reset(self, num_tasks=1):
#         self.hidden = self.hidden.new_full((1, num_tasks, self.hidden_dim), 0)
