from .TransformerFunctions import PositionalEncoder
from ..embeddings.EmbeddingFunctions import IdentityEmbedding
from ..embeddings.EmbeddingWrapperFile import EmbeddingWrapper
from ...util.ProgressBar import PBar
from ...data.DataLoading import MyData
from ..interface_base.ModelFittingFunctions import ModelFitter
from ..loss.LossFunctions import LogCoshLoss
from ...data.PreProcessingFunctions import make_input_roll


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


import numpy as np
import pandas as pd



class BaseEncoderTransformerModel(nn.Module):
    '''
    An encoder focused base Transformer model.

    The input is passed through ``embedding_class``, a positional encoder and a
    stack of ``torch.nn.TransformerEncoderLayer``. The encoder outputs of all
    time steps are concatenated into one vector of size
    ``sequence_length * embedding_dim`` per sample, which a linear layer maps
    to ``output_dim`` values.

    The basis for this model can be found at: 
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    '''

    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, output_dim, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = {}, device = 'cpu'):
        '''
        Arguments
        ---------
            sequence_length: integer
                Length of the input sequences. Fixes the input size of the final
                linear layer (sequence_length * embedding_dim).

            embedding_dim: integer
                Feature / embedding size fed to the transformer encoder.

            nhead, dim_feedforward, dropout:
                Passed to torch.nn.TransformerEncoderLayer.

            nlayers: integer
                Number of encoder layers in torch.nn.TransformerEncoder.

            output_dim: integer
                Size of the output of the final linear layer.

            embedding_class, embedding_args:
                Passed to EmbeddingWrapper to build the input embedding.

            device: string
                Device the positional encoder is created on.
        '''
        super(BaseEncoderTransformerModel, self).__init__()
        self.device = device
        # Cached attention mask; (re)built lazily in forward().
        self.out_mask = None
        self.embedding = EmbeddingWrapper(embedding_class, embedding_args)
        self.pos_encoder = PositionalEncoder(sequence_length, embedding_dim, device = self.device).to(self.device)
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.sequence_length = sequence_length

        # The output layer sees the encoder outputs of every time step concatenated.
        self.decoder = nn.Linear(sequence_length*embedding_dim, output_dim)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        '''
        Return a (sz, sz) float mask with 0 on and below the diagonal and -inf
        above it, so that position i can only attend to positions <= i.
        '''
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        '''
        Initialise the output layer: zero bias and weights drawn uniformly from
        [-0.1, 0.1], as in the PyTorch word language model example.
        '''
        initrange = 0.1
        # The bias (not the weight, which is overwritten on the next line) is zeroed.
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, X, has_mask=True):
        '''
        Arguments
        ---------
            X: tensor
                Input batch. Assumed to be (N, S, E) after the embedding -- TODO
                confirm for non-identity embeddings. It is transposed to (S, N, E)
                before the encoder.

            has_mask: bool
                If True, a causal mask is applied so each position only attends
                to itself and earlier positions.

        Returns
        ---------
            out: tensor
                Output of the final linear layer, shape (N, output_dim).
        '''
        out = self.embedding(X)
        out = self.pos_encoder(out)
        # nn.TransformerEncoder (batch_first=False) expects (S, N, E).
        out = out.transpose(0, 1)

        if has_mask:
            seq_len = out.size(0)
            # Rebuild the cached mask only when the length or the device changes.
            if (self.out_mask is None or self.out_mask.size(0) != seq_len
                    or self.out_mask.device != out.device):
                self.out_mask = self._generate_square_subsequent_mask(seq_len).to(out.device)
        else:
            self.out_mask = None

        out = self.transformer_encoder(out.float(), self.out_mask)
        # (S, N, E) -> (N, S*E): time steps concatenated in order.
        out = out.transpose(0, 1).flatten(start_dim = 1)

        return self.decoder(out)




class BaseClassificationEncoderTransformerModel(BaseEncoderTransformerModel):
    '''
    An encoder focused Transformer model that outputs probabilities for a multi-class
    classification task. In contrast to SimpleEncoderModelTimeSum, this class uses a 
    masking matrix to force attention to work in the backwards direction only.
    
    The basis for this model can be found at: 
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py
    
    Input should be of the shape: (N, S, E), where N is the number of samples, 
    S is the sequence length and E is the number of features (or embedding size).
    '''

    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, output_dim, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = {}, device = 'cpu'):
        '''
        All arguments are forwarded unchanged to BaseEncoderTransformerModel.
        output_dim is the number of classes.
        '''
        super(BaseClassificationEncoderTransformerModel, self).__init__(sequence_length, embedding_dim, 
                                                             nhead, dim_feedforward, 
                                                             nlayers, output_dim, 
                                                             dropout, 
                                                             embedding_class, embedding_args, 
                                                             device)

    def forward(self, X, has_mask=True):
        '''
        Return class probabilities of shape (N, output_dim).

        NOTE(review): the output is already softmax-normalised. If it is trained
        with nn.CrossEntropyLoss, which expects unnormalised logits, the softmax
        is effectively applied twice.

        Arguments
        ---------
            X: tensor
                Input of shape (N, S, E).

            has_mask: bool
                Whether to apply the causal attention mask (default True).
        '''
        # Respect the caller's has_mask instead of always forcing the mask on.
        out = super(BaseClassificationEncoderTransformerModel, self).forward(X, has_mask=has_mask)

        return F.softmax(out, dim=-1)


class PytorchTransformerWrapper(nn.Module):
    '''
    Thin wrapper around ``torch.nn.Transformer`` for batch-first tensors.

    Inputs and outputs are (N, S, E). Internally they are transposed to the
    (S, N, E) layout that ``nn.Transformer`` uses by default, and transposed
    back on output. A positional encoder is applied to the source sequence only.
    '''
    def __init__(self, d_model, nhead, num_encoder_layers, 
                                     num_decoder_layers, dim_feedforward, dropout, 
                                     activation, device = 'cpu', max_sequence_length = 1000):
        '''
        Arguments
        ---------
            d_model, nhead, num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout, activation:
                Passed straight through to torch.nn.Transformer.

            device: string
                Device the positional encoder is created on.

            max_sequence_length: integer
                Sequence length given to the positional encoder (presumably the
                longest source sequence it supports -- confirm against
                PositionalEncoder). Defaults to 1000, the previously hard-coded
                value.
        '''
        super(PytorchTransformerWrapper, self).__init__()

        self.device = device

        self.pos_encoder = PositionalEncoder(max_sequence_length, d_model, device = self.device).to(self.device)

        self.transformer = nn.Transformer(d_model = d_model, nhead = nhead,
                                          num_encoder_layers = num_encoder_layers,
                                          num_decoder_layers = num_decoder_layers,
                                          dim_feedforward = dim_feedforward,
                                          dropout = dropout, activation = activation)

    def forward(self, src, tgt, src_mask=None, 
                tgt_mask=None, memory_mask=None, 
                src_key_padding_mask=None, 
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        '''
        Run the full encoder-decoder.

        Arguments
        ---------
            src: tensor
                Source sequence of shape (N, S, E).

            tgt: tensor
                Target sequence of shape (N, T, E).

            *_mask, *_key_padding_mask:
                Forwarded to torch.nn.Transformer.forward.

        Returns
        ---------
            out: tensor
                Decoder output of shape (N, T, E).
        '''
        src = self.pos_encoder(src).transpose(0, 1).float()
        tgt = tgt.transpose(0, 1)

        # Keyword arguments keep this robust to changes in nn.Transformer's
        # positional parameter order.
        out = self.transformer(src, tgt,
                               src_mask = src_mask,
                               tgt_mask = tgt_mask,
                               memory_mask = memory_mask,
                               src_key_padding_mask = src_key_padding_mask,
                               tgt_key_padding_mask = tgt_key_padding_mask,
                               memory_key_padding_mask = memory_key_padding_mask)

        return out.transpose(0, 1)

    def encode(self, src, src_mask = None, src_key_padding_mask = None):
        '''
        Run only the encoder on a (N, S, E) source and return the memory as
        (N, S, E).
        '''
        src = self.pos_encoder(src).transpose(0, 1).float()

        out = self.transformer.encoder(src, mask = src_mask,
                                       src_key_padding_mask = src_key_padding_mask)

        return out.transpose(0, 1)

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None,
                  tgt_key_padding_mask=None,
                  memory_key_padding_mask=None):
        '''
        Run only the decoder.

        Arguments
        ---------
            tgt: tensor
                Target sequence of shape (N, T, E). No positional encoding is
                applied, matching forward().

            memory: tensor
                Encoder output of shape (N, S, E), e.g. from encode().

        Returns
        ---------
            out: tensor
                Decoder output of shape (N, T, E).
        '''
        tgt = tgt.transpose(0, 1)
        memory = memory.transpose(0, 1)

        out = self.transformer.decoder(tgt, memory, tgt_mask = tgt_mask, memory_mask = memory_mask,
                                       tgt_key_padding_mask = tgt_key_padding_mask,
                                       memory_key_padding_mask = memory_key_padding_mask)

        return out.transpose(0, 1)

    def generate_square_mask(self, size):
        '''
        This generates a square mask with shape (size,size). It has 0s on the diagonal
        and lower triangular, and -inf on the upper triangular. A 2x2 example:
        [[0, -inf]
         [0,    0]]
        
        Arguments
        ---------
            size: integer
                Size of the square matrix
        
        Returns
        ---------
            out: tensor
                Mask tensor.
        '''
        return self.transformer.generate_square_subsequent_mask(size)
        




class BaseRegressionEncoderTransformerModel(BaseEncoderTransformerModel):
    '''
    An encoder focused Transformer model for regression tasks. It returns the raw
    output of the final linear layer (output_dim values per sample). When
    has_mask is True, a masking matrix forces attention to work in the
    backwards direction only.
    
    The basis for this model can be found at: 
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py
    
    Input should be of the shape: (N, S, E), where N is the number of samples, 
    S is the sequence length and E is the number of features (or embedding size).
    '''

    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, output_dim, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = {}, device = 'cpu'):
        '''
        All arguments are forwarded unchanged to BaseEncoderTransformerModel.
        '''
        super(BaseRegressionEncoderTransformerModel, self).__init__(sequence_length, embedding_dim, 
                                                             nhead, dim_feedforward, 
                                                             nlayers, output_dim, 
                                                             dropout, 
                                                             embedding_class, embedding_args, 
                                                             device)

    def forward(self, X, has_mask=True):
        '''
        Return the regression output of shape (N, output_dim).

        Arguments
        ---------
            X: tensor
                Input of shape (N, S, E).

            has_mask: bool
                Whether to apply the causal attention mask (default True).
        '''
        # Forward has_mask instead of silently using the parent's default.
        return super(BaseRegressionEncoderTransformerModel, self).forward(X, has_mask=has_mask)



class ClassificationEncoderTransformerModel(ModelFitter):
    '''
    This class is a classification transformer encoder model interface. Please see available
    functions. 
    
    '''
    

    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, output_dim, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = {}, device = 'cpu', 
                 learning_rate = 0.01, epochs = 150, batch_size = 20, verbose = True):
        
        '''
        Arguments
        ---------
            sequence_length: integer
                This is the length of the input sequence.

            embedding_dim: integer
                This is the number of features or size of the embedding of 
                each data point in the sequence.

            nhead: integer
                This is the number of attention heads that will be used in the model.

            dim_feedforward: integer
                This is the dimension of the feedforward layer that is an argument in 
                torch.nn.TransformerEncoderLayer.

            nlayers: integer
                This is the number of sub encoder layers that is an argument in
                torch.nn.TransformerEncoder.

            output_dim: integer
                This is the number of classes. The model outputs a softmax over
                this many values.

            dropout: float 0-1
                This is the dropout probability that is an argument in
                torch.nn.TransformerEncoder.

            embedding_class: torch.nn.Module subclass
                This class is used to perform an embedding on the data before it
                is passed into the transformer. This is separate from the positional
                encoder. By default this class is simply an identity function.

            embedding_args: dict
                This is a dictionary containing arguments for the embedding_class.

            device: string
                This is the name of the device for the pytorch model to be run on. By
                default this is 'cpu', however if you have CUDA set up, you can pass 'cuda'
                as the argument.

            learning_rate: float
                This is the learning rate for the training of the model. By default this
                is set to 0.01.

            epochs: integer
                This is the number of epochs the model will be trained using. By default this
                is set to 150.

            batch_size: integer
                This is the batch_size that will be used in the training. By default this is
                set to 20.

            verbose: bool
                This dictates whether the training information will be printed.

        '''
        
        self.model = BaseClassificationEncoderTransformerModel(sequence_length, embedding_dim, 
                                                             nhead, dim_feedforward, 
                                                             nlayers, output_dim, 
                                                             dropout, 
                                                             embedding_class, embedding_args, 
                                                             device)
        self.model = self.model.to(device)
        
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        
        # NOTE(review): the model's forward() already applies softmax, while
        # CrossEntropyLoss applies log-softmax internally, so outputs are
        # normalised twice during training. Consider NLLLoss on log-probabilities.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)
        
        self.verbose = verbose
        self.device = device
        
        super(ClassificationEncoderTransformerModel, self).__init__(self.model, data_prepare_class = MyData)
    
    
    def fit(self, X, Y, X_val = None, Y_val = None, from_numpy = True):
        '''
        Fit transformer model.
        
        Arguments
        ---------
            X: numpy array
                This is an array containing the input. It must be of shape (N, S, E),
                where N is the number of samples, S is the sequence length and E is the
                number of features (or embedding size).
            
            Y: numpy array
                These are the integer class targets. This must be of shape (N,1).
                NOTE(review): nn.CrossEntropyLoss expects class indices of shape
                (N,); confirm that ModelFitter reshapes (N,1) targets.
                
            X_val: numpy array
                This is the validation set that the model will be evaluated on as the model 
                is trained.
            
            Y_val: numpy array
                This is the validation targets set that the model will be evaluated on as the model 
                is trained. Validation only happens when both X_val and Y_val are given.

            from_numpy: bool
                If False, X, Y, X_val and Y_val are assumed to already be tensors.
        
        '''
        if from_numpy:
            X = torch.from_numpy(X).float()
            Y = torch.from_numpy(Y).long()

        testing_too = X_val is not None and Y_val is not None

        if testing_too:
            if from_numpy:
                X_val = torch.from_numpy(X_val).float()
                Y_val = torch.from_numpy(Y_val).long()
        else:
            # A lone X_val or Y_val cannot be used for validation. Drop both so the
            # fitter never receives a half-specified (possibly still numpy) pair.
            X_val, Y_val = None, None

        self.model = super(ClassificationEncoderTransformerModel, self).fit(X, Y, n_epochs = self.epochs, 
                                                                     criterion = self.criterion, 
                                                                     optimizer = self.optimizer, 
                                                                     data_params = {'batch_size': self.batch_size}, 
                                                                     verbose = self.verbose, 
                                                                     X_val = X_val, Y_val = Y_val)
        
        return self
    
    def decision_function(self, X, output_numpy = True, from_numpy = True):
        '''
        Passes X through the model to calculate the class probabilities.
        
        Arguments
        ---------
            X: numpy array
                This is the input that you want to pass through the model.

            output_numpy: bool
                If True, return a numpy array; otherwise return a tensor.

            from_numpy: bool
                If True, X is converted from numpy to a float tensor first.
        
        Returns
        ---------
            out: numpy array:
                This is the output of the model.
        
        '''
        # Cast to float exactly as in fit(), so float64 input matches the model weights.
        if from_numpy: X = torch.from_numpy(X).float()
        out = super(ClassificationEncoderTransformerModel, self).predict(X)
        if output_numpy: out = out.cpu().detach().numpy()

        return out
    
    def predict(self, X, output_numpy = True, from_numpy = True):
        '''
        Passes X through the model to calculate the outputs. This function also applies the
        argmax function to return the index of the most probable class.
        
        Arguments
        ---------
            X: numpy array
                This is the input that you want to pass through the model.

            output_numpy: bool
                If True, return a numpy array; otherwise return a tensor.

            from_numpy: bool
                If True, X is converted from numpy first.
        
        Returns
        ---------
            out: numpy array or tensor
                The predicted class index for each sample, shape (N,).
        
        '''
        scores = self.decision_function(X, output_numpy = output_numpy, from_numpy = from_numpy)
        if output_numpy:
            return np.argmax(scores, axis = 1)
        # Return only the class indices so that both branches agree; torch.max would
        # return a (values, indices) tuple here.
        return torch.argmax(scores, dim = 1)
    
    

class RegressionEncoderTransformerModel(ModelFitter):
    '''
    This class is a regression transformer encoder model interface. Please see available
    functions. 
    
    '''
    

    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, output_dim, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = {}, device = 'cpu', 
                 learning_rate = 0.01, epochs = 150, batch_size = 20, verbose = True):
        
        '''
        Arguments
        ---------
            sequence_length: integer
                This is the length of the input sequence.

            embedding_dim: integer
                This is the number of features or size of the embedding of 
                each data point in the sequence.

            nhead: integer
                This is the number of attention heads that will be used in the model.

            dim_feedforward: integer
                This is the dimension of the feedforward layer that is an argument in 
                torch.nn.TransformerEncoderLayer.

            nlayers: integer
                This is the number of sub encoder layers that is an argument in
                torch.nn.TransformerEncoder.

            output_dim: integer
                This is the number of regression outputs per sample.

            dropout: float 0-1
                This is the dropout probability that is an argument in
                torch.nn.TransformerEncoder.

            embedding_class: torch.nn.Module subclass
                This class is used to perform an embedding on the data before it
                is passed into the transformer. This is separate from the positional
                encoder. By default this class is simply an identity function.

            embedding_args: dict
                This is a dictionary containing arguments for the embedding_class.

            device: string
                This is the name of the device for the pytorch model to be run on. By
                default this is 'cpu', however if you have CUDA set up, you can pass 'cuda'
                as the argument.

            learning_rate: float
                This is the learning rate for the training of the model. By default this
                is set to 0.01.

            epochs: integer
                This is the number of epochs the model will be trained using. By default this
                is set to 150.

            batch_size: integer
                This is the batch_size that will be used in the training. By default this is
                set to 20.

            verbose: bool
                This dictates whether the training information will be printed.

        '''
        
        self.model = BaseRegressionEncoderTransformerModel(sequence_length, embedding_dim, 
                                                             nhead, dim_feedforward, 
                                                             nlayers, output_dim, 
                                                             dropout, 
                                                             embedding_class, embedding_args, 
                                                             device)
        self.model = self.model.to(device)
        
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.learning_rate)
        
        self.verbose = verbose
        self.device = device
        
        super(RegressionEncoderTransformerModel, self).__init__(self.model, data_prepare_class = MyData)
    
    
    def fit(self, X, Y, X_val = None, Y_val = None, from_numpy = True):
        '''
        Fit transformer model. Training batches are not shuffled.
        
        Arguments
        ---------
            X: numpy array
                This is an array containing the input. It must be of shape (N, S, E),
                where N is the number of samples, S is the sequence length and E is the
                number of features (or embedding size).
            
            Y: numpy array
                These are the targets. This must be of shape (N, output_dim)
                
            X_val: numpy array
                This is the validation set that the model will be evaluated on as the model 
                is trained.
            
            Y_val: numpy array
                This is the validation targets set that the model will be evaluated on as the model 
                is trained. Validation only happens when both X_val and Y_val are given.

            from_numpy: bool
                If False, X, Y, X_val and Y_val are assumed to already be tensors.
        
        '''
        if from_numpy:
            X = torch.from_numpy(X).float()
            Y = torch.from_numpy(Y).float()

        testing_too = X_val is not None and Y_val is not None

        if testing_too:
            if from_numpy:
                X_val = torch.from_numpy(X_val).float()
                Y_val = torch.from_numpy(Y_val).float()
        else:
            # A lone X_val or Y_val cannot be used for validation. Drop both so the
            # fitter never receives a half-specified (possibly still numpy) pair.
            X_val, Y_val = None, None

        self.model = super(RegressionEncoderTransformerModel, self).fit(X, Y, n_epochs = self.epochs, 
                                                                     criterion = self.criterion, 
                                                                     optimizer = self.optimizer, 
                                                                     data_params = {'batch_size': self.batch_size, 
                                                                                    'shuffle': False}, 
                                                                     verbose = self.verbose, 
                                                                     X_val = X_val, Y_val = Y_val)
        
        return self
    
    def decision_function(self, X, output_numpy = True, from_numpy = True):
        '''
        Passes X through the model to calculate the outputs.
        
        Arguments
        ---------
            X: numpy array
                This is the input that you want to pass through the model.

            output_numpy: bool
                If True, return a numpy array; otherwise return a tensor.

            from_numpy: bool
                If True, X is converted from numpy to a float tensor first.
        
        Returns
        ---------
            out: numpy array:
                This is the output of the model.
        
        '''
        # Cast to float exactly as in fit(), so float64 input matches the model weights.
        if from_numpy: X = torch.from_numpy(X).float()
        out = super(RegressionEncoderTransformerModel, self).predict(X)
        if output_numpy: out = out.cpu().detach().numpy()

        return out

    def predict(self, X, output_numpy = True, from_numpy = True):
        '''
        Passes X through the model to calculate the outputs. Identical to
        decision_function for regression.
        
        Arguments
        ---------
            X: numpy array
                This is the input that you want to pass through the model.
        
        Returns
        ---------
            out: numpy array:
                This is the output of the model.
        
        '''
        return self.decision_function(X, output_numpy = output_numpy, from_numpy = from_numpy)


class TransformerModel(ModelFitter):
    '''
    Simpler interface for the nn.Transformer model with ability to 
    fit and predict.

    Training uses teacher forcing. A start token (all ones) is prepended and an
    end token (all zeros) is appended to every input and target sequence along
    the sequence axis. Predictions are generated autoregressively, one step at
    a time, starting from the start token.

    NOTE(review): layer_norm_eps, embedding_class and embedding_args are accepted
    but not used by this class.
    '''
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 
                 activation='relu',
                 layer_norm_eps=1e-05, 
                 device='cpu',
                 embedding_class = IdentityEmbedding, embedding_args = {},
                 learning_rate = 0.01, epochs = 150, batch_size = 20, verbose = True):
        
        
        self.model = PytorchTransformerWrapper(d_model, nhead, num_encoder_layers, 
                                                 num_decoder_layers, dim_feedforward, dropout, 
                                                 activation, device)
        
        self.model = self.model.to(device)
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device
        
        
        self.train_loss = {}
        self.test_loss = {}
        self.predicted = {}
        # -1 means "never fitted"; decision_function() checks this.
        self.n_trains = -1
        self.data_prepare_class = MyData
        
        
        
        self.optimizer = torch.optim.Adam(list(self.model.parameters()), lr = self.learning_rate)
        self.criterion = nn.MSELoss()
        
        self.verbose = verbose
        
        super(TransformerModel, self).__init__(self.model, data_prepare_class = MyData)
        
        return
    
    @staticmethod
    def _add_start_end_tokens(X, from_numpy):
        '''
        Prepend a start token (ones) and append an end token (zeros) along axis 1
        of a (N, S, E) array or tensor, and return a float tensor of shape
        (N, S+2, E).
        '''
        if from_numpy:
            X = np.insert(X, 0, 1, axis = 1)
            X = np.insert(X, X.shape[1], 0, axis = 1)
            return torch.from_numpy(X).float()

        # Match the input's dtype/device so torch.cat works for GPU or double tensors.
        start = torch.ones((X.size(0), 1, X.size(2)), dtype = X.dtype, device = X.device)
        end = torch.zeros((X.size(0), 1, X.size(2)), dtype = X.dtype, device = X.device)
        return torch.cat([start, X, end], dim = 1).float()

    def _batch_loss(self, inputs, labels):
        '''
        Compute the teacher-forced loss for one batch. The decoder is fed the
        target sequence without its last step and has to predict the target
        shifted by one step.
        '''
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        labels_input = labels[:, :-1, :]
        labels_true = labels[:, 1:, :]

        # Causal mask: step t may only attend to target steps <= t.
        labels_mask = self.model.generate_square_mask(labels_input.size(1)).to(self.device)
        labels_outputs = self.model(inputs, labels_input, tgt_mask = labels_mask)

        return self.criterion(labels_outputs, labels_true)

    def fit(self, X, Y = None, X_val = None, Y_val = None, sequence_length = 'auto',
           from_numpy = True):
        '''
        Fit transformer model.
        
        Arguments
        ---------
            X: numpy array
                This is an array containing the input. It must be of shape (N, S, E),
                where N is the number of samples, S is the sequence length and E is the
                number of features (or embedding size). Please ensure that this array
                is in order of time.
            
            Y: numpy array
                The targets (required). This should be of shape (N, T, E) where N is the
                number of samples, T is the target sequence length and E is the
                number of features (or embedding size). Please ensure that this array
                is in order of time.
                
            X_val: numpy array
                This is the validation set that the model will be evaluated on as the model 
                is trained. The validation targets are built from X_val itself: the
                target for sample i is X_val[i + sequence_length].
                
            Y_val: any
                Its contents are ignored, but validation only runs when both X_val and
                Y_val are not None.
            
            sequence_length: string or integer
                'auto' allows the model to take the sequence length from X.shape[1].
                If an integer is supplied, it will use this value instead (which will 
                most likely cause errors!).
            
            from_numpy: bool
                If this argument is False, you may pass a tensor in place of the numpy
                arrays in arguments X and Y. This tensor must be the same shape as the
                numpy array would have been.

        Raises
        ---------
            ValueError
                If Y is None.
        
        '''
        if Y is None:
            raise ValueError('Please supply the targets Y.')

        if sequence_length == 'auto': sequence_length = X.shape[1]
        self.sequence_length = sequence_length

        X = self._add_start_end_tokens(X, from_numpy)
        Y = self._add_start_end_tokens(Y, from_numpy)

        testing_too = X_val is not None and Y_val is not None

        if testing_too:
            # Validation targets are the windows sequence_length samples ahead.
            Y_val = X_val[sequence_length:, :, :]
            X_val = X_val[:-sequence_length]
            # Tokens are added using each array's own length. The previous code
            # used the padded training length, which is out of bounds for np.insert.
            X_val = self._add_start_end_tokens(X_val, from_numpy)
            Y_val = self._add_start_end_tokens(Y_val, from_numpy)

        self.n_trains += 1

        training_set = self.prepare_data(X, Y)
        training_generator = torch.utils.data.DataLoader(training_set, batch_size = self.batch_size, shuffle = False)

        if testing_too:
            testing_set = self.prepare_data(X_val, Y_val)
            testing_generator = torch.utils.data.DataLoader(testing_set, batch_size = self.batch_size, shuffle = False)

        train_loss_temp = []
        test_loss_temp = []

        epoch_bar = PBar(show_length = 20, n_iterations = self.epochs)
        print_threshold = 0

        self.model.train()

        for epoch in range(self.epochs):
            training_loss = 0

            for inputs, labels in training_generator:
                self.optimizer.zero_grad()
                loss = self._batch_loss(inputs, labels)
                loss.backward()
                self.optimizer.step()
                training_loss += loss.item() * inputs.size(0)

            epoch_loss = training_loss/len(training_set)
            train_loss_temp.append(epoch_loss)

            # validation loss
            self.model.eval()
            if testing_too:
                test_loss = 0
                with torch.no_grad():
                    for inputs, labels in testing_generator:
                        loss = self._batch_loss(inputs, labels)
                        test_loss += loss.item() * inputs.size(0)
                epoch_test_loss = test_loss/len(testing_set)
                test_loss_temp.append(epoch_test_loss)
            self.model.train()

            epoch_bar.update(1)
            bar = epoch_bar.give()
            # Print roughly every 20% of the epochs.
            if (epoch+1)/self.epochs >= print_threshold:
                printing_statement = 'Epochs: {}. epoch {} done. Loss per train sample: {:.2f}.'.format(bar,epoch+1,epoch_loss)
                if testing_too: 
                    printing_statement += ' Loss per test sample: {:.2f}.'.format(epoch_test_loss)
                if self.verbose: 
                    print(printing_statement)
                print_threshold += 0.2
        self.train_loss[self.n_trains] = np.asarray(train_loss_temp)
        self.test_loss[self.n_trains] = np.asarray(test_loss_temp)

        return self

    
    def decision_function(self, X, output_numpy = True, from_numpy = True, predict_sequence_length = 'same'):
        '''
        Passes X through the model to calculate the outputs.
        
        Arguments
        ---------
            X: numpy array or tensor
                This is the input that you want to pass through the model. It must be of shape (N, S, E),
                where N is the number of samples, S is the sequence length and E is the
                number of features (or embedding size). If passing a tensor, please use from_numpy = False.
            
            output_numpy: bool
                This dictates whether the model will output a numpy array or a tensor.
            
            from_numpy: bool
                This tells the function whether the inputted X is a numpy array or a tensor. It will 
                then do conversions or not based on this.
            
            predict_sequence_length: integer or string
                This is the sequence length for the prediction. 'same' means that the sequence
                length will be the same as the input sequence length used in fit(). An integer argument 
                means the output sequence will be of this length.
            
        Returns
        ---------
            out: numpy array or tensor
                This is the output of the model, of shape (N, predict_sequence_length, E). It will be
                a tensor or an array depending on the value of the argument output_numpy.

        Raises
        ---------
            TypeError
                If the model has not been fitted yet.
        
        '''
        if self.n_trains == -1:
            raise TypeError('Please fit the model first by calling .fit(X).')

        if predict_sequence_length == 'same':
            predict_sequence_length = self.sequence_length

        X = self._add_start_end_tokens(X, from_numpy)

        # Every sequence starts decoding from the start token (all ones).
        tgt_start = torch.ones(X.size(0), 1, X.size(2))

        predicting_set = self.prepare_data(X, tgt_start)
        predicting_generator = torch.utils.data.DataLoader(predicting_set, batch_size = self.batch_size)

        batch_outputs = []
        with torch.no_grad():
            self.model.eval()
            for inputs, tgt_input in predicting_generator:
                inputs, tgt_input = inputs.to(self.device), tgt_input.to(self.device)
                memory_input = self.model.encode(inputs)
                for _ in range(predict_sequence_length):
                    tgt_mask = self.model.generate_square_mask(tgt_input.size(1)).to(self.device)
                    tgt_output = self.model.decode(tgt_input, memory_input, tgt_mask = tgt_mask)
                    # Append only the newest step's prediction to the decoder input.
                    tgt_input = torch.cat([tgt_input, tgt_output[:, -1:, :]], dim = 1)
                # Drop the start token.
                batch_outputs.append(tgt_input[:, 1:, :].to('cpu'))
            self.model.train()

        if batch_outputs:
            out = torch.cat(batch_outputs, dim = 0)
        else:
            out = torch.zeros(0, predict_sequence_length, X.size(2))

        if output_numpy: out = out.detach().numpy()

        return out
    
    def predict(self, X, output_numpy = True, from_numpy = True, predict_sequence_length = 'same'):
        '''
        Passes X through the model to calculate the outputs. See decision_function.
        
        Arguments
        ---------
            X: numpy array
                This is the input that you want to pass through the model.

            predict_sequence_length: integer or string
                Forwarded to decision_function ('same' by default).
        
        Returns
        ---------
            out: numpy array:
                This is the output of the model.
        
        '''
        return self.decision_function(X, output_numpy = output_numpy, from_numpy = from_numpy,
                                      predict_sequence_length = predict_sequence_length)


class ODRegressionEncoderTransformerModel(RegressionEncoderTransformerModel):
    '''
    This class is an outlier detection model based on the encoding part of a transformer.
    The model is trained on a self-supervised task: predicting the next time point
    from the ``sequence_length`` points that end at the current one. The squared
    prediction error of every feature is standardized with the mean and std measured
    on the training data, and a normal CDF turns the standardized error into a value
    between 0 and 1. The outlier score of a time point is the mean of these values
    across the features.

    You will be able to access the following attributes after the model has been fit and
    the decision_function() has been run:

        .mean_loss: 
            The mean squared prediction error of each feature, averaged over the
            training samples (computed in fit()).
        
        .std_loss: 
            The standard deviation, over the training samples, of the squared
            prediction error of each feature (computed in fit()).
        
        .predictions: 
            These are the next-step predictions made by the model on the
            self-supervised task. Row i is the prediction for X[i+1].
        
        .loss: 
            The element-wise squared error between the predictions and the true
            values on the self-supervised task. Row i refers to X[i+1].
        
        .likelihood_raw: 
            This is the outlier score of each feature for each predicted data point.
            This might be useful when testing which feature makes a particular point
            an outlier. Row i refers to X[i+1].
        
        .scores:
            These are the outlier scores for the predicted data points (the mean of
            likelihood_raw across the features). Entry i refers to X[i+1].

    The first data point has no history to predict it from, so it gets no
    prediction. predict() and predict_proba() realign their outputs so that
    index j refers to X[j] (X[0] always receives 0).

    '''
    def __init__(self, sequence_length, embedding_dim, nhead, dim_feedforward, 
                 nlayers, dropout=0.5, 
                 embedding_class = IdentityEmbedding, embedding_args = None, device = 'cpu', 
                 learning_rate = 0.01, epochs = 150, batch_size = 20, verbose = True):
        '''
        Arguments
        ---------
            sequence_length: integer
                This is the length of the input sequence.

            embedding_dim: integer
                This is the number of features or size of the embedding of 
                each data point in the sequence. It is also used as the output
                dimension, since the model predicts every feature of the next point.

            nhead: integer
                This is the number of attention heads that will be used in the model.

            dim_feedforward: integer
                This is the dimension of the feedforward layer that is an argument in 
                torch.nn.TransformerEncoderLayer.

            nlayers: integer
                This is the number of sub encoder layers that is an argument in
                torch.nn.TransformerEncoder.

            dropout: float 0-1
                This is the dropout probability that is an argument in
                torch.nn.TransformerEncoder.

            embedding_class: torch.nn.Module subclass
                This class is used to perform an embedding on the data before it
                is passed into the transformer. This is separate from the positional
                encoder. By default this class is simply an identity function.

            embedding_args: dict or None
                This is a dictionary containing arguments for the embedding_class.
                None (the default) means an empty dictionary; a fresh one is created
                for every instance so instances never share a mutable default.

            device: string
                This is the name of the device for the pytorch model to be run on. By
                default this is 'cpu', however if you have CUDA set up, you can pass 'cuda'
                as the argument.

            learning_rate: float
                This is the learning rate for the training of the model. By default this
                is set to 0.01.

            epochs: integer
                This is the number of epochs the model will be trained using. By default this
                is set to 150.

            batch_size: integer
                This is the batch_size that will be used in the training. By default this is
                set to 20.

            verbose: bool
                This dictates whether the training information will be printed.
        '''
        if embedding_args is None:
            embedding_args = {}
        
        super(ODRegressionEncoderTransformerModel, self).__init__(sequence_length, embedding_dim, 
                                                             nhead, dim_feedforward, 
                                                             nlayers, embedding_dim, 
                                                             dropout, 
                                                             embedding_class, embedding_args, 
                                                             device, learning_rate, epochs, 
                                                             batch_size, verbose)

        self.mean_loss = None
        self.std_loss = None
        self.predictions = None
        self.loss = None
        self.likelihood_raw = None
        self.scores = None
        self.sequence_length = sequence_length

        return
    
    def _make_sequences(self, X):
        '''
        Build the self-supervised (inputs, targets) tensors from a (N, E) array.

        sequence_length - 1 zero rows are prepended so that (assuming
        make_input_roll(A, S) returns the windows A[j:j+S] — confirm) window j
        ends at X[j]. Window i is then the input whose target is X[i+1].
        '''
        padding = np.zeros((self.sequence_length - 1, X.shape[1]))
        X_seq = make_input_roll(np.insert(X, 0, padding, axis = 0), self.sequence_length)

        X_f = torch.from_numpy(X_seq[:-1, :, :]).float()
        Y_f = torch.from_numpy(X_seq[1:, -1, :]).float()
        return X_f, Y_f

    def _sequence_loss(self, X_f, Y_f):
        '''
        Predict the targets with the underlying regression model and return
        (predictions, element-wise squared error) as CPU tensors.
        '''
        predictions = super(ODRegressionEncoderTransformerModel, self).decision_function(X_f, 
                                                                                  output_numpy = False, 
                                                                                  from_numpy = False)
        criterion = nn.MSELoss(reduction = 'none')
        loss = criterion(Y_f.to('cpu'), predictions)
        return predictions, loss

    @staticmethod
    def _align_to_input(values, n_points):
        '''
        Place per-prediction values (entry i refers to X[i+1]) into a zero
        array of length n_points so that index j refers to X[j].
        '''
        out = np.zeros(n_points)
        out[1:1 + values.shape[0]] = values
        return out
    
    def fit(self, X, Y= None):
        '''
        Arguments
        ---------
            X: numpy array
                This is the time series that the outlier detection model will fit
                on. It must be of shape (N, E), where N is the number of time points
                and E is the number of features (or embedding size). The rolling
                input sequences of length sequence_length are built internally.
            
            Y: None
                This argument is ignored since the model is unsupervised.
        
        Returns
        ---------
            self
            
        '''
        X_f, Y_f = self._make_sequences(X)
        
        super(ODRegressionEncoderTransformerModel, self).fit(X_f, Y_f, from_numpy = False)
        
        _, loss = self._sequence_loss(X_f, Y_f)
        # per-feature statistics over the training samples (dim 0)
        self.mean_loss = torch.mean(loss, dim = 0).detach().numpy()
        self.std_loss = torch.std(loss, dim = 0).detach().numpy()
        
        return self
    
    def decision_function(self, X):
        '''
        Arguments
        ---------
            X: numpy array
                This is the time series of shape (N, E) to calculate outlier scores
                on, given the training results.
        
        Returns
        ---------
            out: numpy array
                This is a numpy array of shape (N-1,) which contains the outlier
                scores. Entry i is the score of X[i+1]; X[0] has no prediction.
        
        '''
        from scipy.stats import norm
        
        X_f, Y_f = self._make_sequences(X)
        predictions, loss = self._sequence_loss(X_f, Y_f)

        loss = loss.detach().numpy()
        self.loss = loss
        # standardize each feature's error with the training statistics
        standard_loss = (loss - self.mean_loss)/self.std_loss
        likelihood = norm.cdf(standard_loss)
        
        self.likelihood_raw = likelihood

        out = np.mean(likelihood, axis = 1)
        
        self.scores = out
        self.predictions = predictions.cpu().detach().numpy()
        
        return out
    
    def predict(self, X, top_k = 10):
        '''
        This function uses the model to flag the top_k highest-scoring outliers.
        
        Arguments
        ---------
            X: numpy array
                This is the (N, E) array you wish to find the outliers for.
            
            top_k: integer
                This is the number of outliers you want flagged. Values <= 0 flag
                nothing; values larger than the number of scored points flag every
                scored point.

        Returns
        ---------
            outliers: numpy array
                Array of shape (N,) with 1 at the flagged points and 0 elsewhere.
                Index j refers to X[j]; X[0] is never flagged since it has no score.
        '''
        scores = self.decision_function(X)

        flags = np.zeros(scores.shape[0])
        k = min(int(top_k), scores.shape[0])
        # guard: argpartition(...)[-0:] would select every index
        if k > 0:
            flags[np.argpartition(scores, -k)[-k:]] = 1
        
        outliers = self._align_to_input(flags, X.shape[0])
        
        self.outliers = outliers
        
        return outliers

    def predict_proba(self,X):
        '''
        This function uses the model to calculate the outlier scores.
        
        Arguments
        ---------
            X: numpy array
                This is the (N, E) array you wish to find the outliers for.

        Returns
        ---------
            out: numpy array
                Array of shape (N,) where index j holds the outlier score of X[j].
                X[0] has no prediction and receives 0.
        '''
        scores = self.decision_function(X)
        
        out = self._align_to_input(scores, X.shape[0])
        
        self.out = out
        
        return out


  
class ODTransformerModel(TransformerModel):
    '''
    This model is an outlier detection model based on the full transformer model. This is
    built on the pytorch nn.Transformer module.

    The model is trained to forecast the next predict_sequence_length points from
    the sequence_length points that end at the current one. The squared forecast
    error is averaged over the forecast steps, standardized per feature with the
    mean and std measured on the training data, and passed through a normal CDF.
    The outlier score is the mean of these values across the features.

    Attributes available after fit() and decision_function() have been run:

        .mean_loss / .std_loss:
            Per-feature mean / std (over training samples) of the squared forecast
            error averaged over the forecast steps.

        .predictions:
            The forecasts made on the self-supervised task.

        .loss:
            The element-wise squared forecast error, before averaging over steps.

        .likelihood_raw:
            The per-feature outlier score of each forecast window.

        .scores:
            The outlier score of each forecast window. Entry i comes from the input
            window ending at X[i] and its forecast of X[i+1], ..., X[i+p], where
            p = predict_sequence_length.
    
    '''
    
    def __init__(self, sequence_length, predict_sequence_length, d_model, 
                 nhead, num_encoder_layers=2, 
                 num_decoder_layers=2, dim_feedforward=30, dropout=0.1, 
                 activation='relu',
                 layer_norm_eps=1e-05, 
                 device='cpu',
                 embedding_class = IdentityEmbedding, embedding_args = None,
                 learning_rate = 0.01, epochs = 150, batch_size = 20, verbose = True):
        '''
        Arguments
        ---------
            sequence_length: integer
                Length of the input sequences built from the time series.

            predict_sequence_length: integer
                Number of future points forecast from each input sequence. It
                must satisfy 1 <= predict_sequence_length <= sequence_length,
                because the targets are taken from the last
                predict_sequence_length points of a later input window.

            embedding_args: dict or None
                Arguments for embedding_class. None (the default) means an empty
                dictionary; a fresh one is created for every instance.

            The remaining arguments are passed positionally to TransformerModel.

        Raises
        ---------
            ValueError: if predict_sequence_length is outside
                [1, sequence_length].
        '''
        if not 1 <= predict_sequence_length <= sequence_length:
            raise ValueError('predict_sequence_length must be between 1 and '
                             'sequence_length (inclusive), got '
                             f'{predict_sequence_length} with sequence_length={sequence_length}.')
        if embedding_args is None:
            embedding_args = {}
        
        super(ODTransformerModel, self).__init__(d_model, nhead, num_encoder_layers, 
                 num_decoder_layers, dim_feedforward, dropout, 
                 activation,
                 layer_norm_eps, 
                 device,
                 embedding_class, embedding_args,
                 learning_rate, epochs, batch_size, verbose)
        
        self.sequence_length = sequence_length
        self.mean_loss = None
        self.std_loss = None
        self.predictions = None
        self.loss = None
        self.likelihood_raw = None
        self.scores = None
        self.predict_sequence_length = predict_sequence_length
        self.device = device

        return

    def _make_sequences(self, X):
        '''
        Build the self-supervised (inputs, targets) tensors from a (N, E) array.

        sequence_length - 1 zero rows are prepended so that (assuming
        make_input_roll(A, S) returns the windows A[j:j+S] — confirm) window j
        ends at X[j]. Input window i is paired with the last p points of window
        i + p, i.e. X[i+1], ..., X[i+p], where p = predict_sequence_length.
        '''
        p = self.predict_sequence_length
        padding = np.zeros((self.sequence_length - 1, X.shape[1]))
        X_seq = make_input_roll(np.insert(X, 0, padding, axis = 0), self.sequence_length)

        Y_f = torch.from_numpy(X_seq[p:, -p:, :]).float()
        X_f = torch.from_numpy(X_seq[:-p, :, :]).float()
        return X_f, Y_f

    def _sequence_loss(self, X_f, Y_f):
        '''
        Forecast the targets with the underlying transformer and return
        (predictions, element-wise squared error) as CPU tensors.
        '''
        predictions = super(ODTransformerModel, self).decision_function(X_f, 
                                                                        output_numpy = False,
                                                                        from_numpy = False, 
                                                                        predict_sequence_length = self.predict_sequence_length)
        criterion = nn.MSELoss(reduction = 'none')
        loss = criterion(Y_f.to('cpu'), predictions)
        return predictions, loss
        
    def fit(self, X, Y = None):
        '''
        Arguments
        ---------
            X: numpy array
                This is the time series that the outlier detection model will fit
                on. It must be of shape (N, E), where N is the number of time points
                and E is the number of features (or embedding size). The rolling
                input sequences are built internally.
            
            Y: None
                This argument is ignored since the model is unsupervised.
        
        Returns
        ---------
            self
            
        '''
        X_f, Y_f = self._make_sequences(X)
        
        super(ODTransformerModel, self).fit(X_f,Y_f, from_numpy = False)
        
        _, loss = self._sequence_loss(X_f, Y_f)
        # average over the forecast steps, then take per-feature statistics
        # over the training samples
        step_mean_loss = torch.mean(loss, dim = 1)
        self.mean_loss = torch.mean(step_mean_loss, dim = 0).cpu().detach().numpy()
        self.std_loss = torch.std(step_mean_loss, dim = 0).cpu().detach().numpy()
        
        return self
    
    def decision_function(self, X):
        '''
        Arguments
        ---------
            X: numpy array
                This is the time series of shape (N, E) to calculate outlier scores
                on, given the training results.
        
        Returns
        ---------
            out: numpy array
                This is a numpy array of shape (N - predict_sequence_length,) which
                contains the outlier scores of the forecast windows.
        
        '''
        from scipy.stats import norm
        
        X_f, Y_f = self._make_sequences(X)
        predictions, loss = self._sequence_loss(X_f, Y_f)
        
        loss = loss.cpu().detach().numpy()
        self.loss = loss
        loss = np.mean(loss, axis = 1)
        standard_loss = (loss - self.mean_loss)/self.std_loss
        likelihood = norm.cdf(standard_loss)
        
        self.likelihood_raw = likelihood

        out = np.mean(likelihood, axis = 1)
        
        self.scores = out
        self.predictions = predictions.cpu().detach().numpy()
        
        return out
    
    def predict(self, X, top_k = 10):
        '''
        This function uses the model to flag the top_k highest-scoring outliers.
        
        Arguments
        ---------
            X: numpy array
                This is the (N, E) array you wish to find the outliers for.
            
            top_k: integer
                This is the number of outliers you want flagged. Values <= 0 flag
                nothing; values larger than the number of scores flag every
                scored position.

        Returns
        ---------
            outliers: numpy array
                Array of shape (N,) with 1 at the positions of the top_k scores and
                0 elsewhere. Score i is placed at index i (the end of its input window);
                the last predict_sequence_length positions are never flagged.
        '''
        scores = self.decision_function(X)

        outliers = np.zeros(X.shape[0])
        k = min(int(top_k), scores.shape[0])
        # guard: argpartition(...)[-0:] would select every index
        if k > 0:
            outliers[np.argpartition(scores, -k)[-k:]] = 1
        
        self.outliers = outliers
        
        return outliers

    def predict_proba(self,X):
        '''
        This function uses the model to calculate the outlier scores.
        
        Arguments
        ---------
            X: numpy array
                This is the (N, E) array you wish to find the outliers for.

        Returns
        ---------
            out: numpy array
                Array of shape (N,) whose first N - predict_sequence_length entries
                are the scores from decision_function(); the rest are 0.
        '''
        scores = self.decision_function(X)
        
        out = np.zeros(X.shape[0])
        out[:scores.shape[0]] = scores
        
        self.out = out
        
        return out