

"""
Vision Transformer with a learnable, per-sample calibration head
================================================================

Adds an instance-wise temperature-scaling module that is trained jointly
with the main task.  The extra head predicts a strictly *positive* scale s > 0
from the CLS embedding; logits are divided by `s` before soft-max.

Suitable for ViT/DeiT/any Transformer that exposes a pooled token.
"""

from __future__ import annotations     # keeps forward-refs cheap on 3.7+

import math
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------
# 1) Patch embedding (unchanged)
# ---------------------------------------------------------------------
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each to a token.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` performs the patch extraction and the linear projection
    in one step.

    Attributes set here:
        grid_size   : (rows, cols) of the patch grid, ``img_size // patch_size``
                      along each side.
        num_patches : rows * cols.
        proj        : the patch-projection convolution.
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 768):
        super().__init__()
        # Square images / square patches only: one side length describes both axes.
        side = img_size // patch_size
        self.grid_size = (side, side)
        self.num_patches = side * side
        self.proj = nn.Conv2d(in_channels=in_chans,
                              out_channels=embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images (B,C,H,W) to a token sequence (B,N,D)."""
        feature_map = self.proj(x)                         # (B,D,H/ps,W/ps)
        tokens = torch.flatten(feature_map, start_dim=2)   # (B,D,N)
        return tokens.transpose(1, 2)                      # (B,N,D)


# ---------------------------------------------------------------------
# 2) Attention, MLP, Block (standard ViT parts - unchanged)
# ---------------------------------------------------------------------
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim       : token embedding width; must be divisible by ``num_heads``.
        num_heads : number of attention heads (positive).
        attn_drop : dropout probability applied to the attention weights.
        proj_drop : dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int,
                 num_heads: int = 8,
                 attn_drop: float = 0.0,
                 proj_drop: float = 0.0):
        super().__init__()
        # Fail early with a clear message: otherwise the mismatch only shows
        # up as a reshape error inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5      # 1/sqrt(head_dim) logit scaling

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj      = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a (B,N,C) sequence; output is (B,N,C)."""
        B, N, C = x.shape
        # One fused projection yields q, k and v; split it per head.
        qkv = self.qkv(x).reshape(B, N, 3,
                                  self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)          # (3,B,h,N,hd)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B,h,N,N)
        attn = self.attn_drop(attn.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear → GELU → Dropout → Linear → Dropout.

    Args:
        in_features     : input width.
        hidden_features : hidden width; falls back to ``in_features`` when
                          falsy (``None`` or 0).
        out_features    : output width; falls back to ``in_features`` when
                          falsy (``None`` or 0).
        drop            : dropout probability, shared by both dropout sites.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 drop: float = 0.0):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        out_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network along the last dimension."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        out = self.fc2(hidden)
        return self.drop(out)


class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    Two residual branches are applied in sequence: LayerNorm → attention and
    LayerNorm → MLP.  Each branch output passes through stochastic depth
    (drop-path) before being added back to the input.

    Args:
        dim       : token embedding width.
        num_heads : number of attention heads.
        mlp_ratio : MLP hidden width as a multiple of ``dim``.
        attn_drop : dropout on the attention weights.
        proj_drop : dropout after the attention projection and inside the MLP.
        drop_path : probability of dropping a residual branch per sample.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 attn_drop: float = 0.0,
                 proj_drop: float = 0.0,
                 drop_path: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = Attention(dim, num_heads, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp   = MLP(dim, int(dim * mlp_ratio), drop=proj_drop)
        self.drop_path_rate = drop_path

    # stochastic depth
    @staticmethod
    def _drop_path(x: torch.Tensor,
                   drop_prob: float,
                   training: bool) -> torch.Tensor:
        """Zero whole samples of ``x`` with probability ``drop_prob``.

        Surviving samples are rescaled by ``1 / (1 - drop_prob)`` so the
        expected value is unchanged.  Outside training, or when
        ``drop_prob == 0``, ``x`` is returned as is.
        """
        if drop_prob == 0.0 or not training:
            return x
        keep = 1 - drop_prob
        if keep <= 0.0:
            # Every sample is dropped.  Rescaling would divide by zero and
            # produce NaN (inf * 0), so return an all-zero branch directly.
            return torch.zeros_like(x)
        # One Bernoulli(keep) draw per sample, broadcast over remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rnd = keep + torch.rand(shape, device=x.device, dtype=x.dtype)
        rnd.floor_()          # 1 with probability `keep`, else 0
        return x / keep * rnd

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP residual branches to ``x``."""
        x = x + self._drop_path(self.attn(self.norm1(x)),
                                self.drop_path_rate, self.training)
        x = x + self._drop_path(self.mlp(self.norm2(x)),
                                self.drop_path_rate, self.training)
        return x


# ---------------------------------------------------------------------
# 3)  Calibration head
# ---------------------------------------------------------------------
class CalibrationHead(nn.Module):
    """
    Small MLP that maps the CLS embedding to a per-sample scale  s(x) > 0.

    The raw output goes through Softplus (plus a 1e-6 offset), so the scale
    is always strictly positive.  The last layer starts with zero weights and
    a bias of log(e - 1), which makes s ≈ 1 for every input at initialisation.
    """
    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, 1)

        # softplus(b) = log(1 + e^b) equals 1 exactly when b = log(e - 1).
        unit_bias = math.log(math.e - 1.0)
        nn.init.zeros_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, unit_bias)

    def forward(self, cls_vec: torch.Tensor) -> torch.Tensor:
        """Return the scale (B, 1) for CLS embeddings ``cls_vec`` (B, D)."""
        hidden = self.act(self.fc1(cls_vec))
        raw = self.fc2(hidden)
        # The small offset keeps the divisor away from zero if softplus underflows.
        return F.softplus(raw) + 1.0e-6


# ---------------------------------------------------------------------
# 4)  Vision Transformer w/ joint calibration + logging
# ---------------------------------------------------------------------
class VisionTransformer(nn.Module):
    """Vision Transformer whose logits are divided by a learned per-sample scale.

    The CLS token, taken after the final LayerNorm, feeds both the linear
    classifier and the ``CalibrationHead``; ``forward`` returns
    ``head(cls) / calib_head(cls)``.

    Args:
        img_size, patch_size, in_chans, embed_dim : forwarded to ``PatchEmbed``.
        num_classes  : output width of the classifier.
        depth        : number of encoder blocks.
        num_heads, mlp_ratio, attn_drop, drop_path : forwarded to every block
                       (the same ``drop_path`` rate is used at every depth).
        proj_drop    : forwarded to every block and also used for the dropout
                       applied after adding positional embeddings.
        calib_hidden : hidden width of the calibration head.
        temp         : stored as ``self.temp``; ``forward`` in this class does
                       not read it.
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 attn_drop: float = 0.0,
                 proj_drop: float = 0.0,
                 drop_path: float = 0.0,
                 calib_hidden: int = 128,
                 temp: float = 1.0):
        super().__init__()

        # --- patch embedding & positional encodings ------------------------
        self.patch_embed = PatchEmbed(img_size, patch_size,
                                      in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 position for the CLS token prepended in forward().
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=proj_drop)

        # --- transformer encoder -----------------------------------------
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio,
                  attn_drop, proj_drop, drop_path)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.temp = temp

        # --- heads --------------------------------------------------------
        self.head       = nn.Linear(embed_dim, num_classes)
        self.calib_head = CalibrationHead(embed_dim, calib_hidden)

        self._init_weights()

        # --- logging buffers (CLS-norm & scale) ---------------------------
        self._log_enabled: bool = False
        self._z_trace: List[torch.Tensor] = []
        self._s_trace: List[torch.Tensor] = []

    # -----------------------------------------------------------------
    def _init_weights(self) -> None:
        """Truncated-normal init for tokens/classifier; zero classifier bias.

        Every other submodule keeps the initialisation it gave itself.
        """
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

    # -----------------------------------------------------------------
    # convenient toggles for logging
    # -----------------------------------------------------------------
    def enable_logging(self, flag: bool = True) -> None:
        """Turn recording of ‖z_cls‖ and s(z) on/off."""
        self._log_enabled = flag

    def get_logs(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Return snapshots of the recorded traces.

        The lists are shallow copies, so a later ``reset_logs()`` (or further
        forward passes) does not change what the caller already holds.

        Returns:
            z_list : list of 1-D tensors  (‖z_cls‖ per sample, per batch)
            s_list : list of 1-D tensors  (scale s   per sample, per batch)
        """
        return list(self._z_trace), list(self._s_trace)

    def reset_logs(self) -> None:
        """Empty internal buffers (call between epochs)."""
        self._z_trace.clear()
        self._s_trace.clear()

    # -----------------------------------------------------------------
    def forward(self,
                x: torch.Tensor,
                return_scale: bool = False):
        """
        Args
        ----
            x            : input images (B,3,H,W)
            return_scale : if True, also return s(x)  (B,1)

        Returns
        -------
            logits               (B,C)
            (optionally) scale s (B,1)
        """
        B = x.size(0)
        x = self.patch_embed(x)                       # (B,N,D)

        cls_tok = self.cls_token.expand(B, -1, -1)    # (B,1,D)
        x = torch.cat((cls_tok, x), dim=1)
        # Slice so that shorter sequences reuse the leading positions.
        x = self.pos_drop(x + self.pos_embed[:, :x.size(1)])

        for blk in self.blocks:
            x = blk(x)

        cls_vec = self.norm(x[:, 0])                  # (B,D)
        scale   = self.calib_head(cls_vec)            # (B,1)
        logits  = self.head(cls_vec) / scale          # calibrated logits

        # ----- optional recording ----------------------------------------
        if self._log_enabled:
            # Detach and move to CPU so the traces hold no graph or device memory.
            self._z_trace.append(cls_vec.norm(dim=1).detach().cpu())
            self._s_trace.append(scale.squeeze(1).detach().cpu())
        # -----------------------------------------------------------------

        if return_scale:
            return logits, scale
        return logits


# ---------------------------------------------------------------------
# 5) Factory helpers (tiny / small / base / large)
# ---------------------------------------------------------------------
def vitcalib_tiny_patch16_224(**kw) -> VisionTransformer:
    """Build the tiny preset: 192-dim tokens, 12 blocks, 3 heads, patch 16, image 224.

    Extra keyword arguments are passed through to ``VisionTransformer``;
    repeating one of the preset keys raises ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16,
                  embed_dim=192, depth=12, num_heads=3)
    return VisionTransformer(**preset, **kw)


def vitcalib_small_patch16_224(**kw) -> VisionTransformer:
    """Build the small preset: 384-dim tokens, 12 blocks, 6 heads, patch 16, image 224.

    Extra keyword arguments are passed through to ``VisionTransformer``;
    repeating one of the preset keys raises ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16,
                  embed_dim=384, depth=12, num_heads=6)
    return VisionTransformer(**preset, **kw)


def vitcalib_base_patch16_224(**kw) -> VisionTransformer:
    """Build the base preset: 768-dim tokens, 12 blocks, 12 heads, patch 16, image 224.

    Extra keyword arguments are passed through to ``VisionTransformer``;
    repeating one of the preset keys raises ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16,
                  embed_dim=768, depth=12, num_heads=12)
    return VisionTransformer(**preset, **kw)


def vitcalib_large_patch16_224(**kw) -> VisionTransformer:
    """Build the large preset: 1024-dim tokens, 24 blocks, 16 heads, patch 16, image 224.

    Extra keyword arguments are passed through to ``VisionTransformer``;
    repeating one of the preset keys raises ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16,
                  embed_dim=1024, depth=24, num_heads=16)
    return VisionTransformer(**preset, **kw)
