#!/usr/bin/env python3
"""
Multi-GPU Depth Anything AC Inference Script

This script runs depth inference in parallel on all GPUs of a single machine.
Instead of processing images one after another on a single device, it spawns
one worker process per GPU, each holding its own copy of the model, and
distributes the image list across the workers in round-robin order.

For each input/output directory pair it writes three outputs per image:
    ├── depth_colored/       # False-color PNG visualization (for visual inspection)
    ├── depth_bw/            # 16-bit grayscale PNG of the normalized map
    └── depth_npy/           # np.float32 per-image min-max normalized disparity in [0, 1] (relative, not metric; intended for quantitative evaluation)
"""

import argparse
import os
import sys
from glob import glob
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from matplotlib import cm

# Start worker processes with 'spawn' so that each GPU worker is a fresh
# interpreter. PyTorch does not support using CUDA in a child created by 'fork'
# once the parent has initialized CUDA, so forked GPU workers are unreliable.
# NOTE: this runs at import time, and force=True overrides any start method
# already chosen by code that imports this module.
mp.set_start_method("spawn", force=True)

# Put the DepthAnythingAC directory that sits next to this script at the front
# of the import path, so that `depth_anything.dpt` (imported inside worker_fn)
# resolves there. Spawned workers re-import this module and get the same entry.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "DepthAnythingAC"))
sys.path.insert(0, PROJECT_ROOT)

# Image file suffixes (lower-case, with leading dot) looked for when scanning
# input directories.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff", ".tif"]


def get_image_files(input_dir: str, extensions=None) -> List[str]:
    """Return the sorted paths of all image files directly inside ``input_dir``.

    Entries are matched by suffix, case-insensitively, against ``extensions``
    (default: ``IMAGE_EXTENSIONS``). Subdirectories are not searched, only
    regular files (or symlinks to them) are returned, and hidden entries
    (names starting with ``.``) are skipped, as a ``*.ext`` glob would.

    The directory is listed once and each name is checked against a set of
    lower-cased suffixes. Separate ``*.jpg`` / ``*.JPG`` globs would return
    every file twice on case-insensitive filesystems (Windows, default macOS),
    miss mixed-case suffixes such as ``.Jpg``, and misbehave when
    ``input_dir`` contains glob metacharacters such as ``[``.

    Args:
        input_dir: Directory to scan (non-recursive).
        extensions: Optional iterable of suffixes including the leading dot.

    Returns:
        Paths as strings (``input_dir`` joined with each file name), sorted.
        An empty list if ``input_dir`` is not an existing directory.
    """
    allowed = {
        ext.lower()
        for ext in (IMAGE_EXTENSIONS if extensions is None else extensions)
    }
    input_path = Path(input_dir)
    if not input_path.is_dir():
        return []

    return sorted(
        str(entry)
        for entry in input_path.iterdir()
        if not entry.name.startswith(".")
        and entry.suffix.lower() in allowed
        # Checked last: it is the only condition that needs a filesystem stat.
        and entry.is_file()
    )


def split_images(image_list: List[str], num_gpus: int) -> List[List[str]]:
    """
    Split ``image_list`` into ``num_gpus`` shards in round-robin order.

    Item ``i`` goes to shard ``i % num_gpus``, so shard sizes differ by at most
    one. Per-item cost is not taken into account. Items are not inspected, so
    any iterable works (e.g. ``(image_path, output_dir)`` job tuples).

    Args:
        image_list: Items to distribute.
        num_gpus: Number of shards; must be >= 1.

    Returns:
        A list of exactly ``num_gpus`` lists, each preserving input order.
        Some shards are empty when there are fewer items than shards.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")
    items = list(image_list)
    # items[k::n] selects exactly the items whose index i satisfies i % n == k.
    return [items[shard::num_gpus] for shard in range(num_gpus)]


def normalize_depth(disparity_tensor):
    """
    Min-max normalize a disparity map to roughly the range [0, 1].

    The smallest value in the whole tensor maps to 0 and the largest to
    (almost) 1. A 1e-6 term in the denominator keeps a constant map
    (max == min) from dividing by zero; such a map becomes all zeros.

    Despite the function name, no disparity-to-depth inversion happens here:
    the result keeps the disparity ordering and is only rescaled.
    """
    lowest = disparity_tensor.min()
    value_range = disparity_tensor.max() - lowest
    return (disparity_tensor - lowest) / (value_range + 1e-6)


def preprocess_image(image_path, target_size=518):
    """
    Load an image from disk and turn it into a normalized model input tensor.

    Steps:
      1. Read the image with OpenCV (default flags, BGR channel order).
      2. Convert to RGB float32 scaled to [0, 1].
      3. Resize so the shorter side is ``target_size``, keeping the aspect
         ratio. Each side is then rounded up to a multiple of 14 (the ViT
         patch size). The image is resized, slightly stretched if needed,
         not padded.
      4. Apply per-channel ImageNet mean/std normalization.

    Returns:
        (tensor, (h, w)): a float32 tensor of shape (1, 3, H', W') and the
        original image height and width in pixels.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Cannot read image: {image_path}")

    orig_h, orig_w = bgr.shape[:2]
    # OpenCV decodes to BGR; the network expects RGB values in [0, 1].
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Shorter side -> target_size, then round each side up to the patch grid.
    patch = 14
    ratio = target_size / min(orig_h, orig_w)
    out_h = (int(orig_h * ratio) + patch - 1) // patch * patch
    out_w = (int(orig_w * ratio) + patch - 1) // patch * patch
    resized = cv2.resize(rgb, (out_w, out_h), interpolation=cv2.INTER_CUBIC)

    # ImageNet channel statistics, in RGB order.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    normalized = (resized - imagenet_mean) / imagenet_std

    # HWC -> CHW, cast to float32, prepend a batch dimension of 1.
    chw = normalized.transpose(2, 0, 1)
    return torch.from_numpy(chw).float().unsqueeze(0), (orig_h, orig_w)


def postprocess_depth(depth_tensor, original_size):
    """
    Upsample a predicted map to the original image size and copy it to NumPy.

    Accepts a 2-D ``(H, W)``, 3-D ``(B, H, W)`` or 4-D ``(B, 1, H, W)``
    tensor. It is resized with bilinear interpolation
    (``align_corners=True``) to ``original_size``.

    Args:
        depth_tensor: Predicted map on any device.
        original_size: Target ``(height, width)`` in pixels.

    Returns:
        A NumPy array of shape ``(h, w)`` for a single map, or ``(B, h, w)``
        when the batch holds more than one map. Returns ``None`` if the input
        has an unsupported shape or interpolation fails; the error is printed.
        Only the channel axis and a size-1 batch axis are removed, so a
        1-pixel-tall or 1-pixel-wide image keeps both spatial dimensions.
    """
    h, w = original_size

    try:
        if depth_tensor.dim() == 2:
            depth_tensor = depth_tensor[None, None]
        elif depth_tensor.dim() == 3:
            depth_tensor = depth_tensor.unsqueeze(1)
        elif depth_tensor.dim() != 4 or depth_tensor.shape[1] != 1:
            raise ValueError(
                "expected a (H, W), (B, H, W) or (B, 1, H, W) tensor, "
                f"got shape {tuple(depth_tensor.shape)}"
            )

        depth = F.interpolate(depth_tensor, size=(h, w), mode='bilinear', align_corners=True)
        # Drop the channel axis, and the batch axis only when it holds one map.
        depth = depth[:, 0]
        if depth.shape[0] == 1:
            depth = depth[0]
        # .cpu() copies the result off the GPU so NumPy can wrap it.
        return depth.cpu().numpy()
    except Exception as e:
        print(f"Interpolation failed: {str(e)}")
        return None


def _spectral_colormap():
    """Return matplotlib's reversed Spectral colormap on old and new matplotlib.

    ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in
    3.9. The ``matplotlib.colormaps`` registry (added in 3.5) replaces it.
    """
    import matplotlib

    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry["Spectral_r"]
    return cm.get_cmap("Spectral_r")


def _colorize_depth(depth, colormap, spectral_cmap=None):
    """Map a normalized [0, 1] depth array to a BGR uint8 visualization.

    'inferno' uses OpenCV's colormap. 'spectral' uses ``spectral_cmap`` (a
    matplotlib colormap, looked up on demand if not given). Any other value
    produces a gray image replicated to three channels.
    """
    if colormap == 'inferno':
        return cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
    if colormap == 'spectral':
        cmap = spectral_cmap if spectral_cmap is not None else _spectral_colormap()
        # matplotlib returns RGBA floats in [0, 1]; OpenCV writes BGR.
        rgba = (cmap(depth) * 255).astype(np.uint8)
        return cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGR)
    gray = (depth * 255).astype(np.uint8)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)


def _reshape_flat_depth(depth_tensor, original_size):
    """Reshape a flattened 1-D prediction to ``(1, H, W)``, or return None.

    Tries the original image size first, then a square map. The result is
    3-D because ``postprocess_depth`` adds the channel axis itself.
    """
    import math

    if depth_tensor.dim() != 1:
        return None
    numel = depth_tensor.shape[0]
    h, w = original_size
    if numel == h * w:
        return depth_tensor.view(1, h, w)
    side = math.isqrt(numel)
    if side * side == numel:
        return depth_tensor.view(1, side, side)
    return None


def _imwrite_or_raise(path, image):
    """Write ``image`` to ``path`` with OpenCV, raising OSError on failure.

    ``cv2.imwrite`` signals failure by returning False rather than raising.
    """
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite failed for {path}")


def worker_fn(
    rank: int,
    world_size: int,
    jobs: List[Tuple[str, str]],  # List of (image_path, output_dir_base)
    model_path: str,
    encoder: str,
    colormap: str,
    target_size: int,
) -> None:
    """
    Run inference for one shard of jobs on GPU ``rank`` (meant to run in a child process).

    Loads the model on ``cuda:{rank}``. For each ``(image_path, output_dir)``
    job it writes three files under ``output_dir``:
      - ``depth_npy/<stem>.npy``
      - ``depth_bw/<stem>.png`` (16-bit)
      - ``depth_colored/<stem>.png``
    Per-image errors are counted and printed; they do not stop the worker.

    Args:
        rank: CUDA device index this worker binds to.
        world_size: Total number of workers (not used inside the worker).
        jobs: ``(image_path, output_dir)`` pairs assigned to this worker.
        model_path: Checkpoint file, loaded as a state dict.
        encoder: 'vits', 'vitb' or 'vitl'; unknown values fall back to 'vitl'.
        colormap: 'inferno', 'spectral', or anything else for grayscale.
        target_size: Short-side resize target passed to ``preprocess_image``.
    """
    import traceback

    if not jobs:
        print(f"[GPU {rank}] No images assigned, skipping...")
        return

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    print(f"[GPU {rank}] Starting worker with {len(jobs)} images on {device}")
    print(f"[GPU {rank}] GPU: {torch.cuda.get_device_name(rank)}")

    try:
        from depth_anything.dpt import DepthAnything_AC
    except ImportError:
        # NOTE(review): PROJECT_ROOT already ends in "DepthAnythingAC", so this
        # adds a nested DepthAnythingAC/DepthAnythingAC directory. Confirm the
        # intended repository layout.
        sys.path.append(os.path.join(PROJECT_ROOT, "DepthAnythingAC"))
        from depth_anything.dpt import DepthAnything_AC

    # Decoder feature width and per-stage output channels for each ViT size.
    model_configs = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'version': 'v2'},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768], 'version': 'v2'},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'version': 'v2'}
    }

    if encoder not in model_configs:
        print(f"[GPU {rank}] Error: Unknown encoder '{encoder}'. Defaulting to 'vitl'.")
        encoder = 'vitl'

    print(f"[GPU {rank}] Loading model from {model_path}...")
    model = DepthAnything_AC(model_configs[encoder])
    # Load onto the CPU first; the model is moved to this worker's GPU below.
    checkpoint = torch.load(model_path, map_location='cpu')
    load_result = model.load_state_dict(checkpoint, strict=False)
    # strict=False ignores mismatched keys. A wrong or wrapped checkpoint would
    # then run with uninitialized weights, so report any mismatch loudly.
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))
    if missing:
        print(f"[GPU {rank}] WARNING: {len(missing)} model weights missing from checkpoint (e.g. {missing[:3]})")
    if unexpected:
        print(f"[GPU {rank}] WARNING: {len(unexpected)} unexpected checkpoint keys ignored (e.g. {unexpected[:3]})")

    # Inference mode for layers such as dropout / normalization.
    model.eval()
    model = model.to(device)
    print(f"[GPU {rank}] Model loaded successfully!")

    # Create every output subdirectory up front so per-image writes can't fail on a missing dir.
    for out_dir in {job[1] for job in jobs}:
        for sub_dir in ("depth_colored", "depth_bw", "depth_npy"):
            os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

    # Resolve the matplotlib colormap once rather than once per image.
    spectral_cmap = _spectral_colormap() if colormap == 'spectral' else None

    total_images = len(jobs)
    successful = 0
    failed = 0

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for idx, (image_path, output_dir) in enumerate(jobs):
            image_name = Path(image_path).stem

            if idx % 10 == 0:
                print(f"[GPU {rank}] Processing {idx + 1}/{total_images}: {image_name}")

            try:
                image_tensor, original_size = preprocess_image(image_path, target_size)
                image_tensor = image_tensor.to(device)

                prediction = model(image_tensor)
                disparity_tensor = prediction['out']

                # Min-max normalize, then resize back to the original image size.
                depth_tensor = normalize_depth(disparity_tensor)
                depth = postprocess_depth(depth_tensor, original_size)

                # Fallback for a flattened 1-D prediction: restore a 2-D map and retry.
                if depth is None:
                    reshaped = _reshape_flat_depth(depth_tensor, original_size)
                    if reshaped is not None:
                        depth = postprocess_depth(reshaped, original_size)

                if depth is None:
                    print(f"[GPU {rank}] ERROR: Could not process depth for {image_name}")
                    failed += 1
                    continue

                np.save(os.path.join(output_dir, "depth_npy", f"{image_name}.npy"), depth)

                # Stretch [0, 1] over the full uint16 range for a 16-bit PNG.
                depth_16bit = (depth * 65535).astype(np.uint16)
                _imwrite_or_raise(os.path.join(output_dir, "depth_bw", f"{image_name}.png"), depth_16bit)

                depth_colored = _colorize_depth(depth, colormap, spectral_cmap)
                _imwrite_or_raise(os.path.join(output_dir, "depth_colored", f"{image_name}.png"), depth_colored)
                successful += 1

            except Exception as e:
                failed += 1
                print(f"[GPU {rank}] ERROR processing {image_name}: {e}")
                # Full tracebacks only for the first few failures to keep logs readable.
                if failed <= 3:
                    traceback.print_exc()

    print(f"[GPU {rank}] Worker finished! Processed {total_images} images ({successful} successful, {failed} failed).")

    # Release this worker's cached GPU memory before the process exits.
    del model
    torch.cuda.empty_cache()


def find_checkpoint(ckpt_folder: str, encoder: str = None) -> tuple:
    """Locate a ``.pth`` checkpoint in ``ckpt_folder`` and infer its encoder size.

    Candidates are sorted by path so the choice is deterministic.

    If ``encoder`` is given, the first checkpoint whose file name contains it
    is preferred; if none matches, the first candidate is used. If
    ``encoder`` is None, it is inferred from the chosen file name ('vits',
    'vitb' or 'vitl', checked in that order) and defaults to 'vitl'. Only the
    file name is inspected, so directory names cannot skew the result.

    Args:
        ckpt_folder: Directory to search (non-recursive).
        encoder: Optional encoder name; returned unchanged when provided.

    Returns:
        ``(model_path, encoder)``, or ``(None, None)`` if no ``.pth`` file exists.
    """
    from glob import escape as glob_escape

    # Escape the folder so characters such as '[' in it are matched literally.
    pth_files = sorted(glob(os.path.join(glob_escape(ckpt_folder), "*.pth")))
    if not pth_files:
        return None, None

    def _file_name(path):
        return os.path.basename(path).lower()

    model_path = pth_files[0]
    if encoder is not None:
        wanted = encoder.lower()
        model_path = next((p for p in pth_files if wanted in _file_name(p)), model_path)
    else:
        name = _file_name(model_path)
        encoder = next((e for e in ("vits", "vitb", "vitl") if e in name), "vitl")

    return model_path, encoder


def main():
    """Parse CLI arguments, collect (image, output_dir) jobs and run one worker process per GPU.

    Each ``--pairs "input|output"`` contributes every image found directly in
    ``input``; results for those images go under ``output``. Jobs are split
    round-robin across the GPUs in use. Each shard is processed by
    ``worker_fn`` in its own process.

    Raises:
        ValueError: malformed ``--pairs`` value, no checkpoint found, or no
            images across all pairs.
        RuntimeError: no CUDA GPU is available, or a worker process exited
            with a non-zero exit code.
    """
    parser = argparse.ArgumentParser(
        description="Multi-GPU Depth Anything AC Inference explicitly distributing sequence arrays out structurally scaling pipeline limits logically.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--pairs", type=str, action="append", required=True, help="Input|output directory pair. Repeat flag for multiple pairs.")
    parser.add_argument("--model", type=str, default=None, help="Model weights path")
    parser.add_argument("--encoder", type=str, default=None, choices=["vits", "vitb", "vitl"], help="Encoder type")
    parser.add_argument("--checkpoint-dir", type=str, default=None, help="Directory containing checkpoints")
    parser.add_argument("--num-gpus", type=int, default=None, help="Number of GPUs to use")
    parser.add_argument("--colormap", type=str, default="spectral", choices=["inferno", "spectral", "gray"])
    parser.add_argument("--target-size", type=int, default=518)

    args = parser.parse_args()

    # 0 or omitted selects every available GPU. A negative count would yield
    # no shards at all, so reject it up front.
    if args.num_gpus is not None and args.num_gpus < 0:
        parser.error("--num-gpus must be >= 0 (0 or omitted uses all available GPUs)")

    if args.checkpoint_dir is None:
        args.checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")

    if args.model is None:
        args.model, detected_encoder = find_checkpoint(args.checkpoint_dir, args.encoder)
        if args.model is None:
            raise ValueError(f"No checkpoint found in {args.checkpoint_dir}")
        if args.encoder is None:
            args.encoder = detected_encoder
    elif args.encoder is None:
        # Guess from the file name only, so directory names cannot skew the result.
        model_name = os.path.basename(args.model).lower()
        if "vits" in model_name:
            args.encoder = "vits"
        elif "vitb" in model_name:
            args.encoder = "vitb"
        else:
            args.encoder = "vitl"

    print("=" * 60)
    print("Multi-GPU Depth Anything AC Inference")
    print("=" * 60)
    print(f"Model:            {args.model}")
    print(f"Encoder:          {args.encoder}")
    print(f"Colormap:         {args.colormap}")
    print(f"Target size:      {args.target_size}")

    all_jobs: List[Tuple[str, str]] = []

    for pair in args.pairs:
        if "|" not in pair:
            raise ValueError(f"Pair must be 'input|output': {pair}")
        input_dir, output_dir = (part.strip() for part in pair.split("|", 1))
        if not input_dir or not output_dir:
            # An empty side would otherwise fail later with a confusing makedirs/glob error.
            raise ValueError(f"Pair must be 'input|output' with both parts non-empty: {pair}")

        if not os.path.isdir(input_dir):
            print(f"WARNING: Input directory does not exist: {input_dir}. Skipping pair.")
            continue

        images = get_image_files(input_dir)
        if not images:
            print(f"WARNING: No images found in {input_dir}. Skipping pair.")
            continue

        print(f"  Found {len(images)} images in {input_dir}")
        print(f"  Output will be in {output_dir}")

        os.makedirs(output_dir, exist_ok=True)
        all_jobs.extend((img, output_dir) for img in images)

    if not all_jobs:
        raise ValueError("No images to process across provided pairs")

    # Identical jobs would be processed twice, possibly by two GPUs writing the
    # same files concurrently. Keep the first occurrence, preserving order.
    unique_jobs = list(dict.fromkeys(all_jobs))
    if len(unique_jobs) != len(all_jobs):
        print(f"WARNING: Dropped {len(all_jobs) - len(unique_jobs)} duplicate (image, output) jobs.")
        all_jobs = unique_jobs

    print(f"Total images:     {len(all_jobs)}")

    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        raise RuntimeError("No CUDA GPUs available!")

    # Never start more workers than there are GPUs or jobs.
    num_gpus = args.num_gpus if args.num_gpus else available_gpus
    num_gpus = min(num_gpus, available_gpus, len(all_jobs))

    print(f"Available GPUs:   {available_gpus}")
    print(f"Using GPUs:       {num_gpus}")

    job_splits = split_images(all_jobs, num_gpus)

    for i, split in enumerate(job_splits):
        print(f"  GPU {i} will process: {len(split)} images")

    print("\nStarting multi-GPU inference...")
    processes = []

    # One process per GPU; the module-level 'spawn' start method gives each a fresh interpreter.
    for rank in range(num_gpus):
        p = mp.Process(
            target=worker_fn,
            args=(
                rank,
                num_gpus,
                job_splits[rank],
                args.model,
                args.encoder,
                args.colormap,
                args.target_size,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A non-zero exit code means the worker crashed (e.g. model load failure or
    # CUDA OOM); don't report success in that case.
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        for rank in failed_ranks:
            print(f"ERROR: Worker for GPU {rank} exited with code {processes[rank].exitcode}")
        raise RuntimeError(f"{len(failed_ranks)} of {num_gpus} GPU workers failed")

    print("\nAll workers completed!")
    print("\n" + "=" * 60)
    print("Multi-GPU inference completed successfully!")
    print("=" * 60)

# Entry-point guard. With the 'spawn' start method each worker process
# re-imports this script's module; this guard keeps those imports from calling
# main() again and recursively spawning more workers.
if __name__ == "__main__":
    main()
