'''
Article : https://arxiv.org/abs/1704.02971
Source : https://github.com/KurochkinAlexey/DA-RNN
'''
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F

import os
from torch.utils.data import IterableDataset
from zipfile import ZipFile
import pickle as pkl

from torch.utils.tensorboard import SummaryWriter

def direction_accuracy(pred, true):
    """Return the fraction of elements whose direction relative to 1 agrees.

    Both tensors are flattened. An element counts as "down" when
    ``value - 1 < 0`` and as "up" otherwise (so exactly 1 and NaN count as
    "up").

    pred (torch.Tensor) :
        Predicted values.
    true (torch.Tensor) :
        Reference values, expected to hold as many elements as ``pred``.

    Returns a Python float in [0, 1]. Raises ZeroDivisionError when ``pred``
    is empty.
    """
    pred_down = (pred.reshape(-1) - 1) < 0
    true_down = (true.reshape(-1) - 1) < 0

    # One vectorized comparison instead of a Python loop over every element.
    hits = (pred_down == true_down).sum().item()
    return hits / pred_down.shape[0]

def mape(pred, true):
    """Return the mean absolute percentage error of ``pred`` against ``true``.

    The relative error is taken element-wise with respect to ``true``,
    averaged over every element and scaled to percent. The result is a
    Python float.
    """
    relative_error = (true - pred) / true
    percent = relative_error.abs().view(-1).mean() * 100
    return percent.item()

class EarlyStopping():
    """
    source : https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
    Early stopping to stop the training when the loss does not improve after
    certain epochs.

    Call the instance with the latest validation loss after each evaluation
    and read ``early_stop`` to know whether training should stop.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        if self.best_loss is None:
            # first observation only establishes the baseline
            self.best_loss = val_loss
            return

        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
            return

        # Anything that is not a strict improvement counts against patience.
        # This includes a gain of exactly min_delta (e.g. a flat loss with
        # the default min_delta=0), which previously matched no branch and
        # so never triggered early stopping.
        self.counter += 1
        print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
        if self.counter >= self.patience:
            print('INFO: Early stopping')
            self.early_stop = True

# Default hyper-parameters read by the encoder, decoder and DualAttentionRNN.
config = dict(
    # Number of exogenous features
    n_features=1,
    # Number of units of lstm encoder
    lstm_units_encoder=64,
    # Number of units of lstm decoder
    lstm_units_decoder=64,
    # Timesteps of features
    lag=1,
    # File used by DualAttentionRNN.save_model / load_model
    path="Weights/path",
)

class InputAttentionEncoder(nn.Module):
    """First stage of the DA-RNN: an LSTM cell with attention over the input series."""

    def __init__(self, config):
        """
        :param: config: dict providing
            "n_features" (N): int
                number of time series
            "lstm_units_encoder" (M): int
                number of LSTM units
            "lag" (T): int
                number of timesteps
        """
        # Zero-argument super() stays correct when the class is subclassed;
        # super(self.__class__, self) would recurse forever in that case.
        super().__init__()
        self.N = config["n_features"]
        self.M = config["lstm_units_encoder"]
        self.T = config["lag"]

        self.encoder_lstm = nn.LSTMCell(input_size=self.N, hidden_size=self.M)

        #equation 8 matrices

        self.W_e = nn.Linear(2*self.M, self.T)
        self.U_e = nn.Linear(self.T, self.T, bias=False)
        self.v_e = nn.Linear(self.T, 1, bias=False)

    def forward(self, inputs):
        """Encode ``inputs`` of shape (batch, T, N) into hidden states (batch, T, M).

        The hidden and cell states start at zero for every call. All
        intermediate tensors are created on the device of ``inputs``.
        """
        batch_size = inputs.size(0)
        # Follow the input's device/dtype instead of forcing CUDA, so the
        # module also runs on CPU.
        like_inputs = dict(device=inputs.device, dtype=inputs.dtype)
        encoded_inputs = torch.zeros((batch_size, self.T, self.M), **like_inputs)

        #initial hidden states
        h_tm1 = torch.zeros((batch_size, self.M), **like_inputs)
        s_tm1 = torch.zeros((batch_size, self.M), **like_inputs)

        # U_e * x^k does not depend on the time step: compute it once.
        series_proj = self.U_e(inputs.permute(0, 2, 1))  # (batch, N, T)

        for t in range(self.T):
            #concatenate hidden states
            h_c_concat = torch.cat((h_tm1, s_tm1), dim=1)

            #attention weights for each k in N (equation 8)
            # (batch, 1, T) broadcasts against (batch, N, T)
            state_proj = self.W_e(h_c_concat).unsqueeze(1)
            z = torch.tanh(state_proj + series_proj)
            # Drop only the trailing singleton. A bare squeeze() would also
            # collapse the batch axis (batch == 1) or the feature axis
            # (N == 1) and break the softmax over dim=1.
            e_k_t = self.v_e(z).squeeze(-1)

            #normalize attention weights (equation 9)
            alpha_k_t = F.softmax(e_k_t, dim=1)

            #weight inputs (equation 10)
            weighted_inputs = alpha_k_t * inputs[:, t, :]

            #calculate next hidden states (equation 11)
            h_tm1, s_tm1 = self.encoder_lstm(weighted_inputs, (h_tm1, s_tm1))

            encoded_inputs[:, t, :] = h_tm1
        return encoded_inputs

class TemporalAttentionDecoder(nn.Module):
    """Second stage of the DA-RNN: an LSTM cell with attention over encoder states."""

    def __init__(self, config):
        """
        :param: config: dict providing
            "lstm_units_encoder" (M): int
                number of encoder LSTM units
            "lstm_units_decoder" (P): int
                number of decoder LSTM units
            "lag" (T): int
                number of timesteps
        """
        # Zero-argument super() stays correct when the class is subclassed.
        super().__init__()
        self.M = config["lstm_units_encoder"]
        self.P = config["lstm_units_decoder"]
        self.T = config["lag"]

        self.decoder_lstm = nn.LSTMCell(input_size=1, hidden_size=self.P)

        #equation 12 matrices
        self.W_d = nn.Linear(2*self.P, self.M)
        self.U_d = nn.Linear(self.M, self.M, bias=False)
        self.v_d = nn.Linear(self.M, 1, bias = False)

        #equation 15 matrix
        self.w_tilda = nn.Linear(self.M + 1, 1)

        #equation 22 matrices
        self.W_y = nn.Linear(self.P + self.M, self.P)
        self.v_y = nn.Linear(self.P, 1)

    def forward(self, encoded_inputs, y):
        """Predict the next target value.

        encoded_inputs : encoder hidden states, shape (batch, steps, M)
        y : past target values, shape (batch, T, 1)

        Returns a tensor of shape (batch, 1). The decoder states start at
        zero for every call and live on the device of ``encoded_inputs``.
        The context vector is only assigned inside the time loop, so T must
        be at least 1.
        """
        batch_size = encoded_inputs.size(0)
        # Follow the input's device/dtype instead of forcing CUDA.
        like_inputs = dict(device=encoded_inputs.device, dtype=encoded_inputs.dtype)

        #initializing hidden states
        d_tm1 = torch.zeros((batch_size, self.P), **like_inputs)
        s_prime_tm1 = torch.zeros((batch_size, self.P), **like_inputs)

        # U_d * h_i does not depend on the decoding step: compute it once.
        enc_proj = self.U_d(encoded_inputs)  # (batch, steps, M)

        for t in range(self.T):
            #concatenate hidden states
            d_s_prime_concat = torch.cat((d_tm1, s_prime_tm1), dim=1)

            #temporal attention weights (equation 12)
            # (batch, 1, M) broadcasts against (batch, steps, M)
            state_proj = self.W_d(d_s_prime_concat).unsqueeze(1)
            z1 = torch.tanh(state_proj + enc_proj)
            l_i_t = self.v_d(z1)

            #normalized attention weights (equation 13)
            beta_i_t = F.softmax(l_i_t, dim=1)

            #create context vector (equation_14)
            c_t = torch.sum(beta_i_t * encoded_inputs, dim=1)

            #concatenate c_t and y_t
            y_c_concat = torch.cat((c_t, y[:, t, :]), dim=1)
            #create y_tilda
            y_tilda_t = self.w_tilda(y_c_concat)

            #calculate next hidden states (equation 16)
            d_tm1, s_prime_tm1 = self.decoder_lstm(y_tilda_t, (d_tm1, s_prime_tm1))

        #concatenate context vector at step T and hidden state at step T
        d_c_concat = torch.cat((d_tm1, c_t), dim=1)

        #calculate output
        y_Tp1 = self.v_y(self.W_y(d_c_concat))
        return y_Tp1

class DualAttentionRNN(nn.Module):
    """Dual-stage attention RNN (input-attention encoder + temporal-attention
    decoder) bundled with its training, validation and prediction loops.

    Every batch consumed by the loops is a tensor of shape
    (batch, length, features). The first ``length - 1`` rows are the history,
    the last row holds the value to predict for column ``target``, and the
    columns listed in ``EXO_COLUMNS`` are fed to the encoder.
    """

    # Columns of a window used as exogenous (driving) series.
    EXO_COLUMNS = [0, 1, 2, 4]

    #-------------------__INIT__--------------------
    def __init__(self, config):
        '''
        config (dict) :
            Provides "n_features", "lstm_units_encoder", "lstm_units_decoder",
            "lag" and "path" (file used by save_model / load_model).
        '''
        # Zero-argument super() stays correct when the class is subclassed.
        super().__init__()

        self.N = config["n_features"]
        self.M = config["lstm_units_encoder"]
        self.P = config["lstm_units_decoder"]
        self.T = config["lag"]
        self.path = config["path"]

        self.encoder = InputAttentionEncoder(config)
        self.decoder = TemporalAttentionDecoder(config)

        # Keep the historical "model lives on the GPU after construction"
        # behaviour, but do not crash on machines without CUDA.
        if torch.cuda.is_available():
            self.cuda()

    #-------------------FORWARD--------------------
    def forward(self, X_history, y_history):
        '''Encode the exogenous history, then decode it with the target history.'''
        out = self.decoder(self.encoder(X_history), y_history)
        return out

    #--------------------SAVE_MODEL--------------------
    def save_model(self):
        '''Write the state dict to ``self.path``.'''
        torch.save(self.state_dict(), self.path)

    #--------------------SAVE_MODEL--------------------
    def save_model_finetuning(self, path):
        '''Write the state dict to ``path``.'''
        torch.save(self.state_dict(), path)

    #--------------------LOAD_MODEL--------------------
    def load_model(self):
        '''Load the state dict stored at ``self.path``.'''
        # map_location="cpu" lets weights saved from a GPU be read anywhere;
        # load_state_dict then copies them onto the parameters' own device.
        self.load_state_dict(torch.load(self.path, map_location="cpu"))

    #--------------------LOAD_MODEL--------------------
    def load_model_finetuning(self, path):
        '''Load the state dict stored at ``path``.'''
        self.load_state_dict(torch.load(path, map_location="cpu"))

    #--------------------HELPERS--------------------
    @classmethod
    def _split_batch(cls, X_batch, target, device):
        '''Split a (batch, length, features) window into exogenous history,
        target history (batch, length-1, 1) and value to predict (batch, 1),
        all cast to float32 and moved to ``device``.'''
        b, l, _ = X_batch.shape
        exo_batch = X_batch[:, 0:l-1, cls.EXO_COLUMNS].float().to(device)
        endo_batch = X_batch[:, 0:l-1, target].reshape(b, l-1, 1).float().to(device)
        Y_batch = X_batch[:, l-1, target].reshape(b, 1).float().to(device)
        return exo_batch, endo_batch, Y_batch

    def _train_step(self, X_batch, optimizer, loss_fn, metric, device, target):
        '''Run one optimisation step; return (outputs, Y_batch, loss value, metric value).'''
        exo_batch, endo_batch, Y_batch = self._split_batch(X_batch, target, device)

        optimizer.zero_grad()
        outputs = self(exo_batch, endo_batch)
        loss = loss_fn(outputs, Y_batch)
        # Evaluated once and reused for both the running mean and the log.
        batch_metric = metric(outputs, Y_batch)

        loss.backward()
        optimizer.step()
        return outputs, Y_batch, loss.item(), batch_metric

    @staticmethod
    def _log_train_step(writer, step, outputs, Y_batch, loss_value, batch_metric, mean_loss, mean_metric):
        '''Write the per-batch training scalars to TensorBoard.'''
        writer.add_scalar("Loss/Train/Epoch - step", round(mean_loss, 8), step) #Loss during epoch
        writer.add_scalar("Loss/Train/BatchEnd - step", round(loss_value, 8), step) #Loss each batch

        writer.add_scalar("Metric/Train/Epoch - step", round(mean_metric, 8), step) #Metric during epoch
        writer.add_scalar("Metric/Train/BatchEnd - step", round(batch_metric, 8), step) #Metric each batch

        writer.add_scalar("Accuracy/Train/BatchEnd - step", direction_accuracy(outputs, Y_batch), step) #Accuracy of the direction

        writer.add_scalar("MAPE/Train/BatchEnd - step", mape(outputs, Y_batch), step)

        writer.add_scalar("Output/Mean/Train/Batch - step", outputs.mean(), step) #Mean prediction of the batch
        writer.add_scalar("Output/Min/Train/Batch - step", outputs.min(), step) #Min prediction of the batch
        writer.add_scalar("Output/Max/Train/Batch - step", outputs.max(), step) #Max prediction of the batch

    @staticmethod
    def _log_validation(writer, val_counter, results, suffix):
        '''Write one validation result tuple to TensorBoard under ``.../Val/<suffix>``.'''
        val_loss, val_metric, val_accuracy, val_mape = results
        writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss
        writer.add_scalar(f"Metric/Val/{suffix}", val_metric, val_counter) #Val Metric
        writer.add_scalar(f"Accuracy/Val/{suffix}", val_accuracy, val_counter) #Val Accuracy
        writer.add_scalar(f"MAPE/Val/{suffix}", val_mape, val_counter)

    #--------------------COMPUTE_VALIDATION--------------------
    def compute_validation(self, val_data, loss_fn, metric, device, target):
        '''
        Evaluate the model on ``val_data`` and return the per-batch averages
        (val_loss, val_metric, val_accuracy, val_MAPE), each rounded to 8
        decimals. The model is left in eval mode; callers switch back to
        train mode themselves. Raises ZeroDivisionError if ``val_data``
        yields no batch.

        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform evaluation
        target (int) :
            Target serie to forecast
        '''
        self.eval()
        with torch.no_grad():
            batch = 0
            val_err = 0
            val_l = 0
            val_acc = 0
            val_mape = 0
            for X_batch in val_data:
                batch += 1
                exo_batch, endo_batch, Y_batch = self._split_batch(X_batch, target, device)

                outputs = self(exo_batch, endo_batch)

                val_l += loss_fn(outputs, Y_batch).item()
                val_err += metric(outputs, Y_batch)
                val_acc += direction_accuracy(outputs, Y_batch)
                val_mape += mape(outputs, Y_batch)

            val_loss = round(val_l/batch, 8)
            val_metric = round(val_err/batch, 8)
            val_accuracy = round(val_acc/batch, 8)
            val_MAPE = round(val_mape/batch, 8)

            return val_loss, val_metric, val_accuracy, val_MAPE

    #--------------------PREDICT--------------------
    def predict_index(self, data_days, metric, target, device, save_path=''):
        '''
        Load the weights stored at ``self.path``, predict every dataset of
        ``data_days`` and write one ``<day>.csv`` (columns Name, True, Pred)
        per dataset into ``save_path``. Prints the metric and the MAPE
        averaged over the evaluated batches.

        data_days (iterable) :
            Datasets yielding (window, name) pairs and exposing a ``day``
            attribute used as the CSV file name
        metric (function) :
            Metric
        target (int) :
            Target serie to forecast
        device (torch.device) :
            Device where to perform evaluation
        save_path (str) :
            Output directory, created if missing; '' means the current
            directory
        '''
        with torch.no_grad():
            # os.mkdir('') raises, so fall back to the current directory and
            # create intermediate directories when needed.
            out_dir = save_path or '.'
            os.makedirs(out_dir, exist_ok=True)
            self.to(device)
            self.eval()
            self.load_model()
            err = 0
            Mape = 0
            step = 0
            for d in data_days:
                data = torch.utils.data.DataLoader(d, batch_size=40)
                true_chunks = []
                pred_chunks = []
                Y_names = []
                for X_batch, X_names in data:
                    exo_batch, endo_batch, Y_batch = self._split_batch(X_batch, target, device)
                    try:
                        outputs = self(exo_batch, endo_batch)
                    except Exception as exc:
                        # Still best-effort, but no longer silent, and
                        # KeyboardInterrupt / SystemExit are not swallowed.
                        print(f"WARNING: batch skipped, forward pass failed: {exc!r}")
                        continue

                    err += metric(outputs, Y_batch)
                    Mape += mape(outputs, Y_batch)

                    # Collect chunks and concatenate once; repeated np.append
                    # copies the whole array on every batch.
                    true_chunks.append(Y_batch.cpu().numpy().reshape(-1))
                    pred_chunks.append(outputs.cpu().numpy().reshape(-1))
                    Y_names.extend(X_names)

                    step += 1

                empty = np.empty((0,), dtype=np.float32)
                Y_true = np.concatenate(true_chunks) if true_chunks else empty
                Y_pred = np.concatenate(pred_chunks) if pred_chunks else empty
                table = pd.DataFrame({"Name" : Y_names, "True" : Y_true, "Pred" : Y_pred})
                table.to_csv(f'{out_dir}/{d.day}.csv', index=False)

            if step:
                print("RMSE : ", round(err/step, 8))
                print('MAPE : ', round(Mape/step, 8))
            else:
                # Avoid a ZeroDivisionError when nothing could be evaluated.
                print("WARNING: no batch was evaluated, metrics unavailable")

    #-------------------TRAINING_PHASE--------------------
    def training_phase(self, train_data, epochs, optimizer, loss_fn, metric, device, target, batch_start=1, batch_end=40000, val_data=None, val_step=35866, start_val_count=1, log_dir=None, load=False):
        '''
        Train the model, log to TensorBoard and save the weights to
        ``self.path`` at the end of every epoch.

        train_data (torch.utils.data.dataloader.DataLoader) :
            DataLoader containing batches of train inputs and outputs.
        epochs (int) :
            Number of epochs of training.
        optimizer (torch.optim) :
            Optimizer i.e. algorithm used for computing gradient.
        loss_fn (torch.nn.Loss) :
            Loss function the model need to optimized.
        metric (function) :
            A function taking (pred, true) in argument and return a scalar which gauge the training phase.
        device (torch.device) :
            The device where the training need to be performed.
        target (int) :
            Target serie to forecast.
        batch_start (int) :
            1-based index of the first batch of each epoch that is trained on;
            earlier batches are skipped. Also the first TensorBoard step.
        batch_end (int) :
            Last batch index trained on in each epoch.
        val_data (torch.utils.data.dataloader.DataLoader) :
            Validation dataset; validation is skipped when None.
        val_step (int) :
            Validation runs whenever the batch index is a multiple of it.
        start_val_count (int) :
            First TensorBoard step used for validation scalars.
        log_dir (str) :
            Directory given to SummaryWriter.
        load (bool) :
            Whether to load the weights at ``self.path`` before training.
        '''
        self.train()
        self.to(device)
        writer = SummaryWriter(log_dir)
        try:
            step = batch_start
            val_counter = start_val_count
            if load:
                self.load_model()
            for _ in range(epochs):
                batch = 0
                seen = 0  # batches actually trained on in this epoch
                err = 0
                total_loss = 0

                for X_batch in train_data:
                    batch += 1
                    if batch < batch_start:
                        continue
                    if batch > batch_end:
                        break
                    seen += 1

                    outputs, Y_batch, loss_value, batch_metric = self._train_step(
                        X_batch, optimizer, loss_fn, metric, device, target)
                    total_loss += loss_value
                    err += batch_metric

                    # Running means are divided by the number of batches
                    # trained on; dividing by the raw index would count the
                    # skipped batches as zero loss when batch_start > 1.
                    self._log_train_step(writer, step, outputs, Y_batch, loss_value,
                                         batch_metric, total_loss/seen, err/seen)

                    step += 1

                    if val_data is not None and batch % val_step == 0:
                        results = self.compute_validation(val_data, loss_fn, metric, device, target)
                        self.train()
                        self._log_validation(writer, val_counter, results, "BatchEnd - ValStep")
                        val_counter += 1
                self.save_model()
        finally:
            # Flush TensorBoard events even if training is interrupted.
            writer.close()

    def finetuning(self, train_data, epochs, optimizer, loss_fn, metric, device, target, patience, min_delta, val_data=None, log_dir=None, load_path="", save_path=""):
        '''
        Fine-tune the weights stored at ``load_path`` with early stopping on
        the validation loss. The weights are written to ``save_path`` after
        each validation that leaves the early-stopping counter at 0.

        train_data (torch.utils.data.dataloader.DataLoader) :
            Iterable training dataset
        epochs (int) >0 :
            Number of training epochs
        optimizer (torch.optim) :
            Algorithm for perfom optimization
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform training
        target (int) :
            Target serie to forecast
        patience (int) :
            Validations without improvement tolerated before stopping
        min_delta (float) :
            Minimum decrease of the validation loss counted as an improvement
        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset (required)
        log_dir (str) :
            Directory given to SummaryWriter
        load_path (str) :
            File holding the weights to start from
        save_path (str) :
            File receiving the fine-tuned weights
        '''
        if val_data is None:
            # Early stopping needs a validation loss; fail with a clear message.
            raise ValueError("finetuning requires val_data for early stopping")

        self.to(device)
        self.train()
        writer = SummaryWriter(log_dir)
        try:
            step = 1
            val_counter = 0
            self.load_model_finetuning(load_path)
            early_stop = EarlyStopping(patience, min_delta)

            def validate_and_checkpoint(counter):
                # Validate, log, update early stopping and save when the
                # early-stopping counter is still 0.
                results = self.compute_validation(val_data, loss_fn, metric, device, target)
                self.train()
                self._log_validation(writer, counter, results, "ValStep")
                early_stop(results[0])
                if early_stop.counter == 0:
                    self.save_model_finetuning(save_path)

            # Baseline validation of the loaded weights.
            validate_and_checkpoint(val_counter)
            val_counter += 1

            for _ in range(epochs):
                batch = 0
                err = 0
                total_loss = 0

                for X_batch in train_data:
                    batch += 1

                    outputs, Y_batch, loss_value, batch_metric = self._train_step(
                        X_batch, optimizer, loss_fn, metric, device, target)
                    total_loss += loss_value
                    err += batch_metric

                    self._log_train_step(writer, step, outputs, Y_batch, loss_value,
                                         batch_metric, total_loss/batch, err/batch)

                    step += 1

                validate_and_checkpoint(val_counter)
                val_counter += 1

                if early_stop.early_stop:
                    break
        finally:
            # Flush TensorBoard events even if fine-tuning is interrupted.
            writer.close()

class TimeSeriesIterableDataset(IterableDataset):
    '''
    Stream windows stored as pickled collections inside zip archives.

    Every entry of ``dir`` is opened as a zip archive. Members whose name
    ends with ``_<phase>.pickle`` are unpickled, and each element is yielded
    as a NumPy array after NaN / +inf are replaced by 1. Windows whose
    column 3 goes above 2 or below 0.5 are dropped. When ``phase`` is not
    one of the known phases, no member is filtered out.

    NOTE: the iterator does not shard by DataLoader worker, so each worker
    iterates the full dataset.
    '''

    # Suffix identifying the members that belong to each phase.
    _PHASE_SUFFIX = {
        "train": "_train.pickle",
        "val": "_val.pickle",
        "test": "_test.pickle",
    }

    #-----------------------------------__init__
    def __init__(self, dir, phase):
        '''
        dir (str) :
            Directory containing zip files
        phase (str) ["train", "val", "test"] :
            Whether data is train, val or test
        '''
        # super(Class).__init__() builds an unbound super object and never
        # initialises the base class; use the zero-argument form.
        super().__init__()
        self.dir = dir
        self.phase = phase
        self.list_of_zip = os.listdir(self.dir)

    def __iter__(self):
        suffix = self._PHASE_SUFFIX.get(self.phase)
        for zips in self.list_of_zip:
            # Open each archive once instead of once per member.
            with ZipFile(os.path.join(self.dir, zips), 'r') as z:
                filelist = z.namelist()
                if suffix is not None:
                    filelist = [item for item in filelist if item.endswith(suffix)]

                for filename in filelist:
                    with z.open(filename) as file: #Open pickle file
                        # SECURITY: pickle can execute arbitrary code; only
                        # read archives from a trusted source.
                        data = pkl.load(file)
                    # Index by position: ``data`` may be any container that
                    # supports len() and integer lookup.
                    for i in range(len(data)):
                        d = data[i].replace([np.nan, np.inf], 1).values
                        if d[:,3].max()>2 or d[:,3].min()<0.5:
                            continue
                        yield d

class TimeSeriesIterableDatasetFinetuning(IterableDataset):
    '''
    Stream windows stored as pickled collections directly inside a directory.

    Files of ``dir`` whose name ends with ``_<phase>.pickle`` are unpickled,
    and each element is yielded as a NumPy array after NaN / +inf are
    replaced by 1. Windows whose column 3 goes above 10 or below 0.1 are
    dropped. When ``phase`` is not one of the known phases, no file is
    filtered out.

    NOTE: the iterator does not shard by DataLoader worker, so each worker
    iterates the full dataset.
    '''

    # Suffix identifying the files that belong to each phase.
    _PHASE_SUFFIX = {
        "train": "_train.pickle",
        "val": "_val.pickle",
        "test": "_test.pickle",
    }

    #--------------------__INIT__--------------------
    def __init__(self, dir, phase):
        '''
        dir (str) :
            Directory containing pickle files
        phase (str) ["train", "val", "test"] :
            Whether data is train, val or test
        '''
        # super(Class).__init__() builds an unbound super object and never
        # initialises the base class; use the zero-argument form.
        super().__init__()
        self.dir = dir
        self.phase = phase
        self.list_of_file = os.listdir(self.dir)

    #--------------------__ITER__--------------------
    def __iter__(self):
        filelist = self.list_of_file
        suffix = self._PHASE_SUFFIX.get(self.phase)
        if suffix is not None:
            filelist = [item for item in filelist if item.endswith(suffix)]

        for filename in filelist:
            with open(os.path.join(self.dir, filename), 'rb') as file: #Open pickle file
                # SECURITY: pickle can execute arbitrary code; only read
                # files from a trusted source.
                data = pkl.load(file)
            # Index by position: ``data`` may be any container that supports
            # len() and integer lookup.
            for i in range(len(data)):
                d = data[i].replace([np.nan, np.inf], 1).values
                if d[:,3].max()>10 or d[:,3].min()<0.1:
                    continue
                yield d
