"""
Calibration-Aware DeiT
======================

A minimal Data-Efficient Image Transformer (DeiT) enhanced with an
instance-wise temperature-scaling head that is trained jointly with the
main task.  The extra head predicts a positive scale  s > 0  from the CLS
token; logits are divided by s before the soft-max, reducing
over/under-confidence and therefore Expected Calibration Error (ECE).

Reference (original DeiT):  
Touvron et al., *ICML 2021* – “Training Data-Efficient Image Transformers …”
https://github.com/facebookresearch/deit
"""

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------
# 1. Patch Embedding (unchanged)
# ---------------------------------------------------------------------
class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size equals its stride projects every
    non-overlapping ``patch_size x patch_size`` patch to an
    ``embed_dim``-dimensional vector.

    Args:
        img_size: Image side length used to compute ``num_patches``.
        patch_size: Side length of each patch (kernel size and stride).
        in_chans: Number of input channels.
        embed_dim: Dimension of each patch embedding.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``.
        proj: The patch-projection convolution.
    """

    def __init__(self,
                 img_size: int   = 224,
                 patch_size: int = 16,
                 in_chans: int   = 3,
                 embed_dim: int  = 768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to patch tokens of shape (B, N, D)."""
        feature_map = self.proj(x)                          # (B, D, H', W')
        tokens = torch.flatten(feature_map, start_dim=2)    # (B, D, N)
        return tokens.permute(0, 2, 1)                      # (B, N, D)


# ---------------------------------------------------------------------
# 2. Attention, MLP, Transformer Block  (unchanged w.r.t. original mini-DeiT)
# ---------------------------------------------------------------------
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values; the attention
    output of all heads is concatenated and passed through an output
    projection.

    Args:
        dim: Channel dimension of the input tokens.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(per-head dimension).
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (B, N, 3, heads, per_head) -> (3, B, heads, N, per_head)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge the heads back into the channel dimension.
        merged = torch.matmul(weights, v).transpose(1, 2)
        merged = merged.reshape(batch, tokens, channels)

        out = self.proj(merged)
        return self.proj_drop(out)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Input feature dimension.
        hidden_features: Hidden dimension; falls back to ``in_features``
            when falsy (``None`` or ``0``).
        out_features: Output dimension; falls back to ``in_features``
            when falsy (``None`` or ``0``).
        drop: Dropout probability used after the activation and after the
            second linear layer (the same ``nn.Dropout`` module is reused).
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, drop=0.0):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the feed-forward network on the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Block(nn.Module):
    """Pre-norm Transformer encoder block with stochastic depth.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by
    ``x + drop_path(mlp(norm2(x)))``.

    Args:
        dim: Token channel dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the attention projection and
            inside the MLP.
        drop_path: Probability of dropping each residual branch per sample
            during training (stochastic depth).
    """

    def __init__(self, dim, num_heads,
                 mlp_ratio=4.0,
                 attn_drop=0.0, proj_drop=0.0,
                 drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = Attention(dim, num_heads, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp   = MLP(dim, int(dim * mlp_ratio), drop=proj_drop)
        self.drop_path_rate = drop_path

    # stochastic depth
    @staticmethod
    def _drop_path(x, drop_prob, training):
        """Randomly zero whole samples of ``x`` and rescale the survivors.

        Args:
            x: Residual-branch output; the first dimension is the batch.
            drop_prob: Probability of dropping each sample.
            training: When False the input is returned unchanged.

        Returns:
            ``x`` itself when ``drop_prob == 0`` or not training; otherwise a
            tensor in which each sample is either zeroed or divided by the
            keep probability ``1 - drop_prob``.
        """
        if drop_prob == 0. or not training:
            return x
        keep = 1 - drop_prob
        if keep <= 0.:
            # Every sample is dropped.  Dividing by keep == 0 below would
            # give inf * 0 = NaN, so return the zeros directly.
            return torch.zeros_like(x)
        # One Bernoulli(keep) draw per sample, broadcast over other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep + torch.rand(shape, dtype=x.dtype, device=x.device)
        mask.floor_()
        return x / keep * mask

    def forward(self, x):
        """Apply the attention and MLP residual sub-layers to ``x``."""
        x = x + self._drop_path(self.attn(self.norm1(x)),
                                self.drop_path_rate, self.training)
        x = x + self._drop_path(self.mlp(self.norm2(x)),
                                self.drop_path_rate, self.training)
        return x


# ---------------------------------------------------------------------
# 3. NEW – Calibration head
# ---------------------------------------------------------------------
class CalibrationHead(nn.Module):
    """
    Tiny MLP that maps the CLS embedding to a positive scale s > 0.

    The output is ``softplus(fc2(gelu(fc1(cls_vec)))) + eps``; softplus plus
    the small ``eps`` floor guarantees strict positivity.  ``fc2`` starts
    with zero weights and a bias chosen so that ``softplus(bias) + eps == 1``,
    i.e. the head initially outputs s = 1 for every input and leaves the
    logits unscaled.

    Args:
        dim: Dimension of the CLS embedding.
        hidden: Width of the hidden layer.
    """
    # Lower bound added to the softplus output to keep s away from 0.
    _EPS = 1e-4

    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, 1)
        nn.init.constant_(self.fc2.weight, 0.)
        # Inverse softplus of (1 - eps): a zero bias would give
        # softplus(0) = ln 2 ~= 0.693 rather than the intended s = 1.
        nn.init.constant_(self.fc2.bias,
                          math.log(math.expm1(1.0 - self._EPS)))

    def forward(self, cls_vec: torch.Tensor) -> torch.Tensor:
        """Return the per-sample scale, shape ``(..., 1)``."""
        raw = self.fc2(self.act(self.fc1(cls_vec)))
        return F.softplus(raw) + self._EPS


# ---------------------------------------------------------------------
# 4. Calibration-Aware DeiT
# ---------------------------------------------------------------------
class DeiT(nn.Module):
    """Vision Transformer classifier with a per-sample calibration head.

    The image is split into patch tokens, a learnable CLS token is prepended,
    learnable position embeddings are added, and the sequence is run through
    ``depth`` Transformer blocks.  The normalised CLS token feeds both the
    classification head and a :class:`CalibrationHead`; the class logits are
    divided by the predicted positive scale.

    Args:
        img_size: Image side length used to size the position embedding.
        patch_size: Side length of each patch.
        in_chans: Number of input channels.
        num_classes: Number of output classes.
        embed_dim: Token embedding dimension.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
        attn_drop: Dropout probability on attention weights.
        proj_drop: Dropout probability after the position embedding, after
            the attention projection and inside the MLPs.
        drop_path: Stochastic-depth probability, the same for every block.
        calib_hidden: Hidden width of the calibration head.
        temp: Stored as ``self.temp``.  It is not read anywhere in this
            class; the logits are scaled only by the calibration head.
    """

    def __init__(self,
                 img_size: int   = 224,
                 patch_size: int = 16,
                 in_chans: int   = 3,
                 num_classes: int = 1000,
                 embed_dim: int = 384,
                 depth: int = 12,
                 num_heads: int = 6,
                 mlp_ratio: float = 4.0,
                 attn_drop: float = 0.0,
                 proj_drop: float = 0.0,
                 drop_path: float = 0.0,
                 calib_hidden: int = 128,
                 temp=1.0):
        super().__init__()

        # ---- Embedding ------------------------------------------------
        self.patch_embed = PatchEmbed(img_size, patch_size,
                                      in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1,
                                                  embed_dim))
        self.pos_drop = nn.Dropout(p=proj_drop)
        self.temp = temp

        # ---- Transformer encoder --------------------------------------
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio,
                  attn_drop, proj_drop, drop_path)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # ---- Heads ----------------------------------------------------
        self.head       = nn.Linear(embed_dim, num_classes)
        self.calib_head = CalibrationHead(embed_dim, calib_hidden)

        self._init_weights()

    # -----------------------------------------------------------------
    def _init_weights(self):
        """Initialise the CLS token, position embedding and classifier head.

        All other sub-modules keep the initialisation they were built with.
        """
        nn.init.normal_(self.cls_token,  std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=.02)
        nn.init.zeros_(self.head.bias)

    # -----------------------------------------------------------------
    def forward(self, x: torch.Tensor,
                return_scale: bool = False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args
        ----
        x : (B, in_chans, H, W)
        return_scale : if True, also returns per-sample temperature s.

        Returns
        -------
        logits  (B, num_classes)  when ``return_scale`` is False, otherwise
        the tuple  (logits, s)  with  s  of shape (B, 1).  The logits are
        already divided by s.
        """
        B = x.size(0)
        x = self.patch_embed(x)                           # (B, N, D)
        cls = self.cls_token.expand(B, -1, -1)            # (B, 1, D)
        x = torch.cat((cls, x), dim=1)                    # prepend CLS
        # The position embedding is sliced to the actual sequence length.
        x = self.pos_drop(x + self.pos_embed[:, :x.size(1), :])

        for blk in self.blocks:
            x = blk(x)

        cls_vec = self.norm(x[:, 0])                      # (B, D)
        scale   = self.calib_head(cls_vec)                # (B, 1)
        logits  = self.head(cls_vec) / scale              # scaled logits
        return (logits, scale) if return_scale else logits


# ---------------------------------------------------------------------
# 5. Factory helpers (DeiT-Tiny / -Small / -Base)
# ---------------------------------------------------------------------
def deitcalib_tiny(**kw):
    """Build a DeiT-Tiny model (192-dim embedding, 3 heads).

    Extra keyword arguments are forwarded to ``DeiT``.
    """
    model = DeiT(embed_dim=192, num_heads=3, **kw)
    return model

def deitcalib_small(**kw):
    """Build a DeiT-Small model (384-dim embedding, 6 heads).

    Extra keyword arguments are forwarded to ``DeiT``.
    """
    model = DeiT(embed_dim=384, num_heads=6, **kw)
    return model

def deitcalib_base(**kw):
    """Build a DeiT-Base model (768-dim embedding, 12 heads).

    Extra keyword arguments are forwarded to ``DeiT``.
    """
    model = DeiT(embed_dim=768, num_heads=12, **kw)
    return model


# # ---------------------------------------------------------------------
# # 6. Example training snippet (cross-entropy + Brier)
# # ---------------------------------------------------------------------
# if __name__ == "__main__":
#     model = deitcalib_small(num_classes=10).cuda()
#     opt   = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
#     criterion_ce    = nn.CrossEntropyLoss()
#     criterion_brier = nn.MSELoss()
#     λ = 0.1  # calibration weight

#     dummy_x = torch.randn(4, 3, 224, 224).cuda()
#     dummy_y = torch.randint(0, 10, (4,)).cuda()

#     logits, s = model(dummy_x, return_scale=True)
#     p = logits.softmax(1)
#     loss = criterion_ce(logits, dummy_y) + \
#            λ * criterion_brier(p,
#                F.one_hot(dummy_y, 10).float())
#     loss.backward()
#     opt.step()
#     print("loss:", loss.item(), "mean s:", s.mean().item())
