"""
Modified PyTorch networks for use during multi-task training.

Each network has a separate output layer for every training task; all other
layers are shared among tasks, in order to linearize the model.
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

from rlkit.policies.base import Policy
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
from rlkit.torch.modules import LayerNorm


def identity(x):
    """Pass-through activation: hand back the argument untouched."""
    result = x
    return result


class ModifiedMlp(PyTorchModule):
    """
    MLP with a shared trunk and one linear output head per training task.

    Layout: ``input -> fc0 .. fc{k-1} -> last_fc -> lc{index}``. The ``fc*``
    layers and ``last_fc`` are shared by all tasks; ``lc{index}`` is the head
    selected at call time. ``hidden_activation`` follows every ``fc*`` layer
    and ``last_fc``; ``output_activation`` is applied to the selected head's
    output.

    Args:
        hidden_sizes: widths of the shared hidden layers (may be empty).
        output_size: output width of every task head.
        input_size: width of the input features fed to the first layer.
        tasks_num: number of task-specific heads to create.
        last_layer_dim: width of the shared ``last_fc`` layer feeding the heads.
        init_w: half-width of the uniform init range used for ``last_fc``,
            ``temp`` and every task head.
        hidden_activation: activation applied after each hidden layer and
            after ``last_fc``.
        output_activation: activation applied to the selected head's output.
        hidden_init: in-place initializer called on each hidden-layer weight.
        b_init_value: constant that hidden-layer biases are filled with.
        layer_norm: if True, create a ``LayerNorm`` for each hidden layer.
        layer_norm_kwargs: extra keyword arguments passed to each ``LayerNorm``.
    """

    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            tasks_num,
            last_layer_dim,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=identity,
            hidden_init=ptu.fanin_init,
            b_init_value=0.1,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        # Called before any other local is bound, so locals() holds exactly
        # the constructor arguments.
        self.save_init_params(locals())
        super().__init__()

        if layer_norm_kwargs is None:
            layer_norm_kwargs = dict()

        self.tasks_num = tasks_num
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        # Submodules are registered through setattr under fc{i},
        # layer_norm{i} and lc{i}; these plain lists only keep ordered
        # handles on them for iteration/indexing in forward().
        self.fcs = []
        self.layer_norms = []
        in_size = input_size
        self.last_layers = []

        for i, next_size in enumerate(hidden_sizes):
            fc = nn.Linear(in_size, next_size)
            in_size = next_size
            hidden_init(fc.weight)
            fc.bias.data.fill_(b_init_value)
            self.__setattr__("fc{}".format(i), fc)
            self.fcs.append(fc)

            if self.layer_norm:
                ln = LayerNorm(next_size, **layer_norm_kwargs)
                self.__setattr__("layer_norm{}".format(i), ln)
                self.layer_norms.append(ln)

        # Shared projection from the trunk into the heads' input space.
        self.last_fc = nn.Linear(in_size, last_layer_dim)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

        # NOTE(review): `temp` is never used by forward(). It is kept (and
        # registered in the same order) so existing state_dicts and
        # parameter orderings stay compatible — confirm before removing.
        self.temp = nn.Linear(output_size, 1)
        self.temp.weight.data.uniform_(-init_w, init_w)
        self.temp.bias.data.uniform_(-init_w, init_w)

        # One linear head per task: last_layer_dim -> output_size.
        for i in range(self.tasks_num):
            lc = nn.Linear(last_layer_dim, output_size)
            lc.weight.data.uniform_(-init_w, init_w)
            lc.bias.data.uniform_(-init_w, init_w)
            self.__setattr__("lc{}".format(i), lc)
            self.last_layers.append(lc)

    def forward(self, input, index, return_preactivations=False):
        """
        Run the shared trunk, then the head belonging to task ``index``.

        Args:
            input: tensor fed to the first hidden layer (or directly to
                ``last_fc`` when there are no hidden layers).
            index: position of the task head in ``self.last_layers``.
            return_preactivations: if True, also return the raw output of
                the shared ``last_fc`` layer.

        Returns:
            The selected head's output after ``output_activation``, or the
            tuple ``(output, preactivation)`` when ``return_preactivations``
            is True, where ``preactivation`` is ``last_fc``'s raw output.
        """
        h = input
        for i, fc in enumerate(self.fcs):
            h = fc(h)
            # The LayerNorm created for the last hidden layer is not applied.
            if self.layer_norm and i < len(self.fcs) - 1:
                h = self.layer_norms[i](h)
            h = self.hidden_activation(h)
        preactivation = self.last_fc(h)
        features = self.hidden_activation(preactivation)
        # output_activation was previously accepted but silently ignored.
        output = self.output_activation(self.last_layers[index](features))

        if return_preactivations:
            return output, preactivation
        return output


class ModifiedFlattenMlp(ModifiedMlp):
    """
    ModifiedMlp variant that accepts several input tensors.

    The positional inputs are concatenated along dim 1, and the result is
    handed, together with the task ``index`` and any keyword arguments, to
    ``ModifiedMlp.forward``.
    """

    def forward(self, index, *inputs, **kwargs):
        joined = torch.cat(inputs, dim=1)
        result = super().forward(joined, index, **kwargs)
        return result





