import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, Dataset
import sys
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/trend_seasonality_causal_structure/src/')

from layers.Embed import DataEmbedding
from layers.AutoCorrelation import AutoCorrelationLayer
from layers.FourierCorrelation import FourierBlock, FourierCrossAttention
from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
import numpy as np
from omegaconf import DictConfig

from src.models.utils import bce
from src.data import RealDatasetCollection, SyntheticDatasetCollection
import logging

# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)



class Model(LightningModule):
    """
    FEDformer-based causal forecasting model.

    FEDformer performs the attention mechanism on the frequency domain and achieves O(N) complexity.
    Paper link: https://proceedings.mlr.press/v162/zhou22g.html

    On top of the FEDformer encoder/decoder, ``forecast`` applies an outcome head (``outcome_layer``) and a
    treatment head (``treat_layer``) to the decomposed trend / seasonal parts. The training objective is
    outcome MSE + masked treatment BCE + a linear-MMD balancing term on the seasonal part.
    """
    model_type = 'multi'  # key of this model's sub-config in ``hparams.model``

    def __init__(self,
                 args: DictConfig = None,
                 dataset_collection: SyntheticDatasetCollection = None,
                 **kwargs):
        """
        Build the model from the ``args.model.multi`` sub-config.

        Selected keys of ``args.model.multi``:
            version: str, FEDformer variant; 'Wavelets' selects the multi-wavelet blocks, any other value the
                Fourier blocks.
            mode_select: str, mode selection method for the Fourier blocks, options: [random, low].
            modes: int, modes to be selected.
        ``args.dataset.autoregressive`` selects which dataset splits the dataloaders use.

        :param args: full experiment config; kept as ``self.args`` and saved as hyperparameters.
        :param dataset_collection: object exposing ``train_f``/``val_f``/``test_cf_one_step`` (and their
            ``*_non`` variants) consumed by the dataloaders.
        :param kwargs: ignored.
        """
        super().__init__()

        self.dataset_collection = dataset_collection
        self.args = args
        self.save_hyperparameters(args)
        args = args.model.multi  # from here on `args` is the model sub-config
        self.task_name = args.task_name
        self.seq_len = args.seq_len
        self.label_len = args.label_len
        self.pred_len = args.pred_len
        self.version = args.version
        self.mode_select = args.mode_select
        self.modes = args.modes
        self.has_vitals = args.has_vitals
        # Encoder input width = treatments (+ vitals) + outcomes, matching the concat in `forecast`.
        self.enc_in = args.dims_treatment
        self.enc_in += args.dims_vitals if self.has_vitals else 0
        self.enc_in += args.dims_outcome
        self.autoregressive = self.args.dataset.autoregressive
        logger.info('%s: autoregressive=%s', type(self).__name__, self.autoregressive)

        # Decomp
        self.decomp = series_decomp(args.moving_avg)
        self.enc_embedding = DataEmbedding(self.enc_in,
                                           args.d_model,
                                           args.embed,
                                           args.freq,
                                           args.dropout)
        self.dec_embedding = DataEmbedding(self.enc_in,
                                           args.d_model,
                                           args.embed,
                                           args.freq,
                                           args.dropout)

        # NOTE(review): decoder lengths use `seq_len // 2 + pred_len`, i.e. they assume label_len == seq_len // 2.
        if self.version == 'Wavelets':
            encoder_self_att = MultiWaveletTransform(ich=args.d_model,
                                                     L=1,
                                                     base='legendre')
            decoder_self_att = MultiWaveletTransform(ich=args.d_model,
                                                     L=1,
                                                     base='legendre')
            decoder_cross_att = MultiWaveletCross(in_channels=args.d_model,
                                                  out_channels=args.d_model,
                                                  seq_len_q=self.seq_len // 2 + self.pred_len,
                                                  seq_len_kv=self.seq_len,
                                                  modes=self.modes,
                                                  ich=args.d_model,
                                                  base='legendre',
                                                  activation='tanh')
        else:
            encoder_self_att = FourierBlock(in_channels=args.d_model,
                                            out_channels=args.d_model,
                                            seq_len=self.seq_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_self_att = FourierBlock(in_channels=args.d_model,
                                            out_channels=args.d_model,
                                            seq_len=self.seq_len // 2 + self.pred_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_cross_att = FourierCrossAttention(in_channels=args.d_model,
                                                      out_channels=args.d_model,
                                                      seq_len_q=self.seq_len // 2 + self.pred_len,
                                                      seq_len_kv=self.seq_len,
                                                      modes=self.modes,
                                                      mode_select_method=self.mode_select)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        encoder_self_att,  # instead of multi-head attention in transformer
                        args.d_model,
                        args.n_heads),
                    args.d_model,
                    args.d_ff,
                    moving_avg=args.moving_avg,
                    dropout=args.dropout,
                    activation=args.activation
                ) for _ in range(args.e_layers)
            ],
            norm_layer=my_Layernorm(args.d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        decoder_self_att,
                        args.d_model,
                        args.n_heads),
                    AutoCorrelationLayer(
                        decoder_cross_att,
                        args.d_model,
                        args.n_heads),
                    args.d_model,
                    self.enc_in,   # output dims
                    args.d_ff,
                    moving_avg=args.moving_avg,
                    dropout=args.dropout,
                    activation=args.activation,
                )
                for _ in range(args.d_layers)
            ],
            norm_layer=my_Layernorm(args.d_model),
            projection=nn.Linear(args.d_model,
                                 self.enc_in,
                                 bias=True)
        )

        # Maps each static feature (as a length-1 vector) to pred_len values, i.e. repeats it over the horizon.
        self.static_transformation = nn.Linear(1,
                                               self.pred_len)
        # NOTE(review): not used by any method in this class.
        self.input_transformation = nn.Linear(self.enc_in,
                                              args.c_out)

        # Heads: outcome from [trend, seasonal, static, current treatments]; treatment from [trend, static].
        self.outcome_layer = nn.Linear(self.enc_in * 2 + args.dims_static + args.dims_treatment, args.c_out)
        self.treat_layer = nn.Linear(self.enc_in + args.dims_static, args.dims_treatment)

        if self.task_name == 'imputation':
            self.projection = nn.Linear(args.d_model,
                                        args.c_out,
                                        bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(args.d_model,
                                        args.c_out,
                                        bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(args.dropout)
            self.projection = nn.Linear(args.d_model * args.seq_len,
                                        args.num_class)

    def forecast(self, batch):
        """
        Run FEDformer on one batch and apply the outcome / treatment heads.

        The encoder input is ``prev_outputs``, ``prev_treatments`` and (if ``has_vitals``) ``vitals``
        concatenated along the feature axis. Time steps whose previous outputs contain any NaN are zeroed
        before encoding. This is done out-of-place, so ``batch`` itself is not modified.

        :param batch: dict with 'prev_outputs', 'prev_treatments', 'static_features', 'current_treatments'
            and, when ``has_vitals``, 'vitals'.
        :return: ``(est_outcome, est_treatment, trend_part, seasonal_part)``, each covering the last
            ``pred_len`` steps.
        """
        prev_outputs = batch['prev_outputs']                   # B X T X input_dims
        prev_treatments = batch['prev_treatments']             # B X T X treat_dims
        static_features = batch['static_features']             # B X static_dims
        curr_treatments = batch['current_treatments']          # B X pred_len X treat_dims (required by the concat below)
        vitals = batch['vitals'] if self.has_vitals else None  # B X T X vitals_dims

        # Zero every time step that has a NaN in any output feature (out-of-place: keeps the batch intact).
        nan_steps = prev_outputs.isnan().any(dim=-1, keepdim=True)   # B X T X 1
        prev_outputs = prev_outputs.masked_fill(nan_steps, 0)

        static_features = self.static_transformation(static_features.unsqueeze(2)).permute(0, 2, 1)  # B X pred_len X dim_static

        x_enc = torch.cat((prev_outputs, prev_treatments), dim=-1)         # B X T X (dim_output + dim_treat)
        if self.has_vitals:
            x_enc = torch.cat((x_enc, vitals), dim=-1)                     # B X T X (dim_output + dim_treat + dim_vitals)

        # decomp init
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)    # B X pred_len X enc_in
        seasonal_init, trend_init = self.decomp(x_enc)  # x - moving_avg, moving_avg

        # decoder input: last label_len steps, then pred_len placeholder steps (mean for trend, zeros for seasonal)
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)              # B X (label_len + pred_len) X enc_in
        seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)) # B X (label_len + pred_len) X enc_in

        # enc
        enc_out = self.enc_embedding(x_enc)                    # B X T X d_model
        dec_out = self.dec_embedding(seasonal_init)            # B X (label_len + pred_len) X d_model
        enc_out, _ = self.encoder(enc_out, attn_mask=None)

        # dec
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init)
        seasonal_part = seasonal_part[:, -self.pred_len:, :]    # B X pred_len X enc_in
        trend_part = trend_part[:, -self.pred_len:, :]          # B X pred_len X enc_in

        est_outcome = self.outcome_layer(torch.cat([trend_part, seasonal_part, static_features, curr_treatments], dim=-1))    # B X pred_len X c_out
        est_treatment = self.treat_layer(torch.cat([trend_part, static_features], dim=-1))                                    # B X pred_len X dim_treat

        return est_outcome, est_treatment, trend_part, seasonal_part

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Encoder + projection head. Not reached from ``forward``.

        NOTE(review): calls ``enc_embedding`` with two arguments, while ``forecast`` calls it with one — confirm
        the DataEmbedding signature before enabling this task.
        """
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc):
        """Encoder + projection head. Not reached from ``forward`` (see NOTE on ``imputation``)."""
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        """Encoder + flattened classification head. Requires ``task_name == 'classification'`` (sets ``act``,
        ``dropout`` and ``projection``); not reached from ``forward``."""
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero out padded time steps
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def forward(self, x_enc, mask=None):
        """
        Dispatch on ``task_name``.

        :param x_enc: the batch dict consumed by ``forecast`` (name kept for interface compatibility).
        :param mask: unused.
        :return: ``(est_outcome, est_treatment, trend_part, seasonal_part)`` for 'long_term_forecast' and
            'short_term_forecast'; ``None`` for every other task.
        """
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            return self.forecast(x_enc)  # tensors shaped [B, L, D]
        return None

    def generate_continuous_mask(self, B, T, n=5, l=0.1):
        """
        Boolean mask of shape (B, T), True everywhere except ``n`` random contiguous spans of length ``l``
        per row.

        ``n`` and ``l`` may be given as fractions of ``T`` (floats). ``n`` is clamped to ``[1, T // 2]``
        and ``l`` to ``[1, T]``. Without the upper clamp on ``l``, a span longer than ``T`` made
        ``np.random.randint`` raise.
        """
        res = torch.full((B, T), True, dtype=torch.bool)
        if isinstance(n, float):
            n = int(n * T)
        n = max(min(n, T // 2), 1)

        if isinstance(l, float):
            l = int(l * T)
        l = min(max(l, 1), T)

        for i in range(B):
            for _ in range(n):
                t = np.random.randint(T - l + 1)
                res[i, t:t + l] = False
        return res

    def generate_binomial_mask(self, B, T, p=0.5):
        """Boolean mask of shape (B, T) with independent Bernoulli(p) entries (numpy global RNG)."""
        return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)

    def _get_optimizer(self, param_optimizer: list):
        """
        Build the optimizer configured in ``hparams.model[model_type]['optimizer']``.

        Parameters whose name contains 'bias' or 'layer_norm' get no weight decay.
        NOTE(review): the LayerNorm parameters created via ``my_Layernorm`` may not contain 'layer_norm' in
        their names — verify which parameters actually skip weight decay.

        :raises NotImplementedError: for an optimizer_cls other than adamw / adam / sgd.
        """
        no_decay = ['bias', 'layer_norm']
        sub_args = self.hparams.model[self.model_type]
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': sub_args['optimizer']['weight_decay'],
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        lr = sub_args['optimizer']['learning_rate']
        optimizer_cls = sub_args['optimizer']['optimizer_cls']
        if optimizer_cls.lower() == 'adamw':
            optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == 'adam':
            optimizer = optim.Adam(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == 'sgd':
            optimizer = optim.SGD(optimizer_grouped_parameters, lr=lr,
                                  momentum=sub_args['optimizer']['momentum'])
        else:
            raise NotImplementedError()

        return optimizer

    def _get_lr_schedulers(self, optimizer):
        """Attach an ``ExponentialLR(gamma=0.99)`` to one optimizer or to each optimizer of a list."""
        if not isinstance(optimizer, list):
            lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
            return [optimizer], [lr_scheduler]
        else:
            lr_schedulers = []
            for opt in optimizer:
                lr_schedulers.append(optim.lr_scheduler.ExponentialLR(opt, gamma=0.99))
            return optimizer, lr_schedulers

    def configure_optimizers(self):
        """Lightning hook: optimizer over all named parameters, plus schedulers when ``lr_scheduler`` is set."""
        optimizer = self._get_optimizer(list(self.named_parameters()))

        if self.hparams.model[self.model_type]['optimizer']['lr_scheduler']:
            return self._get_lr_schedulers(optimizer)

        return optimizer

    def _get_split_dataset(self, name: str) -> Dataset:
        """Return ``dataset_collection.<name>`` if ``autoregressive`` else ``dataset_collection.<name>_non``."""
        if self.autoregressive:
            return getattr(self.dataset_collection, name)
        return getattr(self.dataset_collection, f'{name}_non')

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split (last incomplete batch dropped)."""
        return DataLoader(self._get_split_dataset('train_f'), shuffle=True,
                          batch_size=self.args.dataset.train_batch_size, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Ordered loader over the validation split (last incomplete batch dropped)."""
        return DataLoader(self._get_split_dataset('val_f'), shuffle=False,
                          batch_size=self.args.dataset.val_batch_size, drop_last=True)

    def predict_dataloader(self) -> DataLoader:
        """Ordered loader over the one-step counterfactual test split.

        NOTE(review): ``drop_last=True`` skips the final partial batch; ``get_predictions`` builds its own loader
        without it.
        """
        return DataLoader(self._get_split_dataset('test_cf_one_step'), shuffle=False,
                          batch_size=self.args.dataset.test_batch_size, drop_last=True)

    def compute_p_t(self, batch):
        """
        Fraction of samples with treatment ``k`` equal to 1 at each time step of ``batch['prev_treatments']``.

        :return: tensor of shape [T, K] (default float dtype, on the batch's device).
        NOTE(review): the result is later used to weight ``current_treatments`` groups; check that both
        treatment tensors have the same number of time steps.
        """
        treatments = batch['prev_treatments']
        counts = (treatments == 1).sum(dim=0)  # [T, K] number of treated samples
        return counts.to(torch.get_default_dtype()) / treatments.shape[0]

    def compute_factual_loss(self, est, true, pi_0, p_t, t):
        """
        Propensity-reweighted squared error.

        Shapes as assumed by the original author: pi_0 [B, T, 4], p_t [T, 4], est/true [B, T, 1], t [B, T, 4].
        NOTE(review): not called anywhere in this class. ``(t[:, cur_ts, :] == cur_t1)`` compares treatment
        values against the treatment *index* and keeps only column 0 — confirm the intended group selection
        before using it.
        """
        ''' Compute sample reweighting '''
        sample_weight = torch.zeros_like(est)
        for cur_ts in range(0, t.shape[1]):
            for cur_t1 in range(0, t.shape[-1]):
                idx_cur_t = (t[:, cur_ts, :] == cur_t1).squeeze()[:, 0]    # find index
                sample_weight[idx_cur_t, cur_ts, 0] = 1.
                if cur_ts > 1:
                    for cur_t2 in range(0, t.shape[-1]):
                        if cur_t2 != cur_t1:
                            if p_t[cur_ts, cur_t2] < 1e-4 or (pi_0[idx_cur_t, cur_ts, cur_t1] < 1e-4).any():
                                continue
                            sample_weight[idx_cur_t, cur_ts, 0] += (pi_0[idx_cur_t, cur_ts, cur_t2])/(pi_0[idx_cur_t, cur_ts, cur_t1] + 1e-7) * (p_t[cur_ts, cur_t1]/(p_t[cur_ts, cur_t2] + 1e-7))

        return torch.mean(sample_weight * torch.square(est - true))

    def compute_cross_entropy_loss(self, est, target):
        """``torch.nn.CrossEntropyLoss`` (mean reduction) between ``est`` and ``target``."""
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(est, target)

        return loss

    def mmd2_lin_dfr(self, upsilon, t, p_t):
        """
        Linear MMD-style balancing penalty on ``upsilon`` across treatment groups.

        For every time step with index > 1, the mean representation of each treatment group
        (samples with ``t[:, ts, k] == 1``) is computed. The penalty adds
        ``sum((2 p_t[ts, k] m_k - 2 p_t[ts, k-1] m_{k-1})**2)`` over consecutive groups ``k``.
        A time step is skipped when any group mean is NaN (e.g. an empty group).

        :param upsilon: representation, indexed as [B, T, D].
        :param t: treatments, [B, T, K].
        :param p_t: per-step group probabilities, [T', K] with T' >= T.
        :return: accumulated penalty (the int 0 if no term was added).
        """
        mmd = 0
        for cur_ts in range(t.shape[1]):
            if cur_ts <= 1:
                continue
            group_means = [torch.mean(upsilon[t[:, cur_ts, k] == 1, cur_ts, :], dim=0)
                           for k in range(t.shape[-1])]
            # Skip the step if any group is empty (mean is NaN). The check is independent of the pair index;
            # the old inner loop reused `i` and clobbered the outer pair index.
            if any(torch.isnan(m).any() for m in group_means):
                continue
            for k in range(1, len(group_means)):
                mmd += torch.sum(torch.square(2.0 * p_t[cur_ts, k] * group_means[k]
                                              - 2.0 * p_t[cur_ts, k - 1] * group_means[k - 1]))

        return mmd

    @staticmethod
    def _last_active_entries(active_entries):
        """
        Mask that is 1 at the last active time step of each sequence.

        Computes ``active[:, t] - active[:, t + 1]``, treating the step after the end as inactive. This
        equals 1 exactly at the last active step when the active steps form a prefix of the sequence.
        Accepts a torch tensor or a numpy array of shape (B, T, D).
        """
        num_samples, _, output_dim = active_entries.shape
        if isinstance(active_entries, torch.Tensor):
            next_active = torch.cat([active_entries[:, 1:, :],
                                     active_entries.new_zeros((num_samples, 1, output_dim))], dim=1)
        else:
            next_active = np.concatenate([active_entries[:, 1:, :],
                                          np.zeros((num_samples, 1, output_dim))], axis=1)
        return active_entries - next_active

    def get_normalised_masked_rmse(self, dataset: Dataset, one_step_counterfactual=True):
        """
        Normalised, active-entries-masked RMSE of this model's outcome predictions on ``dataset``.

        Errors are computed in unscaled units when ``hparams.exp.unscale_rmse`` is set, and reported in
        percent when ``hparams.exp.percentage_rmse`` is set. All RMSEs are divided by ``dataset.norm_const``.

        :return: ``(rmse_orig, rmse_all, rmse_last, rmse_per_step)`` if ``one_step_counterfactual``, else
            ``(rmse_orig, rmse_all)``. ``rmse_last`` only uses the last active entry of each sequence.
        """
        outputs_scaled = self.get_predictions(dataset)

        unscale = self.hparams.exp.unscale_rmse
        percentage = self.hparams.exp.percentage_rmse

        if unscale:
            output_stds, output_means = dataset.scaling_params['output_stds'], dataset.scaling_params['output_means']
            outputs_unscaled = outputs_scaled * output_stds + output_means

            # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
            mse = ((outputs_unscaled - dataset.data['unscaled_outputs']) ** 2) * dataset.data['active_entries']
        else:
            # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
            mse = ((outputs_scaled - dataset.data['outputs']) ** 2) * dataset.data['active_entries']

        # Calculation like in original paper (Masked-Averaging over datapoints (& outputs) and then non-masked time axis)
        mse_step = np.sqrt(mse.sum(0).sum(-1) / dataset.data['active_entries'].sum(0).sum(-1)) / dataset.norm_const
        mse_orig = mse.sum(0).sum(-1) / dataset.data['active_entries'].sum(0).sum(-1)
        mse_orig = mse_orig.mean()
        rmse_normalised_orig = np.sqrt(mse_orig) / dataset.norm_const

        # Masked averaging over all dimensions at once
        mse_all = mse.sum() / dataset.data['active_entries'].sum()
        rmse_normalised_all = np.sqrt(mse_all) / dataset.norm_const

        if percentage:
            rmse_normalised_orig *= 100.0
            rmse_normalised_all *= 100.0
            mse_step *= 100

        if one_step_counterfactual:
            # Only considering last active entry with actual counterfactuals
            last_entries = self._last_active_entries(dataset.data['active_entries'])
            if unscale:
                mse_last = ((outputs_unscaled - dataset.data['unscaled_outputs']) ** 2) * last_entries
            else:
                mse_last = ((outputs_scaled - dataset.data['outputs']) ** 2) * last_entries

            mse_last = mse_last.sum() / last_entries.sum()
            rmse_normalised_last = np.sqrt(mse_last) / dataset.norm_const

            if percentage:
                rmse_normalised_last *= 100.0

            return rmse_normalised_orig, rmse_normalised_all, rmse_normalised_last, mse_step

        return rmse_normalised_orig, rmse_normalised_all

    def bce_loss(self, treatment_pred, current_treatments, kind='predict'):
        """
        Treatment BCE via the project's ``bce`` helper, in ``hparams.dataset.treatment_mode``.

        kind='predict' compares against the true treatments (optionally class-weighted). kind='confuse'
        compares against uniform targets (1/K for multiclass, 0.5 for multilabel).

        NOTE(review): ``self.bce_weights`` is not assigned in this class; when ``hparams.exp.bce_weight`` is
        truthy it must be set elsewhere before training.
        :raises NotImplementedError: for any other ``kind``.
        """
        mode = self.hparams.dataset.treatment_mode
        bce_weights = torch.tensor(self.bce_weights).type_as(current_treatments) if self.hparams.exp.bce_weight else None

        if kind == 'predict':
            bce_loss = bce(treatment_pred, current_treatments, mode, bce_weights)
        elif kind == 'confuse':
            uniform_treatments = torch.ones_like(current_treatments)
            if mode == 'multiclass':
                uniform_treatments *= 1 / current_treatments.shape[-1]
            elif mode == 'multilabel':
                uniform_treatments *= 0.5
            bce_loss = bce(treatment_pred, uniform_treatments, mode)
        else:
            raise NotImplementedError()
        return bce_loss

    def _compute_losses(self, batch, outcome_pred, treatment_pred, season):
        """
        Shared training/validation objective.

        :return: ``(loss, bce_loss, mse_loss)``. ``loss = bce + mse + mmd``. The BCE is averaged over the
            last active step of each sequence; the MSE is a plain mean over all entries.
        """
        mse_loss = F.mse_loss(outcome_pred, batch['outputs'])  # reduction='mean' (replaces deprecated reduce=True)
        last_entries = self._last_active_entries(batch['active_entries'])
        p_t = self.compute_p_t(batch)

        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        bce_loss = (last_entries.squeeze(-1) * bce_loss).sum() / last_entries.sum()

        imbalanced_loss = self.mmd2_lin_dfr(season, batch['current_treatments'], p_t)

        loss = bce_loss + mse_loss + imbalanced_loss
        return loss, bce_loss, mse_loss

    def training_step(self, batch, batch_ind, optimizer_idx=0):
        """Lightning training step: compute and log the combined loss on the training split."""
        # Unconditionally re-enable gradients for every parameter before the forward pass.
        for par in self.parameters():
            par.requires_grad = True

        outcome_pred, treatment_pred, _, season = self(batch)
        loss, bce_loss, mse_loss = self._compute_losses(batch, outcome_pred, treatment_pred, season)

        subset_name = self._get_split_dataset('train_f').subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_ind):
        """Lightning validation step: same objective as training, logged under the validation split name."""
        outcome_pred, treatment_pred, _, season = self(batch)
        loss, bce_loss, mse_loss = self._compute_losses(batch, outcome_pred, treatment_pred, season)

        subset_name = self._get_split_dataset('val_f').subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)

    def test_step(self, batch, batch_ind, **kwargs):
        """
        Lightning test step: MSE and BCE are averaged over all active entries, and the MMD term is applied to
        the seasonal part (as in training/validation).
        """
        # forward returns (outcome, treatment, trend, season); the previous unpacking swapped the first two.
        outcome_pred, treatment_pred, _, season = self(batch)

        p_t = self.compute_p_t(batch)

        # Element-wise squared error so the active-entries mask actually takes effect.
        mse_loss = F.mse_loss(outcome_pred, batch['outputs'], reduction='none')
        mse_loss = (batch['active_entries'] * mse_loss).sum() / batch['active_entries'].sum()

        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        bce_loss = (batch['active_entries'].squeeze(-1) * bce_loss).sum() / batch['active_entries'].sum()

        imbalanced_loss = self.mmd2_lin_dfr(season, batch['current_treatments'], p_t)

        loss = bce_loss + mse_loss + imbalanced_loss

        # NOTE(review): logged under the *validation* subset name, so test metrics share keys with
        # validation metrics — consider the test split's name instead.
        subset_name = self._get_split_dataset('val_f').subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True)

    def predict_step(self, batch, batch_idx, dataset_idx=0):
        """
        Generates normalised output predictions, together with the trend and seasonal parts (all on CPU).
        """
        outcome_pred, _, trend, season = self(batch)
        return outcome_pred.cpu(), trend.cpu(), season.cpu()

    def get_predictions(self, dataset: Dataset) -> np.array:
        """Predict outcomes for the whole ``dataset`` in order; returns a numpy array."""
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.val_batch_size, shuffle=False)
        outcome_pred, _, _ = [torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader))]
        return outcome_pred.numpy()

    def get_trend_seasonality(self, dataset: Dataset) -> np.array:
        """Return the ``(trend, seasonality)`` decoder parts for the whole ``dataset`` as numpy arrays."""
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.val_batch_size, shuffle=False)
        _, trend, seasonality = [torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader))]
        return trend.numpy(), seasonality.numpy()
