from vprnn.layers import VanillaCell
from keras.layers import InputLayer, Dense, RNN, TimeDistributed, Bidirectional, LSTM
from keras.models import Sequential
from .layers import FastVanillaCell, FastSimpleCell

from keras.utils.generic_utils import CustomObjectScope


class SGORNNModel(Sequential):
    """Stack of ``FastVanillaCell`` recurrent layers topped by a ``Dense`` head.

    Keyword Args:
        layers: Number of recurrent layers. All but the last emit full
            sequences; the last one follows ``return_sequences``. Values
            below 1 still produce a single recurrent layer.
        dim: Units of every ``FastVanillaCell``.
        rots: Forwarded to each cell as ``n_rotations``.
        activation: Activation of each recurrent cell.
        output_dim: Units of the final ``Dense`` layer.
        input_dim: Per-timestep feature size; the input layer is declared
            with shape ``(None, input_dim)`` (variable-length time axis).
        output_activation: Activation of the final ``Dense`` layer.
        clip_scalar: Forwarded to each cell.
        return_sequences: If true, the last recurrent layer emits every
            timestep and the head is wrapped in ``TimeDistributed``.
        bidirectional: Wrap every recurrent layer in ``Bidirectional``.

    Positional ``*args`` and any other ``**kwargs`` are passed to
    ``Sequential.__init__``.
    """

    def __init__(self, *args,
                 layers=1, dim=128,
                 rots=7, activation='relu',
                 output_dim=1,
                 input_dim=2,
                 output_activation='linear',
                 clip_scalar=True,
                 return_sequences=False,
                 bidirectional=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.add(InputLayer((None, input_dim)))

        cell_registry = {'FastVanillaCell': FastVanillaCell}

        def make_rnn(emit_sequence):
            """Build one ``RNN`` around a freshly configured cell."""
            cell = FastVanillaCell(dim,
                                   clip_scalar=clip_scalar,
                                   n_rotations=rots,
                                   activation=activation)
            return RNN(cell, return_sequences=emit_sequence)

        # Intermediate layers must emit whole sequences to feed the next RNN;
        # only the topmost one honours the caller's ``return_sequences``.
        sequence_flags = [True] * (layers - 1) + [return_sequences]
        for emit_sequence in sequence_flags:
            if not bidirectional:
                self.add(make_rnn(emit_sequence))
                continue
            # Bidirectional re-creates its backward copy from the wrapped
            # layer's config, so the custom cell must be resolvable by name.
            with CustomObjectScope(cell_registry):
                self.add(Bidirectional(make_rnn(emit_sequence)))

        head = Dense(output_dim, activation=output_activation)
        self.add(TimeDistributed(head) if return_sequences else head)


class LSTMModel(Sequential):
    """Stack of ``LSTM`` layers topped by a ``Dense`` head.

    The constructor signature mirrors the other model classes in this
    module. ``rots`` and ``clip_scalar`` have no LSTM counterpart and are
    accepted but ignored.

    Keyword Args:
        layers: Number of LSTM layers. All but the last emit full sequences;
            the last one follows ``return_sequences``. Values below 1 still
            produce a single LSTM layer.
        dim: Units of every LSTM layer.
        rots: Ignored (signature parity only).
        activation: Forwarded to every ``LSTM`` as its ``activation``
            argument. The default ``'tanh'`` matches Keras' own default.
        output_dim: Units of the final ``Dense`` layer.
        input_dim: Per-timestep feature size; the input layer is declared
            with shape ``(None, input_dim)`` (variable-length time axis).
        output_activation: Activation of the final ``Dense`` layer.
        clip_scalar: Ignored (signature parity only).
        return_sequences: If true, the last LSTM emits every timestep and
            the head is wrapped in ``TimeDistributed``.
        bidirectional: Wrap every LSTM layer in ``Bidirectional``.

    Positional ``*args`` and any other ``**kwargs`` are passed to
    ``Sequential.__init__``.
    """

    def __init__(self, *args,
                 layers=1, dim=128,
                 rots=7, activation='tanh',
                 output_dim=1,
                 input_dim=2,
                 output_activation='linear',
                 clip_scalar=True,
                 return_sequences=False,
                 bidirectional=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.add(InputLayer((None, input_dim)))

        def make_lstm(emit_sequence):
            """Build one LSTM layer; ``activation`` was previously dropped."""
            return LSTM(dim, activation=activation,
                        return_sequences=emit_sequence)

        # Intermediate layers must emit whole sequences to feed the next LSTM;
        # only the topmost one honours the caller's ``return_sequences``.
        for emit_sequence in [True] * (layers - 1) + [return_sequences]:
            lstm = make_lstm(emit_sequence)
            self.add(Bidirectional(lstm) if bidirectional else lstm)

        head = Dense(output_dim, activation=output_activation)
        self.add(TimeDistributed(head) if return_sequences else head)


class FastRNNModel(Sequential):
    """Stack of ``FastSimpleCell`` recurrent layers topped by a ``Dense`` head.

    ``rots`` and ``clip_scalar`` are accepted so the signature matches the
    other model classes in this module, but they are not passed to the cell.

    Keyword Args:
        layers: Number of recurrent layers. All but the last emit full
            sequences; the last one follows ``return_sequences``. Values
            below 1 still produce a single recurrent layer.
        dim: Units of every ``FastSimpleCell``.
        activation: Activation of each recurrent cell.
        output_dim: Units of the final ``Dense`` layer.
        input_dim: Per-timestep feature size; the input layer is declared
            with shape ``(None, input_dim)`` (variable-length time axis).
        output_activation: Activation of the final ``Dense`` layer.
        rots: Ignored.
        clip_scalar: Ignored.
        return_sequences: If true, the last recurrent layer emits every
            timestep and the head is wrapped in ``TimeDistributed``.
        bidirectional: Wrap every recurrent layer in ``Bidirectional``.

    Positional ``*args`` and any other ``**kwargs`` are passed to
    ``Sequential.__init__``.
    """

    def __init__(self, *args,
                 layers=1, dim=128,
                 activation='relu',
                 output_dim=1,
                 input_dim=2,
                 output_activation='linear',
                 rots=None,
                 clip_scalar=True,
                 return_sequences=False,
                 bidirectional=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.add(InputLayer((None, input_dim)))

        cell_registry = {'FastSimpleCell': FastSimpleCell}

        def make_rnn(emit_sequence):
            """Build one ``RNN`` around a freshly configured cell."""
            cell = FastSimpleCell(dim, activation=activation)
            return RNN(cell, return_sequences=emit_sequence)

        # Intermediate layers must emit whole sequences to feed the next RNN;
        # only the topmost one honours the caller's ``return_sequences``.
        sequence_flags = [True] * (layers - 1) + [return_sequences]
        for emit_sequence in sequence_flags:
            if not bidirectional:
                self.add(make_rnn(emit_sequence))
                continue
            # Bidirectional re-creates its backward copy from the wrapped
            # layer's config, so the custom cell must be resolvable by name.
            with CustomObjectScope(cell_registry):
                self.add(Bidirectional(make_rnn(emit_sequence)))

        head = Dense(output_dim, activation=output_activation)
        self.add(TimeDistributed(head) if return_sequences else head)


class VPRNNModel(Sequential):
    """Stack of ``vprnn`` ``VanillaCell`` recurrent layers topped by a ``Dense`` head.

    ``clip_scalar`` is accepted so the signature matches the other model
    classes in this module, but it is not passed to the cell.

    Keyword Args:
        layers: Number of recurrent layers. All but the last emit full
            sequences; the last one follows ``return_sequences``. Values
            below 1 still produce a single recurrent layer.
        dim: Units of every ``VanillaCell``.
        rots: Forwarded to each cell as ``n_rotations``.
        activation: Activation of each recurrent cell.
        output_dim: Units of the final ``Dense`` layer.
        input_dim: Per-timestep feature size; the input layer is declared
            with shape ``(None, input_dim)`` (variable-length time axis).
        clip_scalar: Ignored.
        output_activation: Activation of the final ``Dense`` layer.
        return_sequences: If true, the last recurrent layer emits every
            timestep and the head is wrapped in ``TimeDistributed``.
        bidirectional: Wrap every recurrent layer in ``Bidirectional``.

    Positional ``*args`` and any other ``**kwargs`` are passed to
    ``Sequential.__init__``.
    """

    def __init__(self, *args,
                 layers=1, dim=128,
                 rots=7, activation='relu',
                 output_dim=1,
                 input_dim=2,
                 clip_scalar=True,
                 output_activation='linear',
                 return_sequences=False,
                 bidirectional=False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.add(InputLayer((None, input_dim)))

        cell_registry = {'VanillaCell': VanillaCell}

        def make_rnn(emit_sequence):
            """Build one ``RNN`` around a freshly configured cell."""
            cell = VanillaCell(dim,
                               n_rotations=rots,
                               activation=activation)
            return RNN(cell, return_sequences=emit_sequence)

        # Intermediate layers must emit whole sequences to feed the next RNN;
        # only the topmost one honours the caller's ``return_sequences``.
        sequence_flags = [True] * (layers - 1) + [return_sequences]
        for emit_sequence in sequence_flags:
            if not bidirectional:
                self.add(make_rnn(emit_sequence))
                continue
            # Bidirectional re-creates its backward copy from the wrapped
            # layer's config, so the custom cell must be resolvable by name.
            with CustomObjectScope(cell_registry):
                self.add(Bidirectional(make_rnn(emit_sequence)))

        head = Dense(output_dim, activation=output_activation)
        self.add(TimeDistributed(head) if return_sequences else head)
