from __future__ import annotations

import argparse
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import torch

from pipeline_navi import NaviEditPipeline
from utils import (
    DEFAULT_DATA_ROOT,
    first_param_point,
    load_local_dataset,
    load_yaml_config,
)


# Module-wide logger; its handlers and level are configured in main() via logging.basicConfig.
LOGGER = logging.getLogger("navi_app")

# Directory that contains this script; anchors the default config and output locations.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Default YAML config location: <script dir>/config/navi.yaml
DEFAULT_CONFIG_PATH = _SCRIPT_DIR.joinpath("config", "navi.yaml")
# Default directory for edited results: <script dir>/output
DEFAULT_OUTPUT_DIR = _SCRIPT_DIR.joinpath("output")

# Placeholder for the SD3/SD3.5 weights directory.
# NOTE: pathlib normalises the empty path, so str(DEFAULT_MODEL_PATH) is ".".
DEFAULT_MODEL_PATH = Path("")

# Fallback seed, precision name and preprocessing resolution.
DEFAULT_SEED = 42
DEFAULT_PRECISION = "fp16"
DEFAULT_IMAGE_SIZE = 1024

# Built-in, single-point edit configuration. It is the fallback used when no
# YAML config is available or the YAML provides no parameter grid.
DEFAULT_EDIT_CONFIG: Dict[str, Any] = dict(
    # Sampling / schedule
    noise_samples=1,
    n_steps=28,  # scheduler steps
    t_edit=28,  # edit iterations
    t_ref=20,
    # Guidance scales for the source and target prompts
    src_guidance_scale=3.0,
    tar_guidance_scale=15.0,
    # Edit step size and gain calibration switch
    edit_dt=1.0,
    use_equiv_gain=True,
    # Masking options (interpreted by the pipeline)
    mask_mode="auto_repr",
    mask_ema=0.80,
    mask_pow=1.0,
    clamp_strength=0.0,
    mask_quantile=0.90,
    mask_min_area=0.02,
    mask_max_area=0.65,
    mask_grow=True,
    mask_grow_r_th=0.35,
    mask_grow_patience=2,
    mask_grow_quantile=0.97,
    mask_dilate_max_k=7,
    mask_blur_k=5,
    # CFL options (interpreted by the pipeline)
    use_cfl=True,
    cfl_tau=8.0,
)


# --------------------------------------------------------------------------- #
# Argument parsing
# --------------------------------------------------------------------------- #


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for the demo and parse arguments.

    Most options default to ``None`` so that later stages can distinguish
    "not given on the command line" from an explicit value and only then
    override what came from the config file.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call. Passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Standalone NaviEdit demo for SD3/SD3.5.")
    parser.add_argument(
        "--config",
        type=str,
        default=str(DEFAULT_CONFIG_PATH),
        help="Optional YAML config path (defaults to ./config/navi.yaml if present, else built-in defaults).",
    )
    parser.add_argument("--dataset", type=str, default=None, help="Dataset root. Defaults to ./images.")
    parser.add_argument("--output-dir", type=str, default=None, help="Directory to save edited results.")
    parser.add_argument("--device", type=str, default=None, help="Torch device, e.g. cuda:0 or cpu.")
    parser.add_argument("--precision", choices=["fp32", "fp16", "bf16"], default=None, help="Computation dtype.")
    parser.add_argument("--model-path", type=str, default=None, help="Path to SD3/SD3.5 weights directory.")
    parser.add_argument("--negative-prompt", type=str, default="", help="Optional negative prompt.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed overriding config.")

    # Core NaviEdit knobs
    parser.add_argument("--noise-samples", type=int, default=None, help="Number of MC noise samples (n_avg).")
    parser.add_argument("--n-steps", type=int, default=None, help="Scheduler timesteps (num_inference_steps).")
    parser.add_argument("--t-edit", type=int, default=None, help="Edit iterations on the Navi scale.")
    parser.add_argument("--t-ref", type=int, default=None, help="Reference starting index on the Navi scale.")
    parser.add_argument("--src-guidance", type=float, default=None, help="Guidance scale for source prompt.")
    parser.add_argument("--tar-guidance", type=float, default=None, help="Guidance scale for target prompt.")
    parser.add_argument("--edit-dt", type=float, default=None, help="Edit-time step size (Δs).")
    parser.add_argument(
        "--mask-mode",
        type=str,
        default=None,
        choices=["none", "user", "auto_repr", "auto_dv"],
        help="Masking backend.",
    )
    parser.add_argument("--disable-equivalent-gain", action="store_true", help="Disable gain calibration.")
    parser.add_argument("--image-size", type=int, default=None, help="Resolution for preprocessing / VAE.")

    # --center-crop / --no-center-crop write to the same destination; cropping is on by default.
    parser.add_argument(
        "--center-crop",
        dest="center_crop",
        action="store_true",
        default=True,
        help="Apply center crop before resizing (default).",
    )
    parser.add_argument(
        "--no-center-crop",
        dest="center_crop",
        action="store_false",
        help="Disable center crop before resizing.",
    )

    parser.add_argument("--max-samples", type=int, default=None, help="Limit number of dataset samples.")
    return parser.parse_args(argv)


# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #


def _dtype_from_precision(value: Optional[str], device: torch.device) -> torch.dtype:
    """Translate a precision name into a torch dtype suitable for ``device``.

    Args:
        value: One of ``"fp32"``, ``"fp16"``, ``"bf16"`` (case-insensitive).
            A falsy value selects ``DEFAULT_PRECISION``.
        device: Target device; reduced precisions are not used on CPU.

    Returns:
        The matching ``torch.dtype``, or ``torch.float32`` when a reduced
        precision was requested for a CPU device.

    Raises:
        ValueError: If the precision name is not recognised.
    """
    supported = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    key = (value or DEFAULT_PRECISION).lower()
    resolved = supported.get(key)
    if resolved is None:
        raise ValueError(f"Unsupported precision '{value}'. Choose from {list(supported)}.")

    # Every supported dtype other than float32 is a reduced precision.
    is_reduced = resolved != torch.float32
    if is_reduced and device.type == "cpu":
        LOGGER.warning("Falling back to fp32 on CPU.")
        return torch.float32
    return resolved


def _builtin_config_defaults() -> tuple[str, Dict[str, Any], int, str, int]:
    """Return the built-in fallback settings, with a fresh copy of the edit config."""
    return (
        str(DEFAULT_MODEL_PATH),
        dict(DEFAULT_EDIT_CONFIG),
        DEFAULT_SEED,
        DEFAULT_PRECISION,
        DEFAULT_IMAGE_SIZE,
    )


def _load_config_from_file(path: Optional[str]) -> tuple[str, Dict[str, Any], int, str, int]:
    """Read run settings from a YAML config, falling back to built-in defaults.

    The YAML is read through ``load_yaml_config``. The keys consulted are
    ``trained.path`` and, under ``editor``: ``params_grid``, ``seed``,
    ``seed_list``, ``precision`` and ``image_size``. A key that is missing or
    explicitly null/empty is treated as "not set" and replaced by the
    corresponding default instead of raising.

    Args:
        path: Path to the YAML file, or ``None`` to use the built-in defaults.
            A path that does not point to an existing file also yields the
            defaults (a warning is logged in both cases).

    Returns:
        ``(model_path, edit_config, seed, precision, image_size)``.
    """
    if path is None:
        LOGGER.warning("No config supplied; using built-in Navi defaults.")
        return _builtin_config_defaults()

    config_path = Path(path).expanduser()
    if not config_path.is_file():
        LOGGER.warning("Config file %s not found; using built-in defaults.", config_path)
        return _builtin_config_defaults()

    # `or {}` guards against sections that exist in the YAML but are empty
    # (parsed as None), and against a loader result that is None.
    cfg = load_yaml_config(config_path) or {}
    editor_cfg = cfg.get("editor") or {}
    trained_cfg = cfg.get("trained") or {}

    model_path = trained_cfg.get("path") or str(DEFAULT_MODEL_PATH)
    params_grid = editor_cfg.get("params_grid") or {}
    edit_config = first_param_point(params_grid) if params_grid else dict(DEFAULT_EDIT_CONFIG)

    # An explicit `seed` wins; otherwise take the first entry of `seed_list`.
    seed_value = editor_cfg.get("seed")
    if seed_value is None:
        seed_list = editor_cfg.get("seed_list") or []
        if seed_list:
            seed_value = seed_list[0]
    seed_value = int(seed_value) if seed_value is not None else DEFAULT_SEED

    precision = editor_cfg.get("precision") or DEFAULT_PRECISION

    # Explicit None check so that a null `image_size` falls back instead of
    # failing inside int().
    raw_image_size = editor_cfg.get("image_size")
    image_size = int(raw_image_size) if raw_image_size is not None else DEFAULT_IMAGE_SIZE

    return model_path, edit_config, seed_value, precision, image_size


def _apply_cli_overrides(
    args: argparse.Namespace, edit_config: Dict[str, Any], seed: int, image_size: int
) -> tuple[Dict[str, Any], int, int]:
    """Layer command-line values on top of the loaded configuration.

    Only options that were actually supplied replace the existing values.
    ``edit_config`` is updated in place and the same object is returned.

    Args:
        args: Parsed command-line namespace.
        edit_config: Edit parameters to update.
        seed: Seed to use unless ``args.seed`` is given.
        image_size: Image size to use unless ``args.image_size`` is given.

    Returns:
        ``(edit_config, seed, image_size)`` after applying the overrides.
    """
    # (edit-config key, namespace attribute) pairs, in application order.
    key_to_attr = (
        ("noise_samples", "noise_samples"),
        ("n_steps", "n_steps"),
        ("t_edit", "t_edit"),
        ("t_ref", "t_ref"),
        ("src_guidance_scale", "src_guidance"),
        ("tar_guidance_scale", "tar_guidance"),
        ("edit_dt", "edit_dt"),
    )
    # Read every attribute before touching edit_config.
    candidates = [(cfg_key, getattr(args, attr)) for cfg_key, attr in key_to_attr]
    edit_config.update({cfg_key: given for cfg_key, given in candidates if given is not None})

    if args.mask_mode:
        edit_config["mask_mode"] = args.mask_mode

    if args.disable_equivalent_gain:
        edit_config["use_equiv_gain"] = False

    final_seed = seed if args.seed is None else int(args.seed)
    final_size = image_size if args.image_size is None else int(args.image_size)

    return edit_config, final_seed, final_size


# --------------------------------------------------------------------------- #
# Main
# --------------------------------------------------------------------------- #


def main() -> None:
    """Run the NaviEdit demo end to end.

    Steps: parse the CLI, load the YAML/built-in configuration, apply CLI
    overrides, load the dataset, build the pipeline, then edit each sample and
    save the first returned image as ``NNN.png`` in the output directory.

    Raises:
        TypeError: If the pipeline result's ``images`` is not a list.
    """
    args = parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )

    model_path, edit_config, seed, precision, image_size = _load_config_from_file(args.config)
    edit_config, seed, image_size = _apply_cli_overrides(args, edit_config, seed, image_size)

    # --model-path takes precedence over the config / built-in value. Without
    # this the flag would be accepted by the parser but have no effect.
    if args.model_path:
        model_path = args.model_path

    device = torch.device(args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
    torch_dtype = _dtype_from_precision(args.precision or precision, device)

    dataset_root = Path(args.dataset).expanduser().resolve() if args.dataset else DEFAULT_DATA_ROOT
    ds = load_local_dataset(
        dataset_root,
        image_size=image_size,
        center_crop=args.center_crop,
    )

    output_dir = Path(args.output_dir) if args.output_dir else DEFAULT_OUTPUT_DIR
    output_dir = output_dir.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    LOGGER.info("Loaded %d samples from %s", len(ds), dataset_root)
    LOGGER.info("Model path: %s", model_path)
    LOGGER.info("Seed: %s | Edit config: %s", seed, edit_config)

    pipeline = NaviEditPipeline.from_pretrained_sd3(
        model_path=model_path,
        default_edit_config=edit_config,
        device=device,
        torch_dtype=torch_dtype,
        image_size=image_size,
        use_center_crop=args.center_crop,
        negative_prompt=args.negative_prompt,
    )

    # Number of samples that produced a saved image; output files are numbered
    # by this count (1-based), not by dataset index, so skipped samples leave
    # no gaps in the file names.
    processed = 0
    for idx, sample in enumerate(ds):
        if args.max_samples is not None and idx >= args.max_samples:
            break

        result = pipeline(
            image=sample["original_image"],
            source_prompt=sample["original_prompt"],
            target_prompt=sample["edited_prompt"],
            seed=seed,
        )

        images = result.images
        if not isinstance(images, list):
            raise TypeError("Pipeline returned tensors; rerun with output_type='pil'.")

        if not images:
            LOGGER.warning("Pipeline returned no images for dataset index %d", idx)
            continue

        # Only the first image of each result is written to disk.
        current_index = processed + 1
        save_path = output_dir / f"{current_index:03d}.png"
        images[0].save(save_path)
        LOGGER.info(
            "Saved %s as output #%03d (seed=%s)",
            save_path.name,
            current_index,
            seed,
        )

        processed += 1
        LOGGER.info("Finished output #%03d: %d image(s)", current_index, len(images))

    LOGGER.info("Finished processing %s samples. Results in %s", processed, output_dir)


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
