"""model_builders.py
===================
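Factory that creates the **student** (MedSAM / SAM-Med2D) and the **teacher**
models used throughout *Slimsam Fused Train*, plus a `SAMPartSwitcher` that
swaps the prompt encoder, mask decoder and positional encodings between the
SAM-Med2D build and the original Meta-SAM build. The teacher is a frozen copy
of the student, taken while the original Meta-SAM parts are installed.

`build_teacher_student` takes the parsed CLI namespace and forwards the
attributes the underlying `build_sam_*` functions read (`sammed2d_image_size`,
`adapter`, `sam_checkpoint`, `sammed2d_checkpoint`) as a minimal `args`
namespace; attributes missing from the CLI are passed as `None`.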
"""

from __future__ import annotations

import copy
from argparse import Namespace
from types import SimpleNamespace
from typing import Any, Dict, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from .sam_org import sam_model_registry as smr  # Original Meta SAM
from .sammed2d import sam_model_registry as smdmr  # MedSAM / SAM-Med2D

# Public API of this module: only the factory is exported. The factory also
# returns a SAMPartSwitcher instance, but the class itself is not listed here.
__all__ = ["build_teacher_student"]


# ---------------------------------------------------------------------------
# Helper – pack CLI args into the minimal Namespace MedSAM expects
# ---------------------------------------------------------------------------


def _to_namespace(cli_args: Namespace) -> Namespace:
    """Extract attrs that `build_sam_*` expects and package into Namespace."""

    wanted = [
        # attributes explicitly used inside build_sam_* implementation
        "sammed2d_image_size",
        "adapter",
        "sam_checkpoint",
        "sammed2d_checkpoint",
    ]
    d: Dict[str, Any] = {k: getattr(cli_args, k, None) for k in wanted}
    return SimpleNamespace(**d)


class SAMPartSwitcher:
    """Manage two *families* of SAM parameters and switch them at runtime.

    References (not copies) to the following parts are snapshotted from both
    ``orig_sam`` and ``interp_sam`` at construction time:

    * "prompt" family: ``prompt_encoder`` and ``mask_decoder``;
    * "posrel" family: ``image_encoder.pos_embed`` plus the ``rel_pos`` /
      ``rel_pos_h`` / ``rel_pos_w`` tensors of the blocks listed in
      ``global_attn_idx``.

    Afterwards only ``interp_sam`` is mutated: the ``to_orig_*`` /
    ``to_interp_*`` methods rebind its attributes to one snapshot or the
    other, and ``interpolate`` resizes whatever posrel tensors are currently
    installed.

    Args:
        orig_sam: Model that donates the "orig" parts. With ``free_orig=True``
            its prompt encoder, mask decoder, ``pos_embed`` and every block's
            ``rel_pos*`` attributes are deleted. Only the switcher's snapshot
            keeps the tracked ones alive.
        interp_sam: Model that is modified in place by the switches.
        global_attn_idx: Block indices whose ``rel_pos*`` tensors are tracked
            and resized. Indices >= the encoder depth are skipped.
        free_orig: Whether to strip the parts off ``orig_sam``.
    """

    def __init__(
        self,
        orig_sam: nn.Module,
        interp_sam: nn.Module,
        *,
        global_attn_idx: Sequence[int] = (2, 5, 7, 8, 11, 13, 15, 23, 31),
        free_orig: bool = True,
    ) -> None:
        self.interp_sam = interp_sam
        # NOTE(review): the default looks like a union of global-attention
        # block indices from several ViT sizes; for a single backbone some of
        # these may be windowed blocks whose rel_pos* should not be resized to
        # the global grid -- confirm against the backbone actually used.
        self._global_idx = tuple(global_attn_idx)

        # --- snapshot both variants (object references) ------------------
        self._prompt_orig, self._posrel_orig = self._extract_parts(orig_sam)
        self._prompt_interp, self._posrel_interp = self._extract_parts(interp_sam)

        if free_orig:
            self._detach_prompt_posrel(orig_sam)

        # interp_sam starts out with its own parts installed.
        self._state_prompt = "interp"
        self._state_posrel = "interp"

    # ------------------------------------------------------------------
    #   Public switches (prompt / decoder)
    # ------------------------------------------------------------------
    def to_orig_prompt(self):
        """Install the orig prompt encoder / mask decoder (no-op if active)."""
        if self._state_prompt == "orig":
            return
        self._apply_prompt(self._prompt_orig)
        self._state_prompt = "orig"

    def to_interp_prompt(self):
        """Install interp_sam's own prompt encoder / mask decoder."""
        if self._state_prompt == "interp":
            return
        self._apply_prompt(self._prompt_interp)
        self._state_prompt = "interp"

    # ------------------------------------------------------------------
    #   Public switches (pos_embed & rel_pos*)
    # ------------------------------------------------------------------
    def to_orig_posrel(self):
        """Install the orig pos_embed / rel_pos* tensors (no-op if active)."""
        if self._state_posrel == "orig":
            return
        self._apply_posrel(self._posrel_orig)
        self._state_posrel = "orig"

    def to_interp_posrel(self):
        """Install interp_sam's own pos_embed / rel_pos* tensors."""
        if self._state_posrel == "interp":
            return
        self._apply_posrel(self._posrel_interp)
        self._state_posrel = "interp"

    # convenience helpers ----------------------------------------------------
    def to_orig_all(self):
        """Switch both families to the orig snapshot."""
        self.to_orig_prompt()
        self.to_orig_posrel()

    def to_interp_all(self):
        """Switch both families back to interp_sam's own snapshot."""
        self.to_interp_prompt()
        self.to_interp_posrel()

    # ------------------------------------------------------------------
    #   Interpolate current pos-rel to match interp_sam img_size
    # ------------------------------------------------------------------
    def interpolate(self):
        """Resize the *currently installed* positional encodings to fit interp size.

        The target grid is ``img_size // patch_size`` tokens per side of
        ``interp_sam.image_encoder``:

        * ``pos_embed`` (laid out as (B, H, W, C)) is bilinearly resized to
          (B, token, token, C);
        * ``rel_pos`` (if present) is bilinearly resized to (token, token);
        * ``rel_pos_h`` / ``rel_pos_w`` (laid out as (L, C)) are linearly
          resized along L to ``2 * token - 1``.

        Resizing rebinds each Parameter's ``.data`` instead of ``copy_``-ing
        into it. An in-place copy cannot change a tensor's shape. Rebinding
        keeps the Parameter objects, which are shared with the stored
        snapshots, identical. Tensors that already have the target shape are
        left untouched. Existing ``.grad`` tensors are not modified.
        """
        enc = self.interp_sam.image_encoder
        token = enc.img_size // enc.patch_size
        rel_len = 2 * token - 1

        with torch.no_grad():
            # pos_embed ------------------------------------------------------
            pe = enc.pos_embed
            if pe is not None and tuple(pe.shape[1:3]) != (token, token):
                grid = pe.data.permute(0, 3, 1, 2)  # -> (B, C, H, W)
                grid = F.interpolate(
                    grid, size=(token, token), mode="bilinear", align_corners=False
                )
                pe.data = grid.permute(0, 2, 3, 1).contiguous()

            # rel_pos / rel_pos_h / rel_pos_w -------------------------------
            for idx in self._global_idx:
                if idx >= len(enc.blocks):
                    continue
                attn = enc.blocks[idx].attn

                rp = getattr(attn, "rel_pos", None)
                if rp is not None and tuple(rp.shape) != (token, token):
                    resized = F.interpolate(
                        rp.data[None, None],
                        size=(token, token),
                        mode="bilinear",
                        align_corners=False,
                    )
                    rp.data = resized[0, 0].contiguous()

                # rel_pos_h and rel_pos_w are checked independently.
                for name in ("rel_pos_h", "rel_pos_w"):
                    table = getattr(attn, name, None)
                    if table is not None and table.shape[0] != rel_len:
                        table.data = self._resize_1d(table.data, rel_len)

    # ------------------------------------------------------------------
    #   Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _resize_1d(t: torch.Tensor, tgt: int) -> torch.Tensor:
        """Linearly resample an (L, C) table along L, returning (tgt, C)."""
        if t.shape[0] == tgt:
            return t
        # mode="linear" requires a 3-D (N, C, L) input.
        seq = t.T.unsqueeze(0)
        seq = F.interpolate(seq, size=tgt, mode="linear", align_corners=False)
        return seq.squeeze(0).T.contiguous()

    def _extract_parts(self, sam: nn.Module):
        """Return ``(prompt, posrel)`` dicts of references to *sam*'s parts.

        ``posrel["rel_pos*"]`` map block index -> tensor, only for indices in
        ``self._global_idx`` that exist and whose attn has the attribute.
        """
        prompt = {
            "prompt_encoder": sam.prompt_encoder,
            "mask_decoder": sam.mask_decoder,
        }
        posrel = {
            "pos_embed": sam.image_encoder.pos_embed,
            "rel_pos": {},
            "rel_pos_h": {},
            "rel_pos_w": {},
        }
        for idx in self._global_idx:
            if idx >= len(sam.image_encoder.blocks):
                continue
            attn = sam.image_encoder.blocks[idx].attn
            for key in ("rel_pos", "rel_pos_h", "rel_pos_w"):
                if hasattr(attn, key):
                    posrel[key][idx] = getattr(attn, key)
        return prompt, posrel

    def _apply_prompt(self, d: Dict[str, nn.Module]):
        """Rebind interp_sam's prompt encoder / mask decoder from *d*."""
        sam = self.interp_sam
        sam.prompt_encoder = d["prompt_encoder"]
        sam.mask_decoder = d["mask_decoder"]

    def _apply_posrel(self, d: Dict):
        """Rebind interp_sam's pos_embed and per-block rel_pos* from *d*."""
        sam = self.interp_sam
        sam.image_encoder.pos_embed = d["pos_embed"]
        for idx, obj in d["rel_pos"].items():
            sam.image_encoder.blocks[idx].attn.rel_pos = obj
        for idx, obj in d["rel_pos_h"].items():
            sam.image_encoder.blocks[idx].attn.rel_pos_h = obj
        for idx, obj in d["rel_pos_w"].items():
            sam.image_encoder.blocks[idx].attn.rel_pos_w = obj

    @staticmethod
    def _detach_prompt_posrel(sam: nn.Module):
        """Delete prompt/decoder, pos_embed and *every* block's rel_pos* from *sam*."""
        del sam.prompt_encoder, sam.mask_decoder, sam.image_encoder.pos_embed
        for blk in sam.image_encoder.blocks:
            attn = blk.attn
            for attr in ("rel_pos", "rel_pos_h", "rel_pos_w"):
                if hasattr(attn, attr):
                    delattr(attn, attr)


# ---------------------------------------------------------------------------
# Public factory
# ---------------------------------------------------------------------------


def build_teacher_student(cli_args: Namespace, device: torch.device):
    """Build the student, a frozen teacher copy of it, and a part switcher.

    Steps:

    1. Build the student from the SAM-Med2D registry (``smdmr``) and a second
       model from the original-SAM registry (``smr``). Both use key
       ``cli_args.variant``, receive the namespace from ``_to_namespace``,
       and are moved to *device*.
    2. Wrap them in a ``SAMPartSwitcher``, which strips the original model's
       prompt encoder, mask decoder and positional tensors and keeps its own
       references to the ones it tracks. Then install those "orig" parts
       into the student.
    3. Deep-copy the student (with the "orig" parts installed) into the
       teacher, put it in eval mode and freeze every parameter.

    Note that the teacher's image-encoder blocks therefore come from the
    ``smdmr`` build. Only the prompt encoder, mask decoder and positional
    tensors come from the ``smr`` build.

    Args:
        cli_args: Parsed CLI namespace. It must provide ``variant``. The
            keys forwarded by ``_to_namespace`` are optional.
        device: Device the models are placed on.

    Returns:
        ``(student, teacher, switcher)``. On return the student carries the
        "orig" parts; ``switcher.to_interp_all()`` restores its own.
    """
    variant = cli_args.variant

    # build a minimal args Namespace for the underlying registry
    ns = _to_namespace(cli_args)

    # ---- student (SAM-Med2D registry) --------------------------------------
    student = smdmr[variant](args=ns).to(device)

    # ---- original Meta-SAM build: only donates parts to the switcher -------
    student_org = smr[variant](args=ns).to(device)
    switcher = SAMPartSwitcher(student_org, student)
    switcher.to_orig_all()
    del student_org  # the switcher keeps references to the parts it needs

    # ---- teacher: frozen deep copy of the student ("orig" parts installed) --
    teacher = copy.deepcopy(student)
    teacher.to(device)
    teacher.eval()
    teacher.requires_grad_(False)

    return student, teacher, switcher
