import argparse, glob, os, re, json
from pathlib import Path

# Fixed instruction preamble that may lead the text. It contains a
# "\boxed{<letter>}" format example that must not be mistaken for an answer.
_PREAMBLE_RE = re.compile(
    r"""^\s*You\s+are\s+answering\s+a\s+4-option\s+multiple-choice\s+question\.\s*
             Options\s+are\s+labeled\s+A,\s*B,\s*C,\s*and\s*D\.\s*
             Think\s+step-by-step\s+and\s+show\s+your\s+reasoning\.\s*
             At\s+the\s+very\s+end,\s*output\s+ONE\s+line\s+exactly\s+in\s+this\s+format:\s*
             Final\s+Answer:\s*\\boxed\{\s*[A-D]\s*\}\s*
             where\s+the\s+letter\s+is\s+A,\s*B,\s*C,\s*or\s*D\.\s*
             \n*""",
    re.IGNORECASE | re.VERBOSE | re.DOTALL
)

# Recognised answer spellings; group 1 is always the choice letter.
_ANSWER_PATTERNS = (
    # \boxed{B}
    re.compile(r"\\boxed\s*\{\s*([A-D])\s*\}"),
    # A whole line such as "Final Answer: (B)" or "answer = b" (case-insensitive).
    re.compile(r"(?im)^\s*(?:final\s*answer|answer)\s*[:=]?\s*\(?\s*([A-D])\s*\)?\s*$"),
    # A whole line consisting only of the letter, optionally parenthesised.
    re.compile(r"(?m)^\s*\(?\s*([A-D])\s*\)?\s*$"),
)


def extract_last_choice_letter(text: str):
    """Extract the last A/B/C/D letter from generated text (ignoring the fixed instruction preamble).

    The leading instruction preamble, if present, is removed first. The
    remaining text is scanned with every pattern in ``_ANSWER_PATTERNS`` and
    the match whose letter appears *latest in the text* wins, regardless of
    which pattern produced it.

    Args:
        text: Generated text to scan. ``None`` or an empty string is accepted.

    Returns:
        The upper-case letter ``"A"``-``"D"`` of the last recognised answer,
        or ``None`` when the text is empty or contains no recognised answer.
    """
    if not text:
        return None

    text = _PREAMBLE_RE.sub("", text, count=1)

    last_pos, last_letter = -1, None
    for pattern in _ANSWER_PATTERNS:
        for m in pattern.finditer(text):
            # Rank candidates by where the letter sits in the text, not by
            # which pattern found it; otherwise an early bare "A" line would
            # override a later "\boxed{C}".
            pos = m.start(1)
            if pos >= last_pos:
                # The answer-line pattern is case-insensitive, so normalise
                # to upper case to stay comparable with gold letters.
                last_pos, last_letter = pos, m.group(1).upper()
    return last_letter


def extract_gold_choice_letter(text: str):
    """Pull the reference A/B/C/D letter out of a gold-answer string.

    Three strategies are tried in order and the first that matches wins:

    1. a letter following a ``####`` marker,
    2. the whole (stripped) text being a single, optionally parenthesised letter,
    3. the first standalone upper-case A-D token anywhere in the text.

    Returns the matched letter, or ``None`` if the text is empty or nothing matches.
    """
    if not text:
        return None

    # Each attempt is lazy so later regexes only run when earlier ones miss.
    attempts = (
        lambda: re.search(r"####\s*([A-D])\b", text),
        lambda: re.match(r"^\s*\(?\s*([A-D])\s*\)?\s*$", text.strip()),
        lambda: re.search(r"\b([A-D])\b", text),
    )
    for attempt in attempts:
        found = attempt()
        if found:
            return found.group(1)
    return None


def compare_mcq(generated_text: str, gold_text: str) -> bool:
    """Judge a multiple-choice answer.

    The gold letter is extracted first; if none can be found the result is
    ``False`` and the generated text is not inspected. Otherwise the result is
    ``True`` only when a predicted letter is found and equals the gold letter.
    """
    expected = extract_gold_choice_letter(gold_text)
    if not expected:
        return False

    predicted = extract_last_choice_letter(generated_text)
    if not predicted:
        return False
    return predicted == expected


def load_and_merge_json_files(pattern: str, num_ranks: int = 8):
    """Concatenate the ``"rows"`` lists of per-rank JSON files.

    For each rank in ``range(num_ranks)`` the literal substring ``rank*`` in
    *pattern* is replaced by ``rank<N>`` to form the file path. Missing files
    produce a warning and are skipped; files that fail to load or parse
    produce an error message and are skipped. A file without a ``"rows"`` key
    contributes nothing.

    Returns:
        A single list holding the rows of every successfully loaded file, in
        rank order.
    """
    merged = []
    rank_paths = [(idx, pattern.replace("rank*", f"rank{idx}")) for idx in range(num_ranks)]

    for idx, path in rank_paths:
        if os.path.exists(path):
            try:
                with open(path, 'r', encoding='utf-8') as handle:
                    payload = json.load(handle)
                    shard = payload.get("rows", [])
                    merged.extend(shard)
                    print(f"  Loaded rank{idx}: {len(shard)} rows")
            except Exception as exc:
                # Best effort: report the bad file and keep merging the rest.
                print(f"[ERROR] Failed to load {path}: {exc}")
        else:
            print(f"[WARNING] File not found: {path}")

    return merged


# Candidate key names for each logical field, in order of preference.
_GOLD_KEYS = ("gold", "ground_truth_answer")
_PRED_KEYS = ("pred", "generated_text")
_OK_KEYS = ("ok", "is_correct")

# String spellings of a stored correctness flag that count as true.
_TRUE_STRINGS = {"true", "1", "yes", "y", "t"}


def _first_present_key(row, candidates, current):
    """Return the first candidate key found in *row*, or *current* if none is present."""
    for key in candidates:
        if key in row:
            return key
    return current


def _detect_result_fields(rows):
    """Scan dict rows until gold, prediction and correctness keys have all been seen."""
    gold_key = pred_key = ok_key = None
    for row in rows:
        if not isinstance(row, dict):
            continue
        gold_key = _first_present_key(row, _GOLD_KEYS, gold_key)
        pred_key = _first_present_key(row, _PRED_KEYS, pred_key)
        ok_key = _first_present_key(row, _OK_KEYS, ok_key)
        if gold_key and pred_key and ok_key:
            break
    return gold_key, pred_key, ok_key


def _coerce_ok_flag(value):
    """Interpret a stored correctness flag (bool, number or string) as a bool."""
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return False


def _rejudge_rows(rows, gold_key, pred_key, ok_key):
    """Re-score rows in place with compare_mcq; return (total, correct, changed)."""
    total = correct = changed = 0
    for row in rows:
        # Non-dict rows and rows carrying an "error" key are left untouched
        # and excluded from the counts.
        if not isinstance(row, dict) or "error" in row:
            continue

        total += 1
        old_ok = _coerce_ok_flag(row.get(ok_key, False))
        new_ok = compare_mcq(row.get(pred_key, ""), row.get(gold_key, ""))

        if old_ok != new_ok:
            changed += 1

        row[ok_key] = new_ok
        correct += int(new_ok)
    return total, correct, changed


def patch_json_category(dir_path: str, pattern: str, output_name: str, num_ranks: int = 8):
    """Process one category of JSON files (e.g., baseline, soft_argmax, etc.)

    Merges the per-rank files matching *pattern* inside *dir_path*, detects
    which keys hold the gold answer, the prediction and the correctness flag,
    re-judges every eligible row with ``compare_mcq`` and writes the merged,
    re-judged rows plus summary statistics to ``dir_path/output_name``.

    Args:
        dir_path: Directory containing the rank files; also the output directory.
        pattern: File name pattern containing the literal ``rank*`` placeholder.
        output_name: File name of the merged JSON to write.
        num_ranks: Number of rank files to look for.

    Returns:
        None. Progress, errors and the final summary are printed. The function
        returns early (writing nothing) when no rows are loaded or the field
        names cannot be detected, and skips the summary if saving fails.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Processing: {pattern}")
    print(f"{banner}")

    # Load and merge all rank files
    all_rows = load_and_merge_json_files(os.path.join(dir_path, pattern), num_ranks)
    if not all_rows:
        print("[WARNING] No rows found")
        return

    gold_key, pred_key, ok_key = _detect_result_fields(all_rows)
    if not (gold_key and pred_key and ok_key):
        print(f"[ERROR] Could not detect field names. Found: gold={gold_key}, pred={pred_key}, ok={ok_key}")
        return

    print(f"[INFO] Detected fields: gold='{gold_key}', pred='{pred_key}', ok='{ok_key}'")

    total, correct, changed = _rejudge_rows(all_rows, gold_key, pred_key, ok_key)
    # max(1, total) avoids a division by zero when every row was skipped.
    accuracy = correct / max(1, total)

    output_data = {
        "rows": all_rows,
        "accuracy": accuracy,
        "total": total,
        "correct": correct
    }

    # Save merged and fixed JSON
    output_path = os.path.join(dir_path, output_name)
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"✓ Saved to: {output_path}")
    except Exception as e:
        print(f"[ERROR] Failed to save {output_path}: {e}")
        return

    rule = "─" * 60
    print(f"\n{rule}")
    print("RESULTS:")
    print(f"{rule}")
    print(f"  Total rows:        {total}")
    print(f"  Correct answers:   {correct}/{total}")
    print(f"  Changed judgments: {changed}")
    print(f"  New accuracy:      {accuracy:.4f}")
    print(f"{rule}")


def main():
    """Command-line entry point.

    Parses ``--dir``, ``--stem``, ``--num-ranks`` and ``--skip-baseline``,
    checks that the directory exists, and runs ``patch_json_category`` once
    per category. For the stem ``baseline`` (case-insensitive) only the
    baseline files are processed; for any other stem the ``vec_base``,
    ``soft_prob`` and ``soft_argmax`` files of that stem are processed,
    preceded by the baseline files unless ``--skip-baseline`` is given.
    """
    parser = argparse.ArgumentParser(
        description="Fix multiple-choice (A/B/C/D) judgment errors in JSON result files and merge rank files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--dir", required=True, help="Directory containing the JSON files")
    parser.add_argument("--stem", required=True, help="File stem (e.g., a0.1__linear__L27 or 'baseline')")
    parser.add_argument("--num-ranks", type=int, default=8, help="Number of rank files to merge (default: 8)")
    parser.add_argument("--skip-baseline", action="store_true", help="Skip processing baseline files (when already processed)")
    opts = parser.parse_args()

    if not os.path.isdir(opts.dir):
        print(f"[ERROR] Directory not found: {opts.dir}")
        return

    heavy_rule = "=" * 60
    print(f"\n{heavy_rule}")
    print("MCQ FIX: Correcting A/B/C/D evaluation errors in JSON files")
    print(f"{heavy_rule}")
    print(f"Directory: {opts.dir}")
    print(f"Stem: {opts.stem}")
    print(f"Number of ranks: {opts.num_ranks}")
    if opts.skip_baseline:
        print("Skip baseline: YES (already processed)")
    print(f"{heavy_rule}")

    # Each category is an (input pattern, output file name) pair.
    baseline_category = ("baseline.rank*.json", "baseline.fixeval.json")
    if opts.stem.lower() == "baseline":
        categories = [baseline_category]
    else:
        categories = [] if opts.skip_baseline else [baseline_category]
        for variant in ("vec_base", "soft_prob", "soft_argmax"):
            categories.append(
                (f"{opts.stem}.{variant}.rank*.json", f"{opts.stem}.{variant}.fixeval.json")
            )

    for pattern, output_name in categories:
        patch_json_category(opts.dir, pattern, output_name, opts.num_ranks)

    print(f"\n{heavy_rule}")
    print(f"COMPLETED: Processed all categories for stem '{opts.stem}'")
    print(f"{heavy_rule}\n")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()