'''
Article : https://arxiv.org/pdf/2106.09305v3.pdf
Source : https://github.com/cure-lab/SCINet/blob/main/models/SCINet.py
'''

import math
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
import torch
import numpy as np
import pandas as pd

from torch.utils.data import IterableDataset
import os
from zipfile import ZipFile
import pickle as pkl

from torch.utils.tensorboard import SummaryWriter

def direction_accuracy(pred, true):
    """Return the fraction of elements whose direction matches the target.

    Both tensors are flattened. An element counts as "down" when
    ``value - 1 < 0`` and as "up" otherwise, i.e. values are compared
    against 1. The result is the share of positions where prediction and
    target fall on the same side.

    :param pred: tensor of predicted values
    :param true: tensor of target values with as many elements as ``pred``
    :return: matching fraction as a Python float in [0, 1]
    :raises ZeroDivisionError: if ``pred`` has no elements
    """
    pred_down = (pred.reshape(-1) - 1) < 0
    true_down = (true.reshape(-1) - 1) < 0
    # A single vectorised comparison replaces the per-element Python loop.
    matches = int((pred_down == true_down).sum().item())
    return matches / pred_down.shape[0]

def mape(pred, true):
    """Return the mean absolute percentage error of ``pred`` against ``true``.

    Computes ``mean(|true - pred| / |true|) * 100`` over all elements and
    returns it as a Python number. Zero entries in ``true`` are not
    special-cased.
    """
    relative_error = (true - pred) / true
    percentage = relative_error.abs().view(-1).mean() * 100
    return percentage.item()

class EarlyStopping():
    """
    source : https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
    Early stopping to stop the training when the loss does not improve after
    certain epochs.

    Call the instance with the latest validation loss after every evaluation;
    ``early_stop`` becomes True once ``patience`` consecutive calls brought no
    improvement larger than ``min_delta``.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
        else:
            # Everything that is not a strict improvement counts against the
            # patience budget. This includes the case where the difference is
            # exactly min_delta (e.g. a flat loss with min_delta=0), which
            # previously matched no branch and therefore never stopped.
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True

# Default hyper-parameters; SCINet.__init__ reads these keys by name.
config = {
    # Sequence geometry: length projected to, length fed in, features per step.
    "output_len": 1,
    "input_len": 70,
    "input_dim": 5,
    # Encoder: hidden-channel multiplier, number of stacks (1 or 2), tree depth.
    "hid_size": 1,
    "num_stacks": 1,
    "num_levels": 3,
    # Decoder: 1 selects the Conv1d projection, >1 the Linear-based decoder.
    "num_decoder_layer": 1,
    # Trailing input steps concatenated before the second stack (0 = whole input).
    "concat_len": 0,
    # Convolution settings of the interactor branches.
    "groups": 1,
    "kernel": 5,
    "dropout": 0.0,
    # Truthy: the second stack projects to a single timestep.
    "single_step_output_One": 0,
    # Not read by SCINet.__init__ in this file.
    "input_len_seg": 0,
    # Feature switches: positional encoding, invertible interaction, RIN.
    "positionalE": True,
    "modified": True,
    "RIN": False,
    # Location used by save_model / load_model.
    "path": "Weight/",
}

class Splitting(nn.Module):
    """Split a 3-D tensor along its second axis into even- and odd-indexed parts."""

    def __init__(self):
        super().__init__()

    def even(self, x):
        """Return the entries at positions 0, 2, 4, ... of axis 1."""
        return x[:, 0::2, :]

    def odd(self, x):
        """Return the entries at positions 1, 3, 5, ... of axis 1."""
        return x[:, 1::2, :]

    def forward(self, x):
        """Return the pair ``(even part, odd part)`` of ``x``."""
        even_part = self.even(x)
        odd_part = self.odd(x)
        return (even_part, odd_part)


class Interactor(nn.Module):
    """Interaction block operating on the even/odd halves of a sequence.

    Four identically shaped convolutional branches (``phi``, ``psi``, ``P``,
    ``U``) are built. Each one is: replication padding -> Conv1d(kernel) ->
    LeakyReLU -> Dropout -> Conv1d(kernel 3) -> Tanh. The padding is chosen
    so that the temporal length is unchanged by a branch.
    """

    def __init__(self, in_planes, splitting=True,
                 kernel = 5, dropout=0.5, groups = 1, hidden_size = 1, INN = True):
        """
        :param in_planes: number of channels of the branch inputs/outputs
        :param splitting: if True ``forward`` splits its input itself,
               otherwise it expects an ``(even, odd)`` pair
        :param kernel: kernel size of the first convolution of each branch
        :param dropout: dropout probability inside each branch
        :param groups: ``groups`` argument of both convolutions
        :param hidden_size: multiplier giving the hidden channel count
               ``int(in_planes * hidden_size)``
        :param INN: selects the interaction variant used in ``forward``
        """
        super().__init__()
        self.modified = INN
        self.kernel_size = kernel
        self.dilation = 1
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.groups = groups

        # Left/right replication padding (stride is 1 everywhere).
        if self.kernel_size % 2 == 0:
            pad_l = self.dilation * (self.kernel_size - 2) // 2 + 1
            pad_r = self.dilation * (self.kernel_size) // 2 + 1
        else:
            # The kernel size of the second layer is fixed to 3.
            pad_l = pad_r = self.dilation * (self.kernel_size - 1) // 2 + 1

        self.splitting = splitting
        self.split = Splitting()

        hidden_channels = int(in_planes * self.hidden_size)

        def make_branch():
            # One pad -> conv -> activation -> dropout -> conv -> tanh stack.
            return nn.Sequential(
                nn.ReplicationPad1d((pad_l, pad_r)),
                nn.Conv1d(in_planes, hidden_channels,
                          kernel_size=self.kernel_size, dilation=self.dilation,
                          stride=1, groups=self.groups),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.Dropout(self.dropout),
                nn.Conv1d(hidden_channels, in_planes,
                          kernel_size=3, stride=1, groups=self.groups),
                nn.Tanh(),
            )

        # Layers are created in the order P, U, phi, psi and registered in
        # the order phi, psi, P, U so that both the random initialisation
        # sequence and the state_dict key order stay as they were.
        branch_p = make_branch()
        branch_u = make_branch()
        branch_phi = make_branch()
        branch_psi = make_branch()

        self.phi = branch_phi
        self.psi = branch_psi
        self.P = branch_p
        self.U = branch_u

    def forward(self, x):
        """Return the updated ``(even, odd)`` pair in channel-first layout.

        The two halves are permuted with ``(0, 2, 1)`` before the branches
        are applied; the results are returned without permuting back.
        """
        even_half, odd_half = self.split(x) if self.splitting else x

        even_half = even_half.permute(0, 2, 1)
        odd_half = odd_half.permute(0, 2, 1)

        if not self.modified:
            detail = odd_half - self.P(even_half)
            coarse = even_half + self.U(detail)
            return (coarse, detail)

        # Scale each half by the exponential of the other half's branch ...
        detail = odd_half.mul(torch.exp(self.phi(even_half)))
        coarse = even_half.mul(torch.exp(self.psi(odd_half)))
        # ... then add/subtract the cross terms (U is evaluated before P).
        return (coarse + self.U(detail), detail - self.P(coarse))


class InteractorLevel(nn.Module):
    """Wrapper holding one splitting ``Interactor`` under the name ``level``."""

    def __init__(self, in_planes, kernel, dropout, groups, hidden_size, INN):
        super().__init__()
        interactor_options = dict(
            in_planes=in_planes,
            splitting=True,
            kernel=kernel,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN,
        )
        self.level = Interactor(**interactor_options)

    def forward(self, x):
        """Return the ``(even, odd)`` pair produced by the wrapped interactor."""
        even_out, odd_out = self.level(x)
        return (even_out, odd_out)

class LevelSCINet(nn.Module):
    """One tree level: run an ``InteractorLevel`` and permute both outputs."""

    def __init__(self, in_planes, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.interact = InteractorLevel(
            in_planes=in_planes,
            kernel=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN,
        )

    def forward(self, x):
        """Return the even and odd results, each permuted with ``(0, 2, 1)``."""
        even_out, odd_out = self.interact(x)
        # even: B, T, D odd: B, T, D
        even_out = even_out.permute(0, 2, 1)
        odd_out = odd_out.permute(0, 2, 1)
        return even_out, odd_out

class SCINet_Tree(nn.Module):
    """Recursive binary tree of ``LevelSCINet`` blocks.

    Every node splits its input into an even and an odd sub-series with
    ``workingblock``. Nodes with ``current_level != 0`` hand both halves to
    child trees; the two results are then interleaved back into one series.
    """

    def __init__(self, in_planes, current_level, kernel_size, dropout, groups, hidden_size, INN):
        """
        :param in_planes: channel count passed on to ``LevelSCINet``
        :param current_level: remaining depth below this node (0 = leaf)
        :param kernel_size: passed on to ``LevelSCINet``
        :param dropout: passed on to ``LevelSCINet``
        :param groups: passed on to ``LevelSCINet``
        :param hidden_size: passed on to ``LevelSCINet``
        :param INN: passed on to ``LevelSCINet``
        """
        super().__init__()
        self.current_level = current_level

        self.workingblock = LevelSCINet(
            in_planes = in_planes,
            kernel_size = kernel_size,
            dropout = dropout,
            groups= groups,
            hidden_size = hidden_size,
            INN = INN)

        if current_level!=0:
            # Child creation order (odd first) is kept so that seeded
            # initialisation and state_dict ordering do not change.
            self.SCINet_Tree_odd=SCINet_Tree(in_planes, current_level-1, kernel_size, dropout, groups, hidden_size, INN)
            self.SCINet_Tree_even=SCINet_Tree(in_planes, current_level-1, kernel_size, dropout, groups, hidden_size, INN)

    def zip_up_the_pants(self, even, odd):
        """Interleave ``even`` and ``odd`` along axis 1.

        The result is ``even[0], odd[0], even[1], odd[1], ...`` over the
        common length. When ``even`` is longer than ``odd`` its last step is
        appended at the end.

        :param even: tensor indexed as (B, L_even, D)
        :param odd: tensor indexed as (B, L_odd, D)
        :return: interleaved tensor indexed as (B, L, D)
        """
        even_len = even.shape[1]
        odd_len = odd.shape[1]
        paired = min(odd_len, even_len)
        # Stack the paired steps on a new axis and flatten it into the time
        # axis: position 2*i holds even[i], position 2*i + 1 holds odd[i].
        merged = torch.stack((even[:, :paired], odd[:, :paired]), dim=2)
        merged = merged.reshape(even.shape[0], 2 * paired, even.shape[2])
        if odd_len < even_len:
            merged = torch.cat((merged, even[:, -1:]), dim=1)
        return merged #B, L, D

    def forward(self, x):
        """Split ``x``, recurse into the children if any, and re-interleave."""
        x_even_update, x_odd_update= self.workingblock(x)
        # The sub-series are reordered recursively, level by level.
        if self.current_level ==0:
            return self.zip_up_the_pants(x_even_update, x_odd_update)
        else:
            return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update))

class EncoderTree(nn.Module):
    """Encoder consisting of one ``SCINet_Tree`` whose root is at level ``num_levels - 1``."""

    def __init__(self, in_planes,  num_levels, kernel_size, dropout, groups, hidden_size, INN):
        super().__init__()
        self.levels = num_levels
        root_level = num_levels - 1
        self.SCINet_Tree = SCINet_Tree(
            in_planes=in_planes,
            current_level=root_level,
            kernel_size=kernel_size,
            dropout=dropout,
            groups=groups,
            hidden_size=hidden_size,
            INN=INN,
        )

    def forward(self, x):
        """Return the output of the wrapped tree for ``x``."""
        return self.SCINet_Tree(x)

class SCINet(nn.Module):
    def __init__(self, config):
        """Build the network from a configuration mapping.

        :param config: mapping providing the keys ``input_dim``,
               ``input_len``, ``output_len``, ``hid_size``, ``num_levels``,
               ``groups``, ``modified``, ``kernel``, ``dropout``,
               ``single_step_output_One``, ``concat_len``, ``positionalE``,
               ``RIN``, ``num_decoder_layer``, ``path`` and ``num_stacks``
        :raises KeyError: if one of these keys is missing
        """
        super(SCINet, self).__init__()

        self.input_dim = config["input_dim"]
        self.input_len = config["input_len"]
        self.output_len = config["output_len"]
        self.hidden_size = config["hid_size"]
        self.num_levels = config["num_levels"]
        self.groups = config["groups"]
        self.modified = config["modified"]
        self.kernel_size = config["kernel"]
        self.dropout = config["dropout"]
        self.single_step_output_One = config["single_step_output_One"]
        self.concat_len = config["concat_len"]
        self.pe = config["positionalE"]
        self.RIN= config["RIN"]
        self.num_decoder_layer = config["num_decoder_layer"]
        self.path = config["path"]

        self.blocks1 = EncoderTree(
            in_planes=self.input_dim,
            num_levels = self.num_levels,
            kernel_size = self.kernel_size,
            dropout = self.dropout,
            groups = self.groups,
            hidden_size = self.hidden_size,
            INN =  self.modified)

        if config["num_stacks"] == 2: # we only implement two stacks at most.
            self.blocks2 = EncoderTree(
                in_planes=self.input_dim,
                num_levels = self.num_levels,
                kernel_size = self.kernel_size,
                dropout = self.dropout,
                groups = self.groups,
                hidden_size = self.hidden_size,
                INN =  self.modified)

        self.stacks = config["num_stacks"]

        # NOTE(review): at this point only the encoder trees are registered.
        # The layers they build in this file are Conv1d based, so none of
        # the branches below appears to match anything. Kept unchanged;
        # confirm before removing.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

        # Decoder of the first stack. The Conv1d is always created first (and
        # replaced by a Linear for multi-layer decoders) so that the sequence
        # of random initialisations is the same as before.
        self.projection1 = nn.Conv1d(self.input_len, self.output_len, kernel_size=1, stride=1, bias=False)
        self.div_projection = nn.ModuleList()
        self.overlap_len = self.input_len//4
        self.div_len = self.input_len//6

        if self.num_decoder_layer > 1:
            self.projection1 = nn.Linear(self.input_len, self.output_len)
            for layer_idx in range(self.num_decoder_layer-1):
                div_projection = nn.ModuleList()
                for i in range(6):
                    # Window i starts at i*div_len and spans at most
                    # overlap_len steps, clipped to the input length.
                    lens = min(i*self.div_len+self.overlap_len,self.input_len) - i*self.div_len
                    div_projection.append(nn.Linear(lens, self.div_len))
                self.div_projection.append(div_projection)

        # Projection of the second stack. Its input is the (possibly
        # truncated) original series concatenated with the first output.
        if self.stacks == 2:
            second_in = (self.concat_len if self.concat_len else self.input_len) + self.output_len
            # single_step_output_One: only output the N_th timestep.
            second_out = 1 if self.single_step_output_One else self.output_len
            self.projection2 = nn.Conv1d(second_in, second_out,
                                        kernel_size = 1, bias = False)

        # For positional encoding: the width is rounded up to an even number
        # because it is split into a sine half and a cosine half.
        self.pe_hidden_size = config["input_dim"]
        if self.pe_hidden_size % 2 == 1:
            self.pe_hidden_size += 1

        num_timescales = self.pe_hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0

        log_timescale_increment = (
                math.log(float(max_timescale) / float(min_timescale)) /
                max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

        ### RIN Parameters ###
        if self.RIN:
            self.affine_weight = nn.Parameter(torch.ones(1, 1, config["input_dim"]))
            self.affine_bias = nn.Parameter(torch.zeros(1, 1, config["input_dim"]))


    def get_position_encoding(self, x):
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32, device=x.device)  # tensor([0., 1., 2., 3., 4.], device='cuda:0')
        temp1 = position.unsqueeze(1)  # 5 1
        temp2 = self.inv_timescales.unsqueeze(0)  # 1 256
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)  # 5 256
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  #[T, C]
        signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
        signal = signal.view(1, max_length, self.pe_hidden_size)
    
        return signal

    def forward(self, x):
        #assert self.input_len % (np.power(2, self.num_levels)) == 0 # evenly divided the input length into two parts. (e.g., 32 -> 16 -> 8 -> 4 for 3 levels)
        if self.pe:
            pe = self.get_position_encoding(x)
            if pe.shape[2] > x.shape[2]:
                x += pe[:, :, :-1]
            else:
                x += self.get_position_encoding(x)

        ### activated when RIN flag is set ###
        if self.RIN:
            print('/// RIN ACTIVATED ///\r',end='')
            means = x.mean(1, keepdim=True).detach()
            #mean
            x = x - means
            #var
            stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x /= stdev
            # affine
            # print(x.shape,self.affine_weight.shape,self.affine_bias.shape)
            x = x * self.affine_weight + self.affine_bias

        # the first stack
        res1 = x
        x = self.blocks1(x)
        x += res1
        if self.num_decoder_layer == 1:
            x = self.projection1(x)
        else:
            x = x.permute(0,2,1)
            for div_projection in self.div_projection:
                output = torch.zeros(x.shape,dtype=x.dtype).cuda()
                for i, div_layer in enumerate(div_projection):
                    div_x = x[:,:,i*self.div_len:min(i*self.div_len+self.overlap_len,self.input_len)]
                    output[:,:,i*self.div_len:(i+1)*self.div_len] = div_layer(div_x)
                x = output
            x = self.projection1(x)
            x = x.permute(0,2,1)

        if self.stacks == 1:
            ### reverse RIN ###
            if self.RIN:
                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x

        elif self.stacks == 2:
            MidOutPut = x
            if self.concat_len:
                x = torch.cat((res1[:, -self.concat_len:,:], x), dim=1)
            else:
                x = torch.cat((res1, x), dim=1)

            # the second stack
            res2 = x
            x = self.blocks2(x)
            x += res2
            x = self.projection2(x)
            
            ### Reverse RIN ###
            if self.RIN:
                MidOutPut = MidOutPut - self.affine_bias
                MidOutPut = MidOutPut / (self.affine_weight + 1e-10)
                MidOutPut = MidOutPut * stdev
                MidOutPut = MidOutPut + means

            if self.RIN:
                x = x - self.affine_bias
                x = x / (self.affine_weight + 1e-10)
                x = x * stdev
                x = x + means

            return x, MidOutPut
    
    #--------------------SAVE_MODEL--------------------
    def save_model(self):
        torch.save(self.state_dict(), self.path)

    #--------------------SAVE_MODEL--------------------
    def save_model_finetuning(self, path):
        torch.save(self.state_dict(), path)

    #--------------------LOAD_MODEL--------------------
    def load_model(self):
        self.load_state_dict(torch.load(self.path))
    
    #--------------------LOAD_MODEL--------------------
    def load_model_finetuning(self, path):
        self.load_state_dict(torch.load(path))

    #--------------------PREDICT--------------------
    def predict_index(self, data_days, metric, target, device, save_path=''):
        '''
        Predict every dataset in ``data_days`` and write one CSV per dataset.

        The weights at ``self.path`` are loaded first. Each dataset is read
        through a DataLoader with batch size 40 whose batches are pairs
        ``(X_batch, X_names)``. The target is column ``target`` of the last
        timestep; the model input is timesteps 1 .. l-2 of the first five
        columns. The CSV has the columns Name / True / Pred and is named
        after the dataset's ``day`` attribute. The batch-averaged metric and
        MAPE are printed at the end.

        data_days (iterable) :
            Datasets to evaluate, each exposing a ``day`` attribute
        metric (function) :
            Called as ``metric(outputs, Y_batch)`` for every batch
        target (int) :
            Index of the column to forecast
        device (torch.device) :
            Device where to perform evaluation
        save_path (str) :
            Output directory, created if missing; an empty string means the
            current working directory

        Raises ZeroDivisionError when no batch at all was produced.
        '''
        # '' used to fail in os.mkdir('') and would have produced the path
        # '/<day>.csv'; treat it as the current directory instead.
        out_dir = save_path if save_path else '.'
        with torch.no_grad():
            # makedirs also handles nested directories that do not exist yet.
            os.makedirs(out_dir, exist_ok=True)
            self.to(device)
            self.eval()
            self.load_model()
            err=0
            Mape=0
            step = 0
            for d in data_days:
                data = torch.utils.data.DataLoader(d, batch_size=40)
                # Per-batch pieces are collected and joined once at the end
                # rather than re-copying growing arrays on every batch.
                true_parts = []
                pred_parts = []
                Y_names = []
                for X_batch, X_names in data:
                    b, l, f = X_batch.shape
                    Y_batch = X_batch[:, l-1, target].reshape(b, 1).float().to(device)
                    X_batch = X_batch[:, 1:l-1, 0:5].float().to(device)

                    outputs = self(X_batch)
                    outputs = outputs[:, :, target].reshape(b,1).float().to(device)

                    err+= metric(outputs, Y_batch)

                    Mape += mape(outputs, Y_batch)

                    true_parts.append(Y_batch.cpu().numpy().reshape((b,)))
                    pred_parts.append(outputs.cpu().numpy().reshape((b,)))
                    Y_names.extend(X_names)

                    step+=1
                if true_parts:
                    Y_true = np.concatenate(true_parts)
                    Y_pred = np.concatenate(pred_parts)
                else:
                    Y_true = np.empty((0,), dtype=np.float32)
                    Y_pred = np.empty((0,), dtype=np.float32)
                table = pd.DataFrame({"Name" : Y_names, "True" : Y_true, "Pred" : Y_pred})
                table.to_csv(os.path.join(out_dir, f'{d.day}.csv'), index=False)
            print("RMSE : ", round(err/step, 8))
            print('MAPE : ', round(Mape/step, 8))
    
    def compute_validation(self, val_data, loss_fn, metric, device, target):
        '''
        Evaluate the model on ``val_data`` and return batch-averaged scores.

        The model is switched to eval mode and is left in that mode; callers
        that continue training have to call ``train()`` again. For every
        batch the target is column ``target`` of the last timestep and the
        model input is timesteps 1 .. l-2 of the first five columns.

        val_data (iterable) :
            Yields 3-D batches
        loss_fn (callable) :
            Loss; ``.item()`` is called on its result
        metric (function) :
            Called as ``metric(outputs, Y_batch)``
        device (torch.device) :
            Device where to perform evaluation
        target (int) :
            Index of the column to forecast

        Returns (loss, metric, direction accuracy, MAPE), each averaged over
        the batches and rounded to 8 digits.
        Raises ZeroDivisionError when ``val_data`` yields no batch.
        '''
        self.eval()
        with torch.no_grad():
            batch = 0
            val_err = 0
            val_l = 0
            val_acc = 0
            val_mape = 0
            for X_batch in val_data:
                batch+=1
                b, l, f = X_batch.shape
                Y_batch = X_batch[:, l-1, target].reshape(b, 1).float().to(device)
                X_batch = X_batch[:, 1:l-1, 0:5].float().to(device)

                outputs = self(X_batch)
                outputs = outputs[:, :, target].reshape(b,1).float().to(device)
                loss = loss_fn(outputs, Y_batch)

                val_l+=loss.item()

                val_err+= metric(outputs, Y_batch)

                val_acc+= direction_accuracy(outputs, Y_batch)

                val_mape += mape(outputs, Y_batch)

            val_loss = round(val_l/batch, 8)

            val_metric = round(val_err/batch, 8)

            val_accuracy = round(val_acc/batch, 8)

            val_MAPE = round(val_mape/batch, 8)

            return val_loss, val_metric, val_accuracy, val_MAPE

    def training_phase(self, train_data, epochs, optimizer, loss_fn, metric, device, target, batch_start=1, batch_end=40000, val_data=None, val_step=35866, start_val_count=1, log_dir=None, load=False):
        '''
        Train the model and log to TensorBoard.

        train_data (iterable) :
            Yields 3-D batches; the target is column ``target`` of the last
            timestep, the input is timesteps 1 .. l-2 of the first five columns
        epochs (int) :
            Number of passes over ``train_data``
        optimizer (torch.optim) :
            Optimizer stepping the model parameters
        loss_fn (callable) :
            Loss function to minimize
        metric (function) :
            Called as ``metric(outputs, Y_batch)``; its result is rounded
        device (torch.device) :
            Device where to perform training
        target (int) :
            Index of the column to forecast
        batch_start (int) :
            Batches with a smaller 1-based index are skipped in every epoch;
            also the first logging step
        batch_end (int) :
            Once the batch index exceeds it, the model is saved and the
            epoch ends
        val_data (iterable or None) :
            Validation batches; validation is skipped when None
        val_step (int) :
            Validate every ``val_step`` batches
        start_val_count (int) :
            First x-value of the validation curves
        log_dir (str or None) :
            Passed to ``SummaryWriter``
        load (bool) :
            Load the weights at ``self.path`` before training

        The model is saved to ``self.path`` at the end of every epoch.
        '''
        self.train()
        self.to(device)
        writer = SummaryWriter(log_dir)
        step = batch_start
        val_counter = start_val_count
        # try/finally: pending events are flushed and the writer is closed
        # even when training is interrupted by an error.
        try:
            if load:
                self.load_model()
            for i in range(epochs):

                batch = 0
                err = 0
                total_loss = 0

                for X_batch in train_data:
                    batch+=1
                    if batch<batch_start:
                        continue
                    if batch>batch_end:
                        self.save_model()
                        break

                    b, l, f = X_batch.shape
                    Y_batch = X_batch[:, l-1, target].reshape(b, 1).float().to(device)
                    X_batch = X_batch[:, 1:l-1, 0:5].float().to(device)

                    optimizer.zero_grad()

                    outputs = self(X_batch)
                    outputs = outputs[:, :, target].reshape(b,1).float().to(device)
                    loss = loss_fn(outputs, Y_batch)

                    total_loss+=loss.item()

                    # Logging works on a detached view so that accumulated
                    # values never keep the autograd graph alive; the metric
                    # is evaluated once and reused below.
                    preds = outputs.detach()
                    batch_metric = metric(preds, Y_batch)
                    err+= batch_metric

                    loss.backward()
                    optimizer.step()

                    writer.add_scalar("Loss/Train/Epoch - step", round(total_loss/batch, 8), step) #Loss during epoch
                    writer.add_scalar("Loss/Train/BatchEnd - step", round(loss.item(), 8), step) #Loss each batch

                    writer.add_scalar("Metric/Train/Epoch - step", round(err/batch, 8), step) #Metric during epoch
                    writer.add_scalar("Metric/Train/BatchEnd - step", round(batch_metric, 8), step) #Metric each batch

                    writer.add_scalar("Accuracy/Train/BatchEnd - step", direction_accuracy(preds, Y_batch), step) #Accuracy of the direction

                    writer.add_scalar("MAPE/Train/BatchEnd - step", mape(preds, Y_batch), step)

                    writer.add_scalar("Output/Mean/Train/Batch - step", preds.mean(), step) #Mean prediction of the batch
                    writer.add_scalar("Output/Min/Train/Batch - step", preds.min(), step) #Min prediction of the batch
                    writer.add_scalar("Output/Max/Train/Batch - step", preds.max(), step) #Max prediction of the batch

                    step+=1

                    # val_data defaults to None; validating then would fail.
                    if val_data is not None and batch%val_step == 0:

                        val_loss, val_metric, val_accuracy, val_mape = self.compute_validation(val_data, loss_fn, metric, device, target)
                        # compute_validation leaves the model in eval mode.
                        self.train()

                        writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss

                        writer.add_scalar("Metric/Val/BatchEnd - ValStep", val_metric, val_counter) #Val Metric

                        writer.add_scalar("Accuracy/Val/BatchEnd - ValStep", val_accuracy, val_counter) #Val Accuracy

                        writer.add_scalar("MAPE/Val/BatchEnd - ValStep", val_mape, val_counter)

                        val_counter+=1
                self.save_model()
        finally:
            writer.close()

    def finetuning(self, train_data, epochs, optimizer, loss_fn, metric, device, target, patience, min_delta, val_data=None, log_dir=None, load_path="", save_path=""):
        '''
        Fine-tune weights loaded from ``load_path`` with early stopping.

        The model is validated once before training and after every epoch.
        Whenever the early-stopping counter is 0 after a validation, the
        weights are written to ``save_path``.

        train_data (torch.utils.data.dataloader.DataLoader) :
            Iterable training dataset
        epochs (int) >0 :
            Maximum number of training epochs
        optimizer (torch.optim) :
            Algorithm to perform optimization
        loss_fn (torch.nn.modules.loss) :
            Loss function to minimize
        metric (function) :
            Metric
        device (torch.device) :
            Device where to perform training
        target (int) :
            Target series to forecast
        patience (int) :
            Validations without improvement tolerated before stopping
        min_delta (float) :
            Minimum decrease of the validation loss counted as improvement
        val_data (torch.utils.data.dataloader.DataLoader) :
            Iterable validation dataset (required despite the None default)
        log_dir (str or None) :
            Passed to ``SummaryWriter``
        load_path (str) :
            Checkpoint to start from
        save_path (str) :
            Where the best weights are written

        Raises ValueError when ``val_data`` is None.
        '''
        if val_data is None:
            # Early stopping is driven by the validation loss; fail clearly
            # before any side effect instead of deep inside the validation.
            raise ValueError("finetuning requires val_data for early stopping")

        self.to(device)
        self.train()
        writer = SummaryWriter(log_dir)
        step = 1
        val_counter = 0
        # try/finally: the writer is closed even when an error interrupts.
        try:
            self.load_model_finetuning(load_path)
            early_stop = EarlyStopping(patience, min_delta)

            # Baseline validation of the loaded weights.
            val_loss, val_metric, val_accuracy, val_MAPE = self.compute_validation(val_data, loss_fn, metric, device, target)
            self.train()

            early_stop(val_loss)

            if early_stop.counter==0:
                self.save_model_finetuning(save_path)

            writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss

            writer.add_scalar("Metric/Val/ValStep", val_metric, val_counter) #Val Metric

            writer.add_scalar("Accuracy/Val/ValStep", val_accuracy, val_counter) #Val Accuracy

            writer.add_scalar("MAPE/Val/ValStep", val_MAPE, val_counter)

            val_counter+=1

            for i in range(epochs):

                batch = 0
                err = 0
                total_loss = 0

                for X_batch in train_data:
                    batch+=1
                    b, l, f = X_batch.shape
                    Y_batch = X_batch[:, l-1, target].reshape(b, 1).float().to(device)
                    X_batch = X_batch[:, 1:l-1, 0:5].float().to(device)

                    optimizer.zero_grad()

                    outputs = self(X_batch)
                    outputs = outputs[:, :, target].reshape(b,1).float().to(device)
                    loss = loss_fn(outputs, Y_batch)

                    total_loss+=loss.item()

                    # Logging works on a detached view so that accumulated
                    # values never keep the autograd graph alive; the metric
                    # is evaluated once and reused below.
                    preds = outputs.detach()
                    batch_metric = metric(preds, Y_batch)
                    err+= batch_metric

                    loss.backward()
                    optimizer.step()

                    writer.add_scalar("Loss/Train/Epoch - step", round(total_loss/batch, 8), step) #Loss during epoch
                    writer.add_scalar("Loss/Train/BatchEnd - step", round(loss.item(), 8), step) #Loss each batch

                    writer.add_scalar("Metric/Train/Epoch - step", round(err/batch, 8), step) #Metric during epoch
                    writer.add_scalar("Metric/Train/BatchEnd - step", round(batch_metric, 8), step) #Metric each batch

                    writer.add_scalar("Accuracy/Train/BatchEnd - step", direction_accuracy(preds, Y_batch), step) #Accuracy of the direction

                    writer.add_scalar("MAPE/Train/BatchEnd - step", mape(preds, Y_batch), step)

                    writer.add_scalar("Output/Mean/Train/Batch - step", preds.mean(), step) #Mean prediction of the batch
                    writer.add_scalar("Output/Min/Train/Batch - step", preds.min(), step) #Min prediction of the batch
                    writer.add_scalar("Output/Max/Train/Batch - step", preds.max(), step) #Max prediction of the batch

                    step+=1

                val_loss, val_metric, val_accuracy, val_MAPE = self.compute_validation(val_data, loss_fn, metric, device, target)
                # compute_validation leaves the model in eval mode.
                self.train()

                writer.add_scalar("Loss/Val/ValStep", val_loss, val_counter) #Val Loss

                writer.add_scalar("Metric/Val/ValStep", val_metric, val_counter) #Val Metric

                writer.add_scalar("Accuracy/Val/ValStep", val_accuracy, val_counter) #Val Accuracy

                writer.add_scalar("MAPE/Val/ValStep", val_MAPE, val_counter)

                val_counter+=1

                early_stop(val_loss)

                if early_stop.counter==0:
                    self.save_model_finetuning(save_path)

                if early_stop.early_stop:
                    break
        finally:
            writer.close()

class TimeSeriesIterableDataset(IterableDataset):
    '''
    Stream samples from pickle files stored inside zip archives.

    dir (str) :
        Directory containing zip files
    phase (str) ["train", "val", "test"] :
        Whether data is train, val or test; selects the archive members
        ending in "_<phase>.pickle". Any other value selects every member.
    '''
    #-----------------------------------__init__
    def __init__(self, dir, phase):
        # super(Cls).__init__() built an unbound super object and never ran
        # the parent initialiser.
        super().__init__()
        self.dir = dir
        self.phase = phase
        # Every directory entry is later opened as a zip archive.
        self.list_of_zip = os.listdir(self.dir)

    def __iter__(self):
        '''
        Yield one 2-D numpy array per stored frame.

        NaN and +inf are replaced by 1. Frames whose column 3 has a maximum
        above 2 or a minimum below 0.5 are skipped.
        '''
        suffix_by_phase = {
            "train": "_train.pickle",
            "val": "_val.pickle",
            "test": "_test.pickle",
        }
        suffix = suffix_by_phase.get(self.phase)

        for zips in self.list_of_zip:
            # Open each archive once instead of once per member.
            with ZipFile(os.path.join(self.dir, zips)) as z:
                filelist = z.namelist()
                if suffix is not None:
                    filelist = [item for item in filelist if item.endswith(suffix)]

                for filename in filelist:
                    with z.open(filename) as file: #Open pickle file
                        # NOTE: unpickling can execute arbitrary code; only
                        # use archives from a trusted source.
                        data = pkl.load(file)
                    # Index access is kept because the container type of the
                    # pickled payload is not visible from this file.
                    for i in range(len(data)):
                        d = data[i].replace([np.nan, np.inf], 1).values
                        if d[:,3].max()>2 or d[:,3].min()<0.5:
                            continue
                        yield d

class TimeSeriesIterableDatasetFinetuning(IterableDataset):
    '''
    Stream samples from pickle files stored directly in a directory.

    dir (str) :
        Directory containing the pickle files
    phase (str) ["train", "val", "test"] :
        Whether data is train, val or test; selects the files ending in
        "_<phase>.pickle". Any other value selects every directory entry.
    '''
    #--------------------__INIT__--------------------
    def __init__(self, dir, phase):
        # super(Cls).__init__() built an unbound super object and never ran
        # the parent initialiser.
        super().__init__()
        self.dir = dir
        self.phase = phase
        self.list_of_file = os.listdir(self.dir)

    #--------------------__ITER__--------------------
    def __iter__(self):
        '''
        Yield one 2-D numpy array per stored frame.

        NaN and +inf are replaced by 1. Frames whose column 3 has a maximum
        above 10 or a minimum below 0.1 are skipped.
        '''
        suffix_by_phase = {
            "train": "_train.pickle",
            "val": "_val.pickle",
            "test": "_test.pickle",
        }
        suffix = suffix_by_phase.get(self.phase)

        filelist = self.list_of_file
        if suffix is not None:
            filelist = [item for item in filelist if item.endswith(suffix)]

        for filename in filelist:
            with open(os.path.join(self.dir, filename), 'rb') as file: #Open pickle file
                # NOTE: unpickling can execute arbitrary code; only use
                # files from a trusted source.
                data = pkl.load(file)
            # Index access is kept because the container type of the pickled
            # payload is not visible from this file.
            for i in range(len(data)):
                d = data[i].replace([np.nan, np.inf], 1).values
                if d[:,3].max()>10 or d[:,3].min()<0.1:
                    continue
                yield d
