import chainer
from chainer import Link, Chain, ChainList
import chainer.functions as F
from chainer.link_hooks import SpectralNormalization
import chainer.links as L
from source import yaml_utils as yu
from source import yaml_utils2 as yu2


class MLP(Chain):
    """Fully connected network with optional normalization links.

    The network consists of ``n_layers - 1`` hidden blocks followed by one
    final ``L.Linear``.  Each hidden block is a ``L.Linear`` optionally
    followed by ``L.BatchNormalization`` and/or ``L.LayerNormalization``
    (in that order); the activation is applied at the end of every hidden
    block.  The final linear layer has no activation.

    Args:
        input_dim: Input size of the first linear layer.
        hidden_dim: Output size of every hidden linear layer.
        out_dim: Output size of the final linear layer.
        **flags: Required configuration entries:

            * ``n_layers``: total number of linear layers (hidden + final).
            * ``activation``: passed to ``yu.make_function`` together with
              ``chainer.functions``; presumably describes the activation
              to build -- confirm against ``yaml_utils``.
            * ``bn_flag``: add batch normalization to each hidden block.
            * ``bn_use_gamma`` / ``bn_use_beta``: forwarded to
              ``L.BatchNormalization`` as ``use_gamma`` / ``use_beta``.
            * ``ln_flag``: add layer normalization to each hidden block.
            * ``sn_flag``: attach a ``SpectralNormalization`` hook to every
              linear layer (hidden and final).
            * ``sn_factor``: ``factor`` argument of the hook.
            * ``nobias``: ``nobias`` argument of the hidden linear layers.
            * ``initializer``: passed to ``yu2.make_instance`` together
              with ``chainer.initializers``; the result is used as
              ``initialW`` of every linear layer.

    Raises:
        KeyError: If a required entry is missing from ``flags``.
    """

    def __init__(self, input_dim=None, hidden_dim=None, out_dim=None, **flags):
        super(MLP, self).__init__()

        self.n_layers = flags['n_layers']
        n_hidden = self.n_layers - 1
        self.in_dims = [input_dim] + n_hidden * [hidden_dim]
        self.out_dims = n_hidden * [hidden_dim] + [out_dim]
        self.activation = yu.make_function(F, flags['activation'])
        self.bn_flag = flags['bn_flag']
        self.use_gamma = flags['bn_use_gamma']
        self.use_beta = flags['bn_use_beta']
        self.ln_flag = flags['ln_flag']
        self.sn_flag = flags['sn_flag']
        self.sn_factor = flags['sn_factor']
        self.nobias = flags['nobias']
        with self.init_scope():
            w = yu2.make_instance(chainer.initializers, flags['initializer'])

            # ``self.layers`` is a flat list: every hidden block contributes
            # its Linear link first, then the optional normalization links.
            # The order is kept stable so parameter paths do not change.
            layer_list = []
            for i in range(n_hidden):
                layer_list.append(L.Linear(self.in_dims[i],
                                           self.out_dims[i],
                                           initialW=w, nobias=self.nobias))
                if self.bn_flag:
                    layer_list.append(L.BatchNormalization(
                        self.out_dims[i], use_gamma=self.use_gamma,
                        use_beta=self.use_beta))
                if self.ln_flag:
                    layer_list.append(L.LayerNormalization(
                        self.out_dims[i]))
            self.layers = chainer.ChainList(*layer_list)

            # NOTE: the final layer is always bias-free, independent of
            # ``flags['nobias']``.
            self.last_layer = L.Linear(self.in_dims[self.n_layers - 1],
                                       self.out_dims[self.n_layers - 1],
                                       initialW=w, nobias=True)

            if self.sn_flag:
                # Hook only the Linear links.  They sit at the start of each
                # block, i.e. at multiples of the block size; indexing with
                # the plain block number would hit normalization links
                # (which have no ``W``) whenever bn/ln is enabled.
                block_size = self._block_size
                for k in range(n_hidden):
                    self.layers[k * block_size].add_hook(
                        SpectralNormalization(
                            factor=self.sn_factor, n_power_iteration=1))

                self.last_layer.add_hook(SpectralNormalization(
                    factor=self.sn_factor, n_power_iteration=1))

    @property
    def _block_size(self):
        """Number of links one hidden block occupies in ``self.layers``."""
        return 1 + int(bool(self.bn_flag)) + int(bool(self.ln_flag))

    def __call__(self, x):
        """Apply the network to ``x`` and return the final layer's output.

        Args:
            x: Input passed to the first linear layer.

        Returns:
            Output of ``self.last_layer`` (no activation applied).
        """
        h = x
        block_size = self._block_size
        for k in range(self.n_layers - 1):
            start = k * block_size
            # Linear, then whichever normalization links belong to the block.
            for offset in range(block_size):
                h = self.layers[start + offset](h)
            h = self.activation(h)
        return self.last_layer(h)
