'''
Transformer
source : https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
github : https://github.com/nklingen/Transformer-Time-Series-Forecasting/blob/main/model.py

Positional Encoding
source : https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
github : https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
'''

import os
from zipfile import ZipFile
import pickle as pkl
import numpy as np

import torch
from torch import nn
from torch.utils.data import IterableDataset
from torch.utils.tensorboard import SummaryWriter


#from ..metrics import direction_accuracy

def direction_accuracy(pred, true):
    '''
    Fraction of elements where `pred` and `true` fall on the same side of 1.

    Both tensors are flattened. An element counts as "down" when its value is
    strictly below 1 and as "up" otherwise (exactly 1, and NaN, count as up).

    pred (torch.Tensor) :
        Predicted values.
    true (torch.Tensor) :
        Ground-truth values; must hold the same number of elements as `pred`.

    Returns a Python float in [0, 1]. Raises ZeroDivisionError on empty input.
    '''
    # `x < 1` is equivalent to the previous `(x - 1) < 0` sign test, done in one
    # vectorized pass instead of a Python loop over 0-d tensors.
    pred_down = pred.reshape(-1) < 1
    true_down = true.reshape(-1) < 1
    matches = (pred_down == true_down).sum().item()
    return matches / pred_down.shape[0]

def mape(pred, true):
    '''
    Mean absolute percentage error between `pred` and `true`.

    Computes mean(|(true - pred) / true|) * 100 over every element and returns
    it as a Python float. Elements where `true` is 0 yield inf/nan.
    '''
    deviation = torch.add(true, -pred)
    ratio = torch.abs(deviation / true)
    return (ratio.view(-1).mean() * 100).item()

def get_emb(sin_inp):
    '''
    Build a sinusoidal embedding with sin and cos values interleaved.

    For an input of shape (..., n) the result has shape (..., 2n), laid out as
    [sin(v0), cos(v0), sin(v1), cos(v1), ...] along the last dimension.
    '''
    sin_part = torch.sin(sin_inp)
    cos_part = torch.cos(sin_inp)
    # Pair each sin with its cos on a new trailing axis, then merge the pairs.
    paired = torch.stack([sin_part, cos_part], dim=-1)
    return paired.flatten(start_dim=-2)

#----------------------------------------CONFIG----------------------------------------
# Default hyper-parameters for AR_Transformer:
#   n_features (int) :
#       number of input series fed to the embedding layer
#   lag (int) :
#       number of time steps the positional encoder is sized for; windows
#       longer than this (rounded up to an even number) fail its size check
#   latent_dim (int) :
#       dimension of the embedding for observations; must be divisible by n_head
#   num_layers (int) :
#       number of transformer encoder layers
#   n_head (int) :
#       number of attention heads in each layer
#   dropout (float in [0, 1]) :
#       dropout rate
#   path (str) :
#       file path used to save/load the model weights
config = {
    "n_features" : 5,
    "lag" : 53,
    "latent_dim" : 300,
    "num_layers" : 3,
    "n_head" : 12,
    "dropout" : 0,
    "path" : "Weights/path"
}

#----------------------------------------POSITIONAL_ENCODING1D---------------------------------------
class PositionalEncoding1D(nn.Module):
    #--------------------__INIT__--------------------
    def __init__(self, channels):
        '''
        Sinusoidal positional encoding for tensors shaped (batch, x, channels).

        channels (int) :
            The last dimension of the tensor you want to apply pos emb to.
            It is rounded up to an even number internally so that sin/cos
            pairs fit; the extra column (if any) is cropped off in forward.
        '''
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        # One inverse frequency per sin/cos pair: 1 / 10000^(2i / channels).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        '''
        Return the positional encoding for `tensor` (it is not added to it).

        tensor (torch.Tensor) :
            3-D tensor of shape (batch, x, channels).

        Returns a tensor on the same device and with the same dtype as
        `tensor`, of shape (batch, x, min(channels, self.channels)), identical
        for every batch element. The result is cached and reused while shape,
        device and dtype stay the same, so callers must not modify it in place.

        Raises RuntimeError if `tensor` is not 3-D.
        '''
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        cached = self.cached_penc
        # cached_penc is a plain attribute (not a registered buffer), so
        # module.to()/.half() do not move it; device and dtype must match too.
        if (
            cached is not None
            and cached.shape == tensor.shape
            and cached.device == tensor.device
            and cached.dtype == tensor.dtype
        ):
            return cached

        batch_size, x, orig_ch = tensor.shape
        # Compute on the input's device even if the buffer was left elsewhere.
        inv_freq = self.inv_freq.to(tensor.device)
        pos_x = torch.arange(x, device=tensor.device, dtype=inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, inv_freq)
        # (x, self.channels) with sin/cos interleaved, cast to the input dtype.
        emb = get_emb(sin_inp_x).to(tensor.dtype)

        # Crop the padding column added when the channel count is odd, then
        # tile the same encoding across the batch.
        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc

#----------------------------------------POSITIONAL_ENCODING_PERMUTE1D---------------------------------------
class PositionalEncodingPermute1D(nn.Module):
    '''
    Channels-first wrapper around PositionalEncoding1D.

    Accepts tensors laid out as (batch, channels, x). The encoding is computed
    in PositionalEncoding1D's (batch, x, channels) layout and returned
    permuted back to (batch, channels, x).
    '''
    #--------------------__INIT__--------------------
    def __init__(self, channels):
        '''
        channels (int) :
            Size of dimension 1 of the (batch, channels, x) tensors this
            module receives.
        '''
        super().__init__()
        self.penc = PositionalEncoding1D(channels)

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        # Swapping the last two axes converts channels-first <-> channels-last;
        # applying the same swap twice restores the caller's layout.
        swap = (0, 2, 1)
        channels_last = tensor.permute(*swap)
        encoding = self.penc(channels_last)
        return encoding.permute(*swap)

    @property
    def org_channels(self):
        # Channel count requested at construction (before rounding to even).
        return self.penc.org_channels
    
#----------------------------------------SUMMER---------------------------------------
class Summer(nn.Module):
    '''
    Adds a positional encoding to its input.

    The wrapped module is called on the input and its result is added
    element-wise; both must have exactly the same size.
    '''
    #--------------------__INIT__--------------------
    def __init__(self, penc):
        '''
        penc (nn.Module or callable) :
            Positional-encoding module (e.g. PositionalEncoding1D) whose output
            has the same size as its input.
        '''
        super().__init__()
        self.penc = penc

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        encoding = self.penc(tensor)
        input_size, encoding_size = tensor.size(), encoding.size()
        assert input_size == encoding_size, "The original tensor size {} and the positional encoding tensor size {} must match!".format(
            input_size, encoding_size
        )
        return tensor + encoding

#----------------------------------------AR_TRANSFORMER----------------------------------------
class AR_Transformer(nn.Module):
    '''
    Autoregressive Transformer-encoder forecaster.

    A window of shape (batch, seq_len, n_features) is embedded to latent_dim,
    summed with a positional encoding, passed through a causally masked
    nn.TransformerEncoder and projected by a linear decoder to one value per
    time step, giving an output of shape (batch, seq_len, 1).

    config (dict) :
        Hyper-parameters with keys "n_features", "lag", "latent_dim",
        "num_layers", "n_head", "dropout" and "path" (weights file used by
        save_model / load_model).
    '''
    #--------------------__INIT__--------------------
    def __init__(self, config):
        super(AR_Transformer, self).__init__()

        self.n_features = config["n_features"]
        self.lag = config["lag"]
        self.latent_dim = config["latent_dim"]
        self.num_layers = config["num_layers"]
        self.n_head = config["n_head"]
        self.dropout = config["dropout"]
        self.path = config["path"]

        # Targets / predictions (lists of per-batch tensors) of the latest
        # validation pass run by training_phase.
        self.val_y_true = []
        self.val_y_pred = []

        # NOTE(review): never filled in this file; presumably for evaluate().
        self.test_y_true = []
        self.test_y_pred = []

        self.embedding = nn.Linear(self.n_features, self.latent_dim)
        # NOTE(review): the encoding is sized by `lag` and applied through the
        # permuting wrapper to a (batch, seq_len, latent_dim) tensor, so sin/cos
        # frequencies vary along the time axis and positions along the latent
        # axis -- the transpose of the usual layout. Kept as-is because its
        # inv_freq buffer is part of saved state_dicts. Sequences longer than
        # `lag` (rounded up to even) fail Summer's size check.
        self.positional_encoder = Summer(PositionalEncodingPermute1D(self.lag))
        # batch_first is not set, so these layers use (seq, batch, feature);
        # forward() transposes to that layout.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.n_head, dropout=self.dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=self.num_layers)
        self.decoder = nn.Linear(self.latent_dim, 1)

        self.relu = nn.ReLU()  # NOTE(review): unused in forward(); kept for compatibility

        self.init_weights()

    #--------------------INIT_WEIGHTS--------------------
    def init_weights(self):
        '''Zero the decoder bias and draw its weights uniformly from [-0.1, 0.1].'''
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    #--------------------_GENERATE_SQUARE_SUBSEQUENT_MASK--------------------
    def _generate_square_subsequent_mask(self, sz):
        '''
        Float (sz, sz) attention mask: 0 on and below the diagonal, -inf above,
        so position i may only attend to positions j <= i.
        '''
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    #--------------------FORWARD--------------------
    def forward(self, input, device):
        '''
        input (torch.Tensor) :
            Batch of shape (batch, seq_len, n_features).
        device (torch.device) :
            Device on which the attention mask is created; must match input's.

        Returns a tensor of shape (batch, seq_len, 1) in which step t depends
        only on input steps 0..t of the same sample.
        '''
        _, l, _ = input.shape

        x = self.embedding(input)  # Linear acts on the last dim -> (b, l, latent_dim)
        x = self.positional_encoder(x)

        # The encoder expects (seq, batch, feature), and the causal mask must be
        # sized by the sequence length -- not by the batch size.
        x = x.transpose(0, 1)
        mask = self._generate_square_subsequent_mask(l).to(device)
        output = self.transformer_encoder(x, mask)
        output = output.transpose(0, 1)

        return self.decoder(output)

    #--------------------_PREPARE_INPUTS--------------------
    def _prepare_inputs(self, X_batch, device):
        '''Keep the first l-1 time steps of the first n_features series, as float on device.'''
        l = X_batch.shape[1]
        return X_batch[:, 0:l-1, 0:self.n_features].float().to(device)

    #--------------------_PREPARE_TARGETS--------------------
    def _prepare_targets(self, X_batch, device, target):
        '''Series `target` one step ahead of each input step: shape (batch, l-1, 1), float on device.'''
        b, l, _ = X_batch.shape
        # Step t of the inputs is paired with step t+1 of the raw window.
        return X_batch[:, 1:l, target].reshape((b, l-1, 1)).float().to(device)

    #--------------------SAVE_MODEL--------------------
    def save_model(self):
        '''Save the state_dict to self.path.'''
        torch.save(self.state_dict(), self.path)

    #--------------------LOAD_MODEL--------------------
    def load_model(self):
        '''
        Load the state_dict from self.path into the model.

        Tensors are mapped to CPU first so GPU-saved weights also load on
        CPU-only hosts; load_state_dict copies them onto the model's device.
        NOTE(review): torch.load unpickles the file -- only load trusted weights.
        '''
        self.load_state_dict(torch.load(self.path, map_location="cpu"))

    #--------------------COMPUTE_VALIDATION--------------------
    def compute_validation(self, val_data, loss_fn, metric, device, target):
        '''
        Run the model over `val_data` and return averaged validation statistics.

        The model is switched to eval mode for the pass and restored to its
        previous mode afterwards; no gradients are recorded.

        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform evaluation
        target (int) :
            Target serie to forecast

        Returns (loss, metric, end_metric, accuracy, end_accuracy, MAPE,
        end_MAPE, Y_true, Y_pred): per-batch averages rounded to 8 decimals
        ("end" values only look at the last time step), followed by the lists
        of per-batch target and prediction tensors.

        Raises ValueError if `val_data` yields no batches.
        '''
        was_training = self.training
        self.eval()
        batch = 0
        val_l = 0
        val_err = 0
        val_end_err = 0
        val_acc = 0
        val_end_acc = 0
        val_mape = 0
        val_end_mape = 0
        Y_true = []
        Y_pred = []
        try:
            with torch.no_grad():
                for raw_batch in val_data:
                    batch += 1
                    X_batch = self._prepare_inputs(raw_batch, device)
                    Y_batch = self._prepare_targets(raw_batch, device, target)

                    outputs = self(X_batch, device)
                    last_pred, last_true = outputs[:, -1, :], Y_batch[:, -1, :]

                    val_l += loss_fn(outputs, Y_batch).item()

                    val_err += metric(outputs, Y_batch)
                    val_end_err += metric(last_pred, last_true)

                    val_acc += direction_accuracy(outputs, Y_batch)
                    val_end_acc += direction_accuracy(last_pred, last_true)

                    val_mape += mape(outputs, Y_batch)
                    val_end_mape += mape(last_pred, last_true)

                    Y_true.append(Y_batch.detach())
                    Y_pred.append(outputs.detach())
        finally:
            self.train(was_training)

        if batch == 0:
            raise ValueError("val_data yielded no batches")

        val_loss = round(val_l/batch, 8)

        val_metric = round(val_err/batch, 8)
        val_end_metric = round(val_end_err/batch, 8)

        val_accuracy = round(val_acc/batch, 8)
        val_end_accuracy = round(val_end_acc/batch, 8)

        val_MAPE = round(val_mape/batch, 8)
        val_end_MAPE = round(val_end_mape/batch, 8)

        return val_loss, val_metric, val_end_metric, val_accuracy, val_end_accuracy, val_MAPE, val_end_MAPE, Y_true, Y_pred

    #--------------------EVALUATE--------------------
    def evaluate(self, test_data, loss_fn, metric, device, target):
        '''
        Evaluate the model on a test set.

        NOTE: not implemented yet -- currently only switches the model to eval
        mode and returns None.

        test_data (torch.utils.data.dataloader.DataLoader) :
            Iterable test dataset
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform evaluation
        target (int) :
            Target serie to forecast
        '''
        self.eval()
        # TODO: build the evaluation loop (compute_validation is a template).

    #--------------------PREDICT--------------------
    def predict(self, data, device):
        '''
        Run the model in eval mode over `data` and collect its outputs.

        data (torch.utils.data.dataloader.DataLoader) :
            Iterable dataset of raw windows (the last time step is dropped,
            as in training)
        device (torch.device) :
            Device where to perform evaluation

        Returns a list with one (batch, l-1, 1) prediction tensor per batch.
        '''
        with torch.no_grad():
            self.to(device)
            self.eval()
            Y_pred = []

            for raw_batch in data:
                X_batch = self._prepare_inputs(raw_batch, device)
                Y_pred.append(self(X_batch, device).detach())

            return Y_pred

    #--------------------TRAINING_PHASE--------------------
    def training_phase(self, train_data, epochs, optimizer, loss_fn, metric, device, target, batch_start=1, batch_end=40000, val_data=None, val_step=40000, start_val_count=1, log_dir=None, load=False):
        '''
        Train the model, log to TensorBoard and save weights to self.path
        after every epoch.

        train_data (torch.utils.data.dataloader.DataLoader) :
            Iterable training dataset
        epochs (int) > 0 :
            Number of training epochs
        optimizer (torch.optim) :
            Algorithm used to perform optimization
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform training
        target (int) :
            Target serie to forecast
        batch_start (int) :
            1-based index of the first batch trained in each epoch (earlier
            batches are skipped); also the first TensorBoard step.
        batch_end (int) :
            Index of the last batch trained in each epoch.
        val_data (torch.utils.data.dataloader.DataLoader or None) :
            Iterable validation dataset; validation is skipped when None.
        val_step (int) :
            Validate whenever the batch index is a multiple of this value.
        start_val_count (int) :
            First TensorBoard step used for validation scalars.
        log_dir (str or None) :
            TensorBoard log directory (SummaryWriter's default when None).
        load (bool) :
            Load weights from self.path before training.
        '''
        self.to(device)
        if load:
            self.load_model()
        writer = SummaryWriter(log_dir)
        step = batch_start
        val_counter = start_val_count
        try:
            for _ in range(epochs):
                self.train()
                batch = 0      # batch index in this epoch, including skipped ones
                n_trained = 0  # batches actually trained: denominator of running means
                err = 0
                total_loss = 0
                self.val_y_true = []
                self.val_y_pred = []
                for raw_batch in train_data:
                    batch += 1
                    if batch < batch_start:
                        continue
                    if batch > batch_end:
                        break  # weights are saved right after this loop
                    n_trained += 1
                    X_batch = self._prepare_inputs(raw_batch, device)
                    Y_batch = self._prepare_targets(raw_batch, device, target)

                    optimizer.zero_grad()
                    outputs = self(X_batch, device)
                    loss = loss_fn(outputs, Y_batch)
                    loss.backward()
                    optimizer.step()

                    # Metrics use detached predictions so they stay out of the graph.
                    pred = outputs.detach()
                    last_pred, last_true = pred[:, -1, :], Y_batch[:, -1, :]
                    batch_metric = metric(pred, Y_batch)

                    total_loss += loss.item()
                    err += batch_metric

                    writer.add_scalar("Loss/Train/Epoch - step", round(total_loss/n_trained, 8), step) #Loss during epoch
                    writer.add_scalar("Loss/Train/Batch - step", round(loss.item(), 8), step) #Loss each batch

                    writer.add_scalar("Metric/Train/Epoch - step", round(err/n_trained, 8), step) #Metric during epoch
                    writer.add_scalar("Metric/Train/Batch - step", round(batch_metric, 8), step) #Metric each batch
                    writer.add_scalar("Metric/Train/BatchEnd - step", round(metric(last_pred, last_true), 8), step) #Metric on last value

                    writer.add_scalar("Accuracy/Train/Batch - step", direction_accuracy(pred, Y_batch), step) #Accuracy of the direction
                    writer.add_scalar("Accuracy/Train/BatchEnd - step", round(direction_accuracy(last_pred, last_true), 8), step) #Accuracy on last value

                    writer.add_scalar("MAPE/Train/Batch - step", mape(pred, Y_batch), step)
                    writer.add_scalar("MAPE/Train/BatchEnd - step", mape(last_pred, last_true), step)

                    writer.add_scalar("Output/Mean/Train/Batch - step", pred.mean().item(), step) #Mean prediction of the batch
                    writer.add_scalar("Output/Min/Train/Batch - step", pred.min().item(), step) #Min prediction of the batch
                    writer.add_scalar("Output/Max/Train/Batch - step", pred.max().item(), step) #Max prediction of the batch

                    step += 1

                    if val_data is not None and batch % val_step == 0:
                        # compute_validation restores training mode when done.
                        (val_loss, val_metric, val_end_metric, val_accuracy, val_end_accuracy,
                         val_mape, val_end_mape, Y_true, Y_pred) = self.compute_validation(val_data, loss_fn, metric, device, target)
                        self.val_y_true = Y_true
                        self.val_y_pred = Y_pred

                        writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss

                        writer.add_scalar("Metric/Val/ValStep", val_metric, val_counter) #Val Metric
                        writer.add_scalar("Metric/Val/BatchEnd - ValStep", val_end_metric, val_counter) #Val Metric on last value

                        writer.add_scalar("Accuracy/Val/ValStep", val_accuracy, val_counter) #Val Accuracy
                        writer.add_scalar("Accuracy/Val/BatchEnd - ValStep", val_end_accuracy, val_counter) #Val Accuracy on last value

                        writer.add_scalar("MAPE/Val/ValStep", val_mape, val_counter)
                        writer.add_scalar("MAPE/Val/BatchEnd - ValStep", val_end_mape, val_counter)

                        val_counter += 1
                self.save_model()
        finally:
            writer.close()

#----------------------------------------DATASET----------------------------------------
class TimeSeriesIterableDataset(IterableDataset):
    '''
    Stream time-series windows stored as pickles inside zip archives.

    Every archive in `dir` is scanned for members ending in "_<phase>.pickle".
    Each member is unpickled and iterated; every item must support
    `.replace(...)` and `.values` (presumably pandas DataFrames -- TODO
    confirm). NaN and +inf are replaced by 1, and windows whose column 3
    leaves [0.5, 2] are skipped. The remaining windows are yielded as arrays.
    '''
    # Member-name suffix selected by each phase.
    _PHASE_SUFFIX = {
        "train": "_train.pickle",
        "val": "_val.pickle",
        "test": "_test.pickle",
    }

    #--------------------__INIT__--------------------
    def __init__(self, dir, phase):
        '''
        dir (str) :
            Directory containing zip files
        phase (str) ["train", "val", "test"] :
            Whether data is train, val or test. Any other value disables the
            member-name filter, so every member is loaded.
        '''
        super().__init__()
        self.dir = dir
        self.phase = phase
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        self.list_of_zip = sorted(os.listdir(self.dir))

    #--------------------__ITER__--------------------
    def __iter__(self):
        suffix = self._PHASE_SUFFIX.get(self.phase)
        for zip_name in self.list_of_zip:
            # Open each archive once and keep it open while its members are streamed.
            with ZipFile(os.path.join(self.dir, zip_name)) as z:
                members = z.namelist()
                if suffix is not None:
                    members = [name for name in members if name.endswith(suffix)]

                for member in members:
                    # NOTE(review): pickle can execute arbitrary code -- only
                    # read archives from a trusted source.
                    with z.open(member) as file:
                        data = pkl.load(file)
                    for frame in data:
                        d = frame.replace([np.nan, np.inf], 1).values
                        # Skip windows whose column 3 leaves [0.5, 2].
                        if d[:, 3].max() > 2 or d[:, 3].min() < 0.5:
                            continue
                        yield d