# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Any, List, Optional, Tuple

import mxnet as mx
import numpy as np
from mxnet import autograd

from gluonts.core.component import validated
from gluonts.mx import Tensor

from . import bijection as bij
from .distribution import Distribution, _index_tensor, getF
from .bijection import AffineTransformation


def get_quantile(sorted_samples, q):
    """Return the empirical ``q``-quantile of samples sorted along axis 0.

    The same sample index is used for every remaining position, so
    ``sorted_samples`` must already be sorted along its first axis.

    Parameters
    ----------
    sorted_samples
        Array sorted along axis 0, e.g. of shape
        ``(num_samples, batch_size, seq_len, 1)``; any rank >= 1 works.
    q
        Quantile level, must lie in ``[0, 1]``.

    Returns
    -------
    ``sorted_samples[idx]`` (leading sample axis dropped), where
    ``idx = round((num_samples - 1) * q)``.

    Raises
    ------
    ValueError
        If ``q`` is outside ``[0, 1]`` (a negative index would otherwise
        silently wrap around to the largest samples).
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError(f"quantile level must be in [0, 1], got {q}")

    num_samples = sorted_samples.shape[0]
    # Nearest-rank index. np.round rounds half to even; it does not round up.
    sample_idx = int(np.round((num_samples - 1) * q))

    # Integer indexing drops axis 0 for arrays of any rank.
    return sorted_samples[sample_idx]


def quantile_loss(sorted_samples, y, q):
    """Pinball loss of ``y`` against the empirical ``q``-quantile of samples.

    Parameters
    ----------
    sorted_samples
        Samples sorted along axis 0, e.g. of shape
        ``(num_samples, batch_size, seq_len, 1)``.
    y
        Ground truth; must have the same shape as
        ``get_quantile(sorted_samples, q)``.
    q
        Scalar quantile level.

    Returns
    -------
    Elementwise loss with the shape of ``y``: ``q * (y - pred)`` where
    ``y >= pred``, otherwise ``(1 - q) * (pred - y)``.
    """
    # The quantile is picked by a Python-side index, outside MXNet's graph.
    prediction = get_quantile(sorted_samples, q)

    assert y.shape == prediction.shape

    under_predicted = q * (y - prediction)
    over_predicted = (1 - q) * (prediction - y)
    return mx.nd.where(y >= prediction, under_predicted, over_predicted)



class TransformedDistribution(Distribution):
    r"""
    A distribution obtained by applying a sequence of transformations on top
    of a base distribution.

    Parameters
    ----------
    base_distribution
        Distribution whose samples are transformed.
    transforms
        Bijections applied, in list order, to samples of the base
        distribution.
    """

    @validated()
    def __init__(
        self, base_distribution: Distribution, transforms: List[bij.Bijection]
    ) -> None:
        self.base_distribution = base_distribution
        self.transforms = transforms
        self.is_reparameterizable = self.base_distribution.is_reparameterizable

        # use these to cache shapes and avoid recomputing all steps
        # the reason we cannot do the computations here directly
        # is that this constructor would fail in mx.symbol mode
        self._event_dim: Optional[int] = None
        self._event_shape: Optional[Tuple] = None
        self._batch_shape: Optional[Tuple] = None

    @property
    def F(self):
        # Delegates to the base distribution's array namespace.
        return self.base_distribution.F

    @property
    def support_min_max(self) -> Tuple[Tensor, Tensor]:
        """Support bounds of the base distribution pushed through each transform.

        A decreasing transform swaps the endpoints, so the bounds are
        re-sorted with elementwise minimum/maximum after every step.
        """
        F = self.F
        lb, ub = self.base_distribution.support_min_max
        for t in self.transforms:
            _lb = t.f(lb)
            _ub = t.f(ub)
            lb = F.minimum(_lb, _ub)
            ub = F.maximum(_lb, _ub)
        return lb, ub

    def _slice_bijection(
        self, trans: bij.Bijection, item: Any
    ) -> bij.Bijection:
        """Index the parameters of ``trans`` with ``item``.

        Affine and Box-Cox transforms get their parameter tensors sliced.
        Inverse bijections are sliced recursively. Any other transform is
        returned unchanged.
        """
        from .box_cox_transform import BoxCoxTransform

        if isinstance(trans, bij.AffineTransformation):
            loc = (
                _index_tensor(trans.loc, item)
                if trans.loc is not None
                else None
            )
            scale = (
                _index_tensor(trans.scale, item)
                if trans.scale is not None
                else None
            )
            return bij.AffineTransformation(loc=loc, scale=scale)
        elif isinstance(trans, BoxCoxTransform):
            return BoxCoxTransform(
                _index_tensor(trans.lambda_1, item),
                _index_tensor(trans.lambda_2, item),
            )
        elif isinstance(trans, bij.InverseBijection):
            return bij.InverseBijection(
                self._slice_bijection(trans._bijection, item)
            )
        else:
            return trans

    def __getitem__(self, item):
        bd_slice = self.base_distribution[item]
        trans_slice = [self._slice_bijection(t, item) for t in self.transforms]
        return TransformedDistribution(bd_slice, trans_slice)

    @property
    def event_dim(self):
        # Largest event dimensionality among the base distribution and the
        # transforms; cached after the first access.
        if self._event_dim is None:
            self._event_dim = max(
                [self.base_distribution.event_dim]
                + [t.event_dim for t in self.transforms]
            )
        assert isinstance(self._event_dim, int)
        return self._event_dim

    @property
    def batch_shape(self) -> Tuple:
        # The leading axes of the base distribution's full shape, leaving
        # out the last ``event_dim`` axes.
        if self._batch_shape is None:
            shape = (
                self.base_distribution.batch_shape
                + self.base_distribution.event_shape
            )
            self._batch_shape = shape[: len(shape) - self.event_dim]
        assert isinstance(self._batch_shape, tuple)
        return self._batch_shape

    @property
    def event_shape(self) -> Tuple:
        # The last ``event_dim`` axes of the base distribution's full shape.
        if self._event_shape is None:
            shape = (
                self.base_distribution.batch_shape
                + self.base_distribution.event_shape
            )
            self._event_shape = shape[len(shape) - self.event_dim :]
        assert isinstance(self._event_shape, tuple)
        return self._event_shape

    def sample(
        self, num_samples: Optional[int] = None, dtype=np.float32
    ) -> Tensor:
        # Gradient recording is paused; use ``sample_rep`` for gradients.
        with autograd.pause():
            s = self.base_distribution.sample(
                num_samples=num_samples, dtype=dtype
            )
            for t in self.transforms:
                s = t.f(s)
            return s

    def sample_rep(
        self, num_samples: Optional[int] = None, dtype=float
    ) -> Tensor:
        # ``float`` is what the removed ``np.float`` alias pointed to. Using
        # the alias as a default broke module import on NumPy >= 1.24.
        s = self.base_distribution.sample_rep(
            num_samples=num_samples, dtype=dtype
        )
        for t in self.transforms:
            s = t.f(s)
        return s

    def crps(self, samples: Tensor, y: Tensor) -> Tensor:
        r"""
        CRPS of ``y``, evaluated in the base distribution's space.

        ``y`` and ``samples`` are mapped through the inverse of every
        transform (last to first). The base distribution's CRPS is computed
        there and multiplied by the product of the transforms' ``scale``
        attributes. This relies on CRPS being scale-equivariant under affine
        maps, and assumes positive scales.

        Every transform must expose a ``scale`` attribute. A ``scale`` of
        ``None`` means that transform does not rescale. The score is summed
        over the last axis and returned with a trailing singleton axis.
        """
        # TODO: use event_shape
        F = getF(y)
        scale = None
        for t in self.transforms[::-1]:
            y = t.f_inv(y)
            samples = t.f_inv(samples)
            # A ``None`` scale leaves the accumulated factor unchanged.
            if t.scale is not None:
                scale = t.scale if scale is None else scale * t.scale
        p = self.base_distribution.crps(samples, y)

        # broadcast_mul needs tensor operands, so skip it when no transform
        # rescales.
        if scale is not None:
            p = F.broadcast_mul(p, scale)
        return F.sum(p, axis=-1).expand_dims(-1)

    def log_prob(self, y: Tensor) -> Tensor:
        """Log-density of ``y`` via the change-of-variables formula.

        The transforms are inverted last to first. Each step subtracts its
        log-abs-det-Jacobian, summed over the event axes the transform does
        not itself reduce.
        """
        F = getF(y)
        lp = 0.0
        x = y
        for t in self.transforms[::-1]:
            x = t.f_inv(y)
            ladj = t.log_abs_det_jac(x, y)
            lp -= sum_trailing_axes(F, ladj, self.event_dim - t.event_dim)
            y = x

        return self.base_distribution.log_prob(x) + lp

    def cdf(self, y: Tensor) -> Tensor:
        """CDF of ``y``: the base CDF at the pre-image, flipped when the
        overall transform is decreasing (``sign`` < 0 gives ``1 - F``)."""
        x = y
        sign = 1.0
        for t in self.transforms[::-1]:
            x = t.f_inv(x)
            sign = sign * t.sign
        f = self.base_distribution.cdf(x)
        return sign * (f - 0.5) + 0.5

    def quantile(self, level: Tensor) -> Tensor:
        """Quantiles at ``level``, pushed forward through the transforms.

        For a decreasing overall transform, the base quantile at
        ``1 - level`` is used instead. When the sign is a tensor, this choice
        is made elementwise.
        """
        F = getF(level)

        sign = 1.0
        for t in self.transforms:
            sign = sign * t.sign

        if not isinstance(sign, (mx.nd.NDArray, mx.sym.Symbol)):
            level = level if sign > 0 else (1.0 - level)
            q = self.base_distribution.quantile(level)
        else:
            # level.shape = (#levels,)
            # q_pos.shape = (#levels, batch_size, ...)
            # sign.shape = (batch_size, ...)
            q_pos = self.base_distribution.quantile(level)
            q_neg = self.base_distribution.quantile(1.0 - level)
            cond = F.broadcast_greater(sign, sign.zeros_like())
            cond = F.broadcast_add(cond, q_pos.zeros_like())
            q = F.where(cond, q_pos, q_neg)

        for t in self.transforms:
            q = t.f(q)
        return q


class AffineTransformedDistribution(TransformedDistribution):
    r"""
    Distribution of ``X * scale + loc`` for ``X`` drawn from
    ``base_distribution``.

    ``loc`` and ``scale`` may each be ``None``, in which case that part of
    the affine map is skipped.
    """

    @validated()
    def __init__(
        self,
        base_distribution: Distribution,
        loc: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
    ) -> None:
        super().__init__(base_distribution, [AffineTransformation(loc, scale)])

        self.loc = loc
        self.scale = scale

    @property
    def mean(self) -> Tensor:
        """Mean of the transformed distribution, ``E[X] * scale + loc``."""
        mean = self.base_distribution.mean
        if self.scale is not None:
            mean = mean * self.scale
        if self.loc is not None:
            mean = mean + self.loc
        # Every combination of loc/scale now returns a value (the
        # loc-and-scale case used to fall through and return None).
        return mean

    @property
    def stddev(self) -> Tensor:
        return (
            self.base_distribution.stddev
            if self.scale is None
            else self.base_distribution.stddev * self.scale
        )

    @property
    def variance(self) -> Tensor:
        """Variance (or covariance matrix) of the transformed distribution.

        For (low-rank) multivariate Gaussian bases, the covariance ``Sigma``
        is mapped to ``S @ Sigma @ S`` with ``S = diag(scale)``. Otherwise the
        base variance is multiplied by ``scale ** 2``. Without a ``scale``,
        the base variance is returned unchanged.
        """
        # TODO: cover the multivariate case here too
        from . import LowrankMultivariateGaussian
        from gluonts.mx.distribution import MultivariateGaussian

        base_variance = self.base_distribution.variance
        if self.scale is None:
            return base_variance

        if isinstance(
            self.base_distribution,
            (LowrankMultivariateGaussian, MultivariateGaussian),
        ):
            F = self.F
            scale_diag = self.scale.expand_dims(-1) * F.eye(
                self.base_distribution.dim
            )
            # Repeat along axis 1 to match base_variance.shape[1] for the
            # batched products. Assumes scale has a singleton axis 1 (e.g.
            # one scale per series shared across time) -- TODO confirm.
            scale_diag = F.repeat(
                scale_diag, repeats=base_variance.shape[1], axis=1
            )
            return F.batch_dot(
                scale_diag, F.batch_dot(base_variance, scale_diag)
            )

        return base_variance * self.scale**2

    # TODO: crps

    @staticmethod
    def _diagonal_gaussian_crps(F, mu: Tensor, cov: Tensor, y: Tensor) -> Tensor:
        """Closed-form CRPS of independent Gaussian marginals.

        Only the diagonal of ``cov`` (taken over its last two axes) is used::

            sigma = sqrt(diag(cov)),  z = (y - mu) / sigma
            CRPS  = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))
        """
        sigma = F.sqrt(F.linalg.extractdiag(cov))

        z = (y - mu) / sigma
        pdf = F.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
        cdf = 0.5 * (1 + F.erf(z / np.sqrt(2.0)))

        return sigma * (z * (2 * cdf - 1) + 2 * pdf - 1 / np.sqrt(np.pi))

    def reconciled_closed_form_crps(
        self, mean, variance, y: Tensor, dtype=np.float32
    ) -> Tensor:
        """Gaussian CRPS of ``y`` for an externally supplied mean/covariance.

        Parameters
        ----------
        mean
            Mean tensor, broadcast against ``y``.
        variance
            Covariance matrices; only their diagonal is used.
        y
            Observations.
        dtype
            Unused; kept for interface compatibility.
        """
        return self._diagonal_gaussian_crps(getF(y), mean, variance, y)

    def closed_form_crps(self, y: Tensor, dtype=np.float32) -> Tensor:
        """Gaussian CRPS of ``y`` using this distribution's mean and the
        diagonal of its covariance (``dtype`` is unused)."""
        return self._diagonal_gaussian_crps(
            self.F, self.mean, self.variance, y
        )

    # def crps(self, samples: Tensor, y: Tensor, quantiles=np.arange(0.1, 1.0, 0.1)) -> Tensor:
    #     r"""
    #     Compute the *continuous rank probability score* (CRPS) of `y` according
    #     to the distribution.

    #     Parameters
    #     ----------
    #     samples
    #         Tensor of shape `(*batch_shape, *event_shape)`.
    #     y
    #         Tensor of ground truth

    #     Returns
    #     -------
    #     Tensor
    #         Tensor of shape `batch_shape` containing the CRPS score,
    #         according to the distribution, for each event in `x`.
    #     """
    #     # y is ground truth. Has shape (batch_size, seq_len, m)
    #     # samples has shape = (num_samples, batch_size, seq_len, m)

    #     # sum over m axis, sum over T axis, sum over bs axis

    #     # loss for single ground truth point across all dimensions
    #     # loop through dimensions
    #     losses = []
    #     for d in range(samples.shape[-1]):

    #         # dim of dim_slice = (num_samples,batch_size,seq_len,1)
    #         dim_slice = mx.nd.slice_axis(samples, axis=-1, begin=d, end=d + 1)

    #         # sort samples along sample axis. shape = (num_samples,batch_size,seq_len,1)
    #         sorted_slice = mx.nd.sort(dim_slice, axis=0)  # sort along sample axis (first axis)

    #         # slice of y for dimension d. shape = (batch_size, seq_len,1)
    #         y_slice = mx.nd.slice_axis(y, axis=-1, begin=d, end=d + 1)

    #         # compute quantile loss, shape = (batch_size, seq_len, 1)
    #         #qloss = mx.nd.zeros((y_slice.shape))
    #         qlosses = []
    #         for q in quantiles:
    #             qlosses.append(quantile_loss(sorted_slice, y_slice, q))
    #         #qloss = quantile_loss(sorted_slice, y_slice, .1)
    #         qloss = mx.nd.stack(*qlosses, axis=-1) #shape = (batch_size, seq_len, 1, Q)

    #         #take average
    #         qloss = (1/len(qlosses)) * mx.nd.sum(qloss, axis=-1)  #shape = (batch_size, seq_len,1)

    #         # append qloss tensor
    #         losses.append(mx.nd.squeeze(qloss))  # remove dummy last axis of dim_slice and append

    #     loss = mx.nd.stack(*losses, axis=-1)  # shape = (batch_size, seq_len,m)
    #     return mx.nd.sum(loss, axis=-1).expand_dims(-1)  # shape = (batch_size, seq_len,1)
    #     # return loss #shape = (batch_size, seq_len,m)



def sum_trailing_axes(F, x: Tensor, k: int) -> Tensor:
    """Sum ``x`` over its last ``k`` axes, one axis at a time.

    ``F`` is the array namespace providing ``sum(x, axis=...)``. A ``k`` of
    zero (or less) returns ``x`` unchanged.
    """
    reduced = x
    for _ in range(k):
        reduced = F.sum(reduced, axis=-1)
    return reduced
