"""
This script provides high-level definitions of our models.

These models are based on those defined in rebase_rnn but they include a linear output layer.
"""

import torch
from torch.nn import Module, Embedding, Linear, Dropout, Sequential
from src.rnn import RNN, GRU, LSTM, LINEAR, REGRESSION
import numpy as np

# lookup table: architecture name (string kwarg) -> recurrent model constructor
__RNNS__ = {
    "RNN_TANH": RNN,
    "GRU": GRU,
    "LSTM": LSTM,
    "LINEAR": LINEAR,
    "REGRESSION": REGRESSION,
}

# Groups of dataset names, used by get_model to pick a model type.
# Kept as plain lists on purpose: get_model extends `_datasets_direct` in place.
# NOTE(review): the "wt_embedding" / "wt_onehot" / "direct" suffixes presumably
# describe how the inputs are encoded -- confirm against the data loaders.
_datasets_wt_embedding = [
    'ptb',
    'penn_treebank',
]
_datasets_wt_onehot = [
    'copy_memory',
    'cm',
    'temporal_ordering',
    'to',
    'rcm',
    'random_copy_memory',
]
_datasets_direct = [
    'pmnist',
    'permuted_mnist',
    'not',
    'nottingham',
    'pcifar',
    'permuted_cifar',
]


# determine the proper type of model given the dataset and construct it
def get_model(dataset, **kwargs):
    """Build the model that matches ``dataset``.

    Args:
        dataset: Name of the dataset. Only names found in
            ``_datasets_direct`` or ``_datasets_wt_onehot`` are supported;
            names from ``_datasets_wt_embedding`` are rejected here.
        **kwargs: Keyword arguments forwarded to ``DirectModel``. The
            ``esn`` entry is removed before forwarding; when it is truthy
            every parameter of ``model.rnn`` gets ``requires_grad = False``
            so only the output layer remains trainable. ``esn`` defaults
            to ``False`` when omitted.

    Returns:
        The constructed ``DirectModel`` instance.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Datasets for which only the last time step of the output is classified.
    output_only_last = dataset in ['pmnist', 'permuted_mnist', 'temporal_ordering','to','pcifar','permuted_cifar']
    esn = kwargs.pop("esn", False)

    # One-hot datasets are handled by the same model as the direct ones.
    # Register them in the module-level list only once: an unconditional
    # extend() would append duplicates on every call and grow the list
    # without bound.
    _datasets_direct.extend(
        [name for name in _datasets_wt_onehot if name not in _datasets_direct]
    )

    if dataset not in _datasets_direct:
        raise ValueError("unsupported dataset: {!r}".format(dataset))

    model = DirectModel(output_only_last, **kwargs)

    if esn:
        # Freeze the recurrent part; the classifier stays trainable.
        for param in model.rnn.parameters():
            param.requires_grad = False

    return model

def get_h0(model):
    """Draw a reproducible random initial hidden state for ``model.rnn``.

    The bit generator stored on the recurrent module is first reset to the
    state saved in ``model.rnn.hh_seed``, so repeated calls return the same
    values.

    Args:
        model: Object exposing ``rnn`` with the attributes ``bidirectional``,
            ``mt19937``, ``hh_seed``, ``num_layers`` and ``hidden_size``.
            NOTE(review): ``mt19937`` is presumably a ``numpy.random.MT19937``
            bit generator -- confirm against the rnn implementation.

    Returns:
        A NumPy array of shape
        ``(num_layers * num_directions, 1, hidden_size)`` produced by
        ``RandomState.rand``.
    """
    rnn = model.rnn

    if rnn.bidirectional:
        direction_count = 2
    else:
        direction_count = 1

    # Rewind the generator so the draw below is always the same.
    rnn.mt19937.state = rnn.hh_seed
    generator = np.random.RandomState(rnn.mt19937)

    leading_dim = rnn.num_layers * direction_count
    return generator.rand(leading_dim, 1, rnn.hidden_size)

# constructs the desired model from rebase_rnn and adds a linear output layer
class DirectModel(Module):
    """Recurrent network selected from ``__RNNS__`` plus a linear output layer.

    Args:
        output_only_last: When truthy, ``forward`` classifies only the final
            time step of the recurrent output; otherwise every time step.
        alpha: Forwarded unchanged to the recurrent constructor.
        architecture: Key into ``__RNNS__`` selecting the recurrent class.
        batch_first: Forwarded to the recurrent constructor. ``forward``
            reads it back from ``self.rnn.batch_first`` to decide which axis
            of the recurrent output is indexed as time.
        dropout: Forwarded to the recurrent constructor. When non-zero, a
            ``Dropout(p=dropout)`` is also placed in front of the output layer.
        gain: Forwarded unchanged to the recurrent constructor.
        hidden_size: Hidden size of the recurrent part and input width of the
            output layer.
        initializer: Forwarded to the recurrent constructor as ``init``.
        input_size: Forwarded unchanged to the recurrent constructor.
        max_angle: Forwarded unchanged to the recurrent constructor.
        min_angle: Forwarded unchanged to the recurrent constructor.
        num_layers: Forwarded unchanged to the recurrent constructor.
        output_size: Number of output features of the linear layer.
    """

    def __init__(self, output_only_last, *,
                 alpha:float=1.+1e-3,
                 architecture:str='RNN_TANH',
                 batch_first=False,
                 dropout:float=0,
                 gain:float=1.,
                 # gradient_clipping:float=0.,
                 hidden_size:int=128,
                 initializer:str='orthogonal',
                 input_size:int=52,
                 max_angle:float=2.,
                 min_angle:float=0.,
                 num_layers:int=1,
                 output_size:int=10,
                 ):

        super().__init__()

        self.output_only_last = output_only_last

        # Look up the recurrent class by name and build it.
        rnn = __RNNS__[architecture]
        self.rnn = rnn(input_size=input_size,
                       hidden_size=hidden_size,
                       num_layers=num_layers,
                       init=initializer,
                       alpha=alpha,
                       min_angle=min_angle,
                       max_angle=max_angle,
                       gain=gain,
                       dropout=dropout,
                       batch_first=batch_first)

        # Output layer, optionally preceded by dropout.
        if dropout == 0:
            self.classifier = Linear(hidden_size, output_size)
        else:
            self.classifier = Sequential(Dropout(p=dropout),Linear(hidden_size, output_size))

    def forward(self, x, hidden_init=None, return_hidden=False, return_for_inspection=False):
        """Run the recurrent part and classify its output.

        Args:
            x: Input passed straight to ``self.rnn``.
            hidden_init: Initial hidden state, passed to ``self.rnn`` as ``hx``.
            return_hidden: When truthy, return ``(yhat, hidden)``. Takes
                precedence over ``return_for_inspection``.
            return_for_inspection: When truthy (and ``return_hidden`` is not),
                return ``(yhat, output)`` where ``output`` is the raw
                recurrent output.

        Returns:
            ``yhat`` alone, or one of the tuples described above. With
            ``output_only_last`` the classifier sees only index ``-1`` of the
            time axis (axis 1 if ``batch_first``, else axis 0). Otherwise the
            classifier output is permuted so that the time axis comes last;
            this path requires a 3-D recurrent output.
            NOTE(review): the remaining axes are presumably
            (batch, output_size) -- confirm against the rnn implementation.
        """
        output, hidden = self.rnn(x, hx=hidden_init)
        batch_first = self.rnn.batch_first

        # Both flags are treated as plain booleans; the former mix of
        # truthiness with `== False` / `== True` comparisons could silently
        # pick the full-sequence branch for a falsy non-bool flag.
        if self.output_only_last:
            last_step = output[:, -1, ...] if batch_first else output[-1, ...]
            yhat = self.classifier(last_step)
        else:
            yhat = self.classifier(output)
            # Move the time axis to the end.
            if batch_first:
                yhat = yhat.permute(0, 2, 1)
            else:
                yhat = yhat.permute(1, 2, 0)

        if return_hidden:
            return yhat, hidden
        if return_for_inspection:
            return yhat, output

        return yhat
