
import numpy as np
import random
import os, errno
import sys
from tqdm import trange

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F


class lstm_encoder(nn.Module):
    ''' Encodes time-series sequence '''

    def __init__(self, input_size, hidden_size, num_layers = 1):
        
        '''
        : param input_size:     the number of features in the input X
        : param hidden_size:    the number of features in the hidden state h
        : param num_layers:     number of recurrent layers (i.e., 2 means there are
        :                       2 stacked LSTMs)
        '''
        
        super(lstm_encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # define LSTM layer; batch_first=True means inputs and outputs are laid
        # out as (# in batch, seq_len, features)
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers, batch_first=True)

    def forward(self, x_input):
        
        '''
        : param x_input:               input of shape (# in batch, seq_len, input_size)
        : return lstm_out, hidden:     lstm_out gives the top-layer hidden state for every
        :                              element in the sequence, shape
        :                              (# in batch, seq_len, hidden_size);
        :                              hidden is the tuple (h_n, c_n) holding the hidden state
        :                              and cell state after the last element in the sequence,
        :                              each of shape (num_layers, # in batch, hidden_size)
        '''
        
        # the view guarantees the LSTM receives a 3-D tensor whose last
        # dimension is input_size
        lstm_out, self.hidden = self.lstm(x_input.view(x_input.shape[0], x_input.shape[1], self.input_size))
        
        return lstm_out, self.hidden     
    
    def init_hidden(self, batch_size):
        
        '''
        initialize hidden state
        : param batch_size:    number of sequences in the batch, i.e. x_input.shape[0]
        : return:              zeroed hidden state and cell state, each of shape
        :                      (num_layers, batch_size, hidden_size)
        '''
        
        # Allocate the state with the same dtype/device as the LSTM weights so
        # it stays usable after the module has been moved (.to(device)) or
        # cast (.double()); plain torch.zeros would always be float32 on CPU.
        reference = self.lstm.weight_ih_l0
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        
        return (torch.zeros(state_shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(state_shape, dtype=reference.dtype, device=reference.device))


class lstm_decoder(nn.Module):
    ''' Decodes one time step at a time, starting from the encoder's hidden state '''
    
    def __init__(self, input_size, hidden_size, num_layers = 1):

        '''
        : param input_size:     number of features in each decoder input; the linear
        :                       layer projects back to this size, so it is also the
        :                       number of features in each decoder output
        : param hidden_size:    the number of features in the hidden state h
        : param num_layers:     number of recurrent layers (i.e., 2 means there are
        :                       2 stacked LSTMs)
        '''
        
        super().__init__()
        self.input_size, self.hidden_size, self.num_layers = input_size, hidden_size, num_layers

        # recurrent layer first, projection second
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size, out_features=input_size)

    def forward(self, x_input, encoder_hidden_states):
        
        '''        
        : param x_input:                    should be 2D (batch_size, input_size)
        : param encoder_hidden_states:      (hidden state, cell state) tuple used as the
        :                                   initial state of the LSTM for this step
        : return output, hidden:            output is the linear projection of the LSTM output
        :                                   for this step, shape (batch_size, input_size);
        :                                   hidden gives the hidden state and cell state after
        :                                   this step
        '''
        # insert a sequence axis of length one so the batch-first LSTM sees a
        # single time step per sample
        single_step = torch.unsqueeze(x_input, dim=1)
        step_out, state = self.lstm(single_step, encoder_hidden_states)
        self.hidden = state

        # remove the sequence axis again before projecting
        output = self.linear(torch.squeeze(step_out, dim=1))

        return output, state

class lstm_seq2seq(nn.Module):
    ''' train LSTM encoder-decoder and make predictions '''
    
    def __init__(self, input_size=2, output_size=2, embedding_size=24, num_layers=1, target_len = 24):

        '''
        : param input_size:       the number of features per time step in the encoder input X
        : param output_size:      the number of features per predicted time step; the decoder
        :                         is built with this as its input size, and decoding is seeded
        :                         with the first output_size features of the last input step
        : param embedding_size:   the number of features in the hidden state h of both the
        :                         encoder and the decoder LSTM
        : param num_layers:       number of stacked LSTM layers in the encoder and the decoder
        : param target_len:       number of time steps generated by predict
        '''

        super(lstm_seq2seq, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = embedding_size
        self.target_len = target_len
        self.encoder = lstm_encoder(input_size = input_size, hidden_size = embedding_size, num_layers = num_layers)
        self.decoder = lstm_decoder(input_size = output_size, hidden_size = embedding_size, num_layers = num_layers)

        # populated by train_model: the training tensors, the loss function and
        # the loss tensor of the most recent batch
        self.X = None
        self.Y = None
        self.loss_fn = None
        self.loss = None

    def train_model(self, input_tensor, target_tensor, n_epochs=100, batch_size=150, training_prediction = 'recursive', teacher_forcing_ratio = 0.5, learning_rate = 0.01, dynamic_tf = False):
        
        '''
        train lstm encoder-decoder with Adam on an MSE loss
        
        : param input_tensor:              input data with shape (# samples, seq_len, input_size); PyTorch tensor
        : param target_tensor:             target data; it is compared against the output of predict, which
        :                                  has shape (# samples in batch, target_len, output_size); PyTorch tensor
        : param n_epochs:                  number of epochs 
        : param batch_size:                number of samples per gradient update; the last batch of an
        :                                  epoch is smaller when # samples is not a multiple of batch_size
        : param training_prediction:       currently ignored; training always decodes recursively
        :                                  through predict. Kept so existing callers keep working.
        : param teacher_forcing_ratio:     currently ignored (no teacher forcing is applied)
        : param learning_rate:             float >= 0; learning rate
        : param dynamic_tf:                currently ignored (no teacher forcing is applied)
        : return losses:                   numpy array of length n_epochs holding the mean batch
        :                                  loss of each epoch
        : raises ValueError:               if batch_size is not positive or input_tensor has no samples
        '''
        
        n_samples = input_tensor.shape[0]
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if n_samples == 0:
            raise ValueError("input_tensor must contain at least one sample")

        # initialize array of losses 
        losses = np.full(n_epochs, np.nan)

        self.X = input_tensor
        self.Y = target_tensor

        optimizer = optim.Adam(self.parameters(), lr = learning_rate)
        criterion = nn.MSELoss()
        self.loss_fn = criterion

        # Number of batch iterations, rounded up: flooring would give zero
        # batches (and a division by zero below) when there are fewer samples
        # than batch_size, and would silently skip a trailing partial batch.
        n_batches = (n_samples + batch_size - 1) // batch_size

        with trange(n_epochs) as tr:
            for it in tr:
                
                epoch_loss = 0.

                for b in range(n_batches):
                    # select data (slicing past the end just yields a shorter batch)
                    start = b * batch_size
                    stop = start + batch_size
                    input_batch = input_tensor[start:stop]
                    target_batch = target_tensor[start:stop]

                    # zero the gradient
                    optimizer.zero_grad()
                    
                    outputs = self.predict(input_batch)

                    # compute the loss 
                    loss = criterion(outputs, target_batch)
                    epoch_loss += loss.item()
                    
                    # backpropagation
                    self.loss = loss
                    loss.backward()
                    optimizer.step()

                # loss for epoch 
                epoch_loss /= n_batches 
                losses[it] = epoch_loss

                # progress bar 
                tr.set_postfix(loss="{0:.3f}".format(epoch_loss))
                    
        return losses

    def predict(self, input_tensor):
        
        '''
        encode input_tensor, then decode target_len steps recursively (each
        prediction is fed back as the next decoder input)

        : param input_tensor:      input data (# in batch, seq_len, input_size); PyTorch tensor 
        : return outputs:          tensor of predicted values with shape
        :                          (# in batch, target_len, output_size)

        Gradient tracking is left enabled because train_model backpropagates
        through this method; wrap calls in torch.no_grad() for pure inference.
        '''

        # encode input_tensor; only the final (hidden, cell) state is needed
        _, encoder_hidden = self.encoder(input_tensor)

        # one (# in batch, output_size) prediction per decoded step
        outputs = []

        # seed the decoder with the first output_size features of the last
        # input step and with the encoder's final state
        decoder_input = input_tensor[:, -1, :self.output_size]
        decoder_hidden = encoder_hidden
        
        for _ in range(self.target_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)
            decoder_input = decoder_output
        
        # stack the per-step predictions along a new time axis
        return torch.stack(outputs, dim=1)


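# Disabled reference code: the string below is never executed. It holds the
# encoder-output / decoder loop variants ('recursive', 'teacher_forcing',
# 'mixed_teacher_forcing') that presumably used to live inside train_model.
# NOTE(review): the 'teacher_forcing' branch indexes target_batch[t, :, :]
# while the 'mixed_teacher_forcing' branch uses target_batch[:, t, :];
# reconcile the two before re-enabling any of this.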
'''
encoder_output, encoder_hidden = self.encoder(input_batch)

# decoder with teacher forcing
decoder_input = input_batch[:, -1, :self.output_size]   # shape: (batch_size, input_size)
decoder_hidden = encoder_hidden



if training_prediction == 'recursive':
    # predict recursively
    for t in range(self.target_len): 

        decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        outputs[:,t,:] = decoder_output
        decoder_input = decoder_output

if training_prediction == 'teacher_forcing':
    # use teacher forcing
    if random.random() < teacher_forcing_ratio:
        for t in range(self.target_len): 
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs[:, t] = decoder_output
            decoder_input = target_batch[t, :, :]

    # predict recursively 
    else:
        for t in range(self.target_len): 
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs[:, t] = decoder_output
            decoder_input = decoder_output

if training_prediction == 'mixed_teacher_forcing':
    # predict using mixed teacher forcing
    for t in range(self.target_len):
        decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        outputs[:, t] = decoder_output
        
        # predict with teacher forcing
        if random.random() < teacher_forcing_ratio:
            decoder_input = target_batch[:, t, :]
        
        # predict recursively 
        else:
            decoder_input = decoder_output
'''
