import os
import json
import asyncio
import re
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio

# ==== Configuration ====
# Root folder holding one sub-directory per imaging category.
dataset_root = "/root/autodl-tmp/dataset"

# Category sub-directories to process, in processing order.
CATEGORIES_TO_PROCESS = [
    "Abdominal Imaging", "Bone and Joint Imaging", "Breast Imaging",
    "Cardiac Imaging", "Chest Imaging", "Cranial Imaging",
    "Dental Imaging", "Dermatological Imaging", "Endoscopy Imaging",
    "Fundus Imaging", "Gynecological Imaging", "Pathology Slide Imaging",
]

# Sampling switch for quick trial runs:
#   0     -> process every entry of each category
#   N > 0 -> process only the first N entries of each category
# Only English reports are handled; Chinese report processing has been
# discontinued.
sample_size = 0

# Upper bound on the number of reports handled at the same time.
CONCURRENCY = 10

# Standard position words the region prompt asks the model to choose from.
VALID_POSITIONS = {
    # single-word positions
    "left", "right",
    "upper", "lower",
    "center", "central", "middle",
    # horizontal-then-vertical pairs
    "left upper", "left lower",
    "right upper", "right lower",
    # vertical-then-horizontal pairs
    "upper left", "upper right",
    "lower left", "lower right",
}

# Prompt templates, filled in with str.format() before being sent to the model.
#   "region_en": takes {report}; asks for "adjective | medical term | position"
#                on one line, or the literal answer "No Finding".
#   "neg_en":    takes {original_adjective}; asks for five numbered replacement
#                adjectives, one per line.
# NOTE: because str.format() is used, any literal brace added to a template
# must be doubled ("{{" / "}}").
PROMPTS = {
    # Region-level extraction prompt (English reports).
    "region_en": """Based on the following medical imaging report, carefully extract the key information and provide exactly three components separated by " | ":

1. Shape/appearance adjective: Extract the main descriptive adjective about the lesion's shape, appearance, or characteristics from the report (e.g., irregular, round, oval, raised, flat, uniform, heterogeneous, etc.)

2. Medical lesion term: Extract the main medical term for the abnormality mentioned in the report (e.g., lesion, mass, nodule, opacity, density, etc.)

3. Position: Convert the anatomical location mentioned in the report to one of these standard positions: left, right, center, central, upper, lower, middle, left upper, left lower, right upper, right lower, upper left, upper right, lower left, lower right

If no lesions or abnormalities are described, answer "No Finding".

Format: [adjective] | [medical term] | [position]

Examples:
- For "lesion on the trunk with brownish discoloration": brownish | lesion | center
- For "irregular mass in the left lung": irregular | mass | left
- For "round nodule in the upper right quadrant": round | nodule | upper right

Report content: {report}""",

    # Negative-sample prompt: contrasting adjectives for the extracted one.
    "neg_en": """Generate 5 different shape/appearance adjectives that could replace "{original_adjective}" for medical imaging descriptions.
Focus on contrasting characteristics (if original is "irregular" suggest "regular", "round", "oval", etc.).
Provide only the adjectives, one per line, numbered 1-5.

Original adjective: {original_adjective}"""
}


async def llm_text_generation(client, prompt, max_tokens=128):
    """Run one text-only chat completion and return the reply text.

    Args:
        client: OpenAI-compatible async client exposing
            ``chat.completions.create``.
        prompt: User message sent to the model.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the response carries no choices or the
        message content is ``None``.
    """
    sys_prompt = "You are an expert radiologist."
    resp = await client.chat.completions.create(
        model="QwenVL",
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0.6,
        max_tokens=max_tokens,
    )
    # The message content is optional in the API schema; report a missing
    # reply as empty text instead of raising on `.strip()` / `[0]`.
    if not resp.choices:
        return ""
    content = resp.choices[0].message.content
    return content.strip() if content else ""


def _clean_component(text):
    """Strip whitespace and wrapper characters the model tends to echo."""
    # The prompt shows the format as "[adjective] | [medical term] | [position]",
    # so replies sometimes keep the brackets, quotes or a trailing period.
    return text.strip(" \t\r\n[]\"'`*.")


def _normalize_position(raw_position):
    """Map a free-form position string onto one entry of VALID_POSITIONS."""
    position = " ".join(re.sub(r"[\-_/,]+", " ", raw_position.lower()).split())
    if position in VALID_POSITIONS:
        return position
    # Salvage replies such as "left lung": prefer the longest standard phrase
    # contained in the text (ties broken alphabetically for determinism).
    for candidate in sorted(VALID_POSITIONS, key=lambda p: (-len(p), p)):
        if re.search(rf"\b{re.escape(candidate)}\b", position):
            return candidate
    return "center"  # Same default the unparseable-reply fallback uses


async def generate_region_caption(client, report):
    """Generate the region-level caption for one report.

    Args:
        client: Client forwarded to ``llm_text_generation``.
        report: Report text inserted into the ``region_en`` prompt.

    Returns:
        ``"No Finding"`` when the model says so (case-insensitive) or the
        reply cannot be used at all; otherwise a sentence of the form
        ``"There is <adjective> <medical term> located in the <position>
        of the image"`` where ``<position>`` is always a member of
        ``VALID_POSITIONS``.
    """
    prompt = PROMPTS["region_en"].format(report=report)
    response = await llm_text_generation(client, prompt, max_tokens=64)

    # The model does not reliably reproduce the exact capitalisation.
    if "no finding" in response.lower():
        return "No Finding"

    # Parse only the first line that carries the separator so trailing
    # explanations do not end up inside the position component.
    pipe_lines = [line for line in response.splitlines() if "|" in line]
    candidate = pipe_lines[0] if pipe_lines else response
    parts = [_clean_component(part) for part in candidate.split("|")]

    if len(parts) == 3 and parts[0] and parts[1]:
        adjective, medical_term, raw_position = parts
        position = _normalize_position(raw_position)
        return f"There is {adjective} {medical_term} located in the {position} of the image"

    # Parsing failed: fall back to a simple heuristic on the raw reply.
    print(f"Parsing failed, original response: {response}")
    words = response.split()
    if len(words) >= 2:
        # Assume the first word is the adjective and the second the medical
        # term; the position defaults to "center".
        return f"There is {words[0]} {words[1]} located in the center of the image"
    return "No Finding"


async def generate_negative_captions(client, region_caption):
    """Generate negative captions by swapping only the adjective.

    Args:
        client: Client forwarded to ``llm_text_generation``.
        region_caption: Caption of the form ``"There is <adjective>
            <medical term> located in the <position> of the image"`` or a
            string containing ``"No Finding"``.

    Returns:
        Up to five captions identical to ``region_caption`` except for the
        adjective. An empty list when the caption contains "No Finding" or
        does not follow the expected format.

    Note:
        The adjective is the first whitespace-delimited token after
        "There is"; for a multi-word adjective the remaining words are
        treated as part of the medical term.
    """
    # "No Finding" captions have nothing to negate.
    if "No Finding" in region_caption:
        return []

    # `\S+` (rather than `\w+`) so hyphenated adjectives such as
    # "well-defined" or "ill-defined" are recognised.
    match = re.match(
        r"There is (\S+) (.+) located in the (.+) of the image", region_caption
    )
    if not match:
        return []
    original_adjective, medical_term, position = match.groups()

    # Ask the model for replacement adjectives.
    prompt = PROMPTS["neg_en"].format(original_adjective=original_adjective)
    neg_response = await llm_text_generation(client, prompt, max_tokens=128)

    raw_lines = [line.strip() for line in neg_response.splitlines() if line.strip()]
    # The prompt requests a numbered list; when numbered lines exist, ignore
    # everything else (e.g. an introductory sentence).
    numbered = [line for line in raw_lines if line[0].isdigit()]

    seen = {original_adjective.casefold()}
    negative_captions = []
    for line in numbered or raw_lines:
        # Drop list numbering ("1.", "2)", "3 -") and markdown/quote wrappers.
        new_adjective = line.strip("0123456789.、:-)*\"'` ")
        key = new_adjective.casefold()
        if not new_adjective or key in seen:
            continue  # Empty, repeated, or same as the original adjective
        seen.add(key)
        negative_captions.append(
            f"There is {new_adjective} {medical_term} located in the {position} of the image"
        )
        if len(negative_captions) == 5:
            break

    return negative_captions


async def process_one(client, row, semaphore):
    """Process a single report row into an output record.

    Args:
        client: Client forwarded to the caption generators.
        row: Dict parsed from one input line; ``"report"`` is required,
            ``"image"`` is optional.
        semaphore: ``asyncio.Semaphore`` bounding how many rows run at once.

    Returns:
        A dict with keys ``image``, ``report``, ``region_caption`` and
        ``negative_captions``, or ``None`` when the row has no usable
        report or caption generation failed. ``None`` lets the caller drop
        the row without losing the rest of the batch.
    """
    report = row.get("report")
    if not isinstance(report, str) or not report.strip():
        print(f"[WARN] Row without a usable 'report' field skipped (image={row.get('image', '')!r}).")
        return None

    async with semaphore:
        try:
            # Generate region-level feature report
            region_caption = await generate_region_caption(client, report)
            # Generate negative sample descriptions
            negative_captions = await generate_negative_captions(client, region_caption)
        except Exception as exc:
            # Per-row boundary: an exception escaping here would propagate
            # through gather() and discard every other row's result.
            print(f"[WARN] Failed to process report (image={row.get('image', '')!r}): {exc!r}")
            return None

    # Package output
    return {
        "image": row.get("image", ""),  # Retain original field structure
        "report": report,
        "region_caption": region_caption,
        "negative_captions": negative_captions,
    }


def find_jsonl_file(cat_dir):
    """Find the English input jsonl file of a category directory.

    Args:
        cat_dir: Directory to search (not recursive).

    Returns:
        The name (not the full path) of the alphabetically first file
        ending in ``_en.jsonl`` that is not a ``*_region_en.jsonl`` output
        file, or ``None`` (after printing a warning) when there is none.
    """
    suffix = "_en.jsonl"
    # The generated "<category>_region_en.jsonl" is written into this same
    # directory and also ends with `suffix`; it must never be picked as input.
    output_suffix = "_region_en.jsonl"
    # Sorted because os.listdir() returns entries in arbitrary order.
    files = sorted(
        name
        for name in os.listdir(cat_dir)
        if name.endswith(suffix) and not name.endswith(output_suffix)
    )
    if not files:
        print(f"[WARN] {cat_dir} has no {suffix}, skipped.")
        return None
    return files[0]


async def process_category(category, client, semaphore):
    """Process every report of a single category.

    Reads the category's English jsonl file, runs ``process_one`` on each
    row, and writes the non-``None`` results to
    ``<category>_region_en.jsonl`` in the same directory, overwriting any
    existing file.

    Args:
        category: Name of the category sub-directory under ``dataset_root``.
        client: Client forwarded to ``process_one``.
        semaphore: Semaphore forwarded to ``process_one``.

    Returns:
        None. Returns early when no input file is found.
    """
    cat_dir = os.path.join(dataset_root, category)
    jsonl_file = find_jsonl_file(cat_dir)

    if jsonl_file is None:
        return

    jsonl_path = os.path.join(cat_dir, jsonl_file)

    # Read data. "utf-8-sig" tolerates a leading BOM; blank lines are skipped
    # because json.loads("") raises JSONDecodeError.
    with open(jsonl_path, "r", encoding="utf-8-sig") as f:
        meta = [json.loads(line) for line in f if line.strip()]

    # Small sample test: 0 for full sample, >0 for specified number of samples
    if sample_size > 0:
        meta = meta[:sample_size]
        print(f"  -> Using small sample test, processing {len(meta)} data entries")
    else:
        print(f"  -> Processing all samples, total {len(meta)} data entries")

    # Output file path
    out_path = os.path.join(cat_dir, f"{category}_region_en.jsonl")

    # Batch processing; concurrency is limited by the semaphore inside
    # process_one, and results come back in input order.
    tasks = [process_one(client, row, semaphore) for row in meta]
    results = await tqdm_asyncio.gather(*tasks, desc=f"{category}-EN", unit="reports")

    # Write results, dropping rows for which process_one returned None.
    with open(out_path, "w", encoding="utf-8") as fout:
        for item in results:
            if item is not None:
                fout.write(json.dumps(item, ensure_ascii=False) + "\n")


async def main():
    """Print the configuration, then process each configured category in turn.

    Categories whose directory does not exist under ``dataset_root`` are
    skipped with a warning. The API client is closed when processing ends,
    whether it finished normally or raised.
    """
    # Display current configuration information
    print("🔧 Current configuration: Only process English reports")
    print("🔧 Format requirement: There is + [adjective] + [medical term] + located in the + [position] + of the image")
    print(f"🔧 The position word list contains {len(VALID_POSITIONS)} valid position words")
    if sample_size > 0:
        print(f"🔧 Small sample test mode, processing {sample_size} data entries per category")
    else:
        print("🔧 Full sample processing mode")
    print(f"🔧 Concurrency count: {CONCURRENCY}")
    print("-" * 50)

    client = AsyncOpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
    semaphore = asyncio.Semaphore(CONCURRENCY)

    try:
        # Process categories one by one (only English)
        for category in CATEGORIES_TO_PROCESS:
            cat_dir = os.path.join(dataset_root, category)
            if not os.path.isdir(cat_dir):
                # Make a missing or misspelled category visible instead of
                # silently dropping it.
                print(f"[WARN] {cat_dir} is not a directory, skipped.")
                continue
            print(f"Processing main category: {category} English reports")
            await process_category(category, client, semaphore)
    finally:
        # Release the client's underlying HTTP connections.
        await client.close()

    print("✅ All completed!")


if __name__ == "__main__":
    asyncio.run(main())