#!/usr/bin/env python3

import logging
from pathlib import Path
from typing import List

from inspect_ai import Task, eval
from inspect_ai.dataset import Sample
from inspect_ai.util import sandbox

from inspect_ai.scorer import Score, Target, scorer, stderr
from inspect_ai.solver import TaskState

from inspect_swe import claude_code, codex_cli

import config

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


def _find_property_dirs(root: Path) -> List[Path]:
    if not root.exists():
        raise FileNotFoundError(f"Dataset root not found: {root}")
    return sorted([p for p in root.iterdir() if p.is_dir()])


def _collect_imgs(img_dir: Path) -> List[Path]:
    if not img_dir.exists():
        return []
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    return sorted(
        [p for p in img_dir.iterdir() if p.suffix.lower() in exts and p.is_file()]
    )


def build_dataset():
    root_path = Path("dataset")
    samples: List[Sample] = []

    property_dirs = _find_property_dirs(root_path)

    # Filter to specific apartment if configured
    if config.APARTMENT_NAME:
        property_dirs = [d for d in property_dirs if d.name == config.APARTMENT_NAME]
        if not property_dirs:
            raise RuntimeError(
                f"Apartment '{config.APARTMENT_NAME}' not found in {root_path}"
            )

    for prop_dir in property_dirs:
        img_dir = prop_dir / "imgs"
        imgs = _collect_imgs(img_dir)
        if not imgs:
            continue

        files = {}
        for img in imgs:
            sandbox_path = f"/root/imgs/{img.name}"
            host_path = str(img.resolve())
            files[sandbox_path] = host_path

        samples.append(
            Sample(
                input=[
                    {
                        "role": "user",
                        "content": config.AGENT_INSTR + config.RULES,
                    },
                ],
                files=files,
                id=prop_dir.name,
                metadata={
                    "apartment_name": prop_dir.name,
                },
            )
        )

    if not samples:
        raise RuntimeError(f"No samples found under {root_path}/**/imgs/*")
    return samples


@scorer(metrics=[stderr()])
def read_output(apartment_name: str, agent: str, epoch: int):
    async def score(state: TaskState, target: Target):
        png_bytes: bytes = await sandbox().read_file("/root/output.png", text=False)
        file_path = f"./predictions_agent/{apartment_name}/{agent}_{epoch}.png"
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, "wb") as f:
            f.write(png_bytes)
        return Score(value=0.0)

    return score


def check_output_exists(apartment_name: str, model: str, epoch: int) -> bool:
    """Check if agent prediction output already exists."""
    parts = model.split("/", 1)
    provider = parts[0]
    model_name = parts[1] if len(parts) > 1 else "default"
    safe_model_name = model_name.replace("/", "_")

    output_dir = Path("./predictions_agent") / apartment_name / provider
    output_file = output_dir / f"{safe_model_name}_{epoch}.png"

    return output_file.exists()


def get_agent(model_name: str):
    provider = model_name.split("/")[0]
    if provider == "anthropic":
        return claude_code(model=model_name, disallowed_tools=["web_search"])
    elif provider == "openai":
        return codex_cli(model=model_name, disallowed_tools=["web_search"])
    else:
        raise ValueError(f"Unknown agent provider: {provider}")


def main():
    tasks = []
    total_combinations = 0
    skipped = 0

    for sample in build_dataset():
        for model in config.AGENT_MODELS:
            for epoch in range(1, config.EPOCHS + 1):
                total_combinations += 1
                apartment_name = sample.metadata["apartment_name"]

                if check_output_exists(apartment_name, model, epoch):
                    logging.info(f"Skip: {apartment_name}+{model}+{epoch} (exists)")
                    skipped += 1
                    continue

                tasks.append(
                    Task(
                        dataset=[sample],
                        solver=get_agent(model),
                        sandbox="docker",
                        model=model,
                        scorer=read_output(apartment_name, model, epoch),
                    )
                )

    # Print summary
    print(f"Found {total_combinations} total combinations")
    print(f"Skipped {skipped} existing predictions")
    print(f"Processing {len(tasks)} new predictions")

    if not tasks:
        print("No new predictions to process!")
        return

    # Process tasks in batches to avoid "too many open files" error
    batch_size = config.NUM_WORKERS  # Adjust based on your system
    for i in range(0, len(tasks), batch_size):
        print("Batch:", i // batch_size + 1)
        batch = tasks[i:i + batch_size]
        eval(batch, max_tasks=batch_size)


if __name__ == "__main__":
    main()
