from __future__ import annotations

import argparse
import json
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from PIL import Image

from pipeline_navi import NaviEditPipeline
from utils import first_param_point, load_yaml_config


LOGGER = logging.getLogger("pie_bench")

# Directory holding this script; the bundled config and default PIE root live beside it.
_SCRIPT_DIR = Path(__file__).resolve().parent

DEFAULT_CONFIG_PATH = _SCRIPT_DIR / "config" / "navi.yaml"
# Placeholder for the SD3/SD3.5 weights directory. Note that str(Path("")) == ".",
# so leaving it unset means "current working directory"; override it with
# --model-path or the YAML ``trained.path`` entry.
DEFAULT_MODEL_PATH = Path("")
DEFAULT_SEED = 42
DEFAULT_PRECISION = "fp16"
DEFAULT_IMAGE_SIZE = 1024

# PIE-Bench dataset / export layout defaults.
DEFAULT_PIE_ROOT = _SCRIPT_DIR / "pie_bench"
DEFAULT_MAPPING_FILE = "mapping_file.json"
DEFAULT_IMAGE_SUBDIR = "annotation_images"
DEFAULT_METHOD_NAME = "Navi"

# Fallback edit parameters, used when no YAML config is available or the YAML
# has no ``editor.params_grid``. CLI flags can still override individual keys.
DEFAULT_EDIT_CONFIG: Dict[str, Any] = dict(
    # Noise averaging and step schedule (--noise-samples, --n-steps, --t-edit, --t-ref)
    noise_samples=1,
    n_steps=28,
    t_edit=28,
    t_ref=20,
    # Guidance scales, edit step size and gain calibration
    # (--src-guidance, --tar-guidance, --edit-dt, --disable-equivalent-gain)
    src_guidance_scale=3.0,
    tar_guidance_scale=15.0,
    edit_dt=1.0,
    use_equiv_gain=True,
    # Mask backend (--mask-mode) followed by mask / clamp tuning knobs
    mask_mode="auto_repr",
    mask_ema=0.80,
    mask_pow=1.0,
    clamp_strength=0.0,
    mask_quantile=0.90,
    mask_min_area=0.02,
    mask_max_area=0.65,
    mask_grow=True,
    mask_grow_r_th=0.35,
    mask_grow_patience=2,
    mask_grow_quantile=0.97,
    mask_dilate_max_k=7,
    mask_blur_k=5,
    # CFL-related options (interpreted by the pipeline)
    use_cfl=True,
    cfl_tau=8.0,
)


@dataclass(frozen=True)
class PieRecord:
    """One PIE-Bench sample, as built from the mapping file by ``load_pie_records``.

    Instances are immutable (``frozen=True``).
    """

    # Key of this sample in the mapping JSON.
    sample_id: str
    # Absolute, resolved path to the source image on disk.
    image_path: Path
    # The mapping's ``image_path`` value (relative to the image subdirectory); reused
    # to mirror the same layout under the output/ and data/ export directories.
    relative_path: Path
    # Caption of the source image (may be an empty string if the mapping lacks one).
    original_prompt: str
    # Target caption describing the edited image (may be an empty string).
    edited_prompt: str
    # Natural-language edit instruction; falls back to ``edited_prompt`` when absent.
    edit_instruction: str


# --------------------------------------------------------------------------- #
# Argument parsing
# --------------------------------------------------------------------------- #


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line interface and parse arguments.

    Args:
        argv: Argument list to parse; ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The parsed namespace. Options left unset keep ``None`` so that
        ``apply_cli_overrides`` can tell "not given" apart from explicit values.

    Exits (via ``parser.error``) on values that would otherwise misbehave silently,
    e.g. a negative ``--max-samples`` which would slice records off the end.
    """
    parser = argparse.ArgumentParser(
        description="Run NaviEdit on PIE-Bench data and export PIE-format results.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=str(DEFAULT_CONFIG_PATH),
        help="YAML config describing weights + params (defaults to config/navi.yaml next to this script if present, else built-in defaults).",
    )
    parser.add_argument("--device", type=str, default=None, help="Torch device override, e.g. cuda:0 or cpu.")
    parser.add_argument("--precision", choices=["fp32", "fp16", "bf16"], default=None, help="Computation precision.")
    parser.add_argument("--model-path", type=str, default=None, help="Path to SD3/SD3.5 weights directory.")
    parser.add_argument("--negative-prompt", type=str, default="", help="Optional negative prompt.")

    parser.add_argument("--seed", type=int, default=None, help="Random seed overriding the config file.")
    parser.add_argument("--noise-samples", type=int, default=None, help="Number of MC noise samples (n_avg).")
    parser.add_argument("--n-steps", type=int, default=None, help="Scheduler timesteps.")
    parser.add_argument("--t-edit", type=int, default=None, help="Edit iterations on the Navi scale.")
    parser.add_argument("--t-ref", type=int, default=None, help="Reference start index on the Navi scale.")
    parser.add_argument("--src-guidance", type=float, default=None, help="Guidance scale for source prompt.")
    parser.add_argument("--tar-guidance", type=float, default=None, help="Guidance scale for target prompt.")
    parser.add_argument("--edit-dt", type=float, default=None, help="Edit-time step size (Δs).")
    parser.add_argument(
        "--mask-mode",
        type=str,
        default=None,
        choices=["none", "user", "auto_repr", "auto_dv"],
        help="Masking backend.",
    )
    parser.add_argument("--disable-equivalent-gain", action="store_true", help="Disable gain calibration.")
    parser.add_argument("--image-size", type=int, default=None, help="Resolution used when feeding the VAE.")
    parser.add_argument("--max-samples", type=int, default=None, help="Only process the first N records.")

    parser.add_argument("--pie-root", type=str, default=None, help="Root directory of PIE-Bench data.")
    parser.add_argument(
        "--mapping-file",
        type=str,
        default=DEFAULT_MAPPING_FILE,
        help="Mapping file path relative to --pie-root.",
    )
    parser.add_argument(
        "--image-subdir",
        type=str,
        default=DEFAULT_IMAGE_SUBDIR,
        help="Subdirectory (inside --pie-root) containing the original PIE annotation images.",
    )
    parser.add_argument(
        "--export-root",
        type=str,
        default=None,
        help="Directory that follows PIE-Bench layout (data/... + output/...). Defaults to --pie-root.",
    )
    parser.add_argument(
        "--method-name",
        type=str,
        default=DEFAULT_METHOD_NAME,
        help="Name used under export_root/output/<method_name>/annotation_images.",
    )
    parser.add_argument(
        "--output-subdir",
        type=str,
        default=DEFAULT_IMAGE_SUBDIR,
        help="Subdirectory inside output/<method_name>/ for generated images.",
    )
    parser.add_argument(
        "--source-subdir",
        type=str,
        default=DEFAULT_IMAGE_SUBDIR,
        help="Subdirectory inside data/ where original images are copied when --copy-source is set.",
    )
    parser.add_argument("--copy-source", action="store_true", help="Copy the source PIE images into export_root/data.")
    parser.add_argument(
        "--mapping-dest",
        type=str,
        default="data/mapping_file.json",
        help="Relative path (from export_root) to write the mapping file.",
    )
    parser.add_argument(
        "--no-sync-mapping",
        action="store_true",
        help="Skip copying the PIE mapping file into export_root.",
    )
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing predictions when present.")
    parser.add_argument(
        "--log-every",
        type=int,
        default=25,
        help="Progress logging interval in number of saved samples (0 disables incremental logs).",
    )

    # --center-crop / --no-center-crop share one destination; cropping is on by default.
    parser.add_argument(
        "--center-crop",
        dest="center_crop",
        action="store_true",
        default=True,
        help="Center-crop before resize for VAE preprocessing (default).",
    )
    parser.add_argument(
        "--no-center-crop",
        dest="center_crop",
        action="store_false",
        help="Disable center crop before resizing.",
    )

    args = parser.parse_args(argv)

    # Reject values that would otherwise be accepted and misbehave silently downstream.
    if args.max_samples is not None and args.max_samples < 0:
        parser.error("--max-samples must be >= 0.")
    if args.log_every < 0:
        parser.error("--log-every must be >= 0 (0 disables incremental logs).")
    if args.image_size is not None and args.image_size <= 0:
        parser.error("--image-size must be a positive integer.")

    return args


# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #


def dtype_from_precision(value: Optional[str], device: torch.device) -> torch.dtype:
    """Translate a precision name into the ``torch.dtype`` to run with on *device*.

    Args:
        value: ``"fp32"``, ``"fp16"`` or ``"bf16"`` (case-insensitive). ``None`` or
            an empty string selects ``DEFAULT_PRECISION``.
        device: Target device. 16-bit requests on a CPU device are downgraded to
            fp32 (with a warning).

    Raises:
        ValueError: If the name is not one of the supported precisions.
    """
    supported = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    key = (value or DEFAULT_PRECISION).lower()
    chosen = supported.get(key)
    if chosen is None:
        raise ValueError(f"Unsupported precision '{value}'. Choose from {list(supported)}.")

    sixteen_bit_on_cpu = device.type == "cpu" and chosen in (torch.float16, torch.bfloat16)
    if not sixteen_bit_on_cpu:
        return chosen
    LOGGER.warning("Falling back to fp32 on CPU.")
    return torch.float32


def _builtin_pipeline_defaults() -> tuple[str, Dict[str, Any], int, str, int]:
    """Return the built-in ``(model_path, edit_config, seed, precision, image_size)`` fallback."""
    return (
        str(DEFAULT_MODEL_PATH),
        dict(DEFAULT_EDIT_CONFIG),  # copy so callers can mutate it freely
        DEFAULT_SEED,
        DEFAULT_PRECISION,
        DEFAULT_IMAGE_SIZE,
    )


def load_pipeline_config(path: Optional[str]) -> tuple[str, Dict[str, Any], int, str, int]:
    """Read model path, edit parameters, seed, precision and image size from a YAML config.

    Args:
        path: Path to the YAML config. ``None`` or a path that is not an existing
            file falls back to the built-in defaults (with a warning).

    Returns:
        ``(model_path, edit_config, seed, precision, image_size)`` where:

        * ``model_path`` is ``trained.path`` or ``str(DEFAULT_MODEL_PATH)``;
        * ``edit_config`` is the first point of ``editor.params_grid`` (via
          ``first_param_point``) or a copy of ``DEFAULT_EDIT_CONFIG``;
        * ``seed`` is ``editor.seed``, else the first ``editor.seed_list`` entry,
          else ``DEFAULT_SEED``;
        * ``precision`` / ``image_size`` come from ``editor`` or the defaults.
    """
    if path is None:
        LOGGER.warning("No config supplied; using built-in Navi defaults.")
        return _builtin_pipeline_defaults()

    config_path = Path(path).expanduser()
    if not config_path.is_file():
        LOGGER.warning("Config file %s not found; using built-in defaults.", config_path)
        return _builtin_pipeline_defaults()

    # An empty YAML document or a bare ``editor:`` / ``trained:`` key typically loads
    # as None, so coalesce with ``or {}`` instead of relying on ``.get`` defaults.
    cfg = load_yaml_config(config_path) or {}
    editor_cfg = cfg.get("editor") or {}
    trained_cfg = cfg.get("trained") or {}

    model_path = trained_cfg.get("path") or str(DEFAULT_MODEL_PATH)
    params_grid = editor_cfg.get("params_grid") or {}
    edit_config = first_param_point(params_grid) if params_grid else dict(DEFAULT_EDIT_CONFIG)

    seed_value = editor_cfg.get("seed")
    if seed_value is None:
        seed_list = editor_cfg.get("seed_list") or []
        if seed_list:
            seed_value = seed_list[0]
    seed_value = int(seed_value) if seed_value is not None else DEFAULT_SEED

    precision = editor_cfg.get("precision") or DEFAULT_PRECISION
    # ``int(None)`` would raise for an explicit ``image_size: null``; treat it as unset.
    raw_image_size = editor_cfg.get("image_size")
    image_size = int(raw_image_size) if raw_image_size is not None else DEFAULT_IMAGE_SIZE

    return model_path, edit_config, seed_value, precision, image_size


def apply_cli_overrides(
    args: argparse.Namespace, edit_config: Dict[str, Any], seed: int, image_size: int
) -> tuple[Dict[str, Any], int, int]:
    """Layer command-line options on top of config-file values.

    Only options the user actually supplied (non-``None``) replace config values;
    ``--disable-equivalent-gain`` forces ``use_equiv_gain`` to False. The incoming
    ``edit_config`` is left untouched and an updated copy is returned, so a
    shared/default dict passed in is never mutated.

    Args:
        args: Namespace produced by ``parse_args``.
        edit_config: Edit parameters loaded from the config file.
        seed: Seed from the config file.
        image_size: Image size from the config file.

    Returns:
        ``(edit_config, seed, image_size)`` with the CLI overrides applied.
    """
    updated = dict(edit_config)

    # CLI attribute -> edit_config key for numeric overrides.
    overrides = {
        "noise_samples": args.noise_samples,
        "n_steps": args.n_steps,
        "t_edit": args.t_edit,
        "t_ref": args.t_ref,
        "src_guidance_scale": args.src_guidance,
        "tar_guidance_scale": args.tar_guidance,
        "edit_dt": args.edit_dt,
    }
    updated.update({key: value for key, value in overrides.items() if value is not None})

    if args.mask_mode:
        updated["mask_mode"] = args.mask_mode

    if args.disable_equivalent_gain:
        updated["use_equiv_gain"] = False

    if args.seed is not None:
        seed = int(args.seed)

    if args.image_size is not None:
        image_size = int(args.image_size)

    return updated, seed, image_size


def resolve_path(base: Path, maybe_relative: str | Path) -> Path:
    """Return *maybe_relative* as an absolute, resolved path, anchored at *base* if relative.

    ``~`` is expanded BEFORE checking whether the path is absolute, so
    ``"~/data/mapping.json"`` refers to the user's home directory rather than to
    ``base / "~" / "data/mapping.json"``.

    Args:
        base: Directory that relative paths are interpreted against.
        maybe_relative: Absolute, ``~``-prefixed, or relative path.
    """
    candidate = Path(maybe_relative).expanduser()
    if not candidate.is_absolute():
        candidate = Path(base).expanduser() / candidate
    return candidate.resolve()


def _strip_edit_markers(prompt: str) -> str:
    """Remove square-bracket edit markers, e.g. ``"a [red] car"`` -> ``"a red car"``."""
    return prompt.replace("[", "").replace("]", "")


def load_pie_records(
    root: Path,
    mapping_path: Path,
    image_subdir: str,
    *,
    strip_brackets: bool = True,
) -> List[PieRecord]:
    """Load PIE-Bench samples from the mapping JSON, keeping those whose image exists.

    Args:
        root: PIE-Bench root directory.
        mapping_path: JSON object mapping sample id -> metadata dict.
        image_subdir: Directory under *root* that each ``image_path`` is relative to.
        strip_brackets: Remove ``[``/``]`` from the source and target prompts.
            PIE-Bench marks the edited words in its prompts with square brackets;
            the reference PIE-Bench (PnPInversion) scripts strip them before text
            encoding, and so do we by default.

    Returns:
        Records sorted by sample id. Entries with non-dict metadata, a missing or
        unsafe ``image_path`` (absolute or containing ``..``), or a missing image
        file are skipped with a warning.

    Raises:
        FileNotFoundError: Mapping file or image directory missing, or no usable record.
        ValueError: The mapping JSON's top level is not an object.
    """
    if not mapping_path.exists():
        raise FileNotFoundError(f"PIE mapping file not found: {mapping_path}")

    with mapping_path.open("r", encoding="utf-8") as handle:
        mapping = json.load(handle)

    if not isinstance(mapping, dict):
        raise ValueError(f"Expected mapping JSON to be a dict, got {type(mapping).__name__}")

    img_root = (root / image_subdir).expanduser().resolve()
    if not img_root.exists():
        raise FileNotFoundError(f"PIE image directory does not exist: {img_root}")

    records: List[PieRecord] = []
    for sample_id in sorted(mapping):
        meta = mapping[sample_id]
        if not isinstance(meta, dict):
            LOGGER.warning("Sample %s has non-dict metadata; skipping.", sample_id)
            continue

        rel_value = meta.get("image_path")
        if rel_value is None:
            LOGGER.warning("Sample %s is missing 'image_path'; skipping.", sample_id)
            continue
        rel_path = Path(rel_value)
        # relative_path is later joined onto the output/ and data/ export dirs. An
        # absolute path would replace them entirely (pointing back at the source
        # image) and ``..`` would escape them, so such entries are rejected.
        if rel_path.is_absolute() or ".." in rel_path.parts:
            LOGGER.warning("Sample %s has unsafe image_path %r; skipping.", sample_id, rel_value)
            continue

        abs_path = (img_root / rel_path).resolve()
        if not abs_path.exists():
            LOGGER.warning("Sample %s image not found at %s; skipping.", sample_id, abs_path)
            continue

        original_prompt = meta.get("original_prompt") or meta.get("source_prompt") or ""
        edited_prompt = meta.get("editing_prompt") or meta.get("edited_prompt") or meta.get("target_prompt") or ""
        if strip_brackets:
            original_prompt = _strip_edit_markers(original_prompt)
            edited_prompt = _strip_edit_markers(edited_prompt)
        edit_instruction = meta.get("editing_instruction") or meta.get("edit_prompt") or edited_prompt

        records.append(
            PieRecord(
                sample_id=sample_id,
                image_path=abs_path,
                relative_path=rel_path,
                original_prompt=original_prompt,
                edited_prompt=edited_prompt,
                edit_instruction=edit_instruction,
            )
        )

    if not records:
        raise FileNotFoundError(f"No valid PIE records found in {mapping_path}.")
    return records


def ensure_dir(path: Path) -> None:
    """Create directory *path* and any missing parents; a no-op if it already exists."""
    path.mkdir(exist_ok=True, parents=True)


def copy_file(src: Path, dst: Path, *, overwrite: bool) -> None:
    """Copy *src* to *dst* (preserving metadata), creating parent directories.

    An existing *dst* is only replaced when *overwrite* is True. If *src* and *dst*
    are the same file (e.g. export_root == pie_root), nothing is copied;
    ``shutil.copy2`` would otherwise raise ``shutil.SameFileError``.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        if not overwrite or dst.samefile(src):
            return
    shutil.copy2(src, dst)


def sync_mapping_file(mapping_path: Path, export_root: Path, dest_relative: str, *, overwrite: bool) -> None:
    """Copy the PIE mapping JSON into the export tree.

    *dest_relative* is resolved against *export_root* with ``resolve_path``; an
    existing destination is replaced only when *overwrite* is set (``copy_file``
    semantics).
    """
    copy_file(
        mapping_path,
        resolve_path(export_root, dest_relative),
        overwrite=overwrite,
    )


def save_prediction(image: Image.Image, destination: Path, *, overwrite: bool) -> None:
    """Write *image* to *destination*, creating parent directories first.

    The parent directory is always created; an already existing file is left
    untouched unless *overwrite* is True. The file format is whatever
    ``Image.save`` infers from the destination path.
    """
    ensure_dir(destination.parent)
    if destination.exists() and not overwrite:
        return
    image.save(destination)


# --------------------------------------------------------------------------- #
# Main
# --------------------------------------------------------------------------- #


def _first_pil_image(images: Any, pipeline: Any) -> Optional[Image.Image]:
    """Return the first image of a pipeline ``result.images`` payload as PIL, or None if empty."""
    if isinstance(images, Image.Image):
        return images
    if isinstance(images, (list, tuple)):
        return images[0] if images else None
    if torch.is_tensor(images):
        # NOTE(review): relies on a private pipeline helper, as the original code did.
        return pipeline._tensor_to_pil(images)[0]  # type: ignore[attr-defined]
    return None


def main() -> None:
    """Run NaviEdit over PIE-Bench and export results in the PIE-Bench layout.

    Generated images are written to
    ``<export_root>/output/<method_name>/<output_subdir>/<relative_path>``. With
    ``--copy-source`` the originals are mirrored to
    ``<export_root>/data/<source_subdir>/<relative_path>``, also for samples whose
    prediction already exists, so resumed runs complete the data/ tree. Existing
    predictions are kept unless ``--overwrite`` is given. Per-sample read or pipeline
    failures are logged with a traceback and counted as skipped instead of
    aborting the run.
    """
    args = parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )

    model_path_cfg, edit_config, seed, precision, image_size = load_pipeline_config(args.config)
    edit_config, seed, image_size = apply_cli_overrides(args, edit_config, seed, image_size)

    device = torch.device(args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
    torch_dtype = dtype_from_precision(args.precision or precision, device)

    pie_root = Path(args.pie_root).expanduser().resolve() if args.pie_root else DEFAULT_PIE_ROOT
    export_root = Path(args.export_root).expanduser().resolve() if args.export_root else pie_root
    mapping_path = resolve_path(pie_root, args.mapping_file)

    records = load_pie_records(pie_root, mapping_path, args.image_subdir)
    if args.max_samples is not None:
        # Clamp at 0: a negative slice bound would drop records from the END
        # (records[:-1] keeps all but the last) instead of limiting the count.
        records = records[: max(args.max_samples, 0)]

    if not records:
        LOGGER.error("No PIE records to process. Check dataset paths.")
        return

    LOGGER.info(
        "Loaded %d PIE samples from %s (mapping=%s)",
        len(records),
        pie_root,
        mapping_path,
    )
    LOGGER.info("Seed %s | Edit config %s", seed, edit_config)

    pipeline = NaviEditPipeline.from_pretrained_sd3(
        model_path=args.model_path or model_path_cfg,
        default_edit_config=edit_config,
        device=device,
        torch_dtype=torch_dtype,
        image_size=image_size,
        use_center_crop=args.center_crop,
        negative_prompt=args.negative_prompt,
    )

    output_dir = export_root / "output" / args.method_name / args.output_subdir
    source_dir = export_root / "data" / args.source_subdir
    ensure_dir(output_dir)
    if args.copy_source:
        ensure_dir(source_dir)

    if not args.no_sync_mapping:
        sync_mapping_file(mapping_path, export_root, args.mapping_dest, overwrite=args.overwrite)
        LOGGER.info("Synchronized mapping file to %s", resolve_path(export_root, args.mapping_dest))

    processed = 0
    skipped = 0  # existing predictions + per-sample failures

    for record in records:
        output_path = output_dir / record.relative_path

        if args.copy_source:
            # Mirror the source independently of generation so that re-running with
            # --copy-source over existing predictions still populates data/.
            copy_file(record.image_path, source_dir / record.relative_path, overwrite=args.overwrite)

        if output_path.exists() and not args.overwrite:
            skipped += 1
            continue

        try:
            with Image.open(record.image_path) as img:
                source_image = img.convert("RGB")
        except Exception as exc:  # pragma: no cover - defensive
            LOGGER.exception("Failed to read %s: %s", record.image_path, exc)
            skipped += 1
            continue

        try:
            result = pipeline(
                image=source_image,
                source_prompt=record.original_prompt,
                target_prompt=record.edited_prompt,
                seed=seed,
                output_type="pil",
            )
        except Exception as exc:  # pragma: no cover - runtime safety
            # Best-effort batch run: keep going, but preserve the traceback.
            LOGGER.exception("Pipeline failed on %s: %s", record.sample_id, exc)
            skipped += 1
            continue

        generated = _first_pil_image(result.images, pipeline)
        if generated is None:
            LOGGER.warning("No images returned for sample %s; skipping.", record.sample_id)
            skipped += 1
            continue

        save_prediction(generated, output_path, overwrite=args.overwrite)

        processed += 1
        if args.log_every and processed % args.log_every == 0:
            LOGGER.info("Saved %d/%d samples (skipped=%d)", processed, len(records), skipped)

    LOGGER.info(
        "Finished PIE export. Saved %d sample(s), skipped %d (existing/errors). Results: %s",
        processed,
        skipped,
        output_dir,
    )


# Run the export only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
