import argparse
import json
import logging
import os
import sys
from glob import glob
from typing import Tuple

from contextlib import contextmanager
import numpy as np
import torch
import torch.nn.functional as F
import cv2

"""
Efficiency Metric Analyzer

This script measures how the DepthAnythingAC architecture performs at run time.
To keep image file I/O out of the measurements it feeds the network a randomly
generated in-memory frame, then reports network throughput (FPS) and latency,
peak GPU memory usage (VRAM), and the operation count of a forward pass (FLOPs).
"""

# Absolute path of the directory that contains this script. It is placed at the
# front of sys.path so imports are resolved against this directory first.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, PROJECT_ROOT)


def _get_device(apple_silicon: bool) -> torch.device:
    """Pick the torch device to run on.

    Args:
        apple_silicon: When true, prefer the MPS backend; otherwise prefer CUDA.

    Returns:
        The preferred accelerator if its availability check passes,
        else the CPU device.
    """
    if apple_silicon:
        backend, usable = "mps", torch.backends.mps.is_available()
    else:
        backend, usable = "cuda", torch.cuda.is_available()
    return torch.device(backend if usable else "cpu")


def find_checkpoint(ckpt_folder: str, encoder: str = None) -> tuple:
    """Locate a ``.pth`` checkpoint in a folder and decide which encoder it uses.

    Candidates are sorted by path so the choice is deterministic (plain
    ``glob`` order depends on the file system).

    Args:
        ckpt_folder: Directory searched (non-recursively) for ``*.pth`` files.
        encoder: Optional encoder name. When given, the first checkpoint whose
            file name contains it is preferred; if none matches, the first
            checkpoint is used. When omitted, the encoder is guessed from the
            chosen checkpoint.

    Returns:
        ``(model_path, encoder)``, or ``(None, None)`` when the folder holds
        no ``.pth`` file. A guessed encoder is one of ``"vits"``, ``"vitb"``,
        ``"vitl"`` and defaults to ``"vitl"`` when nothing matches.
    """
    known_encoders = ("vits", "vitb", "vitl")

    pth_files = sorted(glob(os.path.join(ckpt_folder, "*.pth")))
    if not pth_files:
        return None, None

    if encoder is not None:
        # Prefer weights that actually belong to the requested encoder.
        wanted = encoder.lower()
        for path in pth_files:
            if wanted in os.path.basename(path).lower():
                return path, encoder
        return pth_files[0], encoder

    model_path = pth_files[0]
    # The file name is the most specific hint; the full path is only consulted
    # as a fallback so a directory name cannot override the file name.
    for haystack in (os.path.basename(model_path).lower(), model_path.lower()):
        for candidate in known_encoders:
            if candidate in haystack:
                return model_path, candidate
    return model_path, "vitl"


def preprocess_image_np(dummy_bgr: np.ndarray, target_size: int = 518) -> Tuple[torch.Tensor, Tuple[int, int], str]:
    """Convert a BGR image array into a normalized network input tensor.

    The image is converted to RGB and scaled to [0, 1]. It is then resized so
    its shorter side is about ``target_size``, with both sides rounded up to a
    multiple of 14. Finally it is normalized per channel and returned as a
    ``(1, 3, H, W)`` float tensor.

    Args:
        dummy_bgr: Image array passed to ``cv2.cvtColor`` as BGR.
        target_size: Desired length of the shorter side before rounding.

    Returns:
        ``(tensor, (h, w), resolution)``: the input tensor, the source height
        and width, and the processing size formatted as ``"WxH"``.

    Raises:
        ValueError: If ``dummy_bgr`` is ``None``.
    """
    if dummy_bgr is None:
        raise ValueError("dummy image is None")

    def _ceil_to_patch(length: int, patch: int = 14) -> int:
        # Smallest multiple of `patch` that is >= `length`.
        return -(-length // patch) * patch

    rgb = cv2.cvtColor(dummy_bgr, cv2.COLOR_BGR2RGB)
    rgb = rgb.astype(np.float32) / 255.0
    src_h, src_w = rgb.shape[:2]

    # Scale so the shorter side lands on target_size, then snap to the 14 grid.
    ratio = target_size / min(src_h, src_w)
    proc_h = _ceil_to_patch(int(src_h * ratio))
    proc_w = _ceil_to_patch(int(src_w * ratio))

    resized = cv2.resize(rgb, (proc_w, proc_h), interpolation=cv2.INTER_CUBIC)

    # Per-channel statistics (the values match the widely used ImageNet ones).
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (resized - channel_mean) / channel_std

    # HWC -> CHW, then add the batch dimension.
    chw = normalized.transpose(2, 0, 1)
    batch = torch.from_numpy(chw).float().unsqueeze(0)
    return batch, (src_h, src_w), f"{proc_w}x{proc_h}"


def normalize_depth(disparity_tensor: torch.Tensor) -> torch.Tensor:
    """Min-max rescale a tensor using its global minimum and maximum.

    A small epsilon (1e-6) is added to the value range so a constant tensor
    yields zeros instead of dividing by zero.
    """
    lowest = disparity_tensor.min()
    value_range = disparity_tensor.max() - lowest + 1e-6
    shifted = disparity_tensor - lowest
    return shifted / value_range


def postprocess_depth(depth_tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor:
    """Bilinearly resize a depth map to ``original_size``.

    A singleton dimension is inserted at index 1 for ``F.interpolate`` and
    removed again afterwards, so the result has the same rank as the input.

    Args:
        depth_tensor: Depth map; a ``(B, H, W)`` layout is what the bilinear
            interpolation of the expanded tensor requires.
        original_size: Target ``(height, width)``.
    """
    with_channel = torch.unsqueeze(depth_tensor, 1)
    target_h, target_w = original_size
    resized = F.interpolate(
        with_channel,
        size=(target_h, target_w),
        mode="bilinear",
        align_corners=True,
    )
    return torch.squeeze(resized, 1)


def _register_extra_flop_handlers(flop_counter):
    """Register additional per-operator cost handlers on ``flop_counter``.

    Each handler has the signature ``handler(inputs, outputs) -> int`` and is
    installed through ``flop_counter.set_op_handle(op_name, handler)``. Most
    handlers charge a fixed cost per element of the first output; values whose
    shape is unknown (or partly ``None``) are charged zero.
    """

    def _dims(value):
        # Shape as a list, for real tensors or for objects exposing
        # `.type().sizes()`; None when no shape can be determined.
        if torch.is_tensor(value):
            return list(value.shape)
        type_getter = getattr(value, "type", None)
        if not callable(type_getter):
            return None
        sizes_getter = getattr(type_getter(), "sizes", None)
        if not callable(sizes_getter):
            return None
        sizes = sizes_getter()
        return None if sizes is None else list(sizes)

    def _count(value):
        # Element count; 0 for unknown shapes, empty shapes or unknown dims.
        dims = _dims(value)
        if not dims:
            return 0
        total = 1
        for extent in dims:
            if extent is None:
                return 0
            total *= int(extent)
        return int(total)

    def _leading_output_count(outputs):
        # Only the first output is measured when several are provided.
        target = outputs
        if isinstance(outputs, (tuple, list)) and outputs:
            target = outputs[0]
        return _count(target)

    def _per_element(cost):
        # Build a handler that charges `cost` per element of the first output.
        def _handler(inputs, outputs):
            return _leading_output_count(outputs) * cost
        return _handler

    def _free(inputs, outputs):
        return 0

    def _attention(inputs, outputs):
        q_dims = _dims(inputs[0]) if len(inputs) > 0 else None
        k_dims = _dims(inputs[1]) if len(inputs) > 1 else None
        if q_dims and k_dims and len(q_dims) == 4 and len(k_dims) >= 3:
            batch, heads, n_query, head_dim = q_dims
            n_key = k_dims[2]
            if None not in (batch, heads, n_query, head_dim, n_key):
                # First term: presumably the two matmuls (QK^T and attn @ V).
                # Second term: 5 per attention score, same as the softmax cost.
                matmul_cost = int(2 * batch * heads * n_query * head_dim * n_key)
                softmax_cost = int(5 * batch * heads * n_query * n_key)
                return matmul_cost + softmax_cost
        # Shapes unavailable: fall back to a per-output-element estimate.
        return _leading_output_count(outputs) * 5

    softmax = _per_element(5)
    gelu = _per_element(8)
    sigmoid = _per_element(4)
    elementwise = _per_element(1)
    bicubic = _per_element(11)

    registrations = (
        ("aten::softmax", softmax),
        ("aten::gelu", gelu),
        ("aten::sigmoid", sigmoid),
        ("aten::add", elementwise),
        ("aten::add_", elementwise),
        ("aten::mul", elementwise),
        ("aten::mul_", elementwise),
        ("aten::pow", elementwise),
        ("aten::exp", elementwise),
        ("aten::unflatten", _free),
        ("aten::upsample_bicubic2d", bicubic),
        ("aten::scaled_dot_product_attention", _attention),
    )
    for op_name, handler in registrations:
        flop_counter.set_op_handle(op_name, handler)


@contextmanager
def _patch_attention_for_jit_tracing(model):
    """Temporarily swap attention ``forward`` methods for an explicit version.

    Every class that has an instance in ``model.modules()`` exposing ``qkv``,
    ``proj`` and ``num_heads`` gets its ``forward`` replaced by a plain
    matmul/softmax implementation, so each arithmetic op is visible while the
    context is active. On exit, including on error, every class is returned
    to its previous state.

    The pre-patch state is read from the class's own ``__dict__`` rather than
    via attribute lookup. A class that only inherited ``forward`` therefore
    has the temporary attribute removed again. This also prevents a subclass
    from capturing an already-patched parent method as its "original".

    Args:
        model: Module whose sub-modules are scanned for attention-like layers.

    Yields:
        None. The patch is active for the duration of the ``with`` body.
    """

    def _traceable_forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but ignored.
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * (head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        # `attn_drop` / `proj_drop` may be a callable (e.g. a dropout module)
        # or a plain probability; both forms are honoured.
        attn_drop = getattr(self, "attn_drop", None)
        if callable(attn_drop):
            attn = attn_drop(attn)
        elif isinstance(attn_drop, (float, int)) and float(attn_drop) > 0:
            attn = torch.nn.functional.dropout(attn, p=float(attn_drop), training=self.training)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        proj_drop = getattr(self, "proj_drop", None)
        if callable(proj_drop):
            x = proj_drop(x)
        elif isinstance(proj_drop, (float, int)) and float(proj_drop) > 0:
            x = torch.nn.functional.dropout(x, p=float(proj_drop), training=self.training)
        return x

    not_defined = object()  # marks classes that merely inherited `forward`
    saved = {}              # class -> entry found in its own __dict__

    try:
        # Patching happens inside the try so that a failure part-way through
        # still restores the classes that were already modified.
        for module in model.modules():
            cls = type(module)
            if cls in saved:
                continue
            if not (hasattr(module, "qkv") and hasattr(module, "proj")
                    and hasattr(module, "num_heads")):
                continue
            saved[cls] = cls.__dict__.get("forward", not_defined)
            cls.forward = _traceable_forward
        yield
    finally:
        for cls, own_forward in saved.items():
            if own_forward is not_defined:
                delattr(cls, "forward")
            else:
                cls.forward = own_forward


def measure_flops(model: torch.nn.Module, example_input: torch.Tensor) -> float:
    """Count the operations of one forward pass with fvcore.

    The model is wrapped so that a dict result is reduced to its ``"out"``
    entry, attention layers are temporarily patched for tracing, and the extra
    operator handlers are registered before the total is computed.

    Args:
        model: Network to analyse.
        example_input: Input passed to the model; the wrapper is moved to
            this tensor's device.

    Returns:
        The total reported by ``FlopCountAnalysis`` as a float. Returns
        ``0.0`` (with a logged warning) when fvcore is not installed or the
        analysis raises.
    """
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError:
        logging.warning("Package 'fvcore' not found. FLOPs will be N/A. Install with: pip install fvcore")
        return 0.0

    class _OutOnly(torch.nn.Module):
        """Adapter returning only the tensor the analysis should follow."""

        def __init__(self, inner):
            super().__init__()
            self.m = inner

        def forward(self, x):
            result = self.m(x)
            if isinstance(result, dict) and "out" in result:
                return result["out"]
            return result

    traced = _OutOnly(model).to(example_input.device)

    total = 0.0
    try:
        with _patch_attention_for_jit_tracing(traced):
            analysis = FlopCountAnalysis(traced, example_input)
            _register_extra_flop_handlers(analysis)
            total = float(analysis.total())
    except Exception as e:
        # Best effort: a failed analysis must not abort the benchmark.
        logging.warning(f"FLOPs calculation failed; FLOPs will be N/A. Reason: {e}")
    return total


def _encoder_from_path(model_path: str) -> str:
    """Guess the encoder variant from a checkpoint path; defaults to "vitl".

    The file name is checked first and the full path only as a fallback, so a
    directory name cannot override a hint in the file name itself.
    """
    for haystack in (os.path.basename(model_path).lower(), model_path.lower()):
        for candidate in ("vits", "vitb", "vitl"):
            if candidate in haystack:
                return candidate
    return "vitl"


def _synchronize(device: torch.device) -> None:
    """Wait for queued work on an asynchronous backend (CUDA or MPS) to finish."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps" and hasattr(torch, "mps"):
        # torch.mps is missing on older PyTorch builds, hence the guard.
        torch.mps.synchronize()


def _run_pipeline(model, x, original_size, use_amp, amp_dtype):
    """Run one forward pass followed by depth normalization and resizing."""
    with torch.no_grad():
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                pred = model(x)
        else:
            pred = model(x)
        disp = pred["out"] if isinstance(pred, dict) and "out" in pred else pred
        depth = normalize_depth(disp)
        return postprocess_depth(depth, original_size)


def main():
    """Benchmark DepthAnythingAC on a synthetic frame and save the metrics.

    Steps: parse the CLI, resolve the checkpoint and encoder, load the model,
    build a random 1280x720 input, optionally count FLOPs (CUDA only), warm
    up for 10 iterations and time 50 iterations. The metrics are printed and
    written to ``<output>/efficiency_metrics.json``.

    Raises:
        ValueError: If no checkpoint can be found or the path is not a file.
    """
    import time  # only the non-CUDA timing path needs the wall clock

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Depth Anything AC : Efficiency Benchmark Only")
    parser.add_argument("--output", type=str, required=True, help="Output directory to save efficiency_metrics.json")
    parser.add_argument("--model", type=str, default=None, help="Model .pth path")
    parser.add_argument("--checkpoint-dir", type=str, default=None, help="Directory containing checkpoints")
    parser.add_argument("--encoder", type=str, default=None, choices=["vits", "vitb", "vitl"])
    parser.add_argument("--target-size", type=int, default=518, help="Target processing size bounds")
    parser.add_argument("--minibatch_size", type=int, default=1, help="Kept strictly for schema parity.")
    parser.add_argument("--mask-edges", action="store_true", help="Kept strictly for schema parity.")

    # AMP is on by default (--amp is effectively a no-op); --no-amp turns it off.
    parser.add_argument("--amp", action="store_true", default=True, help="Enable autocast AMP (default: on)")
    parser.add_argument("--no-amp", action="store_true", help="Disable autocast AMP")
    parser.add_argument("--amp-dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="AMP dtype")
    parser.add_argument("--apple_silicon", action="store_true", help="Use Apple Silicon (MPS)")
    args = parser.parse_args()

    device = _get_device(args.apple_silicon)
    logging.info(f"Device: {device}")

    if args.checkpoint_dir is None:
        args.checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")

    model_path = args.model
    encoder = args.encoder

    if model_path is None:
        model_path, detected_encoder = find_checkpoint(args.checkpoint_dir, encoder)
        if model_path is None:
            raise ValueError(f"No checkpoint found in {args.checkpoint_dir}")
        if encoder is None:
            encoder = detected_encoder
    elif encoder is None:
        encoder = _encoder_from_path(model_path)

    if not os.path.isfile(model_path):
        raise ValueError(f"Checkpoint not found: {model_path}")

    try:
        from depth_anything.dpt import DepthAnything_AC
    except ImportError:
        # Fall back to a vendored copy located under PROJECT_ROOT.
        sys.path.append(os.path.join(PROJECT_ROOT, "DepthAnythingAC"))
        from depth_anything.dpt import DepthAnything_AC

    model_configs = {
        "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024], "version": "v2"},
        "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768], "version": "v2"},
        "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384], "version": "v2"},
    }

    if encoder not in model_configs:
        logging.warning(f"Unknown encoder '{encoder}', defaulting to 'vitl'")
        encoder = "vitl"

    logging.info(f"Loading model: {model_path} (encoder={encoder})")
    model = DepthAnything_AC(model_configs[encoder])
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(model_path, map_location="cpu")
    load_result = model.load_state_dict(ckpt, strict=False)
    # strict=False would otherwise hide a checkpoint/architecture mismatch.
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys or unexpected_keys:
        logging.warning(
            f"Checkpoint does not fully match the model: "
            f"{len(missing_keys)} missing key(s), {len(unexpected_keys)} unexpected key(s)"
        )

    # Evaluation mode (e.g. disables dropout) for deterministic inference cost.
    model = model.to(device).eval()

    # Autocast is only used on CUDA.
    use_amp = args.amp and (not args.no_amp) and (device.type == "cuda")
    amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16

    # Synthetic frame: keeps disk reads and image decoding out of the timing.
    H_bench, W_bench = 720, 1280
    dummy_bgr = np.random.randint(0, 256, (H_bench, W_bench, 3), dtype=np.uint8)

    x, original_size, processing_resolution = preprocess_image_np(dummy_bgr, args.target_size)
    x = x.to(device, non_blocking=True)

    flops_val = 0.0
    if device.type == "cuda":
        flops_val = measure_flops(model, x)

    logging.info("Warming up (10 iterations)...")
    # Warm-up keeps one-time initialisation costs out of the timed section.
    for _ in range(10):
        _run_pipeline(model, x, original_size, use_amp, amp_dtype)

    _synchronize(device)
    if device.type == "cuda":
        # Peak memory should reflect steady-state inference only.
        torch.cuda.reset_peak_memory_stats(device)

    n_frames = 50
    logging.info(f"Measuring latency over {n_frames} frames... (amp={use_amp}, dtype={args.amp_dtype})")

    if device.type == "cuda":
        # CUDA events time the work on the GPU stream itself.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(n_frames):
            _run_pipeline(model, x, original_size, use_amp, amp_dtype)
        end_event.record()
        torch.cuda.synchronize()
        total_time_s = start_event.elapsed_time(end_event) / 1000.0
        max_vram_gb = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    else:
        # perf_counter is monotonic. The synchronisation matters on MPS, where
        # kernels run asynchronously and the clock would otherwise stop before
        # the work is done.
        start_time = time.perf_counter()
        for _ in range(n_frames):
            _run_pipeline(model, x, original_size, use_amp, amp_dtype)
        _synchronize(device)
        total_time_s = time.perf_counter() - start_time
        max_vram_gb = 0.0

    avg_latency_s = total_time_s / n_frames
    fps = 1.0 / avg_latency_s

    metrics = {
        "model_name": f"DepthAnythingAC-{encoder}",
        "input_resolution": f"{W_bench}x{H_bench}",
        "processing_resolution": processing_resolution,
        "minibatch_size": args.minibatch_size,
        "mask_edges": bool(args.mask_edges),
        "latency_sec": round(avg_latency_s, 5),
        "fps": round(fps, 2),
        "vram_gb": round(max_vram_gb, 3),
        "flops_tflops": round(flops_val / 1e12, 4) if flops_val > 0 else "N/A",
        "flops_gflops": round(flops_val / 1e9, 2) if flops_val > 0 else "N/A",
    }

    print("\n" + "=" * 40)
    print(" Efficiency Results")
    print("=" * 40)
    print(json.dumps(metrics, indent=4))
    print("=" * 40)

    os.makedirs(args.output, exist_ok=True)
    json_path = os.path.join(args.output, "efficiency_metrics.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)

    logging.info(f"Metrics saved to: {json_path}")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()