# utils/csv_utils.py  (revised)
import os
import time
from typing import Any, Iterable, Optional, Tuple

import numpy as np
import pandas as pd

try:
    import torch
    _TORCH_AVAILABLE = True
except Exception:
    torch = None
    _TORCH_AVAILABLE = False


def _to_1d_array(x: Any) -> np.ndarray:
    """
    Convert x (scalar | list | tuple | numpy array | array-like | torch tensor)
    to a new 1D float64 numpy array.

    Multi-dimensional inputs are flattened in C order; scalars become a
    length-1 array. Any input numpy can coerce to a float array (e.g. a pandas
    Series) is accepted.

    Raises:
        TypeError / ValueError: if x cannot be converted to floats
        (e.g. ragged nested lists or non-numeric objects).
    """
    if torch is not None and isinstance(x, torch.Tensor):
        # Cast to float64 before .numpy(): dtypes such as bfloat16 have no
        # numpy equivalent and would make .numpy() raise.
        x = x.detach().cpu().to(torch.float64).numpy()
    # np.array copies, so callers never alias the caller's buffer. reshape(-1)
    # guarantees 1D for scalars (0-d) and nested/multi-dim inputs alike.
    return np.array(x, dtype=float).reshape(-1)


def _relative_error(x_opt: np.ndarray, x_star: np.ndarray, eps: float = 1e-12) -> float:
    """
    Return the relative L2 distance of ``x_opt`` from the reference ``x_star``.

    Computed as ||x_opt - x_star||_2 / (||x_star||_2 + eps). ``eps`` keeps the
    denominator positive when ``x_star`` is all zeros. The subtraction follows
    NumPy broadcasting, so scalar and vector inputs are both accepted.
    """
    difference = x_opt - x_star
    reference_norm = np.linalg.norm(x_star)
    return float(np.linalg.norm(difference) / (reference_norm + eps))


def compute_errors(
    a_opt: Any,
    b_opt: Any,
    a_star: Any,
    b_star: Any,
    *,
    as_percent: bool = False,
    eps: float = 1e-12
) -> Tuple[float, float]:
    """
    Relative errors of optimized values against their theoretical targets.

    Every input may be a scalar, list, numpy array, or torch tensor. Each one
    is flattened to a 1D float array, and the error is

        ||opt - star||_2 / (||star||_2 + eps)

    which reduces to |opt - star| / (|star| + eps) for scalars. It is computed
    independently for the ``a`` pair and the ``b`` pair.

    Args:
        a_opt, b_opt: Optimized values.
        a_star, b_star: Theoretical (reference) values.
        as_percent: When True, both errors are multiplied by 100.
        eps: Stabilizer added to the reference norm.

    Returns:
        (err_a, err_b) as Python floats.
    """
    # Convert all four inputs up front (a_opt, a_star, b_opt, b_star order).
    opt_a, star_a, opt_b, star_b = (
        _to_1d_array(value) for value in (a_opt, a_star, b_opt, b_star)
    )

    first = _relative_error(opt_a, star_a, eps=eps)
    second = _relative_error(opt_b, star_b, eps=eps)

    if not as_percent:
        return first, second
    return first * 100.0, second * 100.0


def sort_and_save(
    results: Iterable[dict],
    file_name: str = "results.csv",
    sort_by: Optional[str] = None,
    ascending: bool = True
) -> pd.DataFrame:
    """
    Create a DataFrame from `results`, pick a sensible sort column, save to CSV, and return the sorted DF.

    Args:
        results: Iterable of row dicts, passed straight to ``pd.DataFrame``.
        file_name: Destination CSV path. It is written without the index.
        sort_by: Column to sort on. If None, it is chosen by the heuristics
            below. If the column is not present, rows keep their input order.
        ascending: Sort direction.

    Heuristics for `sort_by` if not provided:
      1) "error_a (%)"
      2) "error_a"
      3) "err_a_l2"
      4) "err_t_l2"
      5) any column whose name contains "error", case-insensitive (first match)
      6) fall back to "u1" if present, else no sorting

    Returns:
        The DataFrame that was written, with a fresh 0..n-1 index. Ties keep
        their input order because the sort is stable.
    """
    df = pd.DataFrame(results)

    if sort_by is None:
        candidates = [
            "error_a (%)",
            "error_a",
            "err_a_l2",
            "err_t_l2",
        ]
        sort_by = next((c for c in candidates if c in df.columns), None)
        if sort_by is None:
            # str(): column labels need not be strings (e.g. int keys in the
            # result dicts), and calling .lower() on them would raise.
            error_cols = [c for c in df.columns if "error" in str(c).lower()]
            sort_by = error_cols[0] if error_cols else ("u1" if "u1" in df.columns else None)

    # Track the key actually used so the log message never claims a sort
    # that did not happen (e.g. an explicit sort_by missing from the data).
    used_key = None
    if sort_by is not None and sort_by in df.columns:
        # mergesort is stable, so rows with equal keys keep their input order
        # and the CSV is reproducible across runs.
        sorted_df = df.sort_values(
            sort_by, ascending=ascending, kind="mergesort"
        ).reset_index(drop=True)
        used_key = sort_by
    else:
        sorted_df = df.reset_index(drop=True)

    sorted_df.to_csv(file_name, index=False)
    print(f"✅ Saved to {file_name} (sorted by: {used_key if used_key is not None else 'none'})")
    return sorted_df

def _slug_float(x: float) -> str:
    """
    Return a filename-safe token for the number ``x``.

    Integral values drop the fractional part (10.0 -> "10"). Other values use
    ``str(x)`` with "." replaced by "p" (0.01 -> "0p01"), which keeps the
    caller's own formatting (e.g. of numpy float32 values).

    Accepts anything ``float()`` accepts, including numeric strings such as
    "10.0".
    """
    value = float(x)
    if value.is_integer():
        # int(value), not int(x): int("10.0") raises for string input.
        return str(int(value))
    return str(x).replace(".", "p")

def save_results_for_c(
    results: Iterable[dict],
    c_value: float,
    root_dir: str,
    setting: str,
    var_name: str = "c",
    file_prefix: str = "results",
    sort_by: Optional[str] = None,
    ascending: bool = True,
    add_timestamp: bool = True,
) -> pd.DataFrame:
    """
    Save results into a CSV file under a path that encodes the c value.

    Layout:
      <root_dir>/<setting>/<var_name>_rel/<var_name>=<slug>/<file_prefix>[_<timestamp>].csv
    where <slug> is ``_slug_float(c_value)`` and <timestamp> is local time as
    %Y-%m-%dT%H-%M-%S. Example with file_prefix="results_seed=0":
      root_dir/setting/c_rel/c=10/results_seed=0_2025-09-22T13-40-00.csv

    Args:
        results: Iterable of row dicts, passed straight to ``pd.DataFrame``.
        c_value: Value encoded in the directory name.
        root_dir, setting: Leading path components.
        var_name: Name used in the "<var_name>_rel/<var_name>=..." components.
        file_prefix: File name stem.
        sort_by: Column to sort on. If None, the first of "error_a (%)",
            "error_a", "err_a_l2", "err_t_l2" present is used, else the first
            column whose name contains "error". An absent column means no sort.
        ascending: Sort direction.
        add_timestamp: Append a timestamp to the file name. When set, a
            numeric suffix (_1, _2, ...) is added if the timestamped name
            already exists, so calls within the same second don't overwrite
            each other. Without a timestamp the file is overwritten.

    Returns:
        The DataFrame that was written.
    """
    df = pd.DataFrame(results)

    # pick sort key if not supplied
    if sort_by is None:
        candidates = ["error_a (%)", "error_a", "err_a_l2", "err_t_l2"]
        sort_by = next((c for c in candidates if c in df.columns), None)
        if sort_by is None:
            # str(): column labels need not be strings (e.g. int keys).
            error_cols = [c for c in df.columns if "error" in str(c).lower()]
            sort_by = error_cols[0] if error_cols else None

    # Only report a sort key that was actually applied.
    used_key = None
    if sort_by is not None and sort_by in df.columns:
        # Stable sort so ties keep input order and output is reproducible.
        df = df.sort_values(
            sort_by, ascending=ascending, kind="mergesort"
        ).reset_index(drop=True)
        used_key = sort_by

    # build directory path
    c_slug = _slug_float(c_value)
    out_dir = os.path.join(root_dir, setting, f"{var_name}_rel", f"{var_name}={c_slug}")
    os.makedirs(out_dir, exist_ok=True)

    # build file name
    if add_timestamp:
        stamp = time.strftime("%Y-%m-%dT%H-%M-%S")
        stem = f"{file_prefix}_{stamp}"
    else:
        stem = file_prefix
    file_path = os.path.join(out_dir, f"{stem}.csv")

    if add_timestamp:
        # The timestamp has 1-second resolution. Disambiguate so back-to-back
        # calls (e.g. a fast loop over seeds) don't silently overwrite files.
        counter = 1
        while os.path.exists(file_path):
            file_path = os.path.join(out_dir, f"{stem}_{counter}.csv")
            counter += 1

    df.to_csv(file_path, index=False)
    print(f"✅ Saved c={c_value} results to {file_path} (sorted by: {used_key if used_key is not None else 'none'})")
    return df
