import torch
import torch.nn as nn
import sys
sys.path.append('/workspace/junghee.kim/Project/Causality/Causal_Effect_Estimation/trend_seasonality_causal_structure/src/')

from layers.Embed import DataEmbedding
from layers.Autoformer_EncDec import series_decomp, series_decomp_multi
import torch.nn.functional as F

from pytorch_lightning import LightningModule
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from src.models.utils import bce
from src.data import RealDatasetCollection, SyntheticDatasetCollection
import logging
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


class MIC(nn.Module):
    """
    MIC layer to extract local and global features.

    For every entry of ``conv_kernel`` the input is decomposed, its seasonal
    part is down-sampled, passed through an isometric convolution, up-sampled
    back to the input length, and the per-scale results are merged by a 2-D
    convolution followed by a position-wise feed-forward block.

    Args:
        feature_size: channel width of the input and of every convolution.
        n_heads: accepted for interface compatibility; not used by this layer.
        dropout: dropout probability applied after each activation.
        decomp_kernel: kernel sizes for the ``series_decomp`` blocks.
        conv_kernel: kernel sizes (and strides) of the down/up-sampling convolutions.
        isometric_kernel: kernel sizes of the isometric convolutions.
        device: stored on the instance; tensors created in the forward pass
            follow the device of their inputs instead.
    """

    def __init__(self, feature_size=512, n_heads=8, dropout=0.05, decomp_kernel=(32,), conv_kernel=(24,),
                 isometric_kernel=(18, 6), device='cuda'):
        super(MIC, self).__init__()
        self.conv_kernel = conv_kernel
        self.device = device

        # isometric convolution
        self.isometric_conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
                                                       kernel_size=i, padding=0, stride=1)
                                             for i in isometric_kernel])

        # downsampling convolution: padding=i//2, stride=i
        self.conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
                                             kernel_size=i, padding=i // 2, stride=i)
                                   for i in conv_kernel])

        # upsampling convolution
        self.conv_trans = nn.ModuleList([nn.ConvTranspose1d(in_channels=feature_size, out_channels=feature_size,
                                                            kernel_size=i, padding=0, stride=i)
                                         for i in conv_kernel])

        self.decomp = nn.ModuleList([series_decomp(k) for k in decomp_kernel])
        self.merge = torch.nn.Conv2d(in_channels=feature_size, out_channels=feature_size,
                                     kernel_size=(len(self.conv_kernel), 1))

        # feedforward network
        self.conv1 = nn.Conv1d(in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1)
        self.norm1 = nn.LayerNorm(feature_size)
        self.norm2 = nn.LayerNorm(feature_size)

        self.norm = torch.nn.LayerNorm(feature_size)
        self.act = torch.nn.Tanh()
        # Honour the constructor argument; its default (0.05) equals the value
        # that used to be hard-coded here.
        self.drop = torch.nn.Dropout(dropout)

    def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric):
        """
        Run one scale: down-sample, isometric convolution, up-sample, residual.

        Args:
            input: tensor unpacked as ``(batch, seq_len, channel)``.
            conv1d: down-sampling convolution for this scale.
            conv1d_trans: matching transposed convolution used for up-sampling.
            isometric: isometric convolution for this scale.

        Returns:
            Tensor with the same shape as ``input``.
        """
        batch, seq_len, channel = input.shape
        x = input.permute(0, 2, 1)

        # downsampling convolution
        x1 = self.drop(self.act(conv1d(x)))

        # isometric convolution: left-pad with (length - 1) zeros.  The padding
        # is created on the device/dtype of the activations so the layer also
        # runs on CPU (get_device() is -1 there, which made 'cuda:-1').
        zeros = torch.zeros((x1.shape[0], x1.shape[1], x1.shape[2] - 1),
                            device=x1.device, dtype=x1.dtype)
        x = torch.cat((zeros, x1), dim=-1)
        x = self.drop(self.act(isometric(x)))
        x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1)

        # upsampling convolution
        x = self.drop(self.act(conv1d_trans(x)))
        x = x[:, :, :seq_len]  # truncate to the original length

        x = self.norm(x.permute(0, 2, 1) + input)
        return x

    def forward(self, src):
        """Apply every scale to ``src``, merge the results and return the refined features."""
        # multi-scale
        multi = []
        for i in range(len(self.conv_kernel)):
            src_out, _ = self.decomp[i](src)  # the trend part is not used here
            src_out = self.conv_trans_conv(src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i])
            multi.append(src_out)

        # merge: stack the per-scale outputs on a new axis 1 (device-agnostic
        # replacement for growing an empty CUDA tensor with repeated cat)
        mg = torch.stack(multi, dim=1)
        mg = self.merge(mg.permute(0, 3, 1, 2)).squeeze(-2).permute(0, 2, 1)

        y = self.norm1(mg)
        y = self.conv2(self.conv1(y.transpose(-1, 1))).transpose(-1, 1)

        return self.norm2(mg + y)


class SeasonalPrediction(nn.Module):
    """
    Stack of ``d_layers`` MIC layers followed by a linear projection to ``c_out`` channels.

    ``dropout`` is accepted but not forwarded to the MIC layers, which therefore
    use their own default.
    """

    def __init__(self, embedding_size=512, n_heads=8, dropout=0.05, d_layers=1, decomp_kernel=[32], c_out=1,
                 conv_kernel=[2, 4], isometric_kernel=[18, 6], device='cuda'):
        super().__init__()

        layers = nn.ModuleList()
        for _ in range(d_layers):
            layers.append(MIC(feature_size=embedding_size, n_heads=n_heads,
                              decomp_kernel=decomp_kernel, conv_kernel=conv_kernel,
                              isometric_kernel=isometric_kernel, device=device))
        self.mic = layers

        self.projection = nn.Linear(embedding_size, c_out)

    def forward(self, dec):
        """Feed ``dec`` through every MIC layer in order, then project the last axis."""
        hidden = dec
        for layer in self.mic:
            hidden = layer(hidden)
        return self.projection(hidden)


class Model(LightningModule):
    
    model_type = 'multi'
    
    """
    Paper link: https://openreview.net/pdf?id=zt53IDUR1U
    """
    def __init__(self,
                 configs: DictConfig = None,
                 dataset_collection: SyntheticDatasetCollection = None,
                 **kwargs):
        """
        Build all sub-modules from ``configs.model.multi``.

        conv_kernel: downsampling and upsampling convolution kernel_size

        ``configs`` is stored as hyper-parameters, and ``configs.model.multi.enc_in``
        is overwritten with the input width derived from the treatment, vitals and
        outcome dimensions.
        """
        super().__init__()
        self.dataset_collection = dataset_collection
        self.configs = configs.model.multi
        self.save_hyperparameters(configs)

        # One decomposition kernel and one isometric kernel per conv kernel.
        window = self.configs.seq_len + self.configs.pred_len
        decomp_kernel = []  # kernel of decomposition operation
        isometric_kernel = []  # kernel of isometric convolution
        for width in self.configs.conv_kernel:
            if width % 2:
                decomp_kernel.append(width)
                isometric_kernel.append((window + width - 1) // width)
            else:
                # the kernel of decomposition operation must be odd
                decomp_kernel.append(width + 1)
                isometric_kernel.append((window + width) // width)

        multi_cfg = configs.model.multi
        self.task_name = self.configs.task_name
        self.pred_len = self.configs.pred_len
        self.seq_len = self.configs.seq_len
        self.has_vitals = multi_cfg.has_vitals
        self.dim_outcome = multi_cfg.outcome_dims
        self.dim_treatment = multi_cfg.treatment_dims
        self.dim_static = multi_cfg.static_dims
        self.dim_vitals = multi_cfg.vitals_dims if self.has_vitals else None

        # Encoder input = treatments (+ vitals when present) + outcomes.
        vitals_width = self.dim_vitals if self.has_vitals else 0
        self.dim_input = self.dim_treatment + vitals_width + self.dim_outcome
        self.configs.enc_in = self.dim_input
        self.autoregressive = configs.dataset.autoregressive

        # Maps each static feature (as a length-1 vector) to pred_len values.
        self.static_transformation = nn.Linear(1, self.pred_len)

        # Multiple Series decomposition block from FEDformer
        self.decomp_multi = series_decomp_multi(decomp_kernel)

        # embedding
        self.dec_embedding = DataEmbedding(self.configs.enc_in, self.configs.d_model, self.configs.embed,
                                           self.configs.freq, self.configs.dropout)

        self.conv_trans = SeasonalPrediction(
            embedding_size=self.configs.d_model,
            n_heads=self.configs.n_heads,
            dropout=self.configs.dropout,
            d_layers=self.configs.d_layers,
            decomp_kernel=decomp_kernel,
            c_out=self.configs.c_out,
            conv_kernel=self.configs.conv_kernel,
            isometric_kernel=isometric_kernel,
            device=torch.device('cuda:' + str(configs.gpus[0])),
        )

        # Output heads for the outcome and the treatment estimates.
        self.outcome_layer = nn.Linear(self.dim_outcome*2 + self.dim_static + self.dim_treatment,
                                       self.configs.c_out)
        self.treat_layer = nn.Linear(self.dim_outcome + self.dim_static, self.dim_treatment)

        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            # refer to DLinear: both linear maps start from uniform averaging weights
            self.regression = nn.Linear(self.configs.seq_len, self.configs.pred_len)
            self.regression.weight = nn.Parameter(
                (1 / self.configs.pred_len) * torch.ones([self.configs.pred_len, self.configs.seq_len]),
                requires_grad=True)
            self.channel_transform = nn.Linear(self.dim_input, self.configs.c_out)
            self.channel_transform.weight = nn.Parameter(
                (1 / self.configs.c_out) * torch.ones([self.configs.c_out, self.dim_input]),
                requires_grad=True)

        if self.task_name in ('imputation', 'anomaly_detection'):
            self.projection = nn.Linear(self.configs.d_model, self.configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(self.configs.dropout)
            self.projection = nn.Linear(self.configs.c_out * self.configs.seq_len, self.configs.num_class)

    def forecast(self, batch):
        prev_outputs = batch['prev_outputs']                   # B X T X input_dims
        prev_treatments = batch['prev_treatments']             # B X T X treat_dims
        static_features = batch['static_features']             # B X static_dims
        curr_treatments = batch['current_treatments']          # B X T X treat_dims
        vitals = batch['vitals'] if self.has_vitals else None  # B X T X vitals_dims
        nan_mask = ~prev_outputs.isnan().any(axis=-1)
        prev_outputs[~nan_mask] = 0
        
        # static_features = self.static_transformation(static_features.unsqueeze(2)).permute(0, 2, 1)  # B X dim_static X pred_len
        static_features = self.static_transformation(static_features.unsqueeze(-1)).permute(0, 2, 1)
        
        x_enc = torch.cat((prev_outputs, prev_treatments), dim=-1)         # B X T X (dim_output + dim_treat)
        if self.has_vitals:
            x_enc = torch.cat((x_enc, vitals), dim=-1)
        
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)
        # print(self.configs.seq_len)
        
        # print("seasonal init: ", seasonal_init_enc.shape)
        # print("trend: ", trend.shape)
        trend = self.regression(trend.permute(0, 2, 1)).permute(0, 2, 1)
        trend = self.channel_transform(trend)
        # print("trend: ", trend.shape)
        
        # embedding
        zeros = torch.zeros([x_enc.shape[0], self.pred_len, x_enc.shape[2]], device=x_enc.device)
        seasonal_init_dec = torch.cat([seasonal_init_enc[:, -self.seq_len:, :], zeros], dim=1)
        # print("seasonal_init_dec: ", seasonal_init_dec.shape)
        # dec_out = self.dec_embedding(seasonal_init_dec, x_mark_dec)
        seasonal_out = self.dec_embedding(seasonal_init_dec, x_mark=None)
        # print("dec_out: ", dec_out.shape)
        seasonal_out = self.conv_trans(seasonal_out)
        # print("dec_out: ", dec_out.shape)
        # print(seasonal_out.shape)
        # print(trend.shape)
        
        # print(static_features.shape)
        est_outcome = self.outcome_layer(torch.cat([trend[:, -self.pred_len:, :], seasonal_out[:, -self.pred_len:, :], static_features, curr_treatments], dim=-1))    # B X pred_len X dim_out
        # print("est_out: ", est_outcome.shape)
        est_treatment = self.treat_layer(torch.cat([trend[:, -self.pred_len:, :], static_features], dim=-1))                                    # B X pred_len X dim_treat
        
        # dec_out = seasonal_out[:, -self.pred_len:, :] + trend[:, -self.pred_len:, :]
        return est_outcome, est_treatment, trend, seasonal_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)

        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, x_mark_dec)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend
        return dec_out

    def anomaly_detection(self, x_enc):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)

        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, None)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)
        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, None)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend

        # Output from Non-stationary Transformer
        output = self.act(dec_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, batch, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            est_outcome, est_treatment, trend, seasonality = self.forecast(batch)
            return est_outcome, est_treatment, trend, seasonality  # [B, L, D]
        # if self.task_name == 'imputation':
        #     dec_out = self.imputation(
        #         x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
        #     return dec_out  # [B, L, D]
        # if self.task_name == 'anomaly_detection':
        #     dec_out = self.anomaly_detection(x_enc)
        #     return dec_out  # [B, L, D]
        # if self.task_name == 'classification':
        #     dec_out = self.classification(x_enc, x_mark_enc)
        #     return dec_out  # [B, N]
        return None


    def generate_continuous_mask(self, B, T, n=5, l=0.1):
        res = torch.full((B, T), True, dtype=torch.bool)
        if isinstance(n, float):
            n = int(n * T)
        n = max(min(n, T // 2), 1)
        
        if isinstance(l, float):
            l = int(l * T)
        l = max(l, 1)
        
        for i in range(B):
            for _ in range(n):
                t = np.random.randint(T-l+1)
                res[i, t:t+l] = False
        return res


    def generate_binomial_mask(self, B, T, p=0.5):
        return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)
    
    def _get_optimizer(self, param_optimizer: list):
        no_decay = ['bias', 'layer_norm']
        sub_args = self.hparams.model[self.model_type]
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': sub_args['optimizer']['weight_decay'],
            },
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        lr = sub_args['optimizer']['learning_rate']
        optimizer_cls = sub_args['optimizer']['optimizer_cls']
        if optimizer_cls.lower() == 'adamw':
            optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == 'adam':
            optimizer = optim.Adam(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == 'sgd':
            optimizer = optim.SGD(optimizer_grouped_parameters, lr=lr,
                                  momentum=sub_args['optimizer']['momentum'])
        else:
            raise NotImplementedError()

        return optimizer

    def _get_lr_schedulers(self, optimizer):
        if not isinstance(optimizer, list):
            lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
            return [optimizer], [lr_scheduler]
        else:
            lr_schedulers = []
            for opt in optimizer:
                lr_schedulers.append(optim.lr_scheduler.ExponentialLR(opt, gamma=0.99))
            return optimizer, lr_schedulers

    def configure_optimizers(self):
        optimizer = self._get_optimizer(list(self.named_parameters()))

        if self.hparams.model[self.model_type]['optimizer']['lr_scheduler']:
            return self._get_lr_schedulers(optimizer)

        return optimizer
    
    def train_dataloader(self) -> DataLoader:
        # sub_args = self.hparams.model[self.model_type]
        if self.autoregressive:
            # print("auto-regressive dataset!!")
            return DataLoader(self.dataset_collection.train_f, shuffle=True, batch_size=self.hparams.dataset.train_batch_size, drop_last=True)
        else:
            # print("non-autoregressive dataset!!")
            return DataLoader(self.dataset_collection.train_f_non, shuffle=True, batch_size=self.hparams.dataset.train_batch_size, drop_last=True)
    
    def val_dataloader(self) -> DataLoader:
        # sub_args = self.hparams.model[self.model_type]
        if self.autoregressive:
            return DataLoader(self.dataset_collection.val_f, shuffle=False, batch_size=self.hparams.dataset.val_batch_size, drop_last=True)
        else:
            return DataLoader(self.dataset_collection.val_f_non, shuffle=False, batch_size=self.hparams.dataset.val_batch_size, drop_last=True)
    
    def predict_dataloader(self) -> DataLoader:
        # sub_args = self.hparams.model[self.model_type]
        if self.autoregressive:
            return DataLoader(self.dataset_collection.test_cf_one_step, shuffle=False, batch_size=self.hparams.dataset.test_batch_size, drop_last=True)
        else:
            return DataLoader(self.dataset_collection.test_cf_one_step_non, shuffle=False, batch_size=self.hparams.dataset.test_batch_size, drop_last=True)
    
    # def compute_p_t(self, batch):
    #     # p_t: [4]

    #     p_t = torch.zeros(batch['prev_treatments'].shape[2])
    #     for cur_ts in range(0, batch['prev_treatments'].shape[1]):
    #         for cur_b in range(0, batch['prev_treatments'].shape[0]):
    #             for cur_t in range(0, batch['prev_treatments'].shape[2]):
    #                 if batch['prev_treatments'][cur_b, cur_ts, cur_t] == 1:
    #                     p_t[cur_t] += 1
    #     p_t /= batch['prev_treatments'].shape[0] / batch['prev_treatments'].shape[1]
    #     return p_t
    
    def compute_p_t(self, batch):
        # p_t: [T, 4]

        p_t = torch.zeros(batch['prev_treatments'].shape[1], batch['prev_treatments'].shape[2])
        for cur_ts in range(0, batch['prev_treatments'].shape[1]):
            for cur_b in range(0, batch['prev_treatments'].shape[0]):
                for cur_t in range(0, batch['prev_treatments'].shape[2]):
                    if batch['prev_treatments'][cur_b, cur_ts, cur_t] == 1:
                        p_t[cur_ts, cur_t] += 1
        p_t /= batch['prev_treatments'].shape[0]
        return p_t

    
    def compute_factual_loss(self, est, true, pi_0, p_t, t):
        """
        Mean of the squared error ``(est - true)**2`` weighted per sample and time step.

        Shape notes below come from the original author's comments and are not
        checked by this method. The weight tensor has the shape of ``est`` and
        only its channel 0 is written.

        Returns:
            Scalar tensor (``torch.mean`` over all elements).
        """
        # pi_0: [B, T, 4]
        # p_t:  [T, 4]
        # est, true: [B, T, 1]
        # t: [B, T, 4]
        
        
        # print("est: ", est.shape)
        # print("true: ", true.shape)
        # print("pi_0: ", pi_0.shape)  # model's estimated treatment
        # print("p_t: ", p_t.shape)    # probability of treatment
        # print("t: ", t.shape)        # to find 
        
        ''' Compute sample reweighting '''
        sample_weight = torch.zeros_like(est)
        for cur_ts in range(0, t.shape[1]):
            for cur_t1 in range(0, t.shape[-1]):
                # Boolean mask over the batch: rows whose value in channel 0 of
                # `t` at this time step equals the loop index cur_t1.
                # NOTE(review): if `t` is one-hot (as the shape note suggests),
                # only cur_t1 in {0, 1} can ever match -- confirm whether an
                # argmax over the last axis was intended. The .squeeze() also
                # drops the batch axis when B == 1 -- verify against callers.
                idx_cur_t = (t[:, cur_ts, :] == cur_t1).squeeze()[:, 0]    # find index
                # base weight of 1 for the selected samples
                sample_weight[idx_cur_t, cur_ts, 0] = 1.
                # the first two time steps keep the base weight only
                if cur_ts > 1:
                    for cur_t2 in  range(0, t.shape[-1]):
                        if cur_t2 != cur_t1:
                            # skip terms whose denominators would be close to zero
                            if p_t[cur_ts, cur_t2] < 1e-4 or (pi_0[idx_cur_t, cur_ts, cur_t1] < 1e-4).any():
                                continue
                            # add (pi_0[t2] / pi_0[t1]) * (p_t[t1] / p_t[t2]); 1e-7 guards the divisions
                            sample_weight[idx_cur_t, cur_ts, 0] += (pi_0[idx_cur_t, cur_ts, cur_t2])/(pi_0[idx_cur_t, cur_ts, cur_t1] + 1e-7) * (p_t[cur_ts, cur_t1]/(p_t[cur_ts, cur_t2] + 1e-7))
        
        return torch.mean(sample_weight * torch.square(est - true))
    
    
    def compute_cross_entropy_loss(self, est, target):
        # target: treatment (0 or 1 one-hot vector)
        # est: class probability
        
        # sigma = torch.sigmoid(tf.matmul(h_rep_norm, W) + b)
        
        # pi_0 = t * sigma + (1.0-t) * (1.0-sigma)
        # loss = -torch.mean( t * torch.log(sigma) + (1.0-t) * torch.log(1.0-sigma) )
        
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(est, target)
        
        return loss
    
    def mmd2_lin_dfr(self, upsilon, t, p_t):
        ''' Linear MMD '''
        # t: treatment variable [B, T, 4]
        # X: data               [B, T, 1]
        # p_t: probability        [T, 4]
        
        mmd = 0
        for cur_ts in range(0, t.shape[1]):
            list_upsilon_mean = []
            for cur_tr in range(0, t.shape[-1]):
                idx_cur_t = (t[:, cur_ts, cur_tr] == 1).squeeze()
                
                list_upsilon_mean.append(torch.mean(upsilon[idx_cur_t, cur_ts, :], dim=0))
            # print(list_X_mean)   # nan
            for i in range(1, len(list_upsilon_mean)):
                if cur_ts > 1:
                    flag = True
                    for i in range(len(list_upsilon_mean)):
                        if torch.isnan(list_upsilon_mean[i]).any():
                            flag = False
                    if flag:
                        mmd += torch.sum(torch.square(2.0*p_t[cur_ts, i]*list_upsilon_mean[i] - 2.0*p_t[cur_ts, i-1]*list_upsilon_mean[i-1]))
        
        return mmd
    
    def get_normalised_masked_rmse(self, dataset: Dataset, one_step_counterfactual=True):
        outputs_scaled = self.get_predictions(dataset)
        
        unscale = self.hparams.exp.unscale_rmse
        percentage = self.hparams.exp.percentage_rmse
        
        if unscale:
            print("unscale")
        else:
            print("not unscale")
            
        if percentage:
            print("percentage error")
        else:
            print("not percentage error")

        if unscale:
            output_stds, output_means = dataset.scaling_params['output_stds'], dataset.scaling_params['output_means']
            outputs_unscaled = outputs_scaled * output_stds + output_means

            # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
            mse = ((outputs_unscaled - dataset.data['unscaled_outputs']) ** 2) * dataset.data['active_entries']
        else:
            # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
            mse = ((outputs_scaled - dataset.data['outputs']) ** 2) * dataset.data['active_entries']

        # Calculation like in original paper (Masked-Averaging over datapoints (& outputs) and then non-masked time axis)
        mse_step = np.sqrt(mse.sum(0).sum(-1) / dataset.data['active_entries'].sum(0).sum(-1)) / dataset.norm_const
        mse_orig = mse.sum(0).sum(-1) / dataset.data['active_entries'].sum(0).sum(-1)
        mse_orig = mse_orig.mean()
        rmse_normalised_orig = np.sqrt(mse_orig) / dataset.norm_const

        # Masked averaging over all dimensions at once
        mse_all = mse.sum() / dataset.data['active_entries'].sum()
        rmse_normalised_all = np.sqrt(mse_all) / dataset.norm_const

        if percentage:
            rmse_normalised_orig *= 100.0
            rmse_normalised_all *= 100.0
            mse_step *= 100

        if one_step_counterfactual:
            # Only considering last active entry with actual counterfactuals
            num_samples, time_dim, output_dim = dataset.data['active_entries'].shape
            last_entries = dataset.data['active_entries'] - np.concatenate([dataset.data['active_entries'][:, :-1, :],
                                                                            np.zeros((num_samples, 1, output_dim))], axis=1)
            if unscale:
                mse_last = ((outputs_unscaled - dataset.data['unscaled_outputs']) ** 2) * last_entries
            else:
                mse_last = ((outputs_scaled - dataset.data['outputs']) ** 2) * last_entries

            mse_last = mse_last.sum() / last_entries.sum()
            rmse_normalised_last = np.sqrt(mse_last) / dataset.norm_const

            if percentage:
                rmse_normalised_last *= 100.0

            return rmse_normalised_orig, rmse_normalised_all, rmse_normalised_last, mse_step

        return rmse_normalised_orig, rmse_normalised_all
    
    
    def bce_loss(self, treatment_pred, current_treatments, kind='predict'):
        mode = self.hparams.dataset.treatment_mode
        bce_weights = torch.tensor(self.bce_weights).type_as(current_treatments) if self.hparams.exp.bce_weight else None

        if kind == 'predict':
            bce_loss = bce(treatment_pred, current_treatments, mode, bce_weights)
        elif kind == 'confuse':
            uniform_treatments = torch.ones_like(current_treatments)
            if mode == 'multiclass':
                uniform_treatments *= 1 / current_treatments.shape[-1]
            elif mode == 'multilabel':
                uniform_treatments *= 0.5
            bce_loss = bce(treatment_pred, uniform_treatments, mode)
        else:
            raise NotImplementedError()
        return bce_loss
    
    
    def training_step(self, batch, batch_ind, optimizer_idx=0):
        for par in self.parameters():
            par.requires_grad = True
        
        outcome_pred, treatment_pred, trend, season = self(batch)

        mse_loss = F.mse_loss(outcome_pred, batch['outputs'], reduce=True)
        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        
        num_samples, time_dim, output_dim = batch['active_entries'].shape
        last_entries = batch['active_entries'] - torch.concat([batch['active_entries'][:, :-1, :],
                                                                 torch.zeros((num_samples, 1, output_dim)).to(treatment_pred.device)], dim=1)
        
        p_t = self.compute_p_t(batch)
        
        # mse_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t, batch['current_treatments'], last_entries)
        # factual_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t, batch['current_treatments'])
        # mse_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t)
        
        
        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        
        imbalanced_loss = self.mmd2_lin_dfr(season, batch['current_treatments'], p_t)
        
        bce_loss = (last_entries.squeeze(-1) * bce_loss).sum() / last_entries.sum()
        # mse_loss = (last_entries * mse_loss).sum() / last_entries.sum()
        
        loss = bce_loss + mse_loss + imbalanced_loss
        
        subset_name = self.train_dataloader().dataset.subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
            
        return loss
        
        
    def validation_step(self, batch, batch_ind):
        outcome_pred, treatment_pred, trend, season = self(batch)
        
        # print("output: ", batch['outputs'].shape)

        mse_loss = F.mse_loss(outcome_pred, batch['outputs'], reduce=True)
        # bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        
        num_samples, time_dim, output_dim = batch['active_entries'].shape
        last_entries = batch['active_entries'] - torch.concat([batch['active_entries'][:, :-1, :],
                                                                 torch.zeros((num_samples, 1, output_dim)).to(treatment_pred.device)], dim=1)
        
        p_t = self.compute_p_t(batch)
        
        # for key in batch.keys():
        #     print(key)
        #     print(batch[key].shape)
        
        # factual_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t, batch['current_treatments'])
        # mse_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t)
        
        # print(batch['current_treatments'].shape)
        # print(treatment_pred.shape)
        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        
        imbalanced_loss = self.mmd2_lin_dfr(season, batch['current_treatments'], p_t)
        
        bce_loss = (last_entries.squeeze(-1) * bce_loss).sum() / last_entries.sum()
        # mse_loss = (last_entries * mse_loss).sum() / last_entries.sum()
        
        loss = bce_loss + mse_loss + imbalanced_loss
        
        subset_name = self.val_dataloader().dataset.subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)

    def test_step(self, batch, batch_ind, **kwargs):
        treatment_pred, outcome_pred, _, _ = self(batch)
        
        p_t = self.compute_p_t(batch)
        
        mse_loss = self.compute_factual_loss(outcome_pred, batch['outputs'], torch.softmax(treatment_pred, dim=-1), p_t, batch['current_treatments'])
        
        
        bce_loss = self.bce_loss(treatment_pred, batch['current_treatments'].double(), kind='predict')
        
        imbalanced_loss = self.mmd2_lin_dfr(batch['prev_outputs'], batch['current_treatments'], p_t)
        
        bce_loss = (batch['active_entries'].squeeze(-1) * bce_loss).sum() / batch['active_entries'].sum()
        mse_loss = (batch['active_entries'] * mse_loss).sum() / batch['active_entries'].sum()
        
        loss = bce_loss + mse_loss + imbalanced_loss
        
        subset_name = self.val_dataloader().dataset.subset_name
        self.log(f'{subset_name}_loss', loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)
        self.log(f'{subset_name}_bce_loss', bce_loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(f'{subset_name}_mse_loss', mse_loss, on_epoch=True, on_step=False, sync_dist=True)

    def predict_step(self, batch, batch_idx, dataset_idx=0):
        """
        Generates normalised output predictions
        """
        outcome_pred, _, trend, season = self(batch)
        return outcome_pred.cpu(), trend.cpu(), season.cpu()
    
    def get_predictions(self, dataset: Dataset) -> np.array:
        # Creating Dataloader
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.val_batch_size, shuffle=False)
        outcome_pred, _, _ = [torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader))]
        return outcome_pred.numpy()
    
    def get_trend_seasonality(self, dataset: Dataset) -> np.array:
        # Creating Dataloader
        data_loader = DataLoader(dataset, batch_size=self.hparams.dataset.val_batch_size, shuffle=False)
        _, trend, seasonality = [torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader))]
        return trend.numpy(), seasonality.numpy()