"""Compute hypervolume for a batch of solutions, under multi-objective MINIMIZATION."""

import torch
from pymoo.indicators.hv import HV
import numpy as np
from typing import Optional, Tuple
from torch import Tensor
from data.function_preprocessing import min_max_normalize
from einops import repeat
from utils.data import tnp


def norm_compute_hv_batch(
    solutions: Tensor,
    minimum: Tensor,
    maximum: Tensor,
    ref_point: Optional[Tensor] = None,
    y_mask: Optional[Tensor] = None,
    normalize: bool = True,
) -> Tuple[np.ndarray, Tensor, Tensor]:
    """Prepare batched solutions / reference points and compute their hypervolume.

    Pipeline:
        1. Resolve the reference points (falling back to `maximum` when
           `ref_point` is None) and bring them to shape [B, 1, max_y_dim].
        2. If `normalize` is True, min-max normalize both `solutions` and
           `ref_point` using `minimum` and `maximum`.
        3. Compute one hypervolume value per batch element, restricted to the
           objective dimensions selected by `y_mask`.

    Args:
        solutions: [B, N, max_y_dim]
        minimum, maximum, ref_point, y_mask: [B, max_y_dim] or [max_y_dim]
        normalize: Normalize `solutions` and `ref_point` to [0, 1] if True

    Returns:
        hv: hypervolume values, [B]
        solutions: (optionally normalized) solutions, [B, N, max_y_dim]
        ref_point: (optionally normalized) reference points, [B, 1, max_y_dim]
    """
    # Steps 1 and 2: resolve, reshape and (optionally) normalize the inputs
    prepared_solutions, prepared_ref = _get_solutions_n_ref_points(
        solutions, minimum, maximum, ref_point=ref_point, normalize=normalize
    )

    # Step 3: per-batch hypervolume on the prepared tensors
    hv_values = compute_hv_batch(
        ref_point=prepared_ref, solutions=prepared_solutions, y_mask=y_mask
    )

    return hv_values, prepared_solutions, prepared_ref


def compute_hv_batch(
    ref_point: Tensor, solutions: Tensor, y_mask: Optional[Tensor] = None
) -> np.ndarray:
    """Compute hypervolume for batch solutions with respect to reference points.

    Each batch element is evaluated independently with `compute_hv`, using only
    the objective dimensions selected by its row of `y_mask`, so the number of
    active dimensions may differ between batch elements.

    Args:
        ref_point: [B, 1, max_y_dim]
        solutions: [B, N, max_y_dim]
        y_mask: Optional mask for valid dims, [B, max_y_dim] or [max_y_dim]
            (a 1-D mask is shared by every batch element). Nonzero / True marks
            a valid dimension; non-boolean masks are cast to bool. If None, all
            dimensions are treated as valid.

    Returns: hvs [B], float64 numpy array

    Raises:
        AssertionError: If `ref_point` or `y_mask` does not match the shape
            implied by `solutions`.
    """
    B, N, max_y_dim = solutions.shape
    assert ref_point.shape == (B, 1, max_y_dim), f"ref_point: {ref_point.shape}"

    # Prepare y_mask
    if y_mask is None:
        y_mask = torch.ones((B, max_y_dim), dtype=torch.bool, device=solutions.device)
    else:
        if y_mask.ndim == 1:
            y_mask = repeat(y_mask, "d -> b d", b=B)
        assert y_mask.shape == (B, max_y_dim), f"y_mask: {y_mask.shape}"
        # Only bool tensors act as masks when used as an index: a 0/1 integer
        # mask would be read as positions to gather, and a float mask is not a
        # valid index at all. Cast so that every mask dtype selects dimensions.
        y_mask = y_mask.bool()

    # Valid y dimension counts, NOTE can vary between y[b].
    # Computed with a single reduction instead of one `.item()` call per row.
    valid_dim_counts = y_mask.sum(dim=1).tolist()

    hvs = np.empty(B, dtype=np.float64)

    for b, dy in enumerate(valid_dim_counts):
        mask_b = y_mask[b]  # [max_y_dim]

        refs_b = ref_point[b, 0, mask_b]  # [dy]
        sols_b = solutions[b, :, mask_b]  # [N, dy]

        assert refs_b.shape == (dy,), f"refs_b: {refs_b.shape}"
        assert sols_b.shape == (N, dy), f"sols_b: {sols_b.shape}"

        hvs[b] = compute_hv(ref_point=refs_b, solutions=sols_b)

    return hvs


def compute_hv(ref_point: np.ndarray | Tensor, solutions: np.ndarray | Tensor) -> float:
    """Evaluate the hypervolume of one solution set against one reference point.

    Both inputs are first passed through `tnp` (presumably a tensor-to-NumPy
    conversion; confirm in `utils.data`) and then handed to pymoo's `HV`
    indicator.

    Args: ref_point [y_dim], solutions [num_solutions, y_dim]

    Returns: hv, float
    """
    ref_np, sols_np = tnp(ref_point), tnp(solutions)

    # Build the indicator for this reference point and apply it in one step
    return HV(ref_point=ref_np)(sols_np)


def _get_solutions_n_ref_points(
    solutions: Tensor,
    minimum: Tensor,
    maximum: Tensor,
    ref_point: Optional[Tensor] = None,
    normalize: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Resolve, reshape and optionally normalize inputs for hypervolume computation.

    Behaviour:
        - A missing `ref_point` falls back to `maximum`.
        - The reference points are always returned with shape [B, 1, D].
        - With `normalize=True`, both `solutions` and `ref_point` are passed
          through `min_max_normalize` together with `minimum` and `maximum`.

    Normalizing outputs during training can remove bias towards larger objectives.

    Args:
        solutions: [B, N, D]
        minimum, maximum, ref_point: [B, D] or [D]
        normalize: Normalize solutions and reference points to [0, 1] if True

    Returns:
        solutions [B, N, D], ref_point [B, 1, D]
    """
    batch_size, _, num_dims = solutions.shape

    # Pick the reference values; only a user-supplied tensor is shape-checked
    if ref_point is not None:
        accepted_shapes = ((batch_size, num_dims), (num_dims,))
        assert ref_point.shape in accepted_shapes
    else:
        ref_point = maximum

    # Broadcast a shared [D] reference to every batch element, then add the
    # singleton "solution" axis: [B, D] -> [B, 1, D]
    if ref_point.ndim == 1:
        ref_point = repeat(ref_point, "d -> b d", b=batch_size)
    ref_point = ref_point.unsqueeze(1)

    if not normalize:
        return solutions, ref_point

    # Normalize solutions and reference points with the same bounds
    scaled_solutions = min_max_normalize(solutions, minimum, maximum)
    scaled_ref = min_max_normalize(ref_point, minimum, maximum)
    return scaled_solutions, scaled_ref
