#!/usr/bin/env python3

from __future__ import annotations

import math
import warnings
from numbers import Number
from typing import Optional, Tuple, Union

import torch
from linear_operator import to_dense, to_linear_operator
from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
from torch import Tensor
from torch.distributions import MultivariateNormal as TMultivariateNormal, Chi2
from torch.distributions.kl import register_kl
from torch.distributions.utils import _standard_normal, lazy_property

from .. import settings
from ..utils.warnings import NumericalWarning
from .distribution import Distribution


class MultivariateQExponential(TMultivariateNormal, Distribution):
    """
    Constructs a multivariate q-exponential random variable, based on mean and covariance.
    Can be multivariate, or a batch of multivariate q-exponentials.

    Passing a vector mean corresponds to a multivariate q-exponential.
    Passing a matrix mean corresponds to a batch of multivariate q-exponentials.

    :param mean: `... x N` mean of qep distribution.
    :param covariance_matrix: `... x N X N` covariance matrix of qep distribution.
    :param power: (scalar) power of qep distribution. A plain Python `int`/`float`
        is converted to a tensor. (Default: 2.)
    :param validate_args: If True, validate `mean` and `covariance_matrix` arguments. (Default: False.)

    :ivar torch.Size base_sample_shape: The shape of a base sample (without
        batching) that is used to generate a single sample.
    :ivar torch.Tensor covariance_matrix: The covariance matrix, represented as a dense :class:`torch.Tensor`
    :ivar ~linear_operator.LinearOperator lazy_covariance_matrix: The covariance matrix, represented
        as a :class:`~linear_operator.LinearOperator`.
    :ivar torch.Tensor mean: The mean.
    :ivar torch.Tensor power: The power of the distribution.
    :ivar torch.Tensor stddev: The standard deviation.
    :ivar torch.Tensor variance: The variance.
    """

    def __init__(
        self,
        mean: Tensor,
        covariance_matrix: Union[Tensor, LinearOperator],
        power: Tensor = torch.tensor(2.0),
        validate_args: bool = False,
    ):
        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
        if self._islazy:
            if validate_args:
                ms = mean.size(-1)
                cs1 = covariance_matrix.size(-1)
                cs2 = covariance_matrix.size(-2)
                if not (ms == cs1 and ms == cs2):
                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
            self.loc = mean
            self._covar = covariance_matrix
            self.__unbroadcasted_scale_tril = None
            self._validate_args = validate_args
            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])

            event_shape = self.loc.shape[-1:]

            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
            # Skip TMultivariateNormal.__init__ and call the next initializer in the MRO with the shapes only.
            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
        else:
            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)

        # log_prob, entropy and get_base_samples call torch.log / torch.lgamma on expressions of
        # `power`, which require a tensor; accept plain Python numbers by converting them here.
        if isinstance(power, (int, float)):
            power = torch.tensor(float(power))
        self.power = power

    def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. The result is `sample_shape + batch_shape + base_sample_shape`.

        :param sample_shape: the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return sample_shape + self._batch_shape + self.base_sample_shape

    @staticmethod
    def _repr_sizes(
        mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], power: Tensor = torch.tensor(2.0)
    ) -> str:
        """Formats the sizes of the arguments for use in error messages."""
        return f"MultivariateQExponential(loc: {mean.size()}, scale: {covariance_matrix.size()}, pow: {power.size()})"

    @property
    def _unbroadcasted_scale_tril(self) -> Tensor:
        if self.islazy and self.__unbroadcasted_scale_tril is None:
            # cache the (dense) Cholesky factor of the lazy covariance
            ust = to_dense(self.lazy_covariance_matrix.cholesky())
            self.__unbroadcasted_scale_tril = ust
        return self.__unbroadcasted_scale_tril

    @_unbroadcasted_scale_tril.setter
    def _unbroadcasted_scale_tril(self, ust: Tensor):
        if self.islazy:
            raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy QEP distributions")
        else:
            self.__unbroadcasted_scale_tril = ust

    def add_jitter(self, noise: float = 1e-4) -> MultivariateQExponential:
        r"""
        Adds a small constant diagonal to the QEP covariance matrix for numerical stability.

        :param noise: The size of the constant diagonal.
        :return: A new distribution with the same mean and power and the jittered covariance.
        """
        return self.__class__(self.mean, self.lazy_covariance_matrix.add_jitter(noise), self.power)

    @property
    def base_sample_shape(self) -> torch.Size:
        base_sample_shape = self.event_shape
        if isinstance(self.lazy_covariance_matrix, RootLinearOperator):
            # Base samples are multiplied by the root, so they take the size of its last dimension.
            base_sample_shape = self.lazy_covariance_matrix.root.shape[-1:]

        return base_sample_shape

    @lazy_property
    def covariance_matrix(self) -> Tensor:
        if self.islazy:
            return self._covar.to_dense()
        else:
            return super().covariance_matrix

    def confidence_region(self) -> Tuple[Tensor, Tensor]:
        """
        Returns 2 standard deviations above and below the mean.

        :return: Pair of tensors of size `... x N`, where N is the
            dimensionality of the random variable. The first (second) Tensor is the
            lower (upper) end of the confidence region.
        """
        # Out-of-place multiply: never mutate the tensor handed back by `stddev`.
        std2 = self.stddev.mul(2)
        mean = self.mean
        return mean.sub(std2), mean.add(std2)

    def expand(self, batch_size: torch.Size) -> MultivariateQExponential:
        r"""
        See :py:meth:`torch.distributions.Distribution.expand
        <torch.distributions.distribution.Distribution.expand>`.

        :param batch_size: The desired batch shape.
        :return: A new distribution whose mean and covariance are expanded to `batch_size`.
        """
        batch_size = torch.Size(batch_size)
        new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
        # `_covar` is only assigned in __init__ for lazy distributions; a distribution built
        # from dense tensors has to expand its dense covariance matrix instead.
        covar = self._covar if self.islazy else self.covariance_matrix
        new_covar = covar.expand(batch_size + covar.shape[-2:])
        return self.__class__(new_loc, new_covar, self.power)

    def get_base_samples(self, sample_shape: torch.Size = torch.Size(), rescale: bool = False) -> Tensor:
        r"""
        Returns marginally identical but uncorrelated (m.i.u.) standard Q-Exponential samples to be used with
        :py:meth:`MultivariateQExponential.rsample(base_samples=base_samples)
        <gpytorch.distributions.MultivariateQExponential.rsample>`.

        Standard normal draws are generated first. When `power != 2` they are normalized to unit
        length along the last dimension and multiplied by a `Chi2(N)` draw raised to `1 / power`.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param rescale: If True, divide the samples by
            :math:`\sqrt{2^{2/q} \Gamma(N/2 + 2/q) / (N \Gamma(N/2))}`, with :math:`q` the power
            and :math:`N` the size of the last dimension. (Default: False.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. standard Q-Exponential samples.
        """
        with torch.no_grad():
            shape = self._extended_shape(sample_shape)
            base_samples = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
            if self.power != 2:
                # One radius per sample vector (trailing dimension of size 1 broadcasts over N).
                radius = Chi2(shape[-1]).sample(shape[:-1] + torch.Size([1])).to(self.loc.device)
                base_samples = torch.nn.functional.normalize(base_samples, dim=-1) * radius ** (1.0 / self.power)
            if rescale:
                log_scale = (
                    2.0 / self.power * math.log(2)
                    - math.log(shape[-1])
                    + torch.lgamma(shape[-1] / 2.0 + 2.0 / self.power)
                    - math.lgamma(shape[-1] / 2.0)
                )
                base_samples /= torch.exp(log_scale / 2.0)
        return base_samples

    @lazy_property
    def lazy_covariance_matrix(self) -> LinearOperator:
        if self.islazy:
            return self._covar
        else:
            return to_linear_operator(super().covariance_matrix)

    def log_prob(self, value: Tensor) -> Tensor:
        r"""
        See :py:meth:`torch.distributions.Distribution.log_prob
        <torch.distributions.distribution.Distribution.log_prob>`.

        With :math:`r = (x - \mu)^\top K^{-1} (x - \mu)`, power :math:`q` and dimension :math:`N`,
        the value computed is

        .. math::
            -\tfrac{1}{2} \left( r^{q/2} + \log |K| + N \log 2\pi \right)
            + \tfrac{N}{2} \left( \tfrac{q}{2} - 1 \right) \log r + \log \tfrac{q}{2}

        where the last two terms are only added when :math:`q \ne 2`.

        :param value: The point(s) at which to evaluate the log density.
        """
        # The parent class implements the Gaussian density, which is only the right answer for
        # power == 2. For any other power fall through to the general formula below.
        if settings.fast_computations.log_prob.off() and self.power == 2:
            return super().log_prob(value)

        if self._validate_args:
            self._validate_sample(value)

        mean, covar, power = self.loc, self.lazy_covariance_matrix, self.power
        diff = value - mean

        # Repeat the covar to match the batch shape of diff
        if diff.shape[:-1] != covar.batch_shape:
            if len(diff.shape[:-1]) < len(covar.batch_shape):
                diff = diff.expand(covar.shape[:-1])
            else:
                padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
                covar = covar.repeat(
                    *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                    1,
                    1,
                )

        # Get log determinant and first part of quadratic form
        covar = covar.evaluate_kernel()
        inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)

        res = -0.5 * sum([inv_quad ** (power / 2.0), logdet, diff.size(-1) * math.log(2 * math.pi)])
        if power != 2:
            res = res + sum(
                [0.5 * diff.size(-1) * (power / 2.0 - 1) * torch.log(inv_quad), torch.log(power / 2.0)]
            )
        return res

    def entropy(self, exact: bool = False) -> Tensor:
        r"""
        See :py:meth:`torch.distributions.Distribution.entropy
        <torch.distributions.distribution.Distribution.entropy>`.

        :param exact: Selects between the two expressions implemented below: with `exact=True`
            the terms use `d` and the entropy of `Chi2(d)`, otherwise `d ** (power / 2)` and
            `-log(d)`, where `d` is the event size. (Default: False.)
        """
        d = self._event_shape[0]  # self.loc.shape[-1]
        if self.islazy:
            logdet = self.lazy_covariance_matrix.logdet()
            res = 0.5 * sum([d * math.log(2 * math.pi), logdet, d ** (1 if exact else self.power / 2.0)])
        else:
            # torch may return an expanded (stride-0) tensor here, which rejects in-place
            # arithmetic; every update below is therefore out-of-place.
            res = super().entropy()
            if not exact:
                res = res + 0.5 * (-d + d ** (self.power / 2.0))
        if self.power != 2:
            res = res + sum(
                [
                    d / 2.0 * (self.power / 2.0 - 1) * (2.0 / self.power * Chi2(d).entropy() if exact else -math.log(d)),
                    -torch.log(self.power / 2.0),
                ]
            )
        return res

    def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None) -> Tensor:
        r"""
        Generates a `sample_shape` shaped reparameterized sample or `sample_shape`
        shaped batch of reparameterized samples if the distribution parameters
        are batched.

        For the MultivariateQExponential distribution, this is accomplished through:

        .. math::
            \boldsymbol \mu + \mathbf L \boldsymbol \epsilon

        where :math:`\boldsymbol \mu \in \mathcal R^N` is the QEP mean,
        :math:`\mathbf L \in \mathcal R^{N \times N}` is a "root" of the
        covariance matrix :math:`\mathbf K` (i.e. :math:`\mathbf L \mathbf
        L^\top = \mathbf K`), and :math:`\boldsymbol \epsilon \in \mathcal R^N` is a
        vector of (approximately) m.i.u. standard Q-Exponential random variables.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
            m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
            reparameterize. (Default: None.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. reparameterized samples.
        :raises RuntimeError: If `base_samples` is given and its trailing shape does not agree
            with the shape of the mean.
        """
        covar = self.lazy_covariance_matrix
        if base_samples is None:
            # Create some samples
            num_samples = sample_shape.numel() or 1

            # Work with the wrapped operator when the covariance exposes one.
            covar_base = covar.base_linear_op if hasattr(covar, "base_linear_op") else covar
            if covar_base.size()[-2:] == torch.Size([1, 1]):
                # 1 x 1 covariance: the root is the elementwise square root.
                covar_root = covar_base.to_dense().sqrt()
            else:
                covar_root = covar_base.root_decomposition().root

            base_samples = self.get_base_samples(torch.Size([num_samples]))
            if len(self.event_shape) == 2:  # multitask case
                # NOTE(review): `_interleaved` is not defined in this class; presumably provided
                # by the multitask subclass -- confirm.
                if not self._interleaved:
                    base_samples = base_samples.transpose(-1, -2)
                base_samples = base_samples.reshape(base_samples.shape[:-2] + covar_base.shape[-1:])
            # Move the sample dimension to the end so the root can be applied with one matmul.
            base_samples = base_samples.permute(*range(1, covar_base.dim()), 0)
            # If necessary, truncate base_samples to the number of columns of the root.
            if covar_root.shape < covar_base.shape:
                base_samples = base_samples[..., : covar_root.size(-1), :]

            # Get samples
            res = covar_root.matmul(base_samples).permute(-1, *range(covar_base.dim() - 1)).contiguous()
            if hasattr(covar, "_remove_batch_dim"):
                res = covar._remove_batch_dim(res.unsqueeze(-1)).squeeze(-1)
            res += self.loc.unsqueeze(0)
            res = res.view(sample_shape + self.loc.shape)

        else:
            covar_root = covar.root_decomposition().root

            # Make sure that the base samples agree with the distribution
            if (
                self.loc.shape != base_samples.shape[-self.loc.dim() :]
                and covar_root.shape[-1] < base_samples.shape[-1]
            ):
                raise RuntimeError(
                    "The size of base_samples (minus sample shape dimensions) should agree with the size "
                    "of self.loc. Expected ...{} but got {}".format(self.loc.shape, base_samples.shape)
                )

            # Determine what the appropriate sample_shape parameter is
            sample_shape = base_samples.shape[: base_samples.dim() - self.loc.dim()]

            # Reshape samples to be batch_size x num_dim x num_samples
            # or num_dim x num_samples
            base_samples = base_samples.view(-1, *self.loc.shape[:-1], covar_root.shape[-1])
            base_samples = base_samples.permute(*range(1, self.loc.dim() + 1), 0)

            # Now reparameterize those base samples
            # If necessary, adjust base_samples for rank of root decomposition
            if covar_root.shape[-1] < base_samples.shape[-2]:
                base_samples = base_samples[..., : covar_root.shape[-1], :]
            elif covar_root.shape[-1] > base_samples.shape[-2]:
                # Instead of raising for an incompatible dimension, apply the transposed root.
                covar_root = covar_root.transpose(-2, -1)
            res = covar_root.matmul(base_samples) + self.loc.unsqueeze(-1)

            # Permute and reshape new samples to be original size
            res = res.permute(-1, *range(self.loc.dim())).contiguous()
            res = res.view(sample_shape + self.loc.shape)

        return res

    def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None) -> Tensor:
        r"""
        Generates a `sample_shape` shaped sample or `sample_shape`
        shaped batch of samples if the distribution parameters
        are batched.

        Note that these samples are not reparameterized and therefore cannot be backpropagated through.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
            m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
            reparameterize. (Default: None.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. samples.
        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape, base_samples=base_samples)

    @property
    def stddev(self) -> Tensor:
        # self.variance is guaranteed to be positive, because we do clamping.
        return self.variance.sqrt()

    def to_data_uncorrelated_dist(self) -> MultivariateQExponential:
        """
        Convert a `... x N` QEP distribution into a batch of uncorrelated Q-Exponential distributions.
        Essentially, this throws away all off-diagonal covariance information: the returned
        distribution keeps the mean and power and uses only the diagonal of the covariance.

        :returns: A (data-uncorrelated) Q-Exponential distribution with a diagonal covariance.
        """
        new_cov = DiagLinearOperator(self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2))
        return self.__class__(mean=self.mean, covariance_matrix=new_cov, power=self.power)

    to_data_independent_dist = to_data_uncorrelated_dist  # alias to the same function with a more appropriate name

    @property
    def variance(self) -> Tensor:
        if self.islazy:
            # overwrite this since torch uses unbroadcasted_scale_tril for this
            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
            diag = diag.view(diag.shape[:-1] + self._event_shape)
            variance = diag.expand(self._batch_shape + self._event_shape)
        else:
            variance = super().variance

        # Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
        # This ensures that all variances are positive
        min_variance = settings.min_variance.value(variance.dtype)
        if variance.lt(min_variance).any():
            warnings.warn(
                "Negative variance values detected. "
                "This is likely due to numerical instabilities. "
                f"Rounding negative variances up to {min_variance}.",
                NumericalWarning,
            )
            variance = variance.clamp_min(min_variance)
        return variance

    def __add__(self, other: MultivariateQExponential) -> MultivariateQExponential:
        """
        Adds another QEP distribution (means and covariances are summed, this distribution's
        power is kept) or shifts the mean by an `int`/`float`.

        :raises RuntimeError: If `other` is of any other type.
        """
        if isinstance(other, MultivariateQExponential):
            return self.__class__(
                mean=self.mean + other.mean,
                covariance_matrix=(self.lazy_covariance_matrix + other.lazy_covariance_matrix),
                power=self.power,
            )
        elif isinstance(other, (int, float)):
            return self.__class__(self.mean + other, self.lazy_covariance_matrix, self.power)
        else:
            raise RuntimeError("Unsupported type {} for addition w/ MultivariateQExponential".format(type(other)))

    def __getitem__(self, idx) -> MultivariateQExponential:
        r"""
        Constructs a new MultivariateQExponential that represents a random variable
        modified by an indexing operation.

        The mean and covariance matrix arguments are indexed accordingly.

        :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
        :raises IndexError: If the index has too many dimensions or contains ambiguous ellipses.
        """

        if not isinstance(idx, tuple):
            idx = (idx,)
        if len(idx) > self.mean.dim() and Ellipsis in idx:
            # An ellipsis that cannot stand for any dimension is dropped.
            idx = tuple(i for i in idx if i != Ellipsis)
            if len(idx) < self.mean.dim():
                raise IndexError("Multiple ambiguous ellipsis in index!")

        rest_idx = idx[:-1]
        last_idx = idx[-1]
        new_mean = self.mean[idx]

        if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx):
            # We are only indexing the batch dimensions in this case
            new_cov = self.lazy_covariance_matrix[idx]
        elif len(idx) > self.mean.dim():
            raise IndexError(f"Index {idx} has too many dimensions")
        else:
            # In this case we know last_idx corresponds to the last dimension
            # of mean and the last two dimensions of lazy_covariance_matrix
            if isinstance(last_idx, int):
                # Selecting a single element: only its variance (a diagonal entry) remains.
                new_cov = DiagLinearOperator(
                    self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)[(*rest_idx, last_idx)]
                )
            elif isinstance(last_idx, slice):
                new_cov = self.lazy_covariance_matrix[(*rest_idx, last_idx, last_idx)]
            elif last_idx is Ellipsis:
                new_cov = self.lazy_covariance_matrix[rest_idx]
            else:
                # Any other index type: index the rows first, then the columns.
                new_cov = self.lazy_covariance_matrix[(*rest_idx, last_idx, slice(None, None, None))][..., last_idx]
        return self.__class__(mean=new_mean, covariance_matrix=new_cov, power=self.power)

    def __mul__(self, other: Number) -> MultivariateQExponential:
        """
        Scales the random variable by an `int`/`float`: the mean is multiplied by `other`
        and the covariance by `other ** 2`.

        :raises RuntimeError: If `other` is not an `int` or `float`.
        """
        if not isinstance(other, (int, float)):
            raise RuntimeError("Can only multiply by scalars")
        if other == 1:
            return self
        return self.__class__(
            mean=self.mean * other,
            covariance_matrix=self.lazy_covariance_matrix * (other**2),
            power=self.power,
        )

    def __radd__(self, other: MultivariateQExponential) -> MultivariateQExponential:
        # Adding to 0 returns the distribution unchanged (e.g. the start value used by `sum`).
        if other == 0:
            return self
        return self.__add__(other)

    def __truediv__(self, other: Number) -> MultivariateQExponential:
        return self.__mul__(1.0 / other)


@register_kl(MultivariateQExponential, MultivariateQExponential)
def kl_qep_qep(p_dist: MultivariateQExponential, q_dist: MultivariateQExponential, exact: bool = False) -> Tensor:
    """
    KL divergence `KL(p || q)` between two multivariate q-exponential distributions.

    Both distributions are first expanded to their broadcast batch shape. For powers other
    than 2 the exact value is intractable; an approximation is provided instead.

    :param p_dist: The first distribution.
    :param q_dist: The second distribution.
    :param exact: Selects between the two expressions implemented below: with `exact=True` the
        terms use `dim` and the entropy of `Chi2(dim)`, otherwise `dim ** (power / 2)` and
        `-log(dim)`, where `dim` is the event size. (Default: False.)
    :return: A tensor with the broadcast batch shape of the two distributions.
    """
    output_shape = torch.broadcast_shapes(p_dist.batch_shape, q_dist.batch_shape)
    if output_shape != p_dist.batch_shape:
        p_dist = p_dist.expand(output_shape)
    if output_shape != q_dist.batch_shape:
        q_dist = q_dist.expand(output_shape)

    q_mean = q_dist.loc
    q_covar = q_dist.lazy_covariance_matrix

    p_mean = p_dist.loc
    p_covar = p_dist.lazy_covariance_matrix
    # Already dense after `to_dense()`, so no further LinearOperator handling is needed.
    root_p_covar = p_covar.root_decomposition().root.to_dense()

    mean_diffs = p_mean - q_mean
    dim = float(mean_diffs.size(-1))

    # Stack the mean difference and the root of p's covariance as columns of one right-hand
    # side, so a single call on q's covariance yields the summed inverse quadratic form.
    inv_quad_rhs = torch.cat([mean_diffs.unsqueeze(-1), root_p_covar], -1)
    logdet_p_covar = p_covar.logdet()
    trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)

    # Compute the KL Divergence.
    res = 0.5 * sum(
        [
            logdet_q_covar,
            logdet_p_covar.mul(-1),
            trace_plus_inv_quad_form ** (q_dist.power / 2.0),
            -(dim ** (1 if exact else p_dist.power / 2.0)),
        ]
    )
    if q_dist.power != 2:
        # exact value is intractable; an approximation is provided instead.
        q_term = -(q_dist.power / 2.0 - 1) * torch.log(trace_plus_inv_quad_form)
        p_term = -(p_dist.power / 2.0 - 1) * (
            2.0 / p_dist.power * Chi2(dim).entropy() if exact else -math.log(dim)
        )
        res = res + dim / 2.0 * sum([q_term, p_term])
    return res
