'''
Transformer
source : https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
github : https://github.com/nklingen/Transformer-Time-Series-Forecasting/blob/main/model.py

Positional Encoding
source : https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
github : https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
'''

import os
from zipfile import ZipFile
import pickle as pkl
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import IterableDataset
from torch.utils.tensorboard import SummaryWriter


#from ..metrics import direction_accuracy

def direction_accuracy(pred, true, tol=0.0001):
    '''
    Fraction of elements whose direction relative to 1 agrees between pred and true.

    Every value v is mapped to +1 if v - 1 > tol, to -1 if v - 1 < -tol and to
    0 otherwise (NaN also maps to 0). Inputs of any shape are flattened.

    pred, true (torch.Tensor) :
        Tensors with the same number of elements.
    tol (float) :
        Dead zone around 1 treated as "no move".
    Returns a Python float in [0, 1].
    Raises ValueError if the element counts differ and ZeroDivisionError on
    empty input.
    '''
    def _direction(x):
        x = x.reshape(-1) - 1
        # Boolean masks give +1 / -1 / 0 without a Python loop; NaN -> 0.
        return (x > tol).long() - (x < -tol).long()

    pred_dir = _direction(pred)
    true_dir = _direction(true)
    if pred_dir.shape != true_dir.shape:
        raise ValueError(
            f"pred and true must have the same number of elements, got {pred_dir.shape[0]} and {true_dir.shape[0]}"
        )
    return (pred_dir == true_dir).sum().item() / pred_dir.shape[0]

def mape(pred, true):
    '''
    Mean absolute percentage error, in percent, returned as a Python float.

    Computes mean(|(true - pred) / true|) * 100 over all elements; entries
    where true == 0 produce inf/nan.
    '''
    relative_error = (true - pred) / true
    percent = relative_error.abs().reshape(-1).mean() * 100
    return percent.item()

def get_emb(sin_inp):
    '''
    Gets a base embedding for one dimension with sin and cos intertwined.

    sin_inp (torch.Tensor) :
        Angles of shape (..., n).
    Returns a (..., 2n) tensor laid out as [sin a0, cos a0, sin a1, cos a1, ...].
    '''
    pairs = torch.stack([torch.sin(sin_inp), torch.cos(sin_inp)], dim=-1)
    return pairs.flatten(start_dim=-2)

#----------------------------------------EARLY_STOPPING----------------------------------------

class EarlyStopping():
    """
    source : https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
    Early stopping to stop the training when the loss does not improve after
    certain epochs.

    Call the instance with every new validation loss. `counter` is reset to 0
    when the loss improves on `best_loss` by more than `min_delta` (so
    `counter == 0` signals a new best, or the first call). Otherwise it is
    incremented, and `early_stop` becomes True once it reaches `patience`.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        else:
            # Not a sufficient improvement. This branch must also cover
            # best_loss - val_loss == min_delta (e.g. an identical loss with
            # min_delta=0); a strict `<` test would never count a plateau.
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True

#----------------------------------------CONFIG----------------------------------------
# Default hyper-parameters for AR_Transformer.
#   n_features (int)          : number of series in input (features per time step)
#   lag (int)                 : input window length; sizes the positional encoding,
#                               so windows must not be longer than this
#   latent_dim (int)          : dimension of embedding for observations; must be
#                               divisible by n_head (multi-head attention requirement)
#   num_layers (int)          : number of transformer encoder layers
#   n_head (int)              : number of attention heads in each layer
#   dropout (float in [0, 1]) : dropout rate
#   path (str)                : file used to save/load the model weights
config = {
    "n_features" : 5,
    "lag" : 53,
    "latent_dim" : 300,
    "num_layers" : 3,
    "n_head" : 12,
    "dropout" : 0,
    "path" : "Weights/path"
}

#----------------------------------------POSITIONAL_ENCODING1D---------------------------------------
class PositionalEncoding1D(nn.Module):
    #--------------------__INIT__--------------------
    def __init__(self, channels):
        '''
        Sinusoidal positional encoding for tensors shaped (batch, x, channels).

        channels (int) :
            The last dimension of the tensor you want to apply pos emb to.
            Rounded up to an even number internally so sin/cos pairs fit.
        '''
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        '''
        tensor (torch.Tensor) :
            3d input (batch, x, ch); only its shape, device and dtype are used.
        Returns the (batch, x, ch) encoding with sin/cos interleaved along the
        last axis. The result is cached and reused while the input's shape,
        device and dtype stay the same.
        Raises RuntimeError if the input is not 3d.
        '''
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        cached = self.cached_penc
        # cached_penc is a plain attribute, not a buffer, so module.to(...) /
        # .half() do not move or cast it: key the cache on device and dtype too.
        if (
            cached is not None
            and cached.shape == tensor.shape
            and cached.device == tensor.device
            and cached.dtype == tensor.dtype
        ):
            return cached

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc

#----------------------------------------POSITIONAL_ENCODING_PERMUTE1D---------------------------------------
class PositionalEncodingPermute1D(nn.Module):
    '''
    PositionalEncoding1D for inputs laid out as (batch, channels, x): the input
    is permuted to (batch, x, channels), encoded, and the encoding is permuted
    back to (batch, channels, x).
    '''
    #--------------------__INIT__--------------------
    def __init__(self, channels):
        '''
        (batch, channels, x)
        channels (int) :
            Size of dimension 1 of the input (the axis moved last before encoding).
        '''
        super().__init__()
        self.penc = PositionalEncoding1D(channels)

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        encoding = self.penc(tensor.permute(0, 2, 1))
        return encoding.permute(0, 2, 1)

    @property
    def org_channels(self):
        '''Channel count requested at construction (before even rounding).'''
        return self.penc.org_channels
    
#----------------------------------------SUMMER---------------------------------------
class Summer(nn.Module):
    '''
    Adds a positional encoding to its input: forward(t) == t + penc(t).
    '''
    #--------------------__INIT__--------------------
    def __init__(self, penc):
        '''
        penc (PositionalEncoding1D or similar callable) :
            The type of positional encoding to run the summer on. It must
            return a tensor of the same size as its input.
        '''
        super(Summer, self).__init__()
        self.penc = penc

    #--------------------FORWARD--------------------
    def forward(self, tensor):
        '''
        Raises ValueError if the encoding's size differs from the input's.
        The check is explicit rather than an `assert`, which `python -O`
        strips and would let a broadcastable mismatch pass silently.
        '''
        penc = self.penc(tensor)
        if tensor.size() != penc.size():
            raise ValueError(
                "The original tensor size {} and the positional encoding tensor size {} must match!".format(
                    tensor.size(), penc.size()
                )
            )
        return tensor + penc

#----------------------------------------AR_TRANSFORMER----------------------------------------
class AR_Transformer(nn.Module):
    '''
    Autoregressive Transformer-encoder forecaster.

    A batch of windows (batch, seq_len, n_features) is embedded to latent_dim,
    summed with a positional encoding, passed through a causally masked
    nn.TransformerEncoder along the time axis and projected to one value per
    step. The output at step t is compared with the target series at step
    t+1. The last step has no valid target (see _split_batch) and is excluded
    from losses and metrics.

    config (dict) :
        Keys "n_features", "lag", "latent_dim", "num_layers", "n_head",
        "dropout" and "path" (file used by save_model / load_model).
    '''
    #--------------------__INIT__--------------------
    def __init__(self, config):
        super(AR_Transformer, self).__init__()

        self.n_features = config["n_features"]
        self.lag = config["lag"]
        self.latent_dim = config["latent_dim"]
        self.num_layers = config["num_layers"]
        self.n_head = config["n_head"]
        self.dropout = config["dropout"]
        self.path = config["path"]

        # Targets / outputs of the latest validation pass.
        self.val_y_true = []
        self.val_y_pred = []

        self.test_y_true = []
        self.test_y_pred = []

        self.embedding = nn.Linear(self.n_features, self.latent_dim)
        # Applied to (batch, seq_len, latent_dim). The permuted encoder treats
        # seq_len as its channel axis, so seq_len may not exceed lag (rounded up
        # to even); otherwise the encoding and input sizes do not match in Summer.
        self.positional_encoder = Summer(PositionalEncodingPermute1D(self.lag))
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.n_head, dropout=self.dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=self.num_layers)
        self.decoder = nn.Linear(self.latent_dim, 1)

        # NOTE(review): not used in forward; kept so attribute access keeps working.
        self.relu = nn.ReLU()

        self.init_weights()

    #--------------------INIT_WEIGHTS--------------------
    def init_weights(self):
        '''Zero the decoder bias and draw its weights uniformly from [-0.1, 0.1].'''
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    #--------------------_GENERATE_SQUARE_SUBSEQUENT_MASK--------------------
    def _generate_square_subsequent_mask(self, sz):
        '''
        Float causal mask (sz, sz): 0.0 on and below the diagonal and -inf
        above it, so position i can only attend to positions <= i.
        '''
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    #--------------------_SPLIT_BATCH--------------------
    def _split_batch(self, X_batch, device, target):
        '''
        Build model inputs and next-step targets from a raw batch.

        X_batch (torch.Tensor) :
            Raw batch indexed (batch, time, series).
        Returns (X, Y, b, l): X is (b, l, n_features) and Y is (b, l, 1) with
        Y[:, t] = X_batch[:, t+1, target]. torch.roll wraps around, so
        Y[:, l-1] holds X_batch[:, 0, target] and is not a valid target.
        '''
        b, l, _ = X_batch.shape
        Y_batch = torch.roll(X_batch, -1, 1)
        X = X_batch[:, :, :self.n_features].float().to(device)
        Y = Y_batch[:, :, target].reshape((b, l, 1)).float().to(device)
        return X, Y, b, l

    #--------------------FORWARD--------------------
    def forward(self, input, device):
        '''
        input (torch.Tensor) :
            Batch of shape (batch, seq_len, n_features).
        device (torch.device) :
            Device on which the causal attention mask is placed.
        Returns (batch, seq_len, 1); output[:, t] only depends on input[:, :t+1].
        '''
        b, l, _ = input.shape

        x = self.embedding(input)  # (b, l, latent_dim)
        x = self.positional_encoder(x)

        # The encoder layer is built without batch_first, i.e. it expects
        # (seq, batch, dim). Passing (b, l, dim) directly made attention and
        # the causal mask run across batch samples instead of across time.
        x = x.transpose(0, 1)  # (l, b, latent_dim)
        mask = self._generate_square_subsequent_mask(l).to(device)
        output = self.transformer_encoder(x, mask)
        output = output.transpose(0, 1)  # (b, l, latent_dim)

        return self.decoder(output)

    #--------------------SAVE_MODEL--------------------
    def save_model(self):
        '''Save the state dict to self.path.'''
        torch.save(self.state_dict(), self.path)

    #--------------------SAVE_MODEL--------------------
    def save_model_finetuning(self, path):
        '''Save the state dict to path.'''
        torch.save(self.state_dict(), path)

    #--------------------LOAD_MODEL--------------------
    def load_model(self):
        '''
        Load the state dict from self.path. Tensors are mapped to CPU first,
        so GPU checkpoints load on CPU-only hosts; load_state_dict then copies
        them into this module's existing parameters.
        '''
        self.load_state_dict(torch.load(self.path, map_location='cpu'))

    #--------------------LOAD_MODEL--------------------
    def load_model_finetuning(self, path):
        '''Load the state dict from path (mapped to CPU, then copied into the parameters).'''
        self.load_state_dict(torch.load(path, map_location='cpu'))

    #--------------------COMPUTE_VALIDATION--------------------
    def compute_validation(self, val_data, loss_fn, metric, device, target, finetune=False):
        '''
        Evaluate the model on a validation stream without gradients.

        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform evaluation
        target (int) :
            Target serie to forecast
        finetune (bool) :
            If True, the loss only uses the last valid step (index l-2).
            Otherwise it uses every step except the last, wrapped-around one.

        Returns (val_loss, val_metric, val_end_metric, val_accuracy,
        val_end_accuracy, val_MAPE, val_end_MAPE, Y_true, Y_pred). The scalars
        are per-batch averages rounded to 8 decimals, and "end" values use step
        l-2 only. Y_true / Y_pred are lists of per-batch (b, l, 1) tensors.
        Leaves the module in eval mode.

        Raises ValueError if val_data yields no batch.
        '''
        self.eval()
        with torch.no_grad():
            batch = 0
            val_l = 0
            val_err = 0
            val_end_err = 0
            val_acc = 0
            val_end_acc = 0
            val_mape = 0
            val_end_mape = 0
            Y_true = []
            Y_pred = []
            for X_batch in val_data:
                batch += 1
                X_batch, Y_batch, _, _ = self._split_batch(X_batch, device, target)

                outputs = self(X_batch, device)
                # Step l-1 has a wrapped-around target (see _split_batch).
                valid_out, valid_y = outputs[:, :-1, :], Y_batch[:, :-1, :]
                end_out, end_y = outputs[:, -2, :], Y_batch[:, -2, :]
                if finetune:
                    loss = loss_fn(end_out, end_y)
                else:
                    loss = loss_fn(valid_out, valid_y)

                val_l += loss.item()

                val_err += metric(valid_out, valid_y)
                val_end_err += metric(end_out, end_y)

                val_acc += direction_accuracy(valid_out, valid_y)
                val_end_acc += direction_accuracy(end_out, end_y)

                val_mape += mape(valid_out, valid_y)
                val_end_mape += mape(end_out, end_y)

                Y_true.append(Y_batch.detach())
                Y_pred.append(outputs.detach())

            if batch == 0:
                raise ValueError("val_data yielded no batches")

            val_loss = round(val_l/batch, 8)

            val_metric = round(val_err/batch, 8)
            val_end_metric = round(val_end_err/batch, 8)

            val_accuracy = round(val_acc/batch, 8)
            val_end_accuracy = round(val_end_acc/batch, 8)

            val_MAPE = round(val_mape/batch, 8)
            val_end_MAPE = round(val_end_mape/batch, 8)

            return val_loss, val_metric, val_end_metric, val_accuracy, val_end_accuracy, val_MAPE, val_end_MAPE, Y_true, Y_pred

    #--------------------_RUN_VALIDATION--------------------
    def _run_validation(self, writer, val_data, loss_fn, metric, device, target, val_counter, finetune=False):
        '''
        Run compute_validation, keep its targets/outputs in self.val_y_true /
        self.val_y_pred, log its scalars at x-value val_counter, switch back to
        training mode and return the validation loss.
        '''
        val_loss, val_metric, val_end_metric, val_accuracy, val_end_accuracy, val_mape, val_end_mape, Y_true, Y_pred = self.compute_validation(val_data, loss_fn, metric, device, target, finetune=finetune)
        self.train()
        self.val_y_true = Y_true
        self.val_y_pred = Y_pred

        writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss

        writer.add_scalar("Metric/Val/ValStep", val_metric, val_counter) #Val Metric
        writer.add_scalar("Metric/Val/BatchEnd - ValStep", val_end_metric, val_counter) #Val Metric on last valid step

        writer.add_scalar("Accuracy/Val/ValStep", val_accuracy, val_counter) #Val Accuracy
        writer.add_scalar("Accuracy/Val/BatchEnd - ValStep", val_end_accuracy, val_counter) #Val Accuracy on last valid step

        writer.add_scalar("MAPE/Val/ValStep", val_mape, val_counter)
        writer.add_scalar("MAPE/Val/BatchEnd - ValStep", val_end_mape, val_counter)
        return val_loss

    #--------------------PREDICT--------------------
    def predict_index(self, data_days, metric, target, device, save_path=''):
        '''
        Load the saved weights, predict the last valid step (l-2) of every
        window and write one CSV (Name, True, Pred) per day.

        data_days (iterable) :
            Datasets, one per day. Each is wrapped in a DataLoader with
            batch_size=40, must yield (X_batch, X_names) pairs and must expose
            a `day` attribute, which is used as the CSV file name.
        metric (function) :
            Metric accumulated on step l-2 and printed as "RMSE".
        target (int) :
            Target serie to forecast
        device (torch.device) :
            Device where to perform evaluation
        save_path (str) :
            Output directory, created with its parents if missing. '' writes
            to the current directory.

        Raises ValueError if no batch was evaluated.
        '''
        with torch.no_grad():
            if save_path:
                os.makedirs(save_path, exist_ok=True)
            self.to(device)
            self.eval()
            self.load_model()
            err = 0
            Mape = 0
            step = 0
            for d in data_days:
                data = torch.utils.data.DataLoader(d, batch_size=40)
                true_parts = []
                pred_parts = []
                Y_names = []
                for X_batch, X_names in data:
                    X_batch, Y_batch, b, l = self._split_batch(X_batch, device, target)
                    outputs = self(X_batch, device)
                    outputs2 = outputs[:, l-2, :]
                    Y_batch2 = Y_batch[:, l-2, :]

                    err += metric(outputs2, Y_batch2)

                    Mape += mape(outputs2, Y_batch2)

                    true_parts.append(Y_batch2.cpu().numpy().reshape((b,)))
                    pred_parts.append(outputs2.cpu().numpy().reshape((b,)))
                    Y_names.extend(X_names)

                    step += 1
                # Concatenate once instead of a quadratic np.append per batch.
                Y_true = np.concatenate(true_parts) if true_parts else np.empty((0,), dtype=np.float32)
                Y_pred = np.concatenate(pred_parts) if pred_parts else np.empty((0,), dtype=np.float32)
                table = pd.DataFrame({"Name" : Y_names, "True" : Y_true, "Pred" : Y_pred})
                table.to_csv(os.path.join(save_path, f'{data.dataset.day}.csv'), index=False)

            if step == 0:
                raise ValueError("data_days yielded no batches")
            print("RMSE : ", round(err/step, 8))
            print('MAPE : ', round(Mape/step, 8))

    #--------------------TRAINING_PHASE--------------------
    def training_phase(self, train_data, epochs, optimizer, loss_fn, metric, device, target, batch_start=1, batch_end=40000, val_data=None, val_step=40000, start_val_count=1, log_dir=None, load=False):
        '''
        Train on every step except the wrapped-around last one, logging to
        TensorBoard and saving to self.path after each epoch.

        train_data (torch.utils.data.dataloader.DataLoader) :
            Iterable training dataset
        epochs (int) >0 :
            Number of training epochs
        optimizer (torch.optim) :
            Algorithm for perfom optimization
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform training
        target (int) :
            Target serie to forecast
        batch_start (int) :
            In every epoch, batches whose 1-based index is below this are
            skipped. This is also the first value of the step counter.
        batch_end (int) :
            When this batch index is exceeded, the model is saved and the
            epoch stops.
        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset. If None, validation is skipped.
        val_step (int) :
            Validate every val_step batches.
        start_val_count (int) :
            First x-value for the validation scalars.
        log_dir (str) :
            SummaryWriter log directory.
        load (bool) :
            Load weights from self.path before training.
        '''
        self.to(device)
        self.train()
        writer = SummaryWriter(log_dir)
        step = batch_start
        val_counter = start_val_count
        if load:
            self.load_model()
        try:
            for _ in range(epochs):
                self.train()
                batch = 0
                n_trained = 0  # batches actually trained on; skipped ones excluded
                err = 0
                total_loss = 0
                self.val_y_true = []
                self.val_y_pred = []
                for X_batch in train_data:
                    batch += 1
                    if batch < batch_start:
                        continue
                    if batch > batch_end:
                        self.save_model()
                        break
                    n_trained += 1
                    X_batch, Y_batch, _, _ = self._split_batch(X_batch, device, target)

                    optimizer.zero_grad()

                    outputs = self(X_batch, device)
                    # Step l-1 has a wrapped-around target: train on steps 0..l-2.
                    valid_out, valid_y = outputs[:, :-1, :], Y_batch[:, :-1, :]
                    end_out, end_y = outputs[:, -2, :], Y_batch[:, -2, :]
                    loss = loss_fn(valid_out, valid_y)

                    total_loss += loss.item()

                    err += metric(valid_out, valid_y)

                    end_err = metric(end_out, end_y)
                    end_acc = direction_accuracy(end_out, end_y)

                    loss.backward()
                    optimizer.step()

                    writer.add_scalar("Loss/Train/Epoch - step", round(total_loss/n_trained, 8), step) #Loss during epoch
                    writer.add_scalar("Loss/Train/Batch - step", round(loss.item(), 8), step) #Loss each batch

                    writer.add_scalar("Metric/Train/Epoch - step", round(err/n_trained, 8), step) #Metric during epoch
                    writer.add_scalar("Metric/Train/Batch - step", round(metric(valid_out, valid_y), 8), step) #Metric each batch
                    writer.add_scalar("Metric/Train/BatchEnd - step", round(end_err, 8), step) #Metric on last valid value

                    writer.add_scalar("Accuracy/Train/Batch - step", direction_accuracy(valid_out, valid_y), step) #Accuracy of the direction
                    writer.add_scalar("Accuracy/Train/BatchEnd - step", round(end_acc, 8), step) #Accuracy on last valid value

                    writer.add_scalar("MAPE/Train/Batch - step", mape(valid_out, valid_y), step)
                    writer.add_scalar("MAPE/Train/BatchEnd - step", mape(end_out, end_y), step)

                    writer.add_scalar("Output/Mean/Train/Batch - step", outputs.mean().item(), step) #Mean prediction of the batch
                    writer.add_scalar("Output/Min/Train/Batch - step", outputs.min().item(), step) #Min prediction of the batch
                    writer.add_scalar("Output/Max/Train/Batch - step", outputs.max().item(), step) #Max prediction of the batch

                    step += 1

                    if val_data is not None and batch % val_step == 0:
                        self._run_validation(writer, val_data, loss_fn, metric, device, target, val_counter)
                        val_counter += 1
                self.save_model()
        finally:
            writer.close()

    #--------------------FINETUNING--------------------
    def finetuning(self, train_data, epochs, optimizer, loss_fn, metric, device, target, patience, min_delta, val_data=None, log_dir=None, load_path="", save_path=""):
        '''
        Fine-tune the weights from load_path on the last valid step (l-2) only.
        Validation runs before training and after each epoch. The weights are
        saved to save_path on every new best validation loss, and training
        stops early per EarlyStopping(patience, min_delta).

        train_data (torch.utils.data.dataloader.DataLoader) :
            Iterable training dataset
        epochs (int) >0 :
            Number of training epochs
        optimizer (torch.optim) :
            Algorithm for perfom optimization
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform training
        target (int) :
            Target serie to forecast
        patience, min_delta :
            Passed to EarlyStopping.
        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset (required: it is iterated unconditionally).
        log_dir (str) :
            SummaryWriter log directory.
        load_path (str) :
            Weights loaded before fine-tuning.
        save_path (str) :
            File the best weights are written to.
        '''
        self.to(device)
        self.train()
        writer = SummaryWriter(log_dir)
        step = 1
        val_counter = 0
        self.load_model_finetuning(load_path)
        early_stop = EarlyStopping(patience, min_delta)
        try:
            # Baseline validation of the loaded weights before any update.
            val_loss = self._run_validation(writer, val_data, loss_fn, metric, device, target, val_counter, finetune=True)
            val_counter += 1
            early_stop(val_loss)
            if early_stop.counter == 0:
                self.save_model_finetuning(save_path)

            for _ in range(epochs):
                self.train()
                batch = 0
                err = 0
                total_loss = 0
                self.val_y_true = []
                self.val_y_pred = []

                for X_batch in train_data:
                    batch += 1
                    X_batch, Y_batch, _, l = self._split_batch(X_batch, device, target)

                    optimizer.zero_grad()

                    outputs = self(X_batch, device)
                    # Only the last valid step (l-2) is fitted during fine-tuning.
                    outputs = outputs[:, l-2, :]
                    Y_batch = Y_batch[:, l-2, :]
                    loss = loss_fn(outputs, Y_batch)

                    total_loss += loss.item()

                    err += metric(outputs, Y_batch)

                    loss.backward()
                    optimizer.step()

                    writer.add_scalar("Loss/Train/Epoch - step", round(total_loss/batch, 8), step) #Loss during epoch
                    writer.add_scalar("Loss/Train/Batch - step", round(loss.item(), 8), step) #Loss each batch

                    writer.add_scalar("Metric/Train/Epoch - step", round(err/batch, 8), step) #Metric during epoch
                    writer.add_scalar("Metric/Train/Batch - step", round(metric(outputs, Y_batch), 8), step) #Metric each batch

                    writer.add_scalar("Accuracy/Train/Batch - step", direction_accuracy(outputs, Y_batch), step) #Accuracy of the direction

                    writer.add_scalar("MAPE/Train/Batch - step", mape(outputs, Y_batch), step)

                    writer.add_scalar("Output/Mean/Train/Batch - step", outputs.mean().item(), step) #Mean prediction of the batch
                    writer.add_scalar("Output/Min/Train/Batch - step", outputs.min().item(), step) #Min prediction of the batch
                    writer.add_scalar("Output/Max/Train/Batch - step", outputs.max().item(), step) #Max prediction of the batch

                    step += 1

                val_loss = self._run_validation(writer, val_data, loss_fn, metric, device, target, val_counter, finetune=True)
                val_counter += 1

                early_stop(val_loss)

                if early_stop.counter == 0:
                    self.save_model_finetuning(save_path)

                if early_stop.early_stop:
                    break
        finally:
            writer.close()

#----------------------------------------DATASET----------------------------------------
class TimeSeriesIterableDataset(IterableDataset):
    '''
    Stream windows stored as pickled sequences of DataFrames inside zip archives.

    Every archive in `dir` is scanned for members ending in "_<phase>.pickle".
    Each DataFrame in those members has NaN/+inf/-inf replaced by 1 and is
    yielded as a NumPy array, unless some value of its column 3 lies outside
    [0.5, 2].

    NOTE(review): pickle.load can execute arbitrary code; only use trusted
    archives. With DataLoader(num_workers>1), every worker iterates the full
    stream (IterableDataset is not sharded here), so samples are duplicated.
    '''
    #--------------------__INIT__--------------------
    def __init__(self, dir, phase):
        '''
        dir (str) :
            Directory containing zip files
        phase (str) ["train", "val", "test"] :
            Whether data is train, val or test. Any other value selects every
            member of every archive.
        '''
        # The previous `super(Cls).__init__()` built an unbound super object
        # and never called IterableDataset.__init__.
        super().__init__()
        self.dir = dir
        self.phase = phase
        self.list_of_zip = os.listdir(self.dir)

    #--------------------__ITER__--------------------
    def __iter__(self):
        # "" matches every name, i.e. no filtering for an unknown phase.
        suffix = f"_{self.phase}.pickle" if self.phase in ("train", "val", "test") else ""
        for zips in self.list_of_zip:
            # Open each archive once for listing and reading its members.
            with ZipFile(os.path.join(self.dir, zips), 'r') as z:
                filelist = [item for item in z.namelist() if item.endswith(suffix)]
                for filename in filelist:
                    with z.open(filename) as file:
                        data = pkl.load(file)
                    # Index access kept: the pickled container type is not visible here.
                    for i in range(len(data)):
                        # -inf is replaced like nan/+inf so no infinite value leaks out.
                        d = data[i].replace([np.nan, np.inf, -np.inf], 1).values
                        if d[:, 3].max() > 2 or d[:, 3].min() < 0.5:
                            continue
                        yield d

class TimeSeriesIterableDatasetFinetuning(IterableDataset):
    '''
    Stream windows stored as pickled sequences of DataFrames in plain files.

    Files in `dir` ending in "_<phase>.pickle" are loaded. Each DataFrame has
    NaN/+inf/-inf replaced by 1 and is yielded as a NumPy array, unless some
    value of its column 3 lies outside [0.1, 10].

    NOTE(review): pickle.load can execute arbitrary code; only use trusted
    files. With DataLoader(num_workers>1), every worker iterates the full
    stream, so samples are duplicated.
    '''
    #--------------------__INIT__--------------------
    def __init__(self, dir, phase):
        '''
        dir (str) :
            Directory containing pickle files
        phase (str) ["train", "val", "test"] :
            Whether data is train, val or test. Any other value selects every
            file in dir.
        '''
        # The previous `super(Cls).__init__()` never called IterableDataset.__init__.
        super().__init__()
        self.dir = dir
        self.phase = phase
        self.list_of_file = os.listdir(self.dir)

    #--------------------__ITER__--------------------
    def __iter__(self):
        # "" matches every name, i.e. no filtering for an unknown phase.
        suffix = f"_{self.phase}.pickle" if self.phase in ("train", "val", "test") else ""
        filelist = [item for item in self.list_of_file if item.endswith(suffix)]

        for filename in filelist:
            with open(os.path.join(self.dir, filename), 'rb') as file:
                data = pkl.load(file)
            # Index access kept: the pickled container type is not visible here.
            for i in range(len(data)):
                # -inf is replaced like nan/+inf so no infinite value leaks out.
                d = data[i].replace([np.nan, np.inf, -np.inf], 1).values
                if d[:, 3].max() > 10 or d[:, 3].min() < 0.1:
                    continue
                yield d
