from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

from architecture.dav2_metric_lusnar import DAV2MetricLargeOutdoor
from dataset_scripts.lusnar_dataset import LusnarDepthDataset

"""
Extrinsic Fine-Tuning Orchestrator

This module centrally manages distributed multi-node algorithmic calibrations enforcing gradient 
alignments structurally mapping latent DepthAnythingV2 topological features to purely novel outdoor metric 
domains. Evaluates runtime mathematical loss functions organically and iteratively updates weight schemas via 
cosine decay logic robustly.
"""

EVAL_SCRIPTS_ROOT = Path("path/to/eval_scripts")
# The core.* evaluation helpers imported below live outside this package; make their root
# importable exactly once.
if all(entry != str(EVAL_SCRIPTS_ROOT) for entry in sys.path):
    sys.path.append(str(EVAL_SCRIPTS_ROOT))

from core.alignment import least_squares_align
from core.evaluator import apply_max_depth_mask
from core.metrics import compute_metrics


# Upper bound on ground-truth depth used when masking pixels for the training losses
# (build_dynamic_depth_mask) and for the per-sample evaluation metrics.
LOSS_MASK_MAX_DEPTH = 50.0

# Metric names averaged and logged per epoch for train and val, in reporting order.
EVAL_METRIC_KEYS = [
    "delta_1", "delta_2", "delta_3",
    "mae", "abs_rel", "rmse", "silog", "irmse", "sq_rel",
    "edge_acc", "edge_comp",
]


def parse_args() -> argparse.Namespace:
    """Define the fine-tuning command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. A ``--batch-size`` of zero or less requests the automatic batch-size
        search, which is bounded by ``--batch-size-max`` and ``--batch-size-min``.
    """
    parser = argparse.ArgumentParser(description="Fine-tune Depth Anything V2 metric-large outdoor on LuSNAR")
    add = parser.add_argument

    # Filesystem locations and run naming (all plain strings).
    for flag, default in (
        ("--base-dir", "path/to/dav2_finetuning"),
        ("--data-root", "path/to/lusnar-dataset"),
        (
            "--pretrained",
            "path/to/dav2_finetuning/Depth-Anything-V2/metric_depth/checkpoints/depth_anything_v2_metric_vkitti_vitl.pth",
        ),
        ("--results-dir", "path/to/dav2_finetuning/results"),
        ("--run-name", "dav2_lusnar_finetune"),
    ):
        add(flag, type=str, default=default)

    # Model input and fine-tuning mode.
    add("--input-size", type=int, default=518)
    add("--max-depth", type=float, default=50.0)
    add("--finetune-mode", type=str, default="full", choices=["full", "lora"])
    add("--lora-rank", type=int, default=8)
    add("--lora-alpha", type=float, default=16.0)
    add("--lora-dropout", type=float, default=0.0)

    # Batching and data loading.
    add("--batch-size", type=int, default=0, help="Per-process batch size. Set <=0 to enable auto search.")
    add("--batch-size-max", type=int, default=64, help="Auto batch-size search start value")
    add("--batch-size-min", type=int, default=2, help="Auto batch-size search minimum value")
    add("--val-batch-size", type=int, default=1)
    add("--workers", type=int, default=6)

    # Optimisation schedule.
    add("--epochs", type=int, default=80)
    add("--lr", type=float, default=5e-5)
    add("--weight-decay", type=float, default=0.01)
    add("--warmup-epochs", type=int, default=2, help="Linear warmup epochs before cosine decay")
    add("--min-lr", type=float, default=1e-6, help="Minimum learning rate for cosine tail")

    # Loss weighting.
    for flag, default in (("--silog-lambda", 0.5), ("--silog-weight", 1.0), ("--grad-weight", 1.0)):
        add(flag, type=float, default=default)

    # Run control.
    for flag, default in (
        ("--early-stop-patience", 10),
        ("--seed", 42),
        ("--print-freq", 20),
        ("--save-every", 1),
    ):
        add(flag, type=int, default=default)

    return parser.parse_args()


def setup_seed(seed: int) -> None:
    """Seed every global random number generator the training run may draw from.

    Seeds Python's ``random`` module, NumPy's global generator, and torch's CPU and CUDA
    generators. The original version left ``random`` unseeded, so any data-pipeline code using it
    was not reproducible.

    Args:
        seed: Seed value. ``main`` passes ``args.seed + rank`` so each process gets its own stream.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def setup_distributed() -> tuple[int, int, int, bool]:
    """Initialise torch.distributed from launcher environment variables.

    Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK``. A missing variable falls back to a
    single-process default. When ``WORLD_SIZE > 1`` a process group is created with the NCCL
    backend if CUDA is available and with Gloo otherwise.

    Returns:
        ``(rank, local_rank, world_size, distributed)``.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    distributed = world_size > 1

    if not distributed:
        return rank, local_rank, world_size, distributed

    if torch.cuda.is_available():
        # Bind the process to its GPU *before* creating the NCCL group so collectives and implicit
        # CUDA work target the right device. The original called set_device after init and also on
        # the CPU/Gloo path, where it fails because there is no CUDA device.
        torch.cuda.set_device(local_rank)
        try:
            dist.init_process_group(backend="nccl", init_method="env://", device_id=torch.device(f"cuda:{local_rank}"))
        except TypeError:
            # Older torch releases do not accept the ``device_id`` keyword.
            dist.init_process_group(backend="nccl", init_method="env://")
    else:
        dist.init_process_group(backend="gloo", init_method="env://")
    return rank, local_rank, world_size, distributed


def cleanup_distributed(distributed: bool) -> None:
    """Synchronise all ranks and destroy the default process group, if one exists.

    This is a no-op for single-process runs or when the process group was never initialised.
    """
    if not (distributed and dist.is_initialized()):
        return
    dist.barrier()
    dist.destroy_process_group()


def setup_logger(run_dir: Path, rank: int) -> logging.Logger:
    """Create (or reset) the INFO-level logger for one rank.

    Every rank appends to ``run_dir / f"train_rank{rank}.log"``. Rank 0 also logs to a console
    ``StreamHandler``. Calling this again for the same rank replaces the previous handlers
    instead of stacking new ones.

    Args:
        run_dir: Existing directory that receives the log file.
        rank: Process rank, used in the logger name and the log file name.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(f"train_rank_{rank}")
    logger.setLevel(logging.INFO)

    # Close handlers from an earlier call before dropping them. Clearing the handler list alone
    # (as before) left the previous FileHandler's file open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")

    file_handler = logging.FileHandler(run_dir / f"train_rank{rank}.log", mode="a", encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    if rank == 0:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    return logger


def build_dynamic_depth_mask(depth: torch.Tensor, base_valid_mask: torch.Tensor) -> torch.Tensor:
    """Narrow a validity mask to pixels whose ground-truth depth is usable.

    A pixel survives only if it is set in ``base_valid_mask`` and its depth lies in
    ``(0, LOSS_MASK_MAX_DEPTH]``. NaN depths fail both comparisons and are dropped too.

    Returns:
        A boolean tensor with the broadcast shape of ``depth`` and ``base_valid_mask``.
    """
    depth_in_range = (depth > 0.0).logical_and(depth <= LOSS_MASK_MAX_DEPTH)
    return base_valid_mask.bool() & depth_in_range


def masked_silog_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    valid_mask: torch.Tensor,
    lambd: float = 0.5,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Scale-invariant logarithmic (SiLog) loss over the masked pixels of a whole batch.

    Predictions and targets are clamped to ``>= eps``. With ``d = log(target) - log(pred)`` over
    all valid pixels, the loss is ``sqrt(max(mean(d**2) - lambd * mean(d)**2, 0) + eps)``.

    Returns:
        A scalar tensor. When no pixel is valid, this is ``pred.sum() * 0.0``, a zero built from
        ``pred`` so it carries ``pred``'s autograd graph.
    """
    mask = valid_mask.bool()
    if not mask.any():
        return pred.sum() * 0.0

    log_ratio = target[mask].clamp(min=eps).log() - pred[mask].clamp(min=eps).log()
    dispersion = log_ratio.pow(2).mean() - lambd * log_ratio.mean().pow(2)
    return (dispersion.clamp(min=0.0) + eps).sqrt()


def masked_gradient_matching_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    valid_mask: torch.Tensor,
) -> torch.Tensor:
    """L1 distance between predicted and target forward-difference gradients, masked.

    Differences are taken along dim 2 (x) and dim 1 (y), i.e. inputs indexed as
    ``[batch, row, col]``. A difference counts only when both pixels it spans are valid. Each
    axis contributes the mean absolute gradient error over its valid pairs, or a zero built from
    ``pred`` if it has none. The result is the average of the two axis terms.
    """
    mask = valid_mask.bool()

    def _axis_error(d_pred: torch.Tensor, d_target: torch.Tensor, pair_ok: torch.Tensor) -> torch.Tensor:
        if pair_ok.any():
            return (d_pred[pair_ok] - d_target[pair_ok]).abs().mean()
        return pred.sum() * 0.0

    horizontal = _axis_error(
        pred[:, :, 1:] - pred[:, :, :-1],
        target[:, :, 1:] - target[:, :, :-1],
        mask[:, :, 1:] & mask[:, :, :-1],
    )
    vertical = _axis_error(
        pred[:, 1:, :] - pred[:, :-1, :],
        target[:, 1:, :] - target[:, :-1, :],
        mask[:, 1:, :] & mask[:, :-1, :],
    )
    return 0.5 * (horizontal + vertical)


@torch.no_grad()
def validate(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    distributed: bool,
    capture_visual: bool = False,
) -> tuple[dict[str, float], dict[str, torch.Tensor] | None]:
    """Measure masked L1 and RMSE of ``model`` over ``loader``, summed across all ranks.

    Runs without gradients and switches the model to eval mode; it does not switch the mode back.
    Predictions are bilinearly resized to the ground-truth resolution when their last two dims
    differ. Only pixels kept by ``build_dynamic_depth_mask`` are scored, and batches with no such
    pixel are skipped.

    Args:
        model: Depth model called as ``model(image)``.
        loader: Yields dicts with ``"image"``, ``"depth"`` and ``"valid_mask"`` tensors.
        device: Device the batch tensors are moved to.
        distributed: If True, the running sums are all-reduced so every rank returns global values.
        capture_visual: If True, also return CPU copies of the first sample of the first scored batch.

    Returns:
        ``(metrics, vis_payload)``. ``metrics`` holds ``val_l1``, ``val_rmse`` and ``valid_pixels``.
        ``vis_payload`` holds ``image``/``pred``/``gt`` tensors, or is None.
    """
    model.eval()

    # Running [sum |err|, sum err^2, pixel count], kept in float64 because float32 stops counting
    # exactly past 2**24 pixels. A single buffer means a single all_reduce. The previous version
    # issued four, one of them for an unused copy of the L1 sum.
    totals = torch.zeros(3, dtype=torch.float64, device=device)
    vis_payload: dict[str, torch.Tensor] | None = None

    for batch in loader:
        image = batch["image"].to(device, non_blocking=True)
        depth = batch["depth"].to(device, non_blocking=True)
        valid = batch["valid_mask"].to(device, non_blocking=True)

        pred = model(image)
        if pred.shape[-2:] != depth.shape[-2:]:
            pred = F.interpolate(pred.unsqueeze(1), size=depth.shape[-2:], mode="bilinear", align_corners=True).squeeze(1)

        valid = build_dynamic_depth_mask(depth, valid)
        if not valid.any():
            continue

        if capture_visual and vis_payload is None:
            vis_payload = {
                "image": image[0].detach().cpu(),
                "pred": pred[0].detach().cpu(),
                "gt": depth[0].detach().cpu(),
            }

        # Compute the per-pixel error once; RMSE reuses it via square().
        abs_err = (pred[valid] - depth[valid]).abs().double()
        totals[0] += abs_err.sum()
        totals[1] += abs_err.square().sum()
        totals[2] += abs_err.numel()

    if distributed:
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)

    abs_sum, sq_sum, pixel_count = totals.tolist()
    denom = max(pixel_count, 1.0)
    metrics = {
        "val_l1": abs_sum / denom,
        "val_rmse": (sq_sum / denom) ** 0.5,
        "valid_pixels": float(pixel_count),
    }
    return metrics, vis_payload


def assert_pretrained_not_overwritten(pretrained_path: str, save_dir: Path) -> None:
    """Refuse a configuration where checkpoints would be saved beside the pretrained weights.

    Both paths are user-expanded (``~``) and resolved before they are compared.

    Raises:
        RuntimeError: If ``save_dir`` is the directory that contains ``pretrained_path``.
    """
    pretrained_parent = Path(pretrained_path).expanduser().resolve().parent
    destination = save_dir.expanduser().resolve()
    if pretrained_parent != destination:
        return
    raise RuntimeError(
        "Safety check failed: checkpoint save directory equals pretrained checkpoint directory. "
        "Choose a different --results-dir so pretrained weights are never overwritten."
    )


def get_shared_run_stamp(rank: int, distributed: bool) -> str:
    """Return a run timestamp that is identical on every rank.

    The stamp is local time formatted ``%Y%m%d_%H%M%S_%f``. When ``SLURM_JOB_ID`` is set and
    non-empty, ``_job<id>`` is appended. In distributed runs rank 0 builds the stamp and
    broadcasts it to the other ranks.
    """
    def _compose() -> str:
        moment = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        slurm_job = os.environ.get("SLURM_JOB_ID", "")
        if not slurm_job:
            return moment
        return f"{moment}_job{slurm_job}"

    if distributed:
        shared = [_compose() if rank == 0 else ""]
        dist.broadcast_object_list(shared, src=0)
        return str(shared[0])
    return _compose()


def _is_oom_error(exc: RuntimeError) -> bool:
    message = str(exc).lower()
    return "out of memory" in message or "cuda error: out of memory" in message


def _batch_candidates(max_batch: int, min_batch: int) -> list[int]:
    max_batch = max(1, int(max_batch))
    min_batch = max(1, int(min_batch))
    if max_batch < min_batch:
        max_batch, min_batch = min_batch, max_batch

    values: list[int] = []
    current = max_batch
    while current >= min_batch:
        values.append(current)
        current //= 2
    if min_batch not in values:
        values.append(min_batch)
    return values


def _probe_training_step(
    model: torch.nn.Module,
    batch: dict,
    args: argparse.Namespace,
    device: torch.device,
) -> None:
    """Run one forward/backward pass of the training loss on ``batch``, without an optimizer step."""
    image = batch["image"].to(device, non_blocking=True)
    depth = batch["depth"].to(device, non_blocking=True)
    valid = batch["valid_mask"].to(device, non_blocking=True)

    pred = model(image)
    if pred.shape[-2:] != depth.shape[-2:]:
        pred = F.interpolate(pred.unsqueeze(1), size=depth.shape[-2:], mode="bilinear", align_corners=True).squeeze(1)
    dynamic_valid = build_dynamic_depth_mask(depth, valid)
    silog_loss = masked_silog_loss(pred, depth, dynamic_valid, lambd=args.silog_lambda)
    grad_loss = masked_gradient_matching_loss(pred, depth, dynamic_valid)
    loss = args.silog_weight * silog_loss + args.grad_weight * grad_loss
    loss.backward()


def autotune_batch_size(
    model: torch.nn.Module,
    train_set: Dataset,
    args: argparse.Namespace,
    device: torch.device,
    rank: int,
    distributed: bool,
    logger: logging.Logger,
) -> int:
    """Choose the per-process training batch size, probing memory when none was given.

    If ``args.batch_size`` is positive, it is used as-is (broadcast from rank 0 when distributed).
    Otherwise rank 0 tries each size from
    ``_batch_candidates(args.batch_size_max, args.batch_size_min)``, largest first. Each try is a
    single forward/backward pass of the training loss, and the first size that does not run out
    of memory wins. If none fits, the size falls back to ``max(1, args.batch_size_min)``. The
    result is broadcast so every rank uses the same value.

    Args:
        model: Un-wrapped model, already on ``device``. Its gradients are cleared after probing.
        train_set: Dataset used to draw one probe batch per candidate size.
        args: Parsed CLI namespace (batch-size bounds and loss weights).
        device: Device the probe runs on.
        rank: Process rank; only rank 0 probes.
        distributed: Whether to broadcast the result from rank 0.
        logger: Logger for progress messages.

    Returns:
        The chosen per-process batch size.

    Raises:
        RuntimeError: Any error during probing that is not an out-of-memory error is re-raised.
    """
    chosen = int(args.batch_size)
    if chosen > 0:
        if distributed:
            chosen_tensor = torch.tensor([chosen], device=device)
            dist.broadcast(chosen_tensor, src=0)
            chosen = int(chosen_tensor.item())
        return chosen

    if rank == 0:
        logger.info("Auto batch-size search enabled (max=%d, min=%d)", args.batch_size_max, args.batch_size_min)
        model.train()
        chosen = -1

        # The probe never runs optimizer.step(), so the two AdamW moment buffers per trainable
        # parameter (allocated lazily on the first real step) would be missing from the measured
        # peak. Hold equivalent memory while probing so the chosen size leaves room for them.
        # NOTE(review): DDP gradient buckets created later are still not accounted for.
        optimizer_state_reserve = [
            torch.empty_like(param)
            for param in model.parameters()
            if param.requires_grad
            for _ in range(2)
        ]

        for candidate in _batch_candidates(args.batch_size_max, args.batch_size_min):
            if len(train_set) < candidate:
                continue

            loader = DataLoader(
                train_set,
                batch_size=candidate,
                shuffle=True,
                num_workers=0,
                pin_memory=True,
                drop_last=True,
            )
            if len(loader) == 0:
                continue

            batch = next(iter(loader))
            hit_oom = False
            model.zero_grad(set_to_none=True)
            try:
                # The host->device copies happen inside the probe, so they are covered as well.
                _probe_training_step(model, batch, args, device)
            except RuntimeError as exc:
                if not _is_oom_error(exc):
                    raise
                hit_oom = True

            # Release memory *outside* the except clause. While the handler runs, the exception's
            # traceback still references the probe frame and its activations, so emptying the
            # cache there cannot free them.
            model.zero_grad(set_to_none=True)
            del batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            if hit_oom:
                logger.info("Auto batch-size OOM at %d, trying smaller batch.", candidate)
                continue

            chosen = candidate
            logger.info("Auto batch-size selected: %d", chosen)
            break

        del optimizer_state_reserve
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if chosen < 0:
            chosen = max(1, int(args.batch_size_min))
            logger.warning("Auto batch-size fallback to minimum=%d", chosen)

    if distributed:
        chosen_tensor = torch.tensor([int(chosen)], device=device)
        dist.broadcast(chosen_tensor, src=0)
        chosen = int(chosen_tensor.item())

    return chosen


def _denormalize_input(image_chw: torch.Tensor) -> torch.Tensor:
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=image_chw.dtype, device=image_chw.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=image_chw.dtype, device=image_chw.device).view(3, 1, 1)
    return (image_chw * std + mean).clamp(0.0, 1.0)


def _depth_to_vis(depth_hw: torch.Tensor, max_depth: float) -> torch.Tensor:
    vis = (depth_hw.clamp(0.0, max_depth) / max(max_depth, 1e-6)).unsqueeze(0)
    return vis.repeat(3, 1, 1)


class WarmupCosineScheduler:
    """Per-step LR schedule: linear warmup, then cosine decay down to ``min_lr``.

    For optimizer step ``k`` (1-based, clamped to ``total_steps``), each param group's LR is
    ``min_lr + (base_lr - min_lr) * s(k)``. Here ``base_lr`` is the group's LR at construction time,
    ``s(k) = k / warmup_steps`` during warmup, and afterwards
    ``s(k) = 0.5 * (1 + cos(pi * (k - warmup_steps) / (total_steps - warmup_steps)))``.
    The final step therefore runs at ``min_lr``.

    Usage: call :meth:`step` once after every ``optimizer.step()``. The constructor already applies
    the LR for step 1, so the first update is warmed up as well.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        min_lr: float,
    ) -> None:
        self.optimizer = optimizer
        self.total_steps = max(1, int(total_steps))
        self.warmup_steps = max(0, int(warmup_steps))
        self.min_lr = float(min_lr)
        # Capture base LRs before overwriting them below.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        # Number of completed optimizer steps the schedule has been advanced past.
        self.last_step = 0
        # Apply the step-1 LR now. Previously the LR was first set only after the first
        # optimizer.step(), so that update ran at the full base LR and every later step
        # lagged the schedule by one.
        self._apply_scale(self._scale_at(1))

    def _scale_at(self, step: int) -> float:
        """Return the schedule multiplier ``s(step)`` in ``[0, 1]`` (step clamped to ``[1, total_steps]``)."""
        step = max(1, min(step, self.total_steps))

        if self.warmup_steps > 0 and step <= self.warmup_steps:
            return step / float(self.warmup_steps)

        cosine_steps = max(1, self.total_steps - self.warmup_steps)
        progress = (step - self.warmup_steps) / float(cosine_steps)
        progress = max(0.0, min(1.0, progress))
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    def _apply_scale(self, scale: float) -> None:
        """Set every param group's LR to ``min_lr + (base_lr - min_lr) * scale``."""
        for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups):
            group["lr"] = self.min_lr + (base_lr - self.min_lr) * scale

    def step(self) -> float:
        """Record one finished optimizer step and set the LR for the next one.

        Returns:
            The first param group's new LR.
        """
        self.last_step += 1
        self._apply_scale(self._scale_at(self.last_step + 1))
        return self.optimizer.param_groups[0]["lr"]

    def get_last_lr(self) -> list[float]:
        """Return the LR currently set on each param group."""
        return [group["lr"] for group in self.optimizer.param_groups]


def _valid_gt_mask_np(gt: np.ndarray) -> np.ndarray:
    return np.isfinite(gt) & (gt > 0.0) & (gt < 65500.0)


def _compute_eval_metrics_sample(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Score one prediction against its ground truth after least-squares alignment.

    Both maps are cast to float64. A pixel is scored when the prediction is finite, the ground
    truth passes ``_valid_gt_mask_np``, and ``apply_max_depth_mask`` keeps it under
    ``LOSS_MASK_MAX_DEPTH``. The prediction is aligned to the ground truth with
    ``least_squares_align`` on that mask, then ``compute_metrics`` runs on the aligned prediction.

    Returns:
        Whatever ``compute_metrics`` returns. It is presumably keyed like ``EVAL_METRIC_KEYS``;
        confirm against core.metrics.
    """
    pred64 = pred.astype(np.float64)
    gt64 = gt.astype(np.float64)
    scored = np.isfinite(pred64) & _valid_gt_mask_np(gt64)
    scored = apply_max_depth_mask(gt64, scored, LOSS_MASK_MAX_DEPTH)
    aligned, _, _ = least_squares_align(pred64, gt64, scored)
    return compute_metrics(aligned, gt64, scored)


def _new_metric_accumulator() -> dict[str, list[float]]:
    """Return a zeroed ``{metric: [running_sum, sample_count]}`` map over ``EVAL_METRIC_KEYS``.

    Each key gets its own list, so accumulating one metric never affects another.
    """
    accumulator: dict[str, list[float]] = {}
    for metric_name in EVAL_METRIC_KEYS:
        accumulator[metric_name] = [0.0, 0.0]
    return accumulator


def _accumulate_metric_dict(acc: dict[str, list[float]], values: dict) -> None:
    """Fold one sample's metrics into ``acc`` in place.

    For every key in ``EVAL_METRIC_KEYS``, a value that is present, not None and finite is added
    to the running sum, and the sample count is incremented. Missing, None, NaN and infinite
    values are skipped.
    """
    for metric_name in EVAL_METRIC_KEYS:
        sample_value = values.get(metric_name, float("nan"))
        if sample_value is not None and np.isfinite(sample_value):
            running = acc[metric_name]
            running[0] += float(sample_value)
            running[1] += 1.0


def _reduce_metric_accumulator(acc: dict[str, list[float]], device: torch.device, distributed: bool) -> dict[str, list[float]]:
    if not distributed:
        return acc
    for key in EVAL_METRIC_KEYS:
        tensor = torch.tensor(acc[key], device=device, dtype=torch.float64)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        acc[key] = [float(tensor[0].item()), float(tensor[1].item())]
    return acc


def _finalize_metric_accumulator(acc: dict[str, list[float]]) -> dict[str, float]:
    """Convert ``[running_sum, count]`` pairs into per-metric means.

    A metric with a non-positive count (no usable samples) is reported as NaN.
    """
    def _mean(pair: list[float]) -> float:
        total, count = pair
        return float(total / count) if count > 0 else float("nan")

    return {metric_name: _mean(acc[metric_name]) for metric_name in EVAL_METRIC_KEYS}


def main() -> None:
    """Fine-tune DAV2MetricLargeOutdoor on LuSNAR, as one process or one process per rank.

    Flow:
    - Parse CLI args and initialise the optional process group.
    - Create a per-run results directory whose timestamp is shared by all ranks.
    - Build the train/val datasets, model, AdamW optimizer and warmup+cosine LR scheduler.
    - For each epoch: train, validate, and log to TensorBoard/JSON on rank 0.
    - Save ``best_model.pth`` whenever validation ``abs_rel`` improves.
    - Stop early after ``--early-stop-patience`` epochs without improvement.

    The TensorBoard writer and the process group are released in ``finally``, even on error.
    """
    args = parse_args()

    rank, local_rank, world_size, distributed = setup_distributed()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    # Defined before ``try`` so the ``finally`` clause can always reference it.
    writer = None

    try:
        # Offset by rank so each process draws a different random stream.
        setup_seed(args.seed + rank)

        run_stamp = get_shared_run_stamp(rank, distributed)
        run_dir = Path(args.results_dir) / f"{args.run_name}_{run_stamp}"
        ckpt_dir = run_dir / "checkpoints"
        tb_dir = run_dir / "tensorboard"
        # Abort before writing anything if checkpoints would land beside the pretrained weights.
        assert_pretrained_not_overwritten(args.pretrained, ckpt_dir)
        # Only rank 0 creates directories. The barrier keeps other ranks from opening their log
        # files in run_dir before it exists.
        if rank == 0:
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            tb_dir.mkdir(parents=True, exist_ok=True)
        if distributed:
            dist.barrier()

        logger = setup_logger(run_dir, rank)
        writer = SummaryWriter(log_dir=str(tb_dir)) if rank == 0 else None
        logger.info("Starting LuSNAR fine-tuning")
        logger.info("Rank %d/%d initialized", rank, world_size)
        logger.info(
            "LR schedule: linear warmup + cosine decay | base_lr=%.6e min_lr=%.6e warmup_epochs=%d",
            args.lr,
            args.min_lr,
            args.warmup_epochs,
        )

        train_set = LusnarDepthDataset(root=args.data_root, split="train", input_size=args.input_size, max_depth=args.max_depth)
        val_set = LusnarDepthDataset(root=args.data_root, split="val", input_size=args.input_size, max_depth=args.max_depth)

        model = DAV2MetricLargeOutdoor(base_dir=args.base_dir, max_depth=args.max_depth)
        missing, unexpected = model.load_pretrained(args.pretrained, strict=True)
        if missing or unexpected:
            logger.info("Checkpoint load status | missing=%d unexpected=%d", len(missing), len(unexpected))

        # Presumably sets which parameters are trainable (and injects LoRA layers in "lora" mode).
        # Confirm in DAV2MetricLargeOutdoor.configure_finetuning.
        mode_info = model.configure_finetuning(
            mode=args.finetune_mode,
            lora_rank=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
        )
        logger.info(
            "Finetune mode=%s | trainable_params=%d | lora_replaced_linear=%d",
            args.finetune_mode,
            model.trainable_param_count(),
            mode_info.get("lora_replaced_linear", 0),
        )

        model.to(device)
        # Called on every rank: in distributed mode rank 0 probes and broadcasts the chosen size.
        effective_batch_size = autotune_batch_size(model, train_set, args, device, rank, distributed, logger)
        logger.info("Using per-process train batch size: %d", effective_batch_size)
        if distributed:
            dist.barrier()

        train_sampler = DistributedSampler(train_set, shuffle=True) if distributed else None
        # NOTE(review): DistributedSampler may pad ranks with repeated samples to equalise their
        # lengths, so reduced validation metrics can count a few images twice.
        val_sampler = DistributedSampler(val_set, shuffle=False) if distributed else None

        train_loader = DataLoader(
            train_set,
            batch_size=effective_batch_size,
            sampler=train_sampler,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True,
        )
        val_loader = DataLoader(
            val_set,
            batch_size=args.val_batch_size,
            sampler=val_sampler,
            shuffle=False,
            num_workers=max(1, args.workers // 2),
            pin_memory=True,
            drop_last=False,
        )

        if distributed:
            # find_unused_parameters=True lets DDP tolerate parameters that receive no gradient in
            # an iteration, at the cost of extra per-iteration work.
            model = DDP(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True,
            )

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if not trainable_params:
            raise RuntimeError("No trainable parameters found for the selected finetune mode")
        optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
        # The scheduler is advanced once per optimizer step, so epoch counts are converted to steps.
        total_steps = args.epochs * max(1, len(train_loader))
        warmup_steps = args.warmup_epochs * max(1, len(train_loader))
        scheduler = WarmupCosineScheduler(
            optimizer=optimizer,
            total_steps=total_steps,
            warmup_steps=warmup_steps,
            min_lr=args.min_lr,
        )

        if rank == 0:
            config = vars(args).copy()
            config["world_size"] = world_size
            config["effective_batch_size"] = effective_batch_size
            with (run_dir / "config.json").open("w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)

        # Early-stopping / best-model state. Only rank 0 updates it (inside ``if rank == 0`` below).
        best_val_abs_rel = float("inf")
        epochs_without_improve = 0
        global_step = 0
        metrics_history: list[dict[str, float | int]] = []

        # NOTE(review): args.save_every is not read in this loop; only the best checkpoint is saved.
        for epoch in range(args.epochs):
            if distributed and train_sampler is not None:
                # Give the distributed sampler a new shuffle order each epoch.
                train_sampler.set_epoch(epoch)

            model.train()
            running_loss = 0.0
            running_steps = 0
            train_metric_acc = _new_metric_accumulator()

            for step, batch in enumerate(train_loader):
                image = batch["image"].to(device, non_blocking=True)
                depth = batch["depth"].to(device, non_blocking=True)
                valid = batch["valid_mask"].to(device, non_blocking=True)

                pred = model(image)
                if pred.shape[-2:] != depth.shape[-2:]:
                    pred = F.interpolate(pred.unsqueeze(1), size=depth.shape[-2:], mode="bilinear", align_corners=True).squeeze(1)

                dynamic_valid = build_dynamic_depth_mask(depth, valid)
                silog_loss = masked_silog_loss(pred, depth, dynamic_valid, lambd=args.silog_lambda)
                grad_loss = masked_gradient_matching_loss(pred, depth, dynamic_valid)
                loss = args.silog_weight * silog_loss + args.grad_weight * grad_loss

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                # Advance the LR schedule once per update. NOTE(review): check that the scheduler
                # applies a warmed-up LR before the very first optimizer.step().
                current_lr = scheduler.step()

                # NOTE(review): per-sample metrics (CPU least-squares alignment) run on every
                # training step and force a GPU->CPU copy; this may dominate step time.
                pred_np = pred.detach().cpu().numpy()
                gt_np = depth.detach().cpu().numpy()
                for idx in range(pred_np.shape[0]):
                    sample_metrics = _compute_eval_metrics_sample(pred_np[idx], gt_np[idx])
                    _accumulate_metric_dict(train_metric_acc, sample_metrics)

                running_loss += loss.item()
                running_steps += 1
                global_step += 1

                if rank == 0 and writer is not None:
                    writer.add_scalar("Train/loss_total_step", loss.item(), global_step)
                    writer.add_scalar("Train/loss_silog_step", silog_loss.item(), global_step)
                    writer.add_scalar("Train/loss_grad_step", grad_loss.item(), global_step)
                    writer.add_scalar("Train/lr", current_lr, global_step)

                # NOTE(review): --print-freq 0 would raise ZeroDivisionError on this modulo.
                if rank == 0 and ((step + 1) % args.print_freq == 0 or (step + 1) == len(train_loader)):
                    logger.info(
                        "Epoch [%d/%d] Step [%d/%d] | train_total=%.6f silog=%.6f grad=%.6f",
                        epoch + 1,
                        args.epochs,
                        step + 1,
                        len(train_loader),
                        running_loss / max(running_steps, 1),
                        silog_loss.item(),
                        grad_loss.item(),
                    )

            # Collective calls: every rank must reach these in the same order.
            train_metric_acc = _reduce_metric_accumulator(train_metric_acc, device, distributed)
            train_epoch_metrics = _finalize_metric_accumulator(train_metric_acc)

            # Global mean training loss = summed loss / summed step count across ranks.
            train_loss_tensor = torch.tensor([running_loss, running_steps], device=device, dtype=torch.float64)
            if distributed:
                dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM)
            epoch_train_loss = float(train_loss_tensor[0].item() / max(train_loss_tensor[1].item(), 1.0))

            # Inline validation pass (the validate() helper defined in this file is not used here).
            # It mirrors the training loss and per-sample metrics without gradients.
            model.eval()
            val_loss_sum = 0.0
            val_steps = 0
            val_metric_acc = _new_metric_accumulator()
            val_vis = None

            with torch.no_grad():
                for batch in val_loader:
                    image = batch["image"].to(device, non_blocking=True)
                    depth = batch["depth"].to(device, non_blocking=True)
                    valid = batch["valid_mask"].to(device, non_blocking=True)

                    pred = model(image)
                    if pred.shape[-2:] != depth.shape[-2:]:
                        pred = F.interpolate(pred.unsqueeze(1), size=depth.shape[-2:], mode="bilinear", align_corners=True).squeeze(1)

                    dynamic_valid = build_dynamic_depth_mask(depth, valid)
                    silog_loss = masked_silog_loss(pred, depth, dynamic_valid, lambd=args.silog_lambda)
                    grad_loss = masked_gradient_matching_loss(pred, depth, dynamic_valid)
                    loss = args.silog_weight * silog_loss + args.grad_weight * grad_loss

                    val_loss_sum += float(loss.item())
                    val_steps += 1

                    pred_np = pred.detach().cpu().numpy()
                    gt_np = depth.detach().cpu().numpy()
                    for idx in range(pred_np.shape[0]):
                        sample_metrics = _compute_eval_metrics_sample(pred_np[idx], gt_np[idx])
                        _accumulate_metric_dict(val_metric_acc, sample_metrics)

                    # Keep rank 0's first validation sample for the best-epoch TensorBoard images.
                    if rank == 0 and val_vis is None:
                        val_vis = {
                            "image": image[0].detach().cpu(),
                            "pred": pred[0].detach().cpu(),
                            "gt": depth[0].detach().cpu(),
                        }

            val_metric_acc = _reduce_metric_accumulator(val_metric_acc, device, distributed)
            val_epoch_metrics = _finalize_metric_accumulator(val_metric_acc)

            val_loss_tensor = torch.tensor([val_loss_sum, val_steps], device=device, dtype=torch.float64)
            if distributed:
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
            epoch_val_loss = float(val_loss_tensor[0].item() / max(val_loss_tensor[1].item(), 1.0))

            # best_val_abs_rel is only updated on rank 0, so ``improved`` is meaningful only there.
            # Other ranks learn the stop decision from the broadcast at the end of the epoch.
            current_val_abs_rel = float(val_epoch_metrics.get("abs_rel", float("nan")))
            improved = np.isfinite(current_val_abs_rel) and (current_val_abs_rel < best_val_abs_rel)

            if rank == 0:
                if writer is not None:
                    writer.add_scalar("Train/loss_total_epoch", epoch_train_loss, epoch + 1)
                    writer.add_scalar("Val/loss_total_epoch", epoch_val_loss, epoch + 1)
                    writer.add_scalar("Train/learning_rate", optimizer.param_groups[0]["lr"], epoch + 1)
                    for key in EVAL_METRIC_KEYS:
                        writer.add_scalar(f"Train/{key}", train_epoch_metrics[key], epoch + 1)
                        writer.add_scalar(f"Val/{key}", val_epoch_metrics[key], epoch + 1)

                metrics_history.append(
                    {
                        "epoch": epoch + 1,
                        "train_total_loss": float(epoch_train_loss),
                        "val_total_loss": float(epoch_val_loss),
                        "train_abs_rel": float(train_epoch_metrics.get("abs_rel", float("nan"))),
                        "val_abs_rel": float(val_epoch_metrics.get("abs_rel", float("nan"))),
                        "train_rmse": float(train_epoch_metrics.get("rmse", float("nan"))),
                        "val_rmse": float(val_epoch_metrics.get("rmse", float("nan"))),
                        "lr": float(optimizer.param_groups[0]["lr"]),
                    }
                )
                # Rewrite the full history each epoch so the file is current even if training dies.
                with (run_dir / "metrics_history.json").open("w", encoding="utf-8") as f:
                    json.dump(metrics_history, f, indent=2)

                logger.info(
                    "Epoch [%d/%d] | train_loss=%.6f val_loss=%.6f | train_abs_rel=%.6f val_abs_rel=%.6f | train_rmse=%.6f val_rmse=%.6f",
                    epoch + 1,
                    args.epochs,
                    epoch_train_loss,
                    epoch_val_loss,
                    train_epoch_metrics.get("abs_rel", float("nan")),
                    val_epoch_metrics.get("abs_rel", float("nan")),
                    train_epoch_metrics.get("rmse", float("nan")),
                    val_epoch_metrics.get("rmse", float("nan")),
                )

                if improved:
                    best_val_abs_rel = current_val_abs_rel
                    epochs_without_improve = 0

                    # Unwrap DDP so the payload comes from the underlying model.
                    model_to_save = model.module if isinstance(model, DDP) else model
                    checkpoint_payload = model_to_save.export_checkpoint_payload()
                    torch.save(checkpoint_payload, ckpt_dir / "best_model.pth")

                    best_meta = {
                        "epoch": epoch + 1,
                        "best_val_abs_rel": best_val_abs_rel,
                        "val_metrics": val_epoch_metrics,
                        "train_metrics": train_epoch_metrics,
                        "args": vars(args),
                    }
                    with (ckpt_dir / "best_model_meta.json").open("w", encoding="utf-8") as f:
                        json.dump(best_meta, f, indent=2)

                    if writer is not None and val_vis is not None:
                        in_vis = _denormalize_input(val_vis["image"]).cpu()
                        pred_vis = _depth_to_vis(val_vis["pred"], args.max_depth).cpu()
                        gt_vis = _depth_to_vis(val_vis["gt"], args.max_depth).cpu()
                        writer.add_image("BestEpoch/input", in_vis, epoch + 1)
                        writer.add_image("BestEpoch/pred", pred_vis, epoch + 1)
                        writer.add_image("BestEpoch/gt", gt_vis, epoch + 1)
                else:
                    epochs_without_improve += 1

                logger.info(
                    "Epoch [%d/%d] | improved=%s | epochs_without_improve=%d/%d",
                    epoch + 1,
                    args.epochs,
                    str(improved),
                    epochs_without_improve,
                    args.early_stop_patience,
                )
                logger.info(
                    "MONITOR epoch=%d val_abs_rel=%.6f best_val_abs_rel=%.6f",
                    epoch + 1,
                    current_val_abs_rel,
                    best_val_abs_rel,
                )

            # Rank 0 decides whether to stop. Broadcasting the flag makes every rank leave the loop
            # on the same epoch, so later collectives stay matched.
            stop_now = False
            if rank == 0:
                stop_now = epochs_without_improve >= args.early_stop_patience

            stop_tensor = torch.tensor([1 if stop_now else 0], device=device)
            if distributed:
                dist.broadcast(stop_tensor, src=0)

            if int(stop_tensor.item()) == 1:
                if rank == 0:
                    logger.info("Early stopping triggered at epoch %d", epoch + 1)
                break

        if rank == 0:
            logger.info("Training complete. Best val_abs_rel=%.6f", best_val_abs_rel)
            if writer is not None:
                writer.flush()
                writer.close()
    finally:
        # Safety net: flush/close the writer on any exit path. Errors here are swallowed so they
        # cannot mask an exception already propagating from the try block.
        if writer is not None:
            try:
                writer.flush()
                writer.close()
            except Exception:
                pass
        cleanup_distributed(distributed)


if __name__ == "__main__":
    main()