"""
Helpers for saving/loading fields and scalars.
"""

from typing import List, Sequence, Tuple

import h5py
import numpy as np
import torch

from Muscat.MeshTools.MeshInspectionTools import ComputeMeshMinMaxLengthScale


def get_bandwidth(mesh) -> float:
    """Return a characteristic length scale of ``mesh`` (called "bandwidth" here).

    This is a thin wrapper: the value computed by Muscat's
    ``ComputeMeshMinMaxLengthScale`` is passed through unchanged.

    NOTE(review): the ``float`` annotation assumes the Muscat helper returns a
    single number; its name mentions both min and max, so confirm the actual
    return type against the Muscat documentation.
    """
    length_scale = ComputeMeshMinMaxLengthScale(mesh)
    return length_scale


def save_fields(filename: str, fields: List[torch.Tensor]) -> None:
    """Write each tensor of ``fields`` to the HDF5 file ``filename``.

    The file is opened in ``"w"`` mode, so any existing file is truncated.
    Tensor ``i`` is stored as the dataset named ``str(i)``. Each tensor is
    detached and copied to the CPU before conversion to NumPy. Read the file
    back with ``load_fields``.
    """
    with h5py.File(filename, "w", libver="latest") as h5_file:
        for position, tensor in enumerate(fields):
            host_array = tensor.detach().cpu().numpy()
            h5_file.create_dataset(str(position), data=host_array)


def save_scalars(file_path: str, data_list: Sequence[np.ndarray]) -> None:
    """Write each array of ``data_list`` to the HDF5 file ``file_path``.

    The file is opened in ``"w"`` mode, so any existing file is truncated.
    Entry ``i`` becomes the dataset ``array_<i>``. Read the file back with
    ``load_scalars``.
    """
    with h5py.File(file_path, "w") as h5_file:
        for i, array in enumerate(data_list):
            dataset_name = f"array_{i}"
            h5_file.create_dataset(dataset_name, data=array)


def load_fields(filename: str) -> List[torch.Tensor]:
    """Load the tensors written by ``save_fields``.

    Datasets are read in ascending *numeric* order of their names
    ('0', '1', ..., '10', ...), so the list order matches the order that was
    passed to ``save_fields``. The returned tensors live on the CPU and share
    memory with the NumPy arrays read from the file.

    Raises:
        ValueError: if a dataset name is not an integer string (``int`` is
            used as the sort key).
    """
    with h5py.File(filename, "r") as f:
        names = sorted(f.keys(), key=int)
        # For a 0-d (scalar) dataset, ``dset[()]`` yields a NumPy scalar, not
        # an ndarray, and ``torch.from_numpy`` rejects it. ``np.asarray``
        # restores a 0-d array so scalar tensors round-trip. It is a no-op for
        # real arrays.
        return [torch.from_numpy(np.asarray(f[name][()])) for name in names]


def load_scalars(file_path: str) -> List[np.ndarray]:
    """Read back the arrays written by ``save_scalars``.

    Dataset names are expected to look like ``array_<n>``. They are ordered
    by the integer ``<n>`` (the text after the first underscore), so the
    result follows the original list order rather than lexicographic order.
    Each value is whatever ``dataset[()]`` returns.
    """
    def _index_of(key: str) -> int:
        # "array_12" -> 12; numeric sort keeps "array_10" after "array_9".
        return int(key.split("_")[1])

    with h5py.File(file_path, "r") as h5_file:
        ordered_keys = sorted(h5_file.keys(), key=_index_of)
        return [h5_file[key][()] for key in ordered_keys]


def relative_error(
    predictions: torch.Tensor,
    ground_truth: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute absolute errors normalized by the peak magnitude of the ground truth.

    Each element's error ``|predictions - ground_truth|`` is divided by the
    single global scale ``max(|ground_truth|)``, taken over the whole tensor.
    It is not scaled per element or per row.

    Args:
        predictions: predicted values; broadcast against ``ground_truth``.
        ground_truth: reference values. The broadcast result must have at
            least two dimensions, because the mean is taken over dim 1.

    Returns:
        A tuple ``(mean_relative_error, relative_errors)``: the mean of the
        normalized errors over dim 1, and the elementwise normalized errors.

    Note:
        If ``ground_truth`` is identically zero, the scale is 0 and the
        result contains ``inf``/``nan``.
    """
    # Normalize by the largest *magnitude*. A plain ``max`` is zero or
    # negative for signed fields whose values are all <= 0, which would
    # produce infinite or negative "errors".
    scale = torch.max(torch.abs(ground_truth))
    relative_errors = torch.abs(predictions - ground_truth) / scale
    mean_relative_error = torch.mean(relative_errors, dim=1)
    return mean_relative_error, relative_errors
