"""
General networks for pytorch.

Algorithm-specific networks should go elsewhere.
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

from rlkit.policies.base import Policy
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
from rlkit.torch.modules import LayerNorm


def identity(x):
    """Return ``x`` unchanged (a no-op callable, e.g. usable as an activation)."""
    return x


class Mlp(PyTorchModule):
    """
    Fully connected network: a stack of hidden ``nn.Linear`` layers followed
    by one final linear output layer.

    :param hidden_sizes: iterable with the width of each hidden layer, in order.
    :param output_size: width of the final layer.
    :param input_size: width of the input features.
    :param init_w: the final layer's weight and bias are both drawn uniformly
        from ``[-init_w, init_w]``.
    :param hidden_activation: callable applied after every hidden layer.
    :param output_activation: callable applied to the final layer's output.
    :param hidden_init: callable invoked on every hidden layer's weight.
    :param b_init_value: constant that every hidden bias is filled with.
    :param layer_norm: when true, one ``LayerNorm`` is created per hidden layer.
    :param layer_norm_kwargs: defaulted to an empty dict when ``None``; it is
        not forwarded to ``LayerNorm`` anywhere in this class.
    """

    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=identity,
            hidden_init=ptu.fanin_init,
            b_init_value=0.1,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        if layer_norm_kwargs is None:
            layer_norm_kwargs = {}

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        # Plain Python lists used for ordered iteration in forward(); the
        # layers are additionally set as attributes (fc<i> / layer_norm<i>).
        self.fcs = []
        self.layer_norms = []

        prev_width = input_size
        for idx, width in enumerate(hidden_sizes):
            hidden_fc = nn.Linear(prev_width, width)
            hidden_init(hidden_fc.weight)
            hidden_fc.bias.data.fill_(b_init_value)
            setattr(self, "fc{}".format(idx), hidden_fc)
            self.fcs.append(hidden_fc)

            if self.layer_norm:
                norm = LayerNorm(width)
                setattr(self, "layer_norm{}".format(idx), norm)
                self.layer_norms.append(norm)

            prev_width = width

        self.last_fc = nn.Linear(prev_width, output_size)
        # Weight first, then bias, so the random draws happen in that order.
        for tensor in (self.last_fc.weight, self.last_fc.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, input, return_preactivations=False):
        """
        Run the network on ``input``.

        :param input: tensor fed to the first hidden layer.
        :param return_preactivations: when true, also return the final
            layer's output before ``output_activation`` is applied.
        :return: ``output`` or the tuple ``(output, preactivation)``.
        """
        last_hidden_idx = len(self.fcs) - 1
        h = input
        for idx, layer in enumerate(self.fcs):
            h = layer(h)
            # Layer norm is applied to every hidden layer except the last one.
            if self.layer_norm and idx < last_hidden_idx:
                h = self.layer_norms[idx](h)
            h = self.hidden_activation(h)

        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        if not return_preactivations:
            return output
        return output, preactivation

# Lookup tables: number of conv layers -> side length ``out_dim`` used to size
# the first Linear layer of the Cnn* heads below
# (in_features = num_filters * out_dim * out_dim).
# The values agree with the spatial output of CnnEncoder's conv stack (one 3x3
# stride-2 conv followed by 3x3 stride-1 convs, no padding) on square inputs of
# 84, 64 and 108 pixels respectively.
OUT_DIM = {2: 39, 4: 35, 6: 31}  # fallback when obs_shape[-1] is neither 64 nor 108
OUT_DIM_64 = {2: 29, 4: 25, 6: 21}  # selected when obs_shape[-1] == 64
OUT_DIM_108 = {4: 47}  # selected when obs_shape[-1] == 108; only 4 layers is listed

def weight_init(m):
    """
    Initialise a single module in place; intended for ``module.apply(...)``.

    * ``nn.Linear``: orthogonal weight, zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: the whole kernel and the bias are
      zeroed, then only the centre spatial tap ``[:, :, mid, mid]`` is given an
      orthogonal initialisation scaled by the ReLU gain.
    * Any other module type is left untouched.

    Layers built with ``bias=False`` (``m.bias is None``) are supported: the
    bias step is simply skipped.

    :param m: the module to initialise.
    :raises AssertionError: if a conv kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # The centre-tap scheme below needs a single, well-defined centre.
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        # The indexed slice is a view, so this writes into the real weights.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)

class CnnEncoder(PyTorchModule):
    """
    Convolutional feature extractor that returns flattened conv features.

    The stack is one 3x3 stride-2 convolution from ``obs_shape[0]`` channels
    to ``num_filters``, followed by ``num_layers - 1`` 3x3 stride-1
    convolutions, each followed by a ReLU. No fully connected or
    normalisation head is applied.

    :param obs_shape: observation shape; index 0 is used as the number of
        input channels.
    :param num_layers: total number of conv layers.
    :param num_filters: output channels of every conv layer.
    :param hidden_dim: accepted but not used by this class.
    :param output_logits: stored on the instance; not used by this class.
    """

    def __init__(self, obs_shape, num_layers=2, num_filters=32,
            hidden_dim=1024, output_logits=False,):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        self.obs_shape = obs_shape
        self.num_layers = num_layers
        self.output_logits = output_logits

        conv_stack = [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
        for _ in range(num_layers - 1):
            conv_stack.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
        self.convs = nn.ModuleList(conv_stack)

        self.apply(weight_init)

    def forward(self, obs):
        """
        Encode ``obs`` and return the features flattened to ``(batch, -1)``.

        If the largest value in ``obs`` exceeds 1 the whole batch is divided
        by 255 first (presumably to map 8-bit pixel values into [0, 1]).
        """
        scaled = obs / 255. if obs.max() > 1. else obs

        features = torch.relu(self.convs[0](scaled))
        for layer_idx in range(1, self.num_layers):
            features = torch.relu(self.convs[layer_idx](features))

        batch_size = features.size(0)
        return features.view(batch_size, -1)

class CnnContextEncoder(PyTorchModule):
    """
    Context encoder that works on already-flattened conv features.

    ``fc_layers`` (Linear + LayerNorm) projects the flattened features to an
    ``embedding_dim`` vector. ``trunk`` (a 3-layer ReLU MLP) maps a context
    vector of width ``embedding_dim + action_dim + 1`` (embedding, action,
    reward) to ``2 * latent_dim`` outputs.

    :param obs_shape: observation shape; only ``obs_shape[-1]`` is read, to
        pick the conv output side length from the ``OUT_DIM*`` tables.
    :param action_dim: width of the action vector in the context.
    :param embedding_dim: width of the feature embedding.
    :param num_layers: key into the ``OUT_DIM*`` table (number of conv layers
        of the encoder that produced the features).
    :param num_filters: channel count of those conv features.
    :param hidden_dim: width of the two hidden trunk layers.
    :param latent_dim: the trunk emits ``2 * latent_dim`` values
        (presumably mean and variance parameters of the latent -- confirm
        against the caller).
    :param output_logits: stored on the instance; not used by this class.
    :raises KeyError: if ``num_layers`` is missing from the selected table.
    """

    def __init__(self, obs_shape, action_dim, embedding_dim, num_layers=2, num_filters=32,
            hidden_dim=1024, latent_dim=5, output_logits=False,):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        self.obs_shape = obs_shape
        self.num_layers = num_layers
        self.output_logits = output_logits

        if obs_shape[-1] == 108:
            out_dim = OUT_DIM_108[num_layers]
        elif obs_shape[-1] == 64:
            out_dim = OUT_DIM_64[num_layers]
        else:
            out_dim = OUT_DIM[num_layers]

        self.fc_layers = nn.Sequential(
            nn.Linear(num_filters * out_dim * out_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )

        self.trunk = nn.Sequential(
            # context is obs embedding + action + reward (one scalar)
            nn.Linear(embedding_dim + action_dim + 1, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)
        )

        self.apply(weight_init)

    def get_image_embedding(self, obs):
        """
        Project flattened conv features to the ``embedding_dim`` embedding.

        :param obs: tensor whose last dimension is
            ``num_filters * out_dim * out_dim`` (the conv encoder is not run
            here, so ``obs`` must already be encoded and flattened).
        """
        return self.fc_layers(obs)

    def forward(self, obs, action, reward):
        """
        Encode one batch of (obs, action, reward) context tuples.

        :param obs: flattened conv features, see ``get_image_embedding``.
        :param action: tensor concatenated along dim 1; expected width
            ``action_dim``.
        :param reward: tensor concatenated along dim 1; expected width 1.
        :return: trunk output of width ``2 * latent_dim``.
        """
        h = self.fc_layers(obs)
        context = torch.cat([h, action, reward], dim=1)
        # The trunk consumes the full context (embedding + action + reward),
        # not the bare embedding: its first layer is sized for the former.
        return self.trunk(context)

    def context_forward(self, context):
        """
        Run only the trunk on an already-assembled context tensor, whose last
        dimension must be ``embedding_dim + action_dim + 1``.
        """
        return self.trunk(context)

    def reset(self, num_tasks=1):
        """No-op: this encoder keeps no per-task state."""
        pass

class CnnPolicyNetwork(PyTorchModule):
    """
    Policy head on top of already-flattened conv features and a latent ``z``.

    ``fc_layers`` (Linear + LayerNorm) projects the features to
    ``feature_dim``; ``trunk`` (3-layer ReLU MLP) maps ``[features, z]`` to
    ``2 * action_shape`` outputs (presumably mean and log-std per action
    dimension -- confirm against the policy that wraps this network).

    :param obs_shape: only ``obs_shape[-1]`` is read, to select the
        ``OUT_DIM*`` table.
    :param action_shape: integer action dimensionality.
    :param latent_dim: width of ``z``.
    :param feature_dim: width of the projected features.
    :param num_layers: key into the selected ``OUT_DIM*`` table.
    :param num_filters: channel count of the incoming conv features.
    :param hidden_dim: width of the two hidden trunk layers.
    :param output_logits: stored on the instance; not used by this class.
    """

    def __init__(self, obs_shape, action_shape, latent_dim, feature_dim, num_layers=2, num_filters=32,
            hidden_dim=1024, output_logits=False,):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        side = obs_shape[-1]
        if side == 108:
            dim_table = OUT_DIM_108
        elif side == 64:
            dim_table = OUT_DIM_64
        else:
            dim_table = OUT_DIM
        out_dim = dim_table[num_layers]
        flat_features = num_filters * out_dim * out_dim

        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, feature_dim),
            nn.LayerNorm(feature_dim),
        )

        self.trunk = nn.Sequential(
            nn.Linear(feature_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape),
        )

        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.latent_dim = latent_dim
        self.feature_dim = feature_dim
        self.output_logits = output_logits

        self.apply(weight_init)

    def forward(self, obs, z):
        """
        :param obs: flattened conv features (the conv encoder is not run
            here); last dimension ``num_filters * out_dim * out_dim``.
        :param z: latent tensor concatenated to the features along dim 1.
        :return: raw trunk output of width ``2 * action_shape``.
        """
        features = self.fc_layers(obs)
        policy_input = torch.cat([features, z], dim=1)
        return self.trunk(policy_input)

class CnnQFunction(PyTorchModule):
    """
    Q-value head on top of already-flattened conv features.

    ``fc_layers`` (Linear + LayerNorm) projects the features to
    ``feature_dim``; ``trunk`` (3-layer ReLU MLP) maps
    ``[features, action, z]`` to a single value per sample.

    :param obs_shape: only ``obs_shape[-1]`` is read, to select the
        ``OUT_DIM*`` table.
    :param action_shape: integer action dimensionality.
    :param latent_dim: width of ``z``.
    :param feature_dim: width of the projected features.
    :param num_layers: key into the selected ``OUT_DIM*`` table.
    :param num_filters: channel count of the incoming conv features.
    :param hidden_dim: width of the two hidden trunk layers.
    :param output_logits: accepted but not used by this class.
    """

    def __init__(self, obs_shape, action_shape, latent_dim, feature_dim, num_layers=2,
            num_filters=32, hidden_dim=1024, output_logits=False):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        side = obs_shape[-1]
        if side == 108:
            dim_table = OUT_DIM_108
        elif side == 64:
            dim_table = OUT_DIM_64
        else:
            dim_table = OUT_DIM
        out_dim = dim_table[num_layers]
        flat_features = num_filters * out_dim * out_dim

        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, feature_dim),
            nn.LayerNorm(feature_dim),
        )

        trunk_in = feature_dim + action_shape + latent_dim
        self.trunk = nn.Sequential(
            nn.Linear(trunk_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.apply(weight_init)

    def forward(self, obs, action, z, detach=False):
        """
        :param obs: flattened conv features (the conv encoder is not run
            here); last dimension ``num_filters * out_dim * out_dim``.
        :param action: action tensor, concatenated along dim 1.
        :param z: latent tensor, concatenated along dim 1.
        :param detach: when true, the returned value is detached from the
            autograd graph.
        :return: tensor with one value per sample (last dimension 1).
        """
        features = self.fc_layers(obs)
        q_value = self.trunk(torch.cat([features, action, z], dim=1))
        return q_value.detach() if detach else q_value

class CnnVf(PyTorchModule):
    """
    State-value head on top of already-flattened conv features.

    ``fc_layers`` (Linear + LayerNorm) projects the features to
    ``feature_dim``; ``trunk`` (3-layer ReLU MLP) maps ``[features, z]`` to a
    single value per sample.

    :param obs_shape: only ``obs_shape[-1]`` is read, to select the
        ``OUT_DIM*`` table.
    :param latent_dim: width of ``z``.
    :param feature_dim: width of the projected features.
    :param num_layers: key into the selected ``OUT_DIM*`` table.
    :param num_filters: channel count of the incoming conv features.
    :param hidden_dim: width of the two hidden trunk layers.
    :param output_logits: accepted but not used by this class.
    """

    def __init__(self, obs_shape, latent_dim, feature_dim, num_layers=2, num_filters=32,
            hidden_dim=1024, output_logits=False):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__()

        side = obs_shape[-1]
        if side == 108:
            dim_table = OUT_DIM_108
        elif side == 64:
            dim_table = OUT_DIM_64
        else:
            dim_table = OUT_DIM
        out_dim = dim_table[num_layers]
        flat_features = num_filters * out_dim * out_dim

        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, feature_dim),
            nn.LayerNorm(feature_dim),
        )

        self.trunk = nn.Sequential(
            nn.Linear(feature_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.apply(weight_init)

    def forward(self, obs, z):
        """
        :param obs: flattened conv features (the conv encoder is not run
            here); last dimension ``num_filters * out_dim * out_dim``.
        :param z: latent tensor concatenated to the features along dim 1.
        :return: tensor with one value per sample (last dimension 1).
        """
        features = self.fc_layers(obs)
        return self.trunk(torch.cat([features, z], dim=1))

class FlattenMlp(Mlp):
    """
    ``Mlp`` that accepts several positional input tensors and joins them
    along dim 1 before running the parent forward pass.
    """

    def forward(self, *inputs, **kwargs):
        """Concatenate ``inputs`` on dim 1 and delegate to the parent ``forward``."""
        joined = torch.cat(inputs, dim=1)
        result = super().forward(joined, **kwargs)
        return result


class MlpPolicy(Mlp, Policy):
    """
    An ``Mlp`` exposed through the ``Policy`` interface, with optional
    observation normalisation applied before the network.

    :param obs_normalizer: optional normaliser; when set (truthy), its
        ``normalize`` method is applied to every observation batch in
        ``forward``.
    """

    def __init__(
            self,
            *args,
            obs_normalizer: TorchFixedNormalizer = None,
            **kwargs
    ):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        self.obs_normalizer = obs_normalizer

    def forward(self, obs, **kwargs):
        """Normalise ``obs`` if a normaliser is configured, then run the MLP."""
        normalizer = self.obs_normalizer
        net_input = normalizer.normalize(obs) if normalizer else obs
        return super().forward(net_input, **kwargs)

    def get_action(self, obs_np):
        """
        Compute the action for a single observation.

        A leading batch axis is added to ``obs_np``, the batch is passed to
        ``get_actions`` and the first row of the result is returned together
        with an empty info dict.
        """
        batch_actions = self.get_actions(obs_np[None])
        first_action = batch_actions[0, :]
        return first_action, {}

    def get_actions(self, obs):
        """Evaluate a batch of observations through ``eval_np``."""
        return self.eval_np(obs)


class TanhMlpPolicy(MlpPolicy):
    """
    A helper class since most policies have a tanh output activation.

    Behaves like ``MlpPolicy`` with ``output_activation`` fixed to
    ``torch.tanh``; every other positional and keyword argument is forwarded
    unchanged.
    """
    def __init__(self, *args, **kwargs):
        # Hand the untouched constructor locals to save_init_params first.
        self.save_init_params(locals())
        # output_activation is supplied here, so passing it again in kwargs
        # would raise a duplicate-keyword TypeError.
        super().__init__(*args, output_activation=torch.tanh, **kwargs)


class MlpEncoder(FlattenMlp):
    """
    Encode context via MLP.

    Adds only a ``reset`` method to ``FlattenMlp``; the forward pass is
    inherited unchanged.
    """

    def reset(self, num_tasks=1):
        """No-op: this encoder keeps no state between calls."""


class RecurrentEncoder(FlattenMlp):
    """
    Encode context via a recurrent network.

    Each timestep is embedded by the inherited hidden MLP layers, the
    sequence is run through a single-layer LSTM, and the LSTM output at the
    last timestep is fed to the inherited ``last_fc`` output layer.

    The LSTM hidden state is kept in the ``hidden`` buffer and carried over
    between ``forward`` calls; the cell state is re-created as zeros on every
    call. ``reset`` re-zeros the hidden state for a given number of tasks.
    """

    def __init__(self, *args, **kwargs):
        # Must stay the first statement: it receives the untouched ctor locals.
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)

        # The LSTM works at the width of the last hidden MLP layer.
        self.hidden_dim = self.hidden_sizes[-1]
        # Initial hidden state has shape (1, 1, hidden_dim), i.e. one task;
        # reset() rebuilds it as (1, num_tasks, hidden_dim).
        self.register_buffer('hidden', torch.zeros(1, 1, self.hidden_dim))

        # LSTM input is (task, seq, feat) because batch_first=True.
        self.lstm = nn.LSTM(self.hidden_dim, self.hidden_dim, num_layers=1, batch_first=True)

    def forward(self, in_, return_preactivations=False):
        """
        :param in_: tensor of shape (task, seq, feat).
        :param return_preactivations: when true, also return the output
            layer's value before ``output_activation``.
        :return: ``output`` or the tuple ``(output, preactivation)``.
        """
        num_tasks, seq_len, feat_dim = in_.size()

        # Embed every timestep independently with the hidden MLP layers.
        embedded = in_.view(num_tasks * seq_len, feat_dim)
        for fc in self.fcs:
            embedded = self.hidden_activation(fc(embedded))
        embedded = embedded.view(num_tasks, seq_len, -1)

        # Fresh zero cell state; the hidden state comes from the buffer.
        cell_init = torch.zeros(self.hidden.size()).to(ptu.device)
        lstm_out, (hn, cn) = self.lstm(embedded, (self.hidden, cell_init))
        self.hidden = hn

        # Only the final timestep's LSTM output is used for the prediction.
        final_step = lstm_out[:, -1, :]

        preactivation = self.last_fc(final_step)
        output = self.output_activation(preactivation)
        if not return_preactivations:
            return output
        return output, preactivation

    def reset(self, num_tasks=1):
        """Replace the hidden state with zeros of shape (1, num_tasks, hidden_dim)."""
        fresh_shape = (1, num_tasks, self.hidden_dim)
        self.hidden = self.hidden.new_full(fresh_shape, 0)
