import numpy as np


def covr_fun(ci, target):
    """Return 1 if ``target`` lies in at least one prediction interval, else 0.

    ``ci`` is either a single ``[lower, upper]`` pair or a stack of such rows.
    Endpoints are inclusive and each row is assumed to satisfy
    ``lower <= upper``. A missing (``None``) or all-NaN ``ci`` counts as not
    covered; a row with a NaN endpoint never covers the target.
    """
    if ci is None or np.all(np.isnan(ci)):
        return 0
    rows = np.atleast_2d(ci)
    inside = (rows[:, 0] <= target) & (target <= rows[:, 1])
    return int(np.any(inside))


def loss_fun(ci):
    """Return the total length of the prediction interval(s) in ``ci``.

    ``ci`` is a single ``[lower, upper]`` pair or a stack of such rows.
    Returns 0 when ``ci`` is ``None`` or all-NaN (NA is treated as 0 loss,
    consistent with the original), and also when any row has a non-finite
    length (an infinite or NaN endpoint) or a length below -1e-9.
    Non-positive lengths down to -1e-9 contribute 0 to the sum.
    """
    if ci is None or np.all(np.isnan(ci)):
        return 0
    rows = np.atleast_2d(ci)
    widths = rows[:, 1] - rows[:, 0]
    # A single unusable row (Inf/NaN length, or clearly inverted) zeroes the loss.
    if not np.all(np.isfinite(widths)) or np.any(widths < -1e-9):
        return 0
    return sum(max(0, width) for width in widths)


def _counts_int_py(intervals, weights, midpoint):
    """
    Weighted fraction of intervals that contain ``midpoint``.

    Parameters
    ----------
    intervals : array-like of shape (K, 2)
        ``[lower, upper]`` rows; endpoints are inclusive.
    weights : array-like of shape (K,)
        Rows whose weight is NaN are ignored.
    midpoint : float
        Point at which coverage is evaluated.

    Returns
    -------
    float
        0.0 if the non-NaN weights sum to less than 1e-9.
        NaN if, for any non-ignored row, containment is undetermined because of
        a NaN endpoint or a NaN ``midpoint`` (mirrors R's NA propagation).
        Otherwise the weighted mean of the 0/1 containment indicators.
    """
    intervals = np.asarray(intervals, dtype=float)
    weights = np.asarray(weights, dtype=float)
    lower = intervals[:, 0]
    upper = intervals[:, 1]

    with np.errstate(invalid="ignore"):  # comparisons involving NaN are False
        lower_ok = lower <= midpoint
        upper_ok = midpoint <= upper
    covered = lower_ok & upper_ok

    # A boolean array can never hold NaN, so NA-ness has to be derived from the
    # inputs. R evaluates `(lower <= m) & (m <= upper)` with three-valued logic:
    # FALSE & NA is FALSE, while TRUE & NA (and NA & NA) is NA.
    mid_nan = np.isnan(midpoint)
    lower_na = np.isnan(lower) | mid_nan
    upper_na = np.isnan(upper) | mid_nan
    definitely_out = (~lower_na & ~lower_ok) | (~upper_na & ~upper_ok)
    indicator_na = (lower_na | upper_na) & ~definitely_out

    valid = ~np.isnan(weights)
    if np.sum(weights[valid]) < 1e-9:
        return 0.0
    if np.any(indicator_na[valid]):
        return np.nan
    # The weight sum is >= 1e-9 here, so np.average cannot divide by zero.
    return np.average(covered[valid], weights=weights[valid])


def _segment_point(lo, hi):
    """Return a point strictly inside the open segment ``(lo, hi)``, ``lo < hi``.

    Finite segments use their midpoint. For unbounded segments the midpoint is
    unusable (``(-inf + b) / 2`` is ``-inf`` and ``(-inf + inf) / 2`` is NaN),
    so a point ``max(1, |endpoint|)`` away from the finite end is used instead
    (0 for ``(-inf, inf)``). When ``lo``/``hi`` are consecutive breaks built
    from all interval endpoints, no endpoint lies strictly inside the segment,
    so any interior point yields the same coverage.
    """
    lo_finite = bool(np.isfinite(lo))
    hi_finite = bool(np.isfinite(hi))
    if lo_finite and hi_finite:
        return (lo + hi) / 2.0
    if lo_finite:  # (lo, +inf)
        return lo + max(1.0, abs(lo))
    if hi_finite:  # (-inf, hi)
        return hi - max(1.0, abs(hi))
    return 0.0  # (-inf, +inf)


def majority_vote(intervals, weights, rho=0.5):
    """
    Weighted majority vote of a set of intervals.

    Mirrors the structure of the original R code. It returns the region of
    points whose covering weight (the normalized weight of the intervals that
    contain them) is strictly greater than ``rho``.

    Parameters
    ----------
    intervals : array-like of shape (K, 2)
        One ``[lower, upper]`` row per expert. Endpoints may be +/-inf;
        rows containing NaN are dropped.
    weights : array-like of shape (K,)
        Expert weights. Rows with a NaN weight are dropped and the remaining
        weights are renormalized to sum to 1.
    rho : float, default 0.5
        Coverage threshold (strict inequality).

    Returns
    -------
    numpy.ndarray of shape (m, 2) or None
        Rows ``[lower, upper]`` are maximal runs of consecutive covered
        segments between sorted distinct endpoints. Unbounded runs keep their
        infinite endpoint.

        Returns ``None`` (mimicking R's NA) when:

        * an input is missing or the lengths mismatch;
        * no valid row remains;
        * the weights sum to at most 1e-9;
        * no segment wider than 1e-9 is covered.

        Isolated covered points, such as the shared endpoint of [0, 1] and
        [1, 2] with equal weights, are not reported.
    """
    if intervals is None or weights is None or len(intervals) != len(weights):
        return None
    intervals = np.asarray(intervals)
    weights = np.asarray(weights, dtype=float)
    if intervals.shape[0] == 0:
        return None

    # Drop rows whose interval or weight is NaN, then renormalize.
    valid_mask = ~np.isnan(intervals).any(axis=1) & ~np.isnan(weights)
    if not np.any(valid_mask):
        return None
    intervals = intervals[valid_mask, :]
    weights = weights[valid_mask]

    weights_sum = np.sum(weights)
    if weights_sum <= 1e-9:
        return None  # Avoid division by zero
    weights = weights / weights_sum

    # Sorted distinct endpoints *including* +/-inf (np.unique flattens and
    # sorts). Keeping the infinite ones is what lets an unbounded majority
    # region such as [-inf, b] be reported rather than truncated at the
    # outermost finite break.
    breaks = np.unique(intervals)
    n_breaks = len(breaks)

    def is_degenerate(k):
        # Segment k is (breaks[k], breaks[k + 1]); near-zero widths are skipped.
        return breaks[k + 1] <= breaks[k] + 1e-9

    def is_covered(k):
        point = _segment_point(breaks[k], breaks[k + 1])
        coverage = _counts_int_py(intervals, weights, point)
        return (not np.isnan(coverage)) and coverage > rho  # NA -> not covered

    merged_intervals_list = []
    i = 0
    while i < n_breaks - 1:
        if is_degenerate(i) or not is_covered(i):
            i += 1
            continue
        # Extend the run across covered (or negligible-width) segments.
        j = i + 1
        while j < n_breaks - 1 and (is_degenerate(j) or is_covered(j)):
            j += 1
        merged_intervals_list.append([breaks[i], breaks[j]])
        # Segment j, if it exists, was just found uncovered: resume after it.
        i = j + 1

    if not merged_intervals_list:
        return None
    return np.array(merged_intervals_list)




def _mix(L, eta):
    """
    Exponential-weights mixture of cumulative losses, following R's ``mix``.

    Parameters
    ----------
    L : array-like
        Cumulative loss of each expert; NaN/Inf entries are allowed.
    eta : float
        Learning rate. A non-finite ``eta`` (Inf, -Inf or NaN) selects the
        follow-the-leader limit, and ``|eta| < 1e-12`` is treated as zero.

    Returns
    -------
    dict
        ``"w"``: the weights, one per entry of ``L`` (empty if ``L`` is empty).
        ``"M"``: the mix loss (NaN if ``L`` is empty).

        Let ``mn`` be the smallest *finite* entry of ``L`` (Inf if there is
        none). Then:

        * non-finite ``eta``: the weight is split evenly over the entries
          equal to ``mn`` (uniform if ``mn`` is Inf), and ``M = mn``;
        * ``|eta| < 1e-12``: the weights are uniform and ``M`` is NaN
          (R's ``log(s / K) / eta`` is 0/0 at eta = 0);
        * otherwise the weights are proportional to
          ``exp(-eta * (L - mn))`` and
          ``M = mn - log(sum(exp(-eta * (L - mn))) / K) / eta``.

        In the last case, if the normalizer is not finite and > 1e-9, every
        weight and ``M`` are NaN.
    """
    n_experts = len(L)
    if n_experts == 0:
        return {"w": np.array([]), "M": np.nan}

    losses = np.array(L, dtype=float)
    finite_losses = losses[np.isfinite(losses)]
    # Minimum over the finite entries only; Inf when there are none.
    mn = np.min(finite_losses) if len(finite_losses) > 0 else np.inf
    uniform = np.ones(n_experts) / n_experts

    # Follow-the-leader limit (eta = Inf, or any other non-finite eta).
    if not np.isfinite(eta):
        if not np.isfinite(mn):
            return {"w": uniform, "M": mn}
        with np.errstate(invalid="ignore"):
            is_leader = (losses == mn).astype(float)
        n_leaders = np.sum(is_leader)
        w = is_leader / n_leaders if n_leaders > 1e-9 else uniform
        return {"w": w, "M": mn}

    # eta ~ 0: uniform weights; M is undefined (0/0), as in R.
    if np.abs(eta) < 1e-12:
        return {"w": uniform, "M": np.nan}

    # Regular case. When mn is finite, the leader's term is exp(0) = 1.
    exponent = -eta * (losses - mn)
    with np.errstate(over="ignore", invalid="ignore"):
        raw = np.exp(exponent)
    total = np.sum(raw)

    if not (np.isfinite(total) and total > 1e-9):
        # NaN/Inf/vanishing normalizer: R would give NaN/Inf; propagate NaN.
        return {"w": np.full(n_experts, np.nan), "M": np.nan}

    mean_raw = total / n_experts
    if mean_raw > 0 and np.isfinite(mean_raw):
        M = mn - np.log(mean_raw) / eta
    else:
        M = np.nan
    return {"w": raw / total, "M": M}


def hedge(losses, eta):
    """
    Hedge (exponential weights) with a fixed learning rate ``eta``.

    There is no clipping and no ``nan_to_num``: NaN losses propagate into the
    weights and into the learner's loss.

    Parameters
    ----------
    losses : numpy.ndarray of shape (N, K) or (N,)
        Loss of each of K experts in each of N rounds. A 1-D array is treated
        as N rounds of a single expert.
    eta : float
        Learning rate.

    Returns
    -------
    dict
        ``"weights"``: (N, K) array. Row ``t`` holds the weights played in
        round ``t``, which depend only on the losses of rounds ``0..t-1``.
        ``"h"``: (N,) array of learner losses,
        ``h[t] = sum(weights[t] * losses[t])``.

        If an update produces a normalizer that is NaN, infinite or zero,
        all subsequent weights are NaN.
    """
    losses = np.asarray(losses, dtype=float)
    if losses.ndim == 1:
        losses = losses.reshape(-1, 1)
    N, K = losses.shape
    if K == 0:
        return {"h": np.array([]), "weights": np.array([]).reshape(N, 0)}

    h = np.full(N, np.nan)
    weights_history = np.full((N, K), np.nan)
    w = np.ones(K) / K

    for t in range(N):
        current_loss = losses[t, :]  # may contain NaN/Inf

        # The learner is charged with the weights chosen *before* seeing
        # losses[t]. These are the same weights recorded in
        # weights_history[t], matching adahedge.
        weights_history[t, :] = w
        with np.errstate(invalid="ignore"):  # 0 * inf -> NaN is propagated
            h[t] = np.sum(w * current_loss)

        # R: w <- w * exp(-eta * l[t,]); w <- w / sum(w).
        # First subtract the smallest finite loss. The shift cancels in the
        # normalization, but it stops exp() from underflowing every weight to
        # 0 when losses are large (exp(-1000) == 0.0), which would otherwise
        # turn all weights into NaN.
        finite_loss = current_loss[np.isfinite(current_loss)]
        shift = np.min(finite_loss) if finite_loss.size else 0.0
        with np.errstate(over="ignore", invalid="ignore"):
            w_unnormalized = w * np.exp(-eta * (current_loss - shift))
        w_sum = np.sum(w_unnormalized)

        if np.isfinite(w_sum) and w_sum > 0:
            w = w_unnormalized / w_sum
        else:
            # NaN/Inf/zero normalizer: propagate the failure as NaN weights.
            w = np.full(K, np.nan)

    return {"h": h, "weights": weights_history}


def adahedge(losses):
    """
    AdaHedge: Hedge with the adaptive learning rate ``eta_t = log(K) / Delta_t``.

    Mirrors the R implementation. Weights and mix losses come from ``_mix``,
    and NaN/Inf values propagate into ``L``, ``h`` and ``Delta`` rather than
    being clipped or replaced.

    Parameters
    ----------
    losses : numpy.ndarray of shape (N, K) or (N,)
        Loss of each of K experts in each of N rounds. A 1-D array is treated
        as a single expert.

    Returns
    -------
    dict
        ``"h"``: (N,) learner loss per round, ``sum(weights[t] * losses[t])``.
        ``"weights"``: (N, K) weights played in round ``t``, computed from
        the cumulative losses of rounds ``0..t-1``.
        ``"eta"``: (N,) learning rate used in round ``t``.
    """
    if losses.ndim == 1:
        losses = losses.reshape(-1, 1)
    N, K = losses.shape
    if K == 0:
        return {
            "h": np.array([]),
            "weights": np.array([]).reshape(N, 0),
            "eta": np.array([]),
        }

    h = np.full(N, np.nan)
    # Cumulative losses; float so NaNs can propagate.
    L = np.zeros(K, dtype=float)
    etas = np.full(N, np.nan)
    weights_history = np.full((N, K), np.nan)
    # Running sum of delta_t = max(0, h[t] - (M_t - M_{t-1})), i.e. AdaHedge's
    # cumulative mixability gap.
    Delta = 0.0

    for t in range(N):
        # R: eta <- log(K) / Delta
        eta_t = np.nan
        if K > 1:
            if Delta > 0:
                # Any positive Delta, including tiny rounding residue (e.g.
                # after a round in which every expert had the same loss), gets
                # R's log(K) / Delta. The result is large but finite, and at
                # worst overflows to Inf.
                with np.errstate(over="ignore"):
                    eta_t = np.log(K) / Delta
            elif Delta == 0:
                eta_t = np.inf  # log(K) / 0 in R
            # Otherwise Delta is NaN (propagated from NaN losses): eta_t stays NaN.
        else:
            # K == 1: log(1) = 0. Delta == 0 gives 0/0 -> NaN; any other Delta
            # (a NaN Delta included) gives eta = 0.
            eta_t = np.nan if Delta == 0 else 0.0

        etas[t] = eta_t

        result_prev = _mix(L, eta_t)
        w = result_prev["w"]  # can be NaN
        M_prev = result_prev["M"]  # can be NaN/Inf
        weights_history[t, :] = w

        current_loss = losses[t, :]  # may contain NaN/Inf

        # R: h[t] <- sum(w * l[t,])
        h[t] = np.sum(w * current_loss)

        # R: L <- L + l[t,]
        L = L + current_loss

        M_curr = _mix(L, eta_t)["M"]  # can be NaN/Inf

        # R: delta <- max(0, h[t] - (result$M - Mprev)).
        # np.maximum propagates NaN; inf - inf -> NaN is intended, so silence it.
        with np.errstate(invalid="ignore"):
            diff_M = M_curr - M_prev
            delta_t = np.maximum(0.0, h[t] - diff_M)

        # R: Delta <- Delta + delta (NaN/Inf propagate, as with R's sum).
        Delta = Delta + delta_t

    return {"h": h, "weights": weights_history, "eta": etas}

