#!/usr/bin/env python3
import argparse
import json
import math
import re
import sys
from pathlib import Path
from statistics import median
from typing import Any, Dict, Iterable, List, Optional, Tuple


def _to_scalar(x: Any) -> Any:
    """Return first element if x is a non-empty list/tuple; otherwise return x."""
    if isinstance(x, (list, tuple)) and x:
        return x[0]
    return x

def _coerce_num(x: Any) -> Optional[float]:
    """Best-effort convert to float; return None on failure."""
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    try:
        return float(str(x))
    except ValueError:
        return None

def _get_agent_scores(eval_block: Any, n_agents: int) -> List[Optional[float]]:
    """
    Turn one evaluation block into exactly max(n_agents, 0) per-agent scores.

    Accepted shapes:
      - dict keyed by agent index as str ("0", "1", ...) or int (0, 1, ...);
        the str key wins when both exist. Values may be a scalar or a
        list/tuple whose first element is the score.
      - list of per-agent values, read positionally; extra entries are ignored.
    Any other value (None, numbers, strings) yields an all-None row. Entries
    that are absent or that _coerce_num cannot convert become None.
    """
    if isinstance(eval_block, dict):
        raw = [eval_block.get(str(idx), eval_block.get(idx, None)) for idx in range(n_agents)]
    elif isinstance(eval_block, list):
        raw = eval_block[:max(n_agents, 0)]
    else:
        raw = []
    result = [_coerce_num(_to_scalar(item)) for item in raw]
    # Pad short lists / unsupported blocks up to n_agents with None.
    result.extend([None] * (n_agents - len(result)))
    return result

def _row_avg(row: List[Optional[float]]) -> Optional[float]:
    vals = [v for v in row if v is not None]
    return sum(vals)/len(vals) if vals else None

def _row_max(row: List[Optional[float]]) -> Optional[float]:
    vals = [v for v in row if v is not None]
    return max(vals) if vals else None

def _row_min(row: List[Optional[float]]) -> Optional[float]:
    vals = [v for v in row if v is not None]
    return min(vals) if vals else None

def _row_median(row: List[Optional[float]]) -> Optional[float]:
    vals = [v for v in row if v is not None]
    return median(vals) if vals else None  # even count => mean of the two middle values

def _mean_ignore_none(arr: List[Optional[float]]) -> Optional[float]:
    """Average of a 1D list ignoring None; return None if all are None."""
    vals = [v for v in arr if v is not None]
    return (sum(vals) / len(vals)) if vals else None


def _infer_agents_from_block(block: Any) -> Optional[int]:
    """Infer number of agents from a block that can be a dict or list."""
    if block is None:
        return None
    if isinstance(block, dict):
        # count keys that look like indices 0..N-1, fallback to len(dict)
        try:
            keys = [int(k) if not isinstance(k, int) else k for k in block.keys()]
            if not keys:
                return None
            # If keys are non-negative integers, infer as max+1
            if all(isinstance(k, int) and k >= 0 for k in keys):
                return max(keys) + 1
        except (ValueError, TypeError):
            pass
        return len(block)
    if isinstance(block, list):
        return len(block)
    return None


def _infer_file_stats(path: str) -> Tuple[int, int]:
    """
    Infer (n_rows, n_agents) by scanning the JSONL file at path once.

    - n_rows: number of non-empty lines that decode to a JSON object. Blank
      lines, undecodable lines and non-object JSON values are not counted.
    - n_agents: the largest positive agent count that _infer_agents_from_block
      reports for the score blocks of the first record where any such count
      exists (later records do not change it).

    Falls back to 500 rows and/or 4 agents when nothing usable is found,
    including when path does not exist.
    """
    n_rows = 0
    n_agents: Optional[int] = None
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue  # malformed lines are not counted
                if not isinstance(obj, dict):
                    continue  # a record must be a JSON object to carry score blocks
                n_rows += 1
                if n_agents is not None:
                    continue
                # Tolerate missing / null / non-object sections.
                evals = obj.get("evaluations")
                evals = evals if isinstance(evals, dict) else {}
                pe = obj.get("predictions_evaluation")
                pe = pe if isinstance(pe, dict) else {}
                candidates = (
                    evals.get("correctness"),
                    evals.get("judge_repetition"),
                    evals.get("judge_collaborative"),
                    evals.get("judge_golden_score"),
                    pe.get("math_accuracy"),
                )
                counts = [
                    c for c in map(_infer_agents_from_block, candidates)
                    if isinstance(c, int) and c > 0
                ]
                if counts:
                    n_agents = max(counts)
    except FileNotFoundError:
        # Deliberate best-effort: fall back to defaults below; main() opens
        # the same path afterwards and reports the missing file there.
        pass
    if n_rows == 0:
        n_rows = 500
    if n_agents is None:
        n_agents = 4
    return n_rows, n_agents


# Metrics summarised in the output: (output-key prefix, top-level record key, block key).
# The tuple order fixes the key order of the output JSON.
_METRICS = (
    ("collaborative", "evaluations", "judge_collaborative"),
    ("repetition", "evaluations", "judge_repetition"),
    ("correctness", "evaluations", "correctness"),
    ("judge_score", "evaluations", "judge_golden_score"),
    ("math_acc", "predictions_evaluation", "math_accuracy"),
)

# Expected input filename: eval_outputs_{YYYYMMDD}_{HHMMSS}_{dataset}_evaluated.jsonl
_DATASET_RE = re.compile(r"^eval_outputs_\d{8}_\d{6}_(.+)_evaluated\.jsonl$")


def _sub_dict(obj: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return obj[key] when it is a dict; otherwise (missing, null, list, ...) an empty dict."""
    val = obj.get(key)
    return val if isinstance(val, dict) else {}


def _summarize(
    lines: Iterable[str], n_rows: int, n_agents: int
) -> Tuple[Dict[str, Optional[float]], int, int]:
    """
    Aggregate per-agent scores from up to n_rows JSONL records.

    For each record and each metric in _METRICS, the agent scores are
    normalised to n_agents entries and reduced to a per-row avg/max/min/median.
    Each per-row series is then averaged across rows, ignoring None.

    Returns (averages, rows_filled, skipped):
      - averages maps "<metric>_<stat>_avg" to a float or None;
      - rows_filled is the number of records used;
      - skipped counts non-empty lines that were not JSON objects.

    Skipped lines do not consume the n_rows budget. A row count inferred from
    the valid lines of a file is therefore not exhausted early by malformed
    lines, which would silently drop trailing valid records.
    """
    row_stats = (
        ("avg", _row_avg),
        ("max", _row_max),
        ("min", _row_min),
        ("med", _row_median),  # even count => mean of the two middle values
    )
    per_row: Dict[str, List[Optional[float]]] = {
        f"{metric}_{stat}": [] for metric, _, _ in _METRICS for stat, _ in row_stats
    }
    rows_filled = 0
    skipped = 0
    for line in lines:
        if rows_filled >= n_rows:
            break
        s = line.strip()
        if not s:
            continue
        try:
            obj = json.loads(s)
        except json.JSONDecodeError:
            obj = None
        if not isinstance(obj, dict):
            # Malformed or non-object line: skip without consuming a row slot.
            skipped += 1
            continue
        for metric, section, block_key in _METRICS:
            scores = _get_agent_scores(_sub_dict(obj, section).get(block_key), n_agents)
            for stat, reduce_row in row_stats:
                per_row[f"{metric}_{stat}"].append(reduce_row(scores))
        rows_filled += 1

    averages = {f"{name}_avg": _mean_ignore_none(vals) for name, vals in per_row.items()}
    return averages, rows_filled, skipped


def main():
    """
    CLI entry point: summarise an evaluated JSONL file into one JSON object.

    Output keys, in order:
      - dataset_name: parsed from the input filename, else None.
      - num_thinkers: agents per row.
      - num_samples: the row budget, i.e. --rows or the count inferred from
        the file. Rows actually read are reported on stderr as rows_filled.
      - "<metric>_<stat>_avg" for each metric in _METRICS and each stat in
        avg/max/min/med.
    """
    parser = argparse.ArgumentParser(
        description="Extract matrices (correctness, judge_repetition, judge_collaborative) and per-row/global stats."
    )
    parser.add_argument("--input_file", "-i", required=True,
                        help="Path to input JSONL (use '-' for stdin)")
    parser.add_argument("--output_file", "-o", required=True,
                        help="Path to output JSON (use '-' for stdout)")
    parser.add_argument("--rows", type=int, default=500,
                        help="Number of rows to read (default: 500)")
    parser.add_argument("--agents", type=int, default=4,
                        help="Number of agents per row (default: 4)")
    args = parser.parse_args()

    # Auto-infer rows/agents from the input file when the user passes a
    # non-positive value or leaves the default. stdin cannot be pre-scanned.
    n_rows = args.rows
    n_agents = args.agents
    input_path = None if args.input_file == "-" else args.input_file
    if input_path:
        inferred_rows, inferred_agents = _infer_file_stats(input_path)
        if n_rows <= 0 or args.rows == parser.get_default("rows"):
            n_rows = inferred_rows
        if n_agents <= 0 or args.agents == parser.get_default("agents"):
            n_agents = inferred_agents

    dataset_name = None
    if input_path:
        m = _DATASET_RE.match(Path(input_path).name)
        if m:
            dataset_name = m.group(1)

    # Aggregate first, then open the output. A failure while reading then
    # does not leave a truncated/empty output file behind.
    if input_path is None:
        averages, rows_filled, skipped = _summarize(sys.stdin, n_rows, n_agents)
    else:
        with open(input_path, "r", encoding="utf-8") as fin:
            averages, rows_filled, skipped = _summarize(fin, n_rows, n_agents)

    out: Dict[str, Any] = {
        "dataset_name": dataset_name,
        "num_thinkers": n_agents,
        "num_samples": n_rows,
        **averages,
    }

    if args.output_file == "-":
        json.dump(out, sys.stdout, ensure_ascii=False, indent=2)
    else:
        with open(args.output_file, "w", encoding="utf-8") as fout:
            json.dump(out, fout, ensure_ascii=False, indent=2)

    if skipped:
        print(f"Warning: skipped {skipped} malformed or non-object line(s).", file=sys.stderr)
    print(f"\nDone. rows_filled={rows_filled}", file=sys.stderr)


if __name__ == "__main__":
    # Run the CLI only when executed as a script. main() returns None, so the
    # process exits with status 0 on success, as when falling off the end.
    raise SystemExit(main())

#bash -lc 'set -euo pipefail; export LC_ALL=C; find /Users/fengtingliao/external/group_think_work/group_think_data/experiments -type f -name "*evaluated.jsonl" -print0 | while IFS= read -r -d "" f; do dir=$(dirname "$f"); out="$dir/extracted_eval_scores.json"; echo "Processing: $f -> $out"; python3 /Users/fengtingliao/external/group_think_work/GroupThink-Training/evaluate/extract_eval_scores.py -i "$f" -o "$out" 2>"$out.stderr.log" || echo "Failed: $f"; done'


#bash -lc 'set -euo pipefail; export LC_ALL=C; find /Users/fengtingliao/external/group_think_work/group_think_data/experiments/run_eval_20250924_103345 -type f -name "*evaluated.jsonl" -print0 | while IFS= read -r -d "" f; do dir=$(dirname "$f"); out="$dir/extracted_eval_scores.json"; echo "Processing: $f -> $out"; python3 /Users/fengtingliao/external/group_think_work/GroupThink-Training/evaluate/extract_eval_scores.py -i "$f" -o "$out" 2>"$out.stderr.log" || echo "Failed: $f"; done'