import argparse
from typing import Tuple

import numpy as np


def load_result_from_npy(npy_path: str):
    """
    Read a ``.npy`` file and return its contents as plain Python objects.

    The array is loaded with ``allow_pickle=True`` and converted with
    ``ndarray.tolist()``.

    NOTE: ``allow_pickle=True`` lets numpy unpickle arbitrary objects stored
    in the file, so only load files that come from a trusted source.
    """
    return np.load(npy_path, allow_pickle=True).tolist()

def compute_vig_scores(result, empty_policy: str = "original", alpha: str = "mean") -> np.ndarray:
    """
    Compute one VIG score per sample.

    For every record the ``"wo_loss"`` and ``"w_loss"`` sequences are cut to
    their common length and subtracted element-wise (``wo_loss - w_loss``).
    The differences are then reduced to a single float32 score:

    * ``alpha == "max"``  -> largest difference
    * ``alpha == "mean"`` -> average difference
    * anything else       -> average of the differences clipped at zero

    A record with no usable differences (a missing or empty loss list) gets a
    sentinel chosen by ``empty_policy``: ``-inf`` for ``"skip"`` and
    ``"original"``, ``+inf`` for ``"selective"``.

    Raises:
        ValueError: an empty record is encountered and ``empty_policy`` is
            not one of the three known values.
    """
    # Sentinel written for records that have nothing to compare.
    empty_fill = {"skip": -np.inf, "original": -np.inf, "selective": np.inf}
    scores = np.empty(len(result), dtype=np.float32)

    for idx, record in enumerate(result):
        with_loss = np.asarray(record.get("w_loss", []), dtype=np.float32)
        without_loss = np.asarray(record.get("wo_loss", []), dtype=np.float32)
        common = min(with_loss.size, without_loss.size)

        if common == 0:
            # The policy is validated lazily: only when an empty record shows up.
            if empty_policy not in empty_fill:
                raise ValueError(f"unknown empty_policy: {empty_policy}")
            scores[idx] = empty_fill[empty_policy]
            continue

        gain = without_loss[:common] - with_loss[:common]
        if alpha == "max":
            value = gain.max()
        elif alpha == "mean":
            value = gain.mean()
        else:
            value = np.maximum(gain, 0).mean()
        scores[idx] = value

    return scores


def select_top_p(
    vig_scores: np.ndarray,
    p: float,
) -> Tuple[float, np.ndarray]:
    """
    Select the samples whose score lies strictly above the top-p% threshold.

    With ``N`` scores, ``K = round(N * p / 100)`` and the threshold ``tau_p``
    is the smallest score among the ``K`` highest ones. The selection is every
    index whose score is strictly greater than ``tau_p``, so samples sitting
    exactly on the threshold are not selected.

    Args:
        vig_scores: per-sample scores (converted to a float32 array).
        p: percentage of samples to target, in the half-open range (0, 100].

    Returns:
        ``(tau_p, sel_idx)`` where ``sel_idx`` is an int64 array of selected
        indices in ascending order. When ``K`` is 0 (no scores at all, or
        ``p`` too small for the given ``N``) nothing is selected and
        ``tau_p`` is ``+inf``.

    Raises:
        ValueError: ``p`` is outside (0, 100].
    """
    if not (0 < p <= 100):
        raise ValueError("p has to be in (0, 100]")

    scores = np.asarray(vig_scores, dtype=np.float32)
    k_target = int(round(scores.size * p / 100))

    if k_target == 0:
        # `order[-0:]` would be the *whole* array, turning "select nothing"
        # into "select almost everything" (and `.min()` fails on empty input).
        return np.float32(np.inf), np.empty(0, dtype=np.int64)

    order = np.argsort(scores)
    top_k = order[-k_target:]
    tau_p = scores[top_k].min()
    sel_idx = np.flatnonzero(scores > tau_p).astype(np.int64)

    return tau_p, sel_idx


def main():
    """
    Command-line entry point.

    Loads the per-sample results from ``--npy-path``, computes the scores with
    ``empty_policy="original"`` and ``alpha="mean"``, selects the indices above
    the ``--p`` threshold, prints a short summary and, unless ``--save-idx`` is
    empty, saves the selected indices with ``np.save``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--p",
        type=float,
        default=70,
        help="ratio p",
    )
    parser.add_argument(
        "--npy-path",
        type=str,
        # Without a path there is nothing to load; let argparse report it
        # instead of failing later inside np.load(None).
        required=True,
        help="vig score of answer tokens file",
    )
    parser.add_argument(
        "--save-idx",
        type=str,
        default="sel_idx.npy",
        help="output file name",
    )
    args = parser.parse_args()

    result = load_result_from_npy(args.npy_path)
    vig_scores = compute_vig_scores(result, empty_policy="original", alpha="mean")
    tau_p, selected_idx = select_top_p(vig_scores, args.p)

    print(f"Total samples: {len(vig_scores)}")
    print(f"p = {args.p:.2f} -> tau_p = {tau_p}")
    print(f"#selected (vig > tau_p): {len(selected_idx)}")

    # An empty --save-idx disables writing the output file.
    if args.save_idx:
        # NOTE: np.save appends ".npy" when the given name lacks that suffix.
        np.save(args.save_idx, selected_idx)
        print(f"Saved the indices as '{args.save_idx}'")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

