from __future__ import annotations

from pathlib import Path
import sys

import torch
import torch.nn as nn

"""
Low-Rank Adaptation (LoRA) injection for a Vision Transformer depth backbone.

This module wraps the ``nn.Linear`` layers of the pre-trained backbone with small
trainable low-rank adapters while leaving the original weights frozen and their
shapes unchanged. Because only the adapter matrices and the depth head are trained,
the number of trainable parameters stays small and checkpoints can store just the
adapted weights instead of the full model.
"""

# Constructor settings per encoder variant, keyed by variant name. The "vitl"
# entry is unpacked as keyword arguments into DepthAnythingV2 further down in
# this file; the other entries are not referenced by the code visible here.
MODEL_CONFIGS = {
    "vits": {
        "encoder": "vits",
        "features": 64,
        "out_channels": [48, 96, 192, 384],
    },
    "vitb": {
        "encoder": "vitb",
        "features": 128,
        "out_channels": [96, 192, 384, 768],
    },
    "vitl": {
        "encoder": "vitl",
        "features": 256,
        "out_channels": [256, 512, 1024, 1024],
    },
    "vitg": {
        "encoder": "vitg",
        "features": 384,
        "out_channels": [1536, 1536, 1536, 1536],
    },
}


def _ensure_metric_depth_import_path(base_dir: Path) -> None:
    """Make ``<base_dir>/Depth-Anything-V2/metric_depth`` importable.

    The directory is resolved to an absolute path and inserted at the front of
    ``sys.path``. Nothing happens when that exact string is already listed.
    """
    resolved_root = (base_dir / "Depth-Anything-V2" / "metric_depth").resolve()
    entry = str(resolved_root)
    if entry in sys.path:
        return
    # Front of the list so this checkout wins over any other installed copy.
    sys.path.insert(0, entry)


class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) residual branch.

    The forward pass computes
    ``base(x) + lora_up(lora_down(dropout(x))) * (alpha / rank)``.
    ``lora_up`` starts at zero, so a freshly built wrapper returns exactly what
    the wrapped layer returns.

    Args:
        base_linear: Layer to wrap. The module is reused (not copied) and all of
            its parameters are frozen in place.
        rank: Inner width of the low-rank branch; must be positive after
            conversion to ``int``.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        dropout: Dropout probability applied to the adapter input only; values
            ``<= 0`` disable dropout.

    Raises:
        ValueError: If ``rank`` is not a positive integer after conversion.
    """

    def __init__(self, base_linear: nn.Linear, rank: int, alpha: float, dropout: float) -> None:
        super().__init__()
        # Convert before validating: a fractional rank such as 0.5 would pass a
        # "> 0" check, truncate to 0 and then divide by zero in the scaling.
        rank = int(rank)
        if rank <= 0:
            raise ValueError("LoRA rank must be positive")

        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.rank = rank
        self.alpha = float(alpha)
        self.scaling = self.alpha / float(self.rank)

        self.base = base_linear
        for param in self.base.parameters():
            param.requires_grad = False

        # Build the adapters on the base layer's device and dtype. Otherwise
        # wrapping a model that was already moved or cast fails in forward()
        # with a device/dtype mismatch.
        factory_kwargs = {"device": base_linear.weight.device, "dtype": base_linear.weight.dtype}
        self.lora_down = nn.Linear(self.in_features, self.rank, bias=False, **factory_kwargs)
        self.lora_up = nn.Linear(self.rank, self.out_features, bias=False, **factory_kwargs)
        self.lora_dropout = nn.Dropout(p=float(dropout)) if dropout > 0 else nn.Identity()

        nn.init.kaiming_uniform_(self.lora_down.weight, a=5**0.5)
        # Zero-initialised up-projection: the adapter contributes nothing until trained.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen base projection plus the scaled low-rank update."""
        base_out = self.base(x)
        lora_out = self.lora_up(self.lora_down(self.lora_dropout(x))) * self.scaling
        return base_out + lora_out


def _replace_backbone_linear_with_lora(module: nn.Module, rank: int, alpha: float, dropout: float) -> int:
    """Wrap every ``nn.Linear`` below ``module`` in a ``LoRALinear``, in place.

    Direct children that are ``nn.Linear`` are replaced via ``setattr``; all
    other children are searched recursively. ``module`` itself is never
    replaced, only its descendants.

    Existing ``LoRALinear`` wrappers are left untouched and are not descended
    into. Their ``base``, ``lora_down`` and ``lora_up`` members are themselves
    ``nn.Linear``, so a repeated call would otherwise wrap them again.

    Args:
        module: Root of the subtree to rewrite.
        rank: LoRA rank passed to each new ``LoRALinear``.
        alpha: LoRA alpha passed to each new ``LoRALinear``.
        dropout: Adapter dropout probability passed to each new ``LoRALinear``.

    Returns:
        Number of ``nn.Linear`` layers wrapped by this call.
    """
    replaced = 0
    # Snapshot the children first: setattr below mutates the registry being iterated.
    for name, child in list(module.named_children()):
        if isinstance(child, LoRALinear):
            # Already adapted; wrapping again would nest adapters.
            continue
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALinear(child, rank=rank, alpha=alpha, dropout=dropout))
            replaced += 1
        else:
            replaced += _replace_backbone_linear_with_lora(child, rank=rank, alpha=alpha, dropout=dropout)
    return replaced


def _set_module_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    """Assign ``requires_grad`` to every parameter yielded by ``module.parameters()``."""
    for tensor in module.parameters():
        tensor.requires_grad = requires_grad


def _state_dict_by_trainable(module: nn.Module) -> dict[str, torch.Tensor]:
    """Return the state-dict entries of ``module`` whose parameter is trainable.

    Entries are visited in ``state_dict()`` order. An entry is kept only when a
    parameter of the same name exists and has ``requires_grad`` set, so buffers
    and frozen parameters are dropped. Kept tensors are detached and moved to
    the CPU.
    """
    params_by_name = dict(module.named_parameters())
    return {
        key: tensor.detach().cpu()
        for key, tensor in module.state_dict().items()
        if key in params_by_name and params_by_name[key].requires_grad
    }


def _extract_lora_state(module: nn.Module) -> dict[str, torch.Tensor]:
    """Collect the adapter weights of every ``LoRALinear`` found in ``module``.

    Only ``lora_down.weight`` and ``lora_up.weight`` are exported, which leaves
    the wrapped base weights out of the result. Keys are
    ``"<name>.lora_down.weight"`` / ``"<name>.lora_up.weight"``, where ``<name>``
    is the submodule name reported by ``module.named_modules()``. Tensors are
    detached and moved to the CPU.
    """
    adapters = (
        (prefix, layer)
        for prefix, layer in module.named_modules()
        if isinstance(layer, LoRALinear)
    )
    state: dict[str, torch.Tensor] = {}
    for prefix, layer in adapters:
        state.update(
            {
                f"{prefix}.lora_down.weight": layer.lora_down.weight.detach().cpu(),
                f"{prefix}.lora_up.weight": layer.lora_up.weight.detach().cpu(),
            }
        )
    return state


class DAV2MetricLargeOutdoor(nn.Module):
    """Wrapper around ``DepthAnythingV2`` built from the ``"vitl"`` configuration.

    The wrapped network lives in ``self.model``. ``self.finetune_mode`` records
    the last mode chosen through ``configure_finetuning`` (``"full"`` initially)
    and decides what ``export_checkpoint_payload`` returns.
    """

    def __init__(self, base_dir: str | Path, max_depth: float = 50.0) -> None:
        """Build the network.

        Args:
            base_dir: Directory that contains ``Depth-Anything-V2/metric_depth``;
                it is put on ``sys.path`` so ``depth_anything_v2`` can be imported.
            max_depth: Forwarded to ``DepthAnythingV2`` as ``max_depth``.
        """
        super().__init__()
        base_dir = Path(base_dir)
        _ensure_metric_depth_import_path(base_dir)

        # Imported lazily: the package only becomes importable after the path setup above.
        from depth_anything_v2.dpt import DepthAnythingV2

        self.model = DepthAnythingV2(**{**MODEL_CONFIGS["vitl"], "max_depth": max_depth})
        self.finetune_mode = "full"

    def load_pretrained(self, checkpoint_path: str | Path, strict: bool = True) -> tuple[list[str], list[str]]:
        """Load weights from ``checkpoint_path`` into ``self.model``.

        The file may hold a bare state dict or a dict with the state dict under
        the ``"model"`` key. A leading ``"module."`` prefix is removed from the
        keys that carry one.

        Call this before ``configure_finetuning("lora")``: wrapping a linear
        layer moves its parameters under a ``.base`` submodule, so the original
        key names no longer match afterwards.

        Args:
            checkpoint_path: Path of the checkpoint file.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            ``(missing_keys, unexpected_keys)`` as reported by ``load_state_dict``.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # SECURITY(review): torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or consider weights_only=True where the installed
        # PyTorch version supports it.
        state = torch.load(str(checkpoint_path), map_location="cpu")
        if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
            state = state["model"]

        prefix = "module."
        if any(key.startswith(prefix) for key in state.keys()):
            # Strip the prefix only where it is actually leading; keys that merely
            # contain "module." further in must stay intact.
            state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state.items()
            }

        missing, unexpected = self.model.load_state_dict(state, strict=strict)
        return list(missing), list(unexpected)

    def configure_finetuning(
        self,
        mode: str,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0,
    ) -> dict[str, int]:
        """Select which parameters are trainable.

        ``"full"`` marks every parameter of ``self.model`` as trainable.
        ``"lora"`` freezes ``self.model.pretrained``, wraps its ``nn.Linear``
        layers in ``LoRALinear`` adapters and leaves ``self.model.depth_head``
        trainable.

        Calling ``"lora"`` again on a backbone that already contains adapters
        does not inject a second set: the existing adapters are kept and made
        trainable again, and the ``lora_*`` arguments of that call are ignored.

        Args:
            mode: ``"full"`` or ``"lora"`` (case and surrounding whitespace are ignored).
            lora_rank: Rank for newly created adapters.
            lora_alpha: Alpha for newly created adapters.
            lora_dropout: Dropout probability for newly created adapters.

        Returns:
            ``{"lora_replaced_linear": n}`` where ``n`` is the number of linear
            layers wrapped by this call (``0`` in ``"full"`` mode or when
            adapters were already present).

        Raises:
            ValueError: If ``mode`` is neither ``"full"`` nor ``"lora"``.
        """
        mode = mode.lower().strip()
        if mode not in {"full", "lora"}:
            raise ValueError(f"Unsupported finetuning mode: {mode}")

        self.finetune_mode = mode

        if mode == "full":
            _set_module_requires_grad(self.model, True)
            return {"lora_replaced_linear": 0}

        backbone = self.model.pretrained
        # Freeze first: adapters created afterwards are trainable by default.
        _set_module_requires_grad(backbone, False)

        existing_adapters = [sub for sub in backbone.modules() if isinstance(sub, LoRALinear)]
        if existing_adapters:
            # Re-injecting would wrap the adapters' own linear layers. Keep the
            # current adapters and undo the freeze that just hit them.
            replaced = 0
            for adapter in existing_adapters:
                _set_module_requires_grad(adapter.lora_down, True)
                _set_module_requires_grad(adapter.lora_up, True)
        else:
            replaced = _replace_backbone_linear_with_lora(
                backbone,
                rank=lora_rank,
                alpha=lora_alpha,
                dropout=lora_dropout,
            )

        _set_module_requires_grad(self.model.depth_head, True)
        return {"lora_replaced_linear": replaced}

    def trainable_param_count(self) -> int:
        """Return the total element count of all parameters with ``requires_grad`` set."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def export_checkpoint_payload(self) -> dict:
        """Build the dict to be saved as a checkpoint for the current mode.

        In ``"lora"`` mode the payload holds ``lora_state_dict`` (adapter weights,
        keyed relative to ``self.model.pretrained``) and ``trainable_state_dict``
        (every trainable parameter, keyed relative to ``self.model``; this also
        covers the adapter weights while they are trainable). Otherwise the
        payload holds the complete ``self.model.state_dict()`` as returned,
        without moving it to the CPU.
        """
        # getattr with a default tolerates instances on which finetune_mode was never set.
        mode = getattr(self, "finetune_mode", "full")
        if mode == "lora":
            return {
                "finetune_mode": "lora",
                "lora_state_dict": _extract_lora_state(self.model.pretrained),
                "trainable_state_dict": _state_dict_by_trainable(self.model),
            }
        return {
            "finetune_mode": "full",
            "model_state_dict": self.model.state_dict(),
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)