from collections.abc import Iterable
import warnings

import numpy as np

def _normalize_axes(axes, ndim):
    """Return *axes* as a tuple, mapping in-range negative axes to non-negative ones.

    Out-of-range entries are passed through untouched so that NumPy raises its
    usual axis error when they are later used in a reduction.
    """
    if not isinstance(axes, Iterable):  # Single int
        axes = (axes,)
    return tuple(ax + ndim if -ndim <= ax < 0 else ax for ax in axes)


def r2_score(y_true, y_pred, axis=None, axis_bias=None, axis_ref=None, force_finite=True):
    """
    R^2 score for multidimensional predictions.

    Computes ``1 - RSS / TSS``, where RSS is the residual sum of squares
    ``sum((y_true - y_pred)**2)`` over ``axis`` and TSS is the total sum of
    squares of ``y_true`` around its mean.

    Parameters
    ----------
    y_true : array_like
        Ground-truth values.
    y_pred : array_like
        Predicted values, combined element-wise with ``y_true``.
    axis : int or iterable of int, default=None
        Axis or axes to collapse (summed over in both RSS and TSS).
        If None, all axes of ``y_true`` are collapsed to yield a single number.
        Negative axes count from the last axis.
    axis_bias : int or iterable of int, default=None
        Axes over which ``y_true.mean`` is taken when computing TSS.
        Must be a subset of ``axis_ref``. Defaults to ``axis_ref``.
    axis_ref : int or iterable of int, default=None
        Reference axes over which variability is measured to normalize the
        score. Axes in ``axis_ref`` that are not in ``axis`` are averaged over
        in TSS. Defaults to ``axis``.
    force_finite : bool, default=True
        If True, scores that are undefined only because TSS is zero are
        replaced: 1 when RSS is also zero (``0 / 0``), 0 when RSS is positive
        (``-inf``). NaN coming from non-finite input data is left as NaN so it
        is not mistaken for a perfect score.

    Returns
    -------
    z : float or np.ndarray
        If ``axis`` covers every axis of ``y_true`` (including ``axis=None``),
        a single Python number. Otherwise an array with the ``axis``
        dimensions removed.

    Raises
    ------
    AssertionError
        If ``axis_bias`` is not a subset of ``axis_ref``.

    Warns
    -----
    UserWarning
        If ``axis`` is not a subset of ``axis_ref``.
    """
    # asanyarray accepts lists/tuples while preserving ndarray subclasses.
    y_true = np.asanyarray(y_true)
    y_pred = np.asanyarray(y_pred)
    ndim = y_true.ndim

    # Bring every axis argument to a tuple of non-negative ints so that the set
    # comparisons below treat e.g. -1 and ndim - 1 as the same axis.
    axis = tuple(range(ndim)) if axis is None else _normalize_axes(axis, ndim)
    axis_ref = axis if axis_ref is None else _normalize_axes(axis_ref, ndim)
    axis_bias = axis_ref if axis_bias is None else _normalize_axes(axis_bias, ndim)

    axis_set = set(axis)
    axis_ref_set = set(axis_ref)

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if not set(axis_bias).issubset(axis_ref_set):
        raise AssertionError(f'axis_bias ({axis_bias}) must be a subset of axis_ref ({axis_ref}) because axis_ measure variability')

    # TSS is summed over `axis` (like RSS) and averaged over the reference axes
    # that are not being collapsed, so it broadcasts against RSS.
    axis_sum = axis
    axis_mean = tuple(sorted(axis_ref_set - axis_set))

    if not axis_set.issubset(axis_ref_set):
        warnings.warn(f"axis {axis} is not a subset of axis_ref {axis_ref}, TSS sums over axis {axis_sum} and averages over remaining axis_ref {axis_mean}")

    # Residual Sum of Squares (RSS)
    rss = np.sum((y_true - y_pred) ** 2, axis=axis, keepdims=True)

    # Total Sum of Squares (TSS): deviations from the mean over axis_bias,
    # averaged over the extra reference axes, then summed over `axis`.
    y_mean = np.mean(y_true, axis=axis_bias, keepdims=True)
    total_sq = (y_true - y_mean) ** 2
    tss = np.mean(total_sq, axis=axis_mean, keepdims=True)
    tss = np.sum(tss, axis=axis_sum, keepdims=True)

    if force_finite:
        # Division by a zero TSS is expected here and patched right below, so
        # the corresponding floating-point warnings are silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            score = np.asanyarray(1 - rss / tss)
        zero_tss = tss == 0
        score[zero_tss & (rss == 0)] = 1  # 0/0: constant target, perfect prediction
        score[zero_tss & (rss > 0)] = 0  # -Inf means no fit, so set to 0
    else:
        score = 1 - rss / tss

    score = np.squeeze(score, axis=axis)  # Drop the collapsed `axis` dimensions

    if len(axis) == ndim:  # Everything collapsed: return a plain number
        score = score.item()

    return score