
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json, re, sys
from pathlib import Path
from string import ascii_uppercase, ascii_lowercase

# Candidate specialty labels offered to the model. The list is interpolated
# into PROMPT_HEADER through its Python repr, so it must remain a list and the
# order/spelling of the entries is part of the generated prompt text.
SPECIALTIES = [
    "Surgery", "Allergy / Immunology", "Sleep Medicine", "Pediatrics - Neonatal",
    "SOAP / Chart / Progress Notes", "Bariatrics", "Pain Management",
    "Lab Medicine - Pathology", "Dermatology", "Orthopedic", "Dentistry",
    "Psychiatry / Psychology", "General Medicine", "Office Notes", "Letters",
    "Neurosurgery", "Radiology", "Cosmetic / Plastic Surgery", "Nephrology",
    "Diets and Nutritions", "Chiropractic", "Gastroenterology",
    "Cardiovascular / Pulmonary", "Speech - Language",
    "Hospice - Palliative Care", "Autopsy", "Endocrinology",
    "Emergency Room Reports", "Discharge Summary", "ENT - Otolaryngology",
    "Urology", "Physical Medicine - Rehab", "Neurology", "Podiatry",
    "Ophthalmology", "Rheumatology", "IME-QME-Work Comp etc.",
    "Hematology - Oncology", "Consult - History and Phy.",
    "Obstetrics / Gynecology",
]

# Task description placed at the start of every rewritten prompt. The count is
# derived from the list so the sentence cannot drift out of sync when a
# specialty is added or removed (it renders as "40" for the list above).
PROMPT_HEADER = (
    "\nTASK: Determine which medical specialty/domain best fits the "
    f"transcription below. The {len(SPECIALTIES)} possible specialties are: "
    f"{SPECIALTIES}"
)

# Output-format instructions, ending with the "INPUT: " marker that precedes
# the transcription.
# NOTE(review): this is a plain string that is never passed through
# str.format(), so the doubled braces "{{" / "}}" are emitted literally in the
# prompt, and there is no separator between PROMPT_HEADER and "Return" when the
# pieces are concatenated. Left byte-identical here because changing the text
# changes the generated dataset -- confirm whether this is intended.
JSON_PROMPT = """Return **valid JSON** that matches this schema:
{{"reason": <Explanation>   // one sentence, "label": <Specialty>}}
###
INPUT: """

# Marker appended after the transcription to cue the model's answer.
PROMPT_FOOTER = "\nOUTPUT:\n"
def extract_transcription(old_prompt: str) -> str:
    """Return the text found between ``INPUT:`` and ``OUTPUT:`` in *old_prompt*.

    The match is greedy and spans newlines, so it runs from the first
    ``INPUT:`` to the last ``OUTPUT:`` in the string. Surrounding whitespace
    is stripped from the result.

    Exits the process (``SystemExit`` carrying an error message) when the
    prompt does not contain an ``INPUT:`` ... ``OUTPUT:`` pair.
    """
    found = re.search(r"INPUT:(.*)OUTPUT:", old_prompt, flags=re.DOTALL)
    if found is not None:
        return found[1].strip()
    sys.exit("✗ Could not find INPUT … OUTPUT block.")


def convert_line(obj: dict) -> dict:
    """Rewrite one dataset record so its prompt uses the reason-first format.

    The transcription is extracted from the record's existing ``"prompt"`` and
    re-wrapped as ``PROMPT_HEADER + JSON_PROMPT + transcription + PROMPT_FOOTER``.
    The ``"gold"`` value is left unchanged.

    Args:
        obj: A parsed JSONL record; must contain ``"prompt"`` and ``"gold"``.

    Returns:
        The same ``obj`` instance, with ``"prompt"`` replaced in place.

    Raises:
        KeyError: If ``"prompt"`` or ``"gold"`` is missing from ``obj``.
        SystemExit: Propagated from ``extract_transcription`` when the old
            prompt has no ``INPUT:`` ... ``OUTPUT:`` block.
    """
    # Check required keys up front so a malformed record is rejected before it
    # is partially rewritten (this replaces a no-op ``obj["gold"] = obj["gold"]``
    # whose only effect was to raise KeyError after the prompt was overwritten).
    for required in ("prompt", "gold"):
        if required not in obj:
            raise KeyError(required)

    transcription = extract_transcription(obj["prompt"])
    obj["prompt"] = f"{PROMPT_HEADER}{JSON_PROMPT}{transcription}{PROMPT_FOOTER}"
    return obj


def main():
    """Convert the validation split to the reason-first prompt format.

    Reads ``proxy_tuning/datasets/mtsample/valid.jsonl`` line by line, passes
    each parsed record through ``convert_line`` and writes the result as one
    JSON object per line to
    ``proxy_tuning/datasets/mtsample/val_reason_first.jsonl`` (the parent
    directory is created if needed). Prints the output path and size when done.
    """
    in_path = Path("proxy_tuning/datasets/mtsample/valid.jsonl")
    out_path = Path("proxy_tuning/datasets/mtsample/val_reason_first.jsonl")
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Records are serialized with ensure_ascii=False, so non-ASCII characters
    # are written verbatim. Pin UTF-8 on both ends instead of relying on the
    # platform default encoding, which may be unable to represent them.
    with in_path.open(encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file) rather than crashing in json.loads.
                continue
            record = convert_line(json.loads(line))
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")

    print(f"✓ wrote {out_path} ({out_path.stat().st_size/1e6:.2f} MB)")


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
