from typing import Callable, List, Optional, Union

from array_api_compat import array_namespace
import numpy as np
import torch

from autogluon.timeseries.dataset import TimeSeriesDataFrame
from autogluon.timeseries.metrics.abstract import TimeSeriesScorer
from autogluon.timeseries.metrics.quantile import (
    WQL as AutogluonWQL,
    SQL as AutogluonSQL,
)
from autogluon.timeseries.metrics.point import MASE as AutogluonMASE
from autogluon.timeseries.metrics.utils import in_sample_abs_seasonal_error
from autogluon.timeseries.utils.datetime import get_seasonality
from autogluon.timeseries.metrics import check_get_evaluation_metric


def _per_level_quantile_loss(
    y_true: Union[np.ndarray, torch.tensor],
    # [num_folds, num_items, num_times]
    q_pred: Union[np.ndarray, torch.tensor],
    # [num_folds, num_items, num_times, num_quantiles]
    quantile_levels: Union[List[float], np.ndarray, torch.tensor],
    # [num_quantiles]
):
    # Input shapes etc should be fine as they have been checked by AbstractMetric.compute_metric

    # print(f"`per_level_quantile_loss`: y_true.shape={y_true.shape}, q_pred.shape={q_pred.shape}")
    y_true = y_true[..., None]

    if isinstance(quantile_levels, list):
        if isinstance(y_true, np.ndarray):
            quantile_levels = np.array(quantile_levels, dtype=y_true.dtype)
        elif isinstance(y_true, torch.Tensor):
            quantile_levels = torch.tensor(
                quantile_levels, dtype=y_true.dtype, device=y_true.device
            )
    qs = quantile_levels[None, None, None, :]

    return 2 * (
        (1 - qs) * (q_pred - y_true) * (y_true < q_pred)
        + qs * (y_true - q_pred) * (y_true >= q_pred)
    )


class AbstractMetric:
    """Base class for forecast metrics computed directly on numpy arrays or torch tensors.

    Subclasses implement ``_compute_metric``. Callers go through ``compute_metric``, which
    validates input ranks and shapes first. ``save_past_metrics`` is a hook for metrics
    that need statistics of past data; it does nothing here.
    """

    def save_past_metrics(self, **kwargs) -> None:
        """Store statistics of past data needed by the metric (no-op by default)."""
        pass

    def compute_metric(
        self,
        y_true: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times]
        q_pred: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times, num_quantiles]
        quantile_levels: Union[List[float], np.ndarray, torch.Tensor],
        # [num_quantiles]
        **kwargs,
    ) -> float:
        """Validate input shapes, then delegate to ``_compute_metric``.

        Extra keyword arguments are forwarded unchanged to ``_compute_metric``.

        Raises:
            AssertionError: (when assertions are enabled) if ``y_true`` is not rank 3,
                ``q_pred`` is not rank 4, their first three dimensions differ, or
                ``quantile_levels`` is not 1-D with one entry per quantile column of
                ``q_pred``.
        """
        assert len(y_true.shape) == 3, f"y_true.shape={y_true.shape}"
        assert len(q_pred.shape) == 4, f"q_pred.shape={q_pred.shape}"
        assert tuple(y_true.shape) == tuple(q_pred.shape[:3]), (
            f"y_true.shape={tuple(y_true.shape)} does not match "
            f"q_pred.shape[:3]={tuple(q_pred.shape[:3])}"
        )
        # Plain Python sequences have no `.shape`; only array-likes get a rank check.
        if not isinstance(quantile_levels, (list, tuple)):
            assert len(quantile_levels.shape) == 1, (
                f"quantile_levels.shape={quantile_levels.shape}"
            )
        assert len(quantile_levels) == q_pred.shape[-1], (
            f"len(quantile_levels)={len(quantile_levels)} does not match "
            f"q_pred.shape[-1]={q_pred.shape[-1]}"
        )
        return self._compute_metric(y_true, q_pred, quantile_levels, **kwargs)

    def _compute_metric(
        self,
        y_true: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times]
        q_pred: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times, num_quantiles]
        quantile_levels: Union[List[float], np.ndarray, torch.Tensor],
        # [num_quantiles]
        **kwargs,
    ) -> float:
        """Compute the metric on already-validated inputs; implemented by subclasses."""
        raise NotImplementedError

    @property
    def ispointforecast(self) -> bool:
        """True for metrics that only score a point (median) forecast, e.g. MASE."""
        raise NotImplementedError


class WQL(AbstractMetric):
    """
    Torch-compatible version of WQL that operates on tensors instead of `TimeSeriesDataFrame`s
    Should be equivalent to `autogluon.timeseries.metrics.quantile.WQL`.

    Args:
        fold_weight_fn: optional callable mapping a 0-based fold index to a multiplicative
            weight applied to that fold's quantile losses.
    """

    def __init__(self, fold_weight_fn: Optional[Callable[[int], float]] = None):
        self.fold_weight_fn = fold_weight_fn

    @property
    def ispointforecast(self):
        return False

    def _fold_weights(self, num_folds: int, like):
        """Return per-fold weights as an array of the same library as `like`."""
        weights = [self.fold_weight_fn(i) for i in range(num_folds)]
        if isinstance(like, torch.Tensor):
            # Build the weights on the loss tensor's device. A CPU weight tensor combined
            # with CUDA losses raises a device-mismatch error.
            return torch.as_tensor(weights, dtype=like.dtype, device=like.device)
        return np.asarray(weights)

    def _compute_metric(
        self,
        y_true: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times]
        q_pred: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times, num_quantiles]
        quantile_levels: Union[List[float], np.ndarray, torch.Tensor],
        # [num_quantiles]
        **kwargs,
    ) -> float:
        """Weighted quantile loss pooled over all folds, items, times and quantiles.

        Computes NaN-ignoring sum of (optionally fold-weighted) losses divided by the
        NaN-ignoring sum of |y_true|. That ratio is then divided by `num_folds` and
        `num_quantiles`. Extra keyword arguments are accepted for compatibility with
        `compute_metric` and ignored.
        """
        num_folds, _, _, num_quantiles = q_pred.shape
        plql = _per_level_quantile_loss(y_true, q_pred, quantile_levels)
        xp = array_namespace(plql, use_compat=False)
        if self.fold_weight_fn is not None:
            plql = plql * self._fold_weights(num_folds, plql)[..., None, None, None]
        # NOTE(review): numerator and denominator are both summed over all folds, so the
        # extra `/ num_folds` makes the score shrink as folds are added. This differs
        # from averaging per-fold WQL when num_folds > 1. The denominator is also
        # unweighted when fold weights are used. Confirm both are intended.
        return xp.nansum(plql) / xp.nansum(xp.abs(y_true)) / num_folds / num_quantiles


class SQL(AbstractMetric):
    """
    Torch-compatible version of SQL that operates on tensors instead of `TimeSeriesDataFrame`s
    Should be equivalent to `autogluon.timeseries.metrics.quantile.SQL`.

    The per-level quantile loss is divided by each fold/item's past absolute seasonal error.
    That error is stored by `save_past_metrics` and floored at 1e-3. The result is averaged
    over all finite entries.

    Args:
        fold_weight_fn: optional callable mapping a 0-based fold index to a multiplicative
            weight applied to that fold's quantile losses.
    """

    def __init__(self, fold_weight_fn: Optional[Callable[[int], float]] = None):
        self.fold_weight_fn = fold_weight_fn
        # numpy array of shape [num_folds, num_items, 1, 1]; set by `save_past_metrics`.
        self._past_abs_seasonal_error = None

    @property
    def ispointforecast(self):
        return False

    def save_past_metrics(
        self,
        data_past: List[TimeSeriesDataFrame],
        target: str = "target",
        seasonal_period: Union[int, None] = None,
        **kwargs,
    ) -> None:
        """Compute and store the in-sample absolute seasonal error for every fold.

        Args:
            data_past: one `TimeSeriesDataFrame` of past observations per fold. All folds
                must yield the same number of items, because the errors are stacked.
            target: name of the target column.
            seasonal_period: seasonal lag. If None, it is inferred from the first fold's
                frequency.

        Raises:
            ValueError: if `data_past` is not a list or is empty.
        """
        if not isinstance(data_past, list):
            raise ValueError(
                "`data_past` must be a list of `TimeSeriesDataFrame`s containing the folds"
            )
        if not data_past:
            raise ValueError("`data_past` must contain at least one fold")

        if seasonal_period is None:
            seasonal_period = get_seasonality(data_past[0].freq)

        seasonal_error_list = []
        for df in data_past:
            se = in_sample_abs_seasonal_error(
                y_past=df[target], seasonal_period=seasonal_period
            ).to_numpy()
            # Expect a flat vector, one entry per item, as the reshape below assumes.
            assert se.shape[0] == se.size
            seasonal_error_list.append(se)
        seasonal_errors = np.stack(seasonal_error_list, axis=0)
        # The trailing singleton axes broadcast over [num_times, num_quantiles] of the loss.
        # Flooring at 1e-3 avoids division by zero for zero seasonal error.
        self._past_abs_seasonal_error = seasonal_errors.reshape(
            len(data_past), -1, 1, 1
        ).clip(min=1e-3)

    @staticmethod
    def _safemean(array: Union[np.ndarray, torch.Tensor]) -> float:
        """Compute mean of a numpy array-like object, ignoring inf, -inf and nan values."""
        xp = array_namespace(array, use_compat=False)
        return xp.nanmean(array[xp.isfinite(array)])

    def _fold_weights(self, num_folds: int, like):
        """Return per-fold weights as an array of the same library as `like`."""
        weights = [self.fold_weight_fn(i) for i in range(num_folds)]
        if isinstance(like, torch.Tensor):
            # Build the weights on the loss tensor's device. A CPU weight tensor combined
            # with CUDA losses raises a device-mismatch error.
            return torch.as_tensor(weights, dtype=like.dtype, device=like.device)
        return np.asarray(weights)

    def _compute_metric(
        self,
        y_true: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times]
        q_pred: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times, num_quantiles]
        quantile_levels: Union[List[float], np.ndarray, torch.Tensor],
        # [num_quantiles]
        **kwargs,
    ) -> float:
        """Mean of the (optionally fold-weighted) quantile loss scaled by past seasonal error.

        Extra keyword arguments are accepted for compatibility with `compute_metric` and
        ignored.

        Raises:
            AssertionError: if `save_past_metrics` has not been called.
        """
        if self._past_abs_seasonal_error is None:
            raise AssertionError("Call `save_past_metrics` before `compute_metric`")

        num_folds, _, _, num_quantiles = q_pred.shape
        assert len(quantile_levels) == num_quantiles

        plql = _per_level_quantile_loss(y_true, q_pred, quantile_levels)
        if self.fold_weight_fn is not None:
            plql = plql * self._fold_weights(num_folds, plql)[..., None, None, None]

        scale = self._past_abs_seasonal_error
        if isinstance(plql, torch.Tensor):
            # Move the stored numpy scale onto the loss tensor's device and dtype. Creating
            # it on CPU as float64 would fail for CUDA losses and silently upcast the result.
            scale = torch.as_tensor(scale, dtype=plql.dtype, device=plql.device)
        return self._safemean(plql / scale)


class MASE(SQL):
    """
    Mean absolute scaled error of the median forecast.

    Computed as `SQL` restricted to the 0.5 quantile. At level 0.5 the doubled quantile loss
    2 * 0.5 * |y - pred| equals the absolute error, which `SQL` scales by the stored past
    seasonal error. Constructor arguments and `save_past_metrics` are inherited from `SQL`.
    """

    @property
    def ispointforecast(self):
        return True

    def _compute_metric(
        self,
        y_true: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times]
        q_pred: Union[np.ndarray, torch.Tensor],
        # [num_folds, num_items, num_times, num_quantiles]
        quantile_levels: Union[List[float], np.ndarray, torch.Tensor],
        # [num_quantiles]
        **kwargs,
    ) -> float:
        """Drop all non-median quantile columns and return the SQL of the 0.5 quantile.

        Extra keyword arguments are accepted for compatibility with `compute_metric` and
        ignored.

        Raises:
            ValueError: if `quantile_levels` does not contain 0.5.
        """
        # Convert to Python floats so `.index` works for lists, numpy arrays and tensors.
        levels = [float(q) for q in quantile_levels]
        if 0.5 not in levels:
            raise ValueError("MASE requires quantile_levels to contain 0.5")
        median_quantile_idx = levels.index(0.5)
        # Keep a size-1 trailing quantile axis so q_pred stays rank 4.
        q_pred = q_pred[..., median_quantile_idx, None]
        return super()._compute_metric(y_true, q_pred, [0.5])


def get_fold_weight_fn(fold_weight_fn_str: str = "none") -> callable:
    """Return a function mapping a 0-based fold index to its weight.

    Supported names: "none" (constant 1), "linear" (index + 1) and "exp" (2 ** index).

    Raises:
        ValueError: for any other name.
    """
    known_weightings = (
        ("none", lambda fold: 1),
        ("linear", lambda fold: fold + 1),
        ("exp", lambda fold: 2**fold),
    )
    for name, weight_fn in known_weightings:
        if fold_weight_fn_str == name:
            return weight_fn
    raise ValueError(f"Unknown fold weight function: {fold_weight_fn_str}")

def get_metric(metric_str: str, **kwargs) -> AbstractMetric:
    """
    Instantiate the array-based metric named by `metric_str` (case-insensitive).

    Supported names are "wql", "sql" and "mase". Keyword arguments go to the metric's
    constructor.

    Raises:
        ValueError: for an unknown metric name.
    """
    metric_classes = {"wql": WQL, "sql": SQL, "mase": MASE}
    key = metric_str.lower()
    if key not in metric_classes:
        raise ValueError(f"Unknown metric: {metric_str}")
    return metric_classes[key](**kwargs)


class ClippedAutogluonSQL(AutogluonSQL):
    """
    Autogluon's SQL metric with the stored past seasonal error floored at 1e-3.

    The floor prevents division by zero when an item's seasonal error is zero.
    """

    def save_past_metrics(
        self,
        data_past: TimeSeriesDataFrame,
        target: str = "target",
        seasonal_period: int = 1,
        **kwargs,
    ) -> None:
        """Store past seasonal errors via the parent implementation, then floor them at 1e-3."""
        super().save_past_metrics(
            data_past=data_past, target=target, seasonal_period=seasonal_period, **kwargs
        )
        floor = 1e-3
        unclipped = self._past_abs_seasonal_error
        self._past_abs_seasonal_error = unclipped.clip(lower=floor)


class ClippedAutogluonMASE(AutogluonMASE):
    """
    Autogluon's MASE metric with the stored past seasonal error floored at 1e-3.

    The floor prevents division by zero when an item's seasonal error is zero. Scoring
    uses the "0.5" quantile column as the point forecast.
    """

    def save_past_metrics(
        self,
        data_past: TimeSeriesDataFrame,
        target: str = "target",
        seasonal_period: int = 1,
        **kwargs,
    ) -> None:
        """Store past seasonal errors via the parent implementation, then floor them at 1e-3."""
        super().save_past_metrics(
            data_past=data_past, target=target, seasonal_period=seasonal_period, **kwargs
        )
        floor = 1e-3
        unclipped = self._past_abs_seasonal_error
        self._past_abs_seasonal_error = unclipped.clip(lower=floor)

    def compute_metric(
        self, data_future: TimeSeriesDataFrame, predictions: TimeSeriesDataFrame, target: str = "target", **kwargs
    ) -> float:
        """Score `predictions` using its "0.5" column as the point forecast.

        The original AutogluonMASE reads a "mean" column. This method works on a copy of
        `predictions`, so the caller's frame is not modified.
        """
        assert "0.5" in predictions.columns
        patched_predictions = predictions.copy()
        patched_predictions["mean"] = patched_predictions["0.5"]
        return super().compute_metric(
            data_future=data_future,
            predictions=patched_predictions,
            target=target,
            **kwargs,
        )


def get_ag_metric(metric_str: str):
    """
    Get the autogluon metric corresponding to the given metric string.

    "sql" and "mase" (case-insensitive) return the patched `ClippedAutogluonSQL` /
    `ClippedAutogluonMASE` metrics. These floor the past seasonal error at 1e-3. Any other
    name is resolved by autogluon's `check_get_evaluation_metric`, using the name as given.
    """
    patched_metrics = {"sql": ClippedAutogluonSQL, "mase": ClippedAutogluonMASE}
    patched_cls = patched_metrics.get(metric_str.lower())
    if patched_cls is None:
        return check_get_evaluation_metric(metric_str)
    return patched_cls()
