# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from typing import List, Tuple
from itertools import product

# Third-party imports
import mxnet as mx
import numpy as np

# First-party imports
from gluonts.core.component import validated
from gluonts.mx import Tensor
from gluonts.mx.distribution import DistributionOutput
from gluonts.mx.util import assert_shape, weighted_average
from gluonts.mx.distribution import LowrankMultivariateGaussian
from gluonts.model.deepvar._network import DeepVARNetwork


def compute_psigma(F, S, A):
    """
    Build the covariance-weighted reconciliation matrix

        P = I - S A^T (A S A^T)^{-1} A

    for every covariance matrix in ``S``.

    Parameters
    ----------
    F
        Unused; kept so existing callers keep working. All operations go
        through ``mx.nd`` directly, so this function only works on NDArrays
        (not on symbols).
    S
        Covariance matrices; the last two axes must be square. Callers in
        this file describe it as
        (batch_size, seq_len, target_dim, target_dim).
    A
        2-D constraint matrix. ``A.shape[1]`` is used as the size of the
        identity, so it has to match the trailing dimension of ``S``
        (presumably (num_constraints, target_dim) -- confirm with callers).

    Returns
    -------
    NDArray
        The projection ``P``, one matrix per leading index of ``S``.
    """
    # S A^T
    SAT = mx.nd.dot(S, A, transpose_b=True)
    # A S A^T, formed as ((S A^T)^T A^T)^T so that the batched operand
    # always stays on the left-hand side of ``mx.nd.dot``.
    ASAT = mx.nd.swapaxes(
        mx.nd.dot(mx.nd.swapaxes(SAT, -2, -1), A, transpose_b=True), -2, -1
    )
    ASAT_inv = mx.nd.linalg_inverse(ASAT)
    ASAT_inv_times_A = mx.nd.dot(ASAT_inv, A)
    SAT_times_ASAT_inv_times_A = mx.nd.batch_dot(SAT, ASAT_inv_times_A)

    # Allocate the identity where S lives (device and dtype); the default
    # context/dtype would break the subtraction for GPU or non-float32 input.
    identity = mx.nd.eye(A.shape[1], ctx=S.context, dtype=S.dtype)

    return identity - SAT_times_ASAT_inv_times_A


def batched_diag_mvn_logpdf(mu, var, target, eps=1e-8):
    """
    Log-density of ``target`` under a Gaussian whose covariance is taken to
    be the diagonal of ``var`` (off-diagonal entries are ignored).

    Parameters
    ----------
    mu : NDArray
        Mean, shape (..., n)
    var : NDArray
        Covariance matrices, shape (..., n, n); only the diagonal is read
    target : NDArray
        Ground truth, shape (..., n)
    eps : float
        Added to every variance to avoid division by zero and log(0).
        Defaults to the previously hard-coded 1e-8.

    Returns
    -------
    log_prob : NDArray
        Log probability with the last axis reduced, shape (...)
    """
    # Per-dimension variances sigma^2, stabilised with eps.
    sigma2 = mx.nd.linalg.extractdiag(var) + eps

    # Quadratic term: sum_i (y_i - mu_i)^2 / sigma_i^2
    quad_term = mx.nd.sum(mx.nd.square(target - mu) / sigma2, axis=-1)

    # Log-determinant of a diagonal matrix: sum_i log(sigma_i^2)
    log_det = mx.nd.sum(mx.nd.log(sigma2), axis=-1)

    n = mu.shape[-1]
    const = n * np.log(2 * np.pi)

    return -0.5 * (const + log_det + quad_term)

def gaussian_crps(mu, sigma, y):
    """
    Closed-form CRPS of a univariate Gaussian N(mu, sigma^2), evaluated
    element-wise with ``mx.nd`` operations:

        CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),
        z = (y - mu) / sigma

    Parameters
    ----------
    mu : mx.nd.NDArray
        Predicted means
    sigma : mx.nd.NDArray
        Predicted standard deviations
    y : mx.nd.NDArray
        Ground truth targets; its context is used for the sqrt(2) constant

    Returns
    -------
    crps : mx.nd.NDArray
        Element-wise CRPS values
    """
    standardized = (y - mu) / sigma

    # sqrt(2) as a one-element array on the same device as ``y``.
    root_two = mx.nd.sqrt(mx.nd.array([2.0], ctx=y.context))
    inv_root_pi = 1 / np.sqrt(np.pi)

    # Standard normal density and distribution function at ``standardized``.
    density = mx.nd.exp(-0.5 * standardized**2) / np.sqrt(2 * np.pi)
    distribution = 0.5 * (1 + mx.nd.erf(standardized / root_two))

    bracket = standardized * (2 * distribution - 1) + 2 * density - inv_root_pi
    return sigma * bracket

class ProbHardE2ENetwork(DeepVARNetwork):
    @validated()
    def __init__(
        self,
        M,
        A,
        num_layers: int,
        num_cells: int,
        cell_type: str,
        history_length: int,
        context_length: int,
        prediction_length: int,
        distr_output: DistributionOutput,
        dropout_rate: float,
        lags_seq: List[int],
        target_dim: int,
        conditioning_length: int,
        cardinality: List[int] = [1],
        embedding_dimension: int = 1,
        scaling: bool = True,
        seq_axis: List[int] = None,
        **kwargs,
    ) -> None:
        """
        Set up the underlying DeepVAR network and store the reconciliation
        inputs.

        Parameters
        ----------
        M
            Stored as ``self.M``; passed as ``reconciliation_mat`` to the
            reconciliation methods of this class.
        A
            Stored as ``self.A``; passed as the constraint matrix ``A`` to
            the reconciliation methods and used by ``reconciliation_error``.
        seq_axis
            Stored as ``self.seq_axis``; axes that ``reconcile_samples``
            processes sequentially (``None`` means no sequential axes).
        **kwargs
            Forwarded to ``DeepVARNetwork.__init__``.

        All remaining parameters are forwarded unchanged to
        ``DeepVARNetwork.__init__``.
        """
        # NOTE(review): ``cardinality`` has a mutable default list; it is
        # only forwarded here, never mutated in this method.
        super().__init__(
            num_layers=num_layers,
            num_cells=num_cells,
            cell_type=cell_type,
            history_length=history_length,
            context_length=context_length,
            prediction_length=prediction_length,
            distr_output=distr_output,
            dropout_rate=dropout_rate,
            lags_seq=lags_seq,
            target_dim=target_dim,
            conditioning_length=conditioning_length,
            cardinality=cardinality,
            embedding_dimension=embedding_dimension,
            scaling=scaling,
            **kwargs
        )

        # Reconciliation state used by the methods below.
        self.M = M
        self.A = A
        self.seq_axis = seq_axis

    def reconcile_mu_var(
        self,
        reconciliation_mat: Tensor,
        seq_axis=None,
        distr=None,
        A: Tensor = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Project the mean and covariance of ``distr`` with the
        covariance-weighted matrix ``P`` returned by ``compute_psigma``.

        Parameters
        ----------
        reconciliation_mat
            Unused; accepted for signature parity with `reconcile_samples`.
        seq_axis
            Unused; accepted for signature parity with `reconcile_samples`.
        distr
            Distribution exposing ``F``, ``mean`` and ``variance``. Required
            despite the ``None`` default.
        A
            Constraint matrix handed to ``compute_psigma``.

        Returns
        -------
        Tuple[Tensor, Tensor]
            ``(P @ mean, P @ variance)``.
        """
        # (batch_size, seq_len, target_dim, target_dim)
        cov = distr.variance
        mean = distr.mean

        # (batch_size, seq_len, target_dim, target_dim)
        projection = compute_psigma(distr.F, cov, A)

        rec_mu = mx.nd.batch_dot(
            projection, mean.expand_dims(axis=-1)
        ).squeeze(-1)

        # With P = I - S A^T (A S A^T)^{-1} A we get P S A^T = 0, hence
        # P S P^T = P S: the right-multiplication by P^T can be skipped.
        rec_S = mx.nd.batch_dot(projection, cov)

        return rec_mu, rec_S


    # def reconcile_samples(    
    #                     self,
    #                     reconciliation_mat: Tensor,
    #                     samples: Tensor,
    #                     seq_axis = None,
    #                     distr = None,
    #                     A: Tensor = None,):
    #     """
    #     Computes coherent samples by projecting unconstrained `samples` using the matrix `self.M`.

    #     Parameters
    #     ----------
    #     samples
    #         Unconstrained samples.
    #         Shape: (num_samples, batch_size, seq_len, num_ts) during training and
    #                (num_parallel_samples x batch_size, seq_len, num_ts) during prediction.

    #     Returns
    #     -------
    #     Coherent samples
    #         Tensor, shape same as that of `samples`.

    #     """

    #     if self.seq_axis:
    #         # bring the axis to iterate in the beginning
    #         samples = mx.nd.moveaxis(samples, self.seq_axis, list(range(len(self.seq_axis))))

    #         out = [
    #             mx.nd.dot(samples[idx], self.M, transpose_b=True)
    #             for idx in product(*[range(x) for x in [samples.shape[d] for d in range(len(self.seq_axis))]])
    #         ]

    #         # put the axis in the correct order again
    #         out = mx.nd.concat(*out, dim=0).reshape(samples.shape)
    #         out = mx.nd.moveaxis(out, list(range(len(self.seq_axis))), self.seq_axis)
    #         return out
    #     else:
    #         return mx.nd.dot(samples, self.M, transpose_b=True)

    
    def reconcile_samples(
        self,
        reconciliation_mat: Tensor,
        samples: Tensor,
        seq_axis=None,
        distr=None,
        A: Tensor = None,
    ) -> Tensor:
        """
        Computes coherent samples by projecting unconstrained `samples`.

        Parameters
        ----------
        reconciliation_mat
            Fixed projection matrix, shape (target_dim, target_dim). Used
            when `distr` is None and whenever `seq_axis` is given.
        samples
            Unconstrained samples
            Shape: `(*batch_shape, target_dim)`
            During training: (num_samples, batch_size, seq_len, target_dim)
            During prediction: (num_parallel_samples x batch_size, seq_len,
            target_dim)
        seq_axis
            Specifies the list of axes that should be reconciled
            sequentially. By default, all axes are processed in parallel.
            Note that this branch always uses `reconciliation_mat`; `distr`
            and `A` are ignored there.
        distr
            Optional distribution. When given (and `seq_axis` is empty), the
            projection is computed from `distr.variance` and `A` via
            `compute_psigma` instead of using `reconciliation_mat`.
        A
            Constraint matrix; only needed together with `distr`.

        Returns
        -------
        Tensor, shape same as that of `samples`
            Coherent samples

        """
        if not seq_axis:
            if distr is None:
                return mx.nd.dot(samples, reconciliation_mat, transpose_b=True)

            # One projection matrix per (batch, time step):
            # (batch_size, seq_len, target_dim, target_dim)
            Psigma = compute_psigma(distr.F, distr.variance, A)

            # Prediction: `samples` is (num_samples, 1, target_dim), i.e. it
            # has exactly one axis fewer than `Psigma`, so every sample is
            # paired with its own matrix.
            if len(Psigma.shape) == len(samples.shape) + 1:
                return mx.nd.batch_dot(
                    Psigma, samples.expand_dims(axis=-1)
                ).squeeze(-1)

            # Training: `samples` carries an extra leading sample axis
            # (num_samples_for_loss, batch_size, seq_len, target_dim); the
            # same matrices are applied to each slice along that axis.
            out = [
                mx.nd.batch_dot(
                    Psigma, samples[sample_idx].expand_dims(axis=-1)
                ).squeeze(-1)
                for sample_idx in range(samples.shape[0])
            ]
            return mx.nd.concat(*out, dim=0).reshape(samples.shape)

        num_dims = len(samples.shape)

        last_dim_in_seq_axis = num_dims - 1 in seq_axis or -1 in seq_axis
        assert not last_dim_in_seq_axis, (
            "The last dimension cannot be processed iteratively. Remove axis"
            f" {num_dims - 1} (or -1) from `seq_axis`."
        )

        # In this case, reconcile samples by going over each index in
        # `seq_axis` iteratively. Note that `seq_axis` can be more than one
        # dimension.
        num_seq_axes = len(seq_axis)

        # bring the axes to iterate in the beginning
        samples = mx.nd.moveaxis(samples, seq_axis, list(range(num_seq_axes)))

        seq_axes_sizes = samples.shape[:num_seq_axes]
        out = [
            mx.nd.dot(samples[idx], reconciliation_mat, transpose_b=True)
            # get the sequential index from the cross-product of their sizes.
            for idx in product(*[range(size) for size in seq_axes_sizes])
        ]

        # put the axis in the correct order again
        out = mx.nd.concat(*out, dim=0).reshape(samples.shape)
        out = mx.nd.moveaxis(out, list(range(len(seq_axis))), seq_axis)
        return out
        
    


    # def reconcile_samples(self, samples):
    #     """
    #     Computes coherent samples by projecting unconstrained `samples` using the matrix `self.M`.

    #     Parameters
    #     ----------
    #     samples
    #         Unconstrained samples.
    #         Shape: (num_samples, batch_size, seq_len, num_ts) during training and
    #                (num_parallel_samples x batch_size, seq_len, num_ts) during prediction.

    #     Returns
    #     -------
    #     Coherent samples
    #         Tensor, shape same as that of `samples`.

    #     """

    #     if self.seq_axis:
    #         # bring the axis to iterate in the beginning
    #         samples = mx.nd.moveaxis(samples, self.seq_axis, list(range(len(self.seq_axis))))

    #         out = [
    #             mx.nd.dot(samples[idx], self.M, transpose_b=True)
    #             for idx in product(*[range(x) for x in [samples.shape[d] for d in range(len(self.seq_axis))]])
    #         ]

    #         # put the axis in the correct order again
    #         out = mx.nd.concat(*out, dim=0).reshape(samples.shape)
    #         out = mx.nd.moveaxis(out, list(range(len(self.seq_axis))), self.seq_axis)
    #         return out
    #     else:
    #         return mx.nd.dot(samples, self.M, transpose_b=True)



    def train_hybrid_forward(
        self,
        F,
        target_dimension_indicator: Tensor,
        past_time_feat: Tensor,
        past_target_cdf: Tensor,
        past_observed_values: Tensor,
        past_is_pad: Tensor,
        future_time_feat: Tensor,
        future_target_cdf: Tensor,
        future_observed_values: Tensor,
        epoch_frac: float,
    ) -> Tuple[Tensor, ...]:
        """
        Computes the training loss, all inputs tensors representing time
        series have NTC layout.

        The loss is the closed-form CRPS of the (optionally reconciled)
        mean/covariance when ``self.CRPS_weight > 0`` and the negative
        log-likelihood under the network's distribution otherwise.

        Parameters
        ----------
        F
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size, history_length,
            num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features)
        future_target_cdf
            Future marginal CDF transformed target values (batch_size,
            prediction_length, target_dim)
        future_observed_values
            Indicator whether or not the future values were observed
            (batch_size, prediction_length, target_dim)
        epoch_frac
            Currently unused by this method.

        Returns
        -------
        loss
            Masked loss averaged over the time axis
        likelihoods
            Negative log-likelihoods for each time step
            (batch_size, context + prediction_length, 1)
        distr_args
            Distribution arguments (context + prediction_length,
            number_of_arguments)
        """
        import warnings

        seq_len = self.context_length + self.prediction_length

        # unroll the decoder in "training mode", i.e. by providing future data
        # as well
        rnn_outputs, _, scale, lags_scaled, inputs = self.unroll_encoder(
            F=F,
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=future_target_cdf,
            target_dimension_indicator=target_dimension_indicator,
        )

        # put together target sequence
        # (batch_size, seq_len, target_dim)
        target = F.concat(
            past_target_cdf.slice_axis(
                axis=1, begin=-self.context_length, end=None
            ),
            future_target_cdf,
            dim=1,
        )

        distr, distr_args = self.distr(
            time_features=inputs,
            rnn_outputs=rnn_outputs,
            scale=scale,
            lags_scaled=lags_scaled,
            target_dimension_indicator=target_dimension_indicator,
            seq_len=seq_len,
        )

        # Assert CRPS_weight and likelihood_weight have harmonious values
        assert self.CRPS_weight >= 0.0, 'CRPS weight must be non-negative'
        assert self.likelihood_weight >= 0.0, 'Likelihood weight must be non-negative!'
        assert self.likelihood_weight + self.CRPS_weight > 0.0, 'At least one of CRPS or likelihood weights must be non-zero'

        # The checks below used to be `assert "<message>"`, which is always
        # true and therefore never reported anything. They are emitted as
        # real warnings now, restricted to what this method actually does.
        if self.CRPS_weight == 0.0 and self.coherent_train_samples:
            # The reconciled moments only feed the CRPS loss.
            warnings.warn(
                'No sampling being performed. coherent_train_samples flag is ignored'
            )
        if self.sample_LH:
            # The sample-based likelihood path is disabled in this
            # implementation (no samples are drawn during training).
            warnings.warn(
                'sample_LH is set but sample-based likelihoods are disabled; '
                'likelihoods are computed from the distribution parameters'
            )

        # Likelihood of the target under the current distribution parameters.
        # The last axis is added so all likelihoods share the shape
        # (batch_size, seq_len, 1).
        likelihoods = -distr.log_prob(target).expand_dims(axis=-1)

        # NOTE(review): `likelihood_weight` is validated above but does not
        # enter the loss; with CRPS_weight > 0 the loss is the unweighted
        # CRPS. Confirm this is intended.
        if self.CRPS_weight > 0.0:
            # Mean/covariance entering the closed-form CRPS; reconciled only
            # on request. Computed here because no other branch needs them
            # (avoids a matrix inverse per step in likelihood-only training).
            if self.coherent_train_samples:
                rec_mu, rec_s = self.reconcile_mu_var(
                    reconciliation_mat=self.M,
                    seq_axis=self.seq_axis,
                    distr=distr,
                    A=self.A,
                )
            else:
                rec_mu = distr.mean
                rec_s = distr.variance

            loss_unmasked = distr.reconciled_closed_form_crps(
                rec_mu, rec_s, target
            ).mean(axis=-1, keepdims=True)
        else:  # CRPS_weight = 0.0 (asserted non-negativity above)
            loss_unmasked = likelihoods

        # get mask values
        past_observed_values = F.broadcast_minimum(
            past_observed_values, 1 - past_is_pad.expand_dims(axis=-1)
        )

        # (batch_size, subseq_length, target_dim)
        observed_values = F.concat(
            past_observed_values.slice_axis(
                axis=1, begin=-self.context_length, end=None
            ),
            future_observed_values,
            dim=1,
        )

        # mask the loss at one time step if one or more observations is missing
        # in the target dimensions (batch_size, subseq_length, 1)
        loss_weights = observed_values.min(axis=-1, keepdims=True)

        assert_shape(loss_weights, (-1, seq_len, 1))  # -1 is batch axis size

        loss = weighted_average(
            F=F, x=loss_unmasked, weights=loss_weights, axis=1
        )

        assert_shape(loss, (-1, -1, 1))

        # Keep the last training distribution accessible on the block.
        self.distribution = distr

        return (loss, likelihoods) + distr_args

    def reconciliation_error(self, samples):
        r"""
        Computes the maximum relative reconciliation error among all the aggregated time series

        .. math::

                        \max_i \frac{|y_i - s_i|} {|y_i|},

        where :math:`i` refers to the aggregated time series index, :math:`y_i` is the (direct) forecast obtained for
        the :math:`i^{th}` time series and :math:`s_i` is its aggregated forecast obtained by summing the corresponding
        bottom-level forecasts. If :math:`y_i` is zero, then the absolute difference, :math:`|s_i|`, is used instead.

        This can be computed as follows given the constraint matrix A:

        .. math::

                        \max \frac{|A \times samples|} {|samples[:r]|},

        where :math:`r` is the number of aggregated time series.

        Parameters
        ----------
        samples
            Samples. Shape: `(*batch_shape, target_dim)`.

        Returns
        -------
        Float
            Reconciliation error


        """

        # The first `num_agg_ts` entries of the last axis are taken as the
        # direct forecasts of the aggregated series.
        num_agg_ts = self.A.shape[0]
        forecasts_agg_ts = samples.slice_axis(
            axis=-1, begin=0, end=num_agg_ts
        ).asnumpy()

        abs_err = mx.nd.abs(
            mx.nd.dot(samples, self.A, transpose_b=True)
        ).asnumpy()

        # The quotient is evaluated for every entry before `np.where` picks
        # the fallback, so zero forecasts would emit divide/invalid runtime
        # warnings even though those entries are discarded. Silence them.
        with np.errstate(divide="ignore", invalid="ignore"):
            rel_err = np.where(
                forecasts_agg_ts == 0,
                abs_err,
                abs_err / np.abs(forecasts_agg_ts),
            )

        return np.max(rel_err)

    def sampling_decoder(
        self,
        F,
        past_target_cdf: Tensor,
        target_dimension_indicator: Tensor,
        time_feat: Tensor,
        scale: Tensor,
        begin_states: List[Tensor],
    ) -> Tensor:
        """
        Draws sample paths by unrolling the RNN one step at a time, feeding
        each step's (optionally reconciled) samples back in as history.

        Parameters
        ----------
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        time_feat
            Dynamic features of future time series (batch_size, history_length,
            num_features)
        scale
            Mean scale for each time series (batch_size, 1, target_dim)
        begin_states
            List of initial states for the RNN layers (batch_size, num_cells)
        Returns
        --------
        sample_paths : Tensor
            Sampled paths, reshaped to (-1, num_parallel_samples,
            prediction_length, target_dim).
        """

        def tile_for_sampling(tensor):
            # one copy of every batch entry per parallel sample path
            return tensor.repeat(repeats=self.num_parallel_samples, axis=0)

        # Every input is blown up to batch_size * num_parallel_samples rows
        # so that all sample paths are unrolled in parallel.
        history = tile_for_sampling(past_target_cdf)
        tiled_time_feat = tile_for_sampling(time_feat)
        tiled_scale = tile_for_sampling(scale)
        tiled_dim_indicator = tile_for_sampling(target_dimension_indicator)

        # `make_states` is defined outside this class; presumably it prepares
        # the RNN states for the tiled batch.
        states = self.make_states(begin_states)

        samples_per_step = []

        # One iteration per forecast step: sample, optionally reconcile, then
        # extend the history with the new values.
        for k in range(self.prediction_length):
            lags = self.get_lagged_subsequences(
                F=F,
                sequence=history,
                sequence_length=self.history_length + k,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            current_time_feat = tiled_time_feat.slice_axis(
                axis=1, begin=k, end=k + 1
            )
            rnn_outputs, states, lags_scaled, inputs = self.unroll(
                F=F,
                begin_state=states,
                lags=lags,
                scale=tiled_scale,
                time_feat=current_time_feat,
                target_dimension_indicator=tiled_dim_indicator,
                unroll_length=1,
            )

            distr, distr_args = self.distr(
                time_features=inputs,
                rnn_outputs=rnn_outputs,
                scale=tiled_scale,
                target_dimension_indicator=tiled_dim_indicator,
                lags_scaled=lags_scaled,
                seq_len=1,
            )

            # (num_parallel_samples*batch_size, 1, m), not coherent yet
            raw_draw = distr.sample()

            if not self.coherent_pred_samples:
                new_samples = raw_draw
            else:
                new_samples = self.reconcile_samples(
                    reconciliation_mat=self.M,
                    samples=raw_draw,
                    seq_axis=self.seq_axis,
                    distr=distr,
                    A=self.A,
                )
                assert_shape(new_samples, raw_draw.shape)

                if self.compute_reconciliation_error:
                    recon_err = self.reconciliation_error(samples=new_samples)
                    print(f"Reconciliation error for prediction time step t={k + 1}: {recon_err}")

            samples_per_step.append(new_samples)
            history = F.concat(history, new_samples, dim=1)

        # (batch_size * num_samples, prediction_length, target_dim)
        stacked = F.concat(*samples_per_step, dim=1)

        # (batch_size, num_samples, prediction_length, target_dim)
        path_shape = (
            -1,
            self.num_parallel_samples,
            self.prediction_length,
            self.target_dim,
        )
        return stacked.reshape(shape=path_shape)


class ProbHardE2ETrainingNetwork(ProbHardE2ENetwork):
    """
    Training-time variant of `ProbHardE2ENetwork`: stores the loss
    configuration and routes `hybrid_forward` to `train_hybrid_forward`.
    """

    def __init__(
        self,
        num_samples_for_loss: int,
        likelihood_weight: float,
        CRPS_weight: float,
        coherent_train_samples: bool,
        warmstart_epoch_frac: float,
        sample_LH: bool,
        **kwargs,
    ) -> None:
        # Everything not listed explicitly configures the base network.
        super().__init__(**kwargs)

        # Loss weights.
        self.CRPS_weight = CRPS_weight
        self.likelihood_weight = likelihood_weight

        # Remaining training options, stored as given.
        self.sample_LH = sample_LH
        self.num_samples_for_loss = num_samples_for_loss
        self.coherent_train_samples = coherent_train_samples
        self.warmstart_epoch_frac = warmstart_epoch_frac

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        target_dimension_indicator: Tensor,
        past_time_feat: Tensor,
        past_target_cdf: Tensor,
        past_observed_values: Tensor,
        past_is_pad: Tensor,
        future_time_feat: Tensor,
        future_target_cdf: Tensor,
        future_observed_values: Tensor,
        epoch_frac: float,
    ) -> Tuple[Tensor, ...]:
        """
        Training entry point; delegates to `train_hybrid_forward` with the
        arguments unchanged. All time series tensors have NTC layout.

        Parameters
        ----------
        F
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size, history_length,
            num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features)
        future_target_cdf
            Future marginal CDF transformed target values (batch_size,
            prediction_length, target_dim)
        future_observed_values
            Indicator whether or not the future values were observed
            (batch_size, prediction_length, target_dim)
        epoch_frac
            Passed through to `train_hybrid_forward`.

        Returns
        -------
        Tuple[Tensor, ...]
            Whatever `train_hybrid_forward` returns: the loss, the
            per-time-step likelihoods and the distribution arguments.
        """
        return self.train_hybrid_forward(
            F=F,
            target_dimension_indicator=target_dimension_indicator,
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
            future_target_cdf=future_target_cdf,
            future_observed_values=future_observed_values,
            epoch_frac=epoch_frac,
        )


class ProbHardE2EPredictionNetwork(ProbHardE2ENetwork):
    """
    Prediction-time variant of `ProbHardE2ENetwork`: stores the sampling
    configuration and routes `hybrid_forward` to `predict_hybrid_forward`.
    """

    @validated()
    def __init__(
        self,
        num_parallel_samples: int,
        compute_reconciliation_error: bool,
        coherent_pred_samples: bool,
        **kwargs,
    ) -> None:
        # Everything not listed explicitly configures the base network.
        super().__init__(**kwargs)

        # Sampling options read by `sampling_decoder`.
        self.coherent_pred_samples = coherent_pred_samples
        self.compute_reconciliation_error = compute_reconciliation_error
        self.num_parallel_samples = num_parallel_samples

        # During decoding the lags are shifted by one: at the first decoder
        # step a lag of one refers to the last available target value.
        self.shifted_lags = [lag - 1 for lag in self.lags_seq]

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        target_dimension_indicator: Tensor,
        past_time_feat: Tensor,
        past_target_cdf: Tensor,
        past_observed_values: Tensor,
        past_is_pad: Tensor,
        future_time_feat: Tensor,
    ) -> Tensor:
        """
        Prediction entry point; delegates to `predict_hybrid_forward` with
        the arguments unchanged. All tensors should have NTC layout.

        Parameters
        ----------
        F
        target_dimension_indicator
            Indices of the target dimension (batch_size, target_dim)
        past_time_feat
            Dynamic features of past time series (batch_size, history_length,
            num_features)
        past_target_cdf
            Past marginal CDF transformed target values (batch_size,
            history_length, target_dim)
        past_observed_values
            Indicator whether or not the values were observed (batch_size,
            history_length, target_dim)
        past_is_pad
            Indicator whether the past target values have been padded
            (batch_size, history_length)
        future_time_feat
            Future time features (batch_size, prediction_length, num_features)

        Returns
        -------
        sample_paths : Tensor
            Whatever `predict_hybrid_forward` returns -- presumably the
            sampled paths from `sampling_decoder`, i.e.
            (batch_size, num_parallel_samples, prediction_length,
            target_dim); confirm against the base class.

        """
        forward_inputs = dict(
            F=F,
            target_dimension_indicator=target_dimension_indicator,
            past_time_feat=past_time_feat,
            past_target_cdf=past_target_cdf,
            past_observed_values=past_observed_values,
            past_is_pad=past_is_pad,
            future_time_feat=future_time_feat,
        )
        return self.predict_hybrid_forward(**forward_inputs)
