import json
import os
import re

# Model identifier, overridable through the MODEL environment variable.
model = os.environ.get("MODEL", "Qwen/Qwen3-8B")
# Generation backend selector, read from the MODE environment variable
# ("api" or "vllm"; defaults to "vllm").
mode = os.environ.get("MODE", "vllm")

# Bind get_response_list to the selected backend. This runs at import time,
# so importing this module is what triggers the backend setup below.
if mode == "api":
    from .gen_by_api import get_response_list
elif mode == "vllm":
    from .gen_by_vllm import get_response_list, init_vllm
    # The vLLM backend is initialised with the configured model immediately.
    init_vllm(model)
else:
    # Fail fast on an unrecognised MODE instead of leaving
    # get_response_list undefined.
    raise ValueError(f"Invalid mode: {mode}")

# Example radiology note used as the user turn of the few-shot demonstration.
# NOTE(review): this value looks like a placeholder -- confirm a real note is
# substituted before few-shot prompting is enabled.
few_shot_radio_note = "<few_shot_radio_note>"

# Assistant answer paired with the example note above. It is spelled exactly
# like the corresponding entry in valid_modalities so that the demonstration
# obeys the "answer with one of the listed options" rule of the system prompt.
few_shot_parse = "X-ray"

# Closed set of labels the model may answer with. "Others" is the catch-all
# for unsure or unlisted modalities. Kept as a list because its repr is
# interpolated verbatim into the system prompt.
valid_modalities = ["X-ray", "CT", "MRI", "Ultrasound", "Fluoroscopy", "Others"]

def get_parse_message(radio_note, enable_few_shot):
    """Build the chat conversation asking for the imaging modality of a note.

    Args:
        radio_note: Text of the radiology note to classify; sent as the final
            user turn.
        enable_few_shot: When truthy, one demonstration exchange (the
            module-level example note and its answer) is inserted between the
            system prompt and the note.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system prompt,
        optionally the demonstration user/assistant pair, then the note.
    """
    system_turn = {
        "role": "system",
        "content": f"You need to determine what type of imaging modality the radiology note below is from: {valid_modalities}. You should only provide one of the above options as your answer. Note that Computed Tomography is also known as CT, and Magnetic Resonance Imaging is also known as MRI. If you are not sure or the modality is not listed, please choose Others. Your answer should be one of the following: {valid_modalities}. Please do not include any other content in your answer.",
    }

    # The demonstration pair is all-or-nothing: both turns or neither.
    if enable_few_shot:
        demonstration = [
            {"role": "user", "content": few_shot_radio_note},
            {"role": "assistant", "content": few_shot_parse},
        ]
    else:
        demonstration = []

    query_turn = {"role": "user", "content": radio_note}
    return [system_turn, *demonstration, query_turn]

def _match_modality(modality):
    """Return the first label of ``valid_modalities`` named in *modality*.

    Matching is case-insensitive and requires the label to stand alone as a
    token: it must not be directly preceded or followed by a letter or digit,
    apart from an optional plural ``s``. Plain substring matching would treat
    every word containing "ct" (e.g. "fracture", "correct") as a CT answer.

    Labels are tried in list order, so the earliest listed label wins when
    several are present. Returns ``None`` when no label is found.
    """
    text = modality.lower()
    for valid_modality in valid_modalities:
        label = re.escape(valid_modality.lower())
        if re.search(rf"(?<![a-z0-9]){label}s?(?![a-z0-9])", text):
            return valid_modality
    return None


def modality_valid(modality):
    """Classify a model answer by whether it names a known modality.

    Args:
        modality: Answer text to inspect.

    Returns:
        ``"Valid"`` when the first label found is a concrete modality,
        ``"Others"`` when the first label found is the catch-all "Others",
        and ``"Invalid"`` when no label is found.
    """
    matched = _match_modality(modality)
    if matched is None:
        return "Invalid"
    if matched == "Others":
        return "Others"
    return "Valid"


def modality_std(modality):
    """Normalise a model answer to its canonical label.

    Args:
        modality: Answer text to inspect.

    Returns:
        The first label of ``valid_modalities`` found in the text, spelled as
        in that list, or ``"Others"`` when none is found.
    """
    matched = _match_modality(modality)
    return "Others" if matched is None else matched

def parse_radio(note_list, enable_few_shot):
    """Request a modality answer for every note in *note_list*.

    Args:
        note_list: Iterable of radiology note texts.
        enable_few_shot: Forwarded to ``get_parse_message`` for every note.

    Returns:
        Whatever ``get_response_list`` returns for the list of prompts, one
        prompt per note in input order.
    """
    prompts = []
    for radio_note in note_list:
        prompts.append(get_parse_message(radio_note, enable_few_shot))
    return get_response_list(prompts)