

import os
import numpy as np


def calculate_metric(preds: np.ndarray, labels: np.ndarray):
    """
    Compute per-class (CP/CR/CF1) and overall (OP/OR/OF1) precision/recall/F1.

    Args:
        preds: binary predictions [N, C] (0/1); an entry counts as a predicted
            positive only when it equals 1.
        labels: ground-truth binary labels [N, C] (0/1)
    Returns:
        (CP, CR, CF1, OP, OR, OF1) scaled to percentages (0-100).

    Notes:
        Per-class denominators are clamped to at least 1, so a class with no
        predicted (resp. true) positives contributes 0 to CP (resp. CR)
        instead of being excluded from the mean.
        F1 values fall back to 0.0 when precision + recall is not positive.
    """
    hits_per_class = (labels * preds).sum(0)
    predicted_per_class = (preds == 1).sum(0)
    actual_per_class = labels.sum(0)

    total_hits = hits_per_class.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        # Overall (micro) scores: pool counts across every class first.
        overall_p = total_hits / max(predicted_per_class.sum(), 1)
        overall_r = total_hits / max(actual_per_class.sum(), 1)

        # Per-class (macro) scores: average the per-class ratios.
        class_p = np.nanmean(hits_per_class / np.maximum(predicted_per_class, 1))
        class_r = np.nanmean(hits_per_class / np.maximum(actual_per_class, 1))

    def harmonic(p, r):
        denom = p + r
        return (2 * p * r) / denom if denom > 0 else 0.0

    class_f1 = harmonic(class_p, class_r)
    overall_f1 = harmonic(overall_p, overall_r)

    return tuple(
        score * 100
        for score in (class_p, class_r, class_f1, overall_p, overall_r, overall_f1)
    )


def load_answers_array(path_without_suffix: str):
    """
    Try to load a per-class result file given a base path WITHOUT suffix.

    Priority:
        - <base>.npz and read key 'answers'
        - <base>.npy
    Args:
        path_without_suffix: base path to which ".npz" / ".npy" is appended.
    Returns:
        The loaded array, or None if neither file exists.
    Raises:
        KeyError: if <base>.npz exists but has no 'answers' entry.
    """
    npz_path = path_without_suffix + ".npz"
    npy_path = path_without_suffix + ".npy"
    # NOTE(review): allow_pickle=True is kept for compatibility (object-dtype
    # answer arrays require it), but it lets a crafted file execute arbitrary
    # code on load -- only point this at trusted result files.
    if os.path.exists(npz_path):
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it. Indexing reads the array into memory,
        # so the returned value stays valid after the file is closed.
        with np.load(npz_path, allow_pickle=True) as z:
            if 'answers' not in z.files:
                raise KeyError(f"'answers' not found in {npz_path}")
            return z['answers']
    if os.path.exists(npy_path):
        return np.load(npy_path, allow_pickle=True)
    return None


def answers_to_binary(ans):
    """
    Convert an answers array to binary {0,1}.

    Supported inputs:
        - strings (str / bytes / object arrays): 'y' or 'yes' (case-insensitive,
          surrounding whitespace ignored) -> 1; anything else, including None,
          -> 0
        - booleans / integers: value > 0 -> 1, otherwise 0
        - floats: value >= 0.5 -> 1, otherwise 0
        - anything else: value != 0 -> 1
    Args:
        ans: array-like of answers, or None.
    Returns:
        uint8 array with the same shape as ``ans``, or None if ``ans`` is None.
    """
    if ans is None:
        return None
    arr = np.asarray(ans)

    # String-like
    if arr.dtype.kind in ('U', 'S', 'O'):
        def _is_yes(v):
            # np.bytes_ subclasses bytes; str() on it would give "b'yes'",
            # so decode explicitly before normalising.
            if isinstance(v, bytes):
                v = v.decode("utf-8", errors="replace")
            return str(v).strip().lower() in ("y", "yes")

        flags = np.fromiter((_is_yes(v) for v in arr.ravel()),
                            dtype=bool, count=arr.size)
        return flags.astype(np.uint8).reshape(arr.shape)

    # Numeric. Compare in the original integer dtype: casting to uint8 first
    # wraps modulo 256 (e.g. 256 -> 0, -1 -> 255), which corrupts the result.
    if np.issubdtype(arr.dtype, np.bool_) or np.issubdtype(arr.dtype, np.integer):
        return (arr > 0).astype(np.uint8)
    if np.issubdtype(arr.dtype, np.floating):
        conf = arr.astype(np.float32)
        return (conf >= 0.5).astype(np.uint8)

    # Fallback
    return (arr != 0).astype(np.uint8)

def _build_base_no_subset(outputs_root: str, dataset: str, method: str,
                          cls_id: int, partition: str):
    """
    Return the suffix-less base path of one class's answers in merged mode
    (i.e. when no subset is given).

    Layout: <outputs_root>/<dataset_dir>/<method>_mlc_1/answer_<cls_id>
    where dataset_dir is the dataset name itself for "coco2014"/"coco2017"
    and "objects365" for "o365". ``partition`` is accepted but does not
    affect the path.

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset in ("coco2014", "coco2017"):
        dataset_dir = dataset
    elif dataset == "o365":
        dataset_dir = "objects365"
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    run_dir = f"{method}_mlc_1"
    return os.path.join(outputs_root, dataset_dir, run_dir, f"answer_{cls_id}")

def build_paths_for_class(outputs_root: str, dataset: str, method: str,
                          cls_id: int, partition: str, subset: str):
    """
    Return the suffix-less base path of one class's result file.

    Layouts:
        coco: <outputs_root>/coco/<method>_coco_train_1/cls_<cls_id>
              (``partition`` and ``subset`` are ignored)
        o365: <outputs_root>/objects365/
              <method>_<partition>_objects365_train_<subset>_1/cls_<cls_id>
    Append ".npy" or ".npz" to obtain a concrete file name.

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset == "coco":
        dataset_dir = "coco"
        run_dir = f"{method}_coco_train_1"
    elif dataset == "o365":
        dataset_dir = "objects365"
        run_dir = f"{method}_{partition}_objects365_train_{subset}_1"
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    return os.path.join(outputs_root, dataset_dir, run_dir, f"cls_{cls_id}")


def load_labels_for_subset(data_root: str, dataset: str, partition: str, subset: str):
    """
    Load the ground-truth label matrix for a dataset (and subset, for O365).

    Locations:
        o365:     <data_root>/objects365/<partition>/formatted_train_labels_<subset>.npy
        coco2014: <data_root>/coco2014/pml/formatted_train_labels.npy
                  (``partition`` and ``subset`` are ignored)
    Returns:
        labels [N, C] as stored in the .npy file.
    Raises:
        ValueError: for an unsupported dataset name (checked before any I/O).
        FileNotFoundError: if the label file does not exist.
    """
    if dataset == "o365":
        label_path = os.path.join(data_root, "objects365", partition,
                                  f"formatted_train_labels_{subset}.npy")
    elif dataset == "coco2014":
        label_path = os.path.join(data_root, "coco2014", "pml",
                                  "formatted_train_labels.npy")
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    if not os.path.exists(label_path):
        raise FileNotFoundError(f"Label file not found: {label_path}")
    # NOTE(review): allow_pickle=True can execute code embedded in a crafted
    # file; only load label files from trusted sources.
    return np.load(label_path, allow_pickle=True)


def print_basic_stats(name: str, preds: np.ndarray, labels: np.ndarray):
    """
    Print quick sanity statistics about a prediction matrix.

    Emits four lines, each tagged with ``name``: the shapes of ``preds``
    and ``labels``, the distinct values in ``preds``, the number of rows
    of ``preds`` that are all zero, and the mean row sum of ``preds``
    (formatted to two decimals).
    """
    print(f"Size ({name}): preds {preds.shape} | labels {labels.shape}")

    distinct_values = np.unique(preds)
    print(f"{name} unique elements in preds: {distinct_values}")

    empty_row_count = int(np.all(preds == 0, axis=1).sum())
    print(f"{name} zero rows: {empty_row_count}")

    mean_row_sum = preds.sum(axis=1).mean()
    print(f"{name} avg candidate labels per image: {mean_row_sum:.2f}")
