import math

import numpy as np
import torch
from torch import Tensor


def metrics2string(**kwargs) -> str:
    """
    Formats metrics as "key: mean(uncertainty)" pairs joined by "; ".

    Each keyword value is a collection of samples (list, numpy array or torch
    tensor) that ``_format_with_uncert`` reduces to its mean and standard
    deviation, rendered in concise parenthesis notation by ``format_si``.
    A value of ``None`` produces ``"key: "``.

    Args:
        **kwargs: Metric name -> samples.
    Returns:
        str: The "; "-separated pairs, or "" when no metrics are given.
    """
    return "; ".join(
        f"{key}: {_format_with_uncert(samples)}" for key, samples in kwargs.items()
    )


def kwargs2string(**kwargs) -> str:
    """
    Joins keyword arguments into a single "key: value" string.

    Pairs are rendered as "keyword1: value1; keyword2: value2 ..." with every
    value passed through ``_format`` (so ``None`` renders as an empty value).

    Args:
        **kwargs: Arbitrary keyword arguments.
    Returns:
        str: The "; "-separated pairs, or "" when no arguments are given.
    """
    rendered = (f"{name}: {_format(val)}" for name, val in kwargs.items())
    return "; ".join(rendered)


def _format(value) -> str:
    """
    Formats a single value for logging.

    - ``None`` -> "".
    - One-element tensors/arrays and numpy floating scalars are converted to a
      Python float first.
    - Multi-element tensors and arrays are rounded to whole numbers and shown
      as Python lists (integer arrays are listed unchanged).
    - Floats use 3 significant digits, switching to ``_exp_format`` when
      ``abs(value) < 1e-2`` or ``abs(value) > 1e4``.
    - Anything else is printed as-is.
    """
    if value is None:
        return ""

    if isinstance(value, Tensor):
        n_elements = value.nelement()
        if n_elements == 1:
            value = float(value.item())
        elif n_elements > 1:
            value = value.round(decimals=0).tolist()
    elif isinstance(value, np.ndarray):
        # Mirror the Tensor path: .item() only works for size-1 arrays, so
        # larger arrays are listed instead of crashing.
        if value.size == 1:
            value = float(value.item())
        elif value.size > 1:
            if np.issubdtype(value.dtype, np.floating):
                value = np.round(value, 0)
            value = value.tolist()
    elif isinstance(value, np.floating):
        # Covers np.float32/np.float16; np.float64 is already a float subclass.
        value = float(value.item())

    if isinstance(value, float):
        if abs(value) < 1e-2 or abs(value) > 1e4:
            return _exp_format(value)
        return f"{value:.3g}"
    return f"{value}"  # Print other types as they are


def _format_with_uncert(samples) -> str:
    """
    Summarizes 1-D metric samples as "mean(uncertainty)" via ``format_si``.

    Args:
        samples: list, numpy array or torch tensor of samples, or None.
    Returns:
        str: "" for None; otherwise ``format_si`` applied to the sample mean and
        standard deviation (computed in float32). A single sample gets a NaN
        standard deviation.
    Raises:
        ValueError: for unsupported input types, or input that is not 1-D with
            at least one sample.
    """
    if samples is None:
        return ""

    if isinstance(samples, list):
        tensor = torch.tensor(samples, dtype=torch.float32)
    elif isinstance(samples, np.ndarray):
        tensor = torch.from_numpy(samples.astype(np.float32))
    elif isinstance(samples, Tensor):
        tensor = samples
    else:
        raise ValueError("Only list, array or tensor are supported.")

    tensor = tensor.to(dtype=torch.float32)
    if tensor.dim() != 1 or tensor.shape[0] == 0:
        raise ValueError("metric_samples should be 1d and contain at least 1 sample.")

    if tensor.shape[0] == 1:
        # The spread of a single sample is undefined.
        mean_value = tensor.item()
        std_value = float("nan")
    else:
        std_t, mean_t = torch.std_mean(tensor)
        mean_value = mean_t.item()
        std_value = std_t.item()
    return format_si(mean_value, std_value)


def format_si(mean, std, sig_digits=1):
    """
    Formats ``mean`` with uncertainty ``std`` in concise parenthesis notation.

    The uncertainty is rounded to ``sig_digits`` significant digits and the
    mean is rounded to the same number of decimal places. The digits in
    parentheses are the uncertainty expressed in units of the mean's last
    printed digit, e.g. ``format_si(1.234, 0.05) -> "1.23(5)"`` and
    ``format_si(12.34, 3.4) -> "12(3)"``.

    Args:
        mean: Central value.
        std: Uncertainty; 0 or NaN means no uncertainty is available.
        sig_digits: Significant digits kept for the uncertainty.
    Returns:
        str: "<mean>(<uncertainty digits>)". When ``std`` is 0 or NaN the mean
        is printed with ``sig_digits`` significant digits followed by "(0)".
    Raises:
        ValueError: if ``std`` is negative (from ``math.log10``).
        OverflowError: if ``std`` is infinite (from ``math.floor``).
    """
    if std == 0 or math.isnan(std):
        # NOTE(review): with the default sig_digits=1 this collapses e.g. 0.873
        # to "0.9(0)"; confirm this precision is intended for single samples.
        return f"{mean:.{sig_digits}g}(0)"

    def uncert_digits(places):
        # Uncertainty rounded to `places` decimals, without the decimal point
        # or leading zeros: 0.05 at 2 places -> "5".
        text = f"{round(std, places):.{places}f}".replace(".", "")
        return text.lstrip("0") or "0"

    # Decimal places needed so that `std` keeps `sig_digits` significant digits.
    decimals = max(sig_digits - 1 - math.floor(math.log10(std)), 0)
    std_digits = uncert_digits(decimals)
    if len(std_digits) > sig_digits and decimals > 0:
        # Rounding carried into the next decade (e.g. 0.096 -> 0.10): drop one
        # decimal place so exactly `sig_digits` digits remain.
        decimals -= 1
        std_digits = uncert_digits(decimals)

    mean_str = f"{round(mean, decimals):.{decimals}f}"
    return f"{mean_str}({std_digits})"


def _exp_format(value: float, precision=2, exp_width=1) -> str:
    """
    Format a float in scientific notation, zero-padding the exponent to at
    least ``exp_width`` digits.

    Python's "e" format already emits at least two exponent digits, so
    ``exp_width`` only has an effect from 3 upward:
    1.2e-5 -> '1.20e-05' (default), '1.20e-005' with ``exp_width=3``.
    Values whose formatting contains no exponent ('inf', '-inf', 'nan') are
    returned unchanged.

    Args:
        value: Number to format.
        precision: Digits after the decimal point of the mantissa.
        exp_width: Minimum number of exponent digits (sign excluded).
    Returns:
        str: The formatted number.
    """
    formatted = f"{value:.{precision}e}"
    mantissa, sep, exponent = formatted.partition("e")
    if not sep:
        return formatted
    sign, digits = exponent[0], exponent[1:]
    return f"{mantissa}e{sign}{digits.rjust(exp_width, '0')}"
