import torch, numpy as np
from torch import nn, einsum
from einops import rearrange

# ---------- utility blocks ----------------------------------
class CyclicShift(nn.Module):
    """Roll a tensor cyclically by the same signed offset along dims 1 and 2."""

    def __init__(self, d):
        super().__init__()
        # Signed offset applied to both rolled dimensions.
        self.d = d

    def forward(self, x):
        """Return ``x`` rolled by ``self.d`` along dimension 1 and dimension 2."""
        offset = self.d
        return torch.roll(x, shifts=(offset, offset), dims=(1, 2))

class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output."""

    def __init__(self, f):
        super().__init__()
        # The wrapped callable; called as f(x, **kwargs).
        self.f = f

    def forward(self, x, **k):
        """Return ``f(x, **k) + x``; keyword arguments are passed through to ``f``."""
        branch = self.f(x, **k)
        return branch + x

class PreNorm(nn.Module):
    """Apply LayerNorm over a trailing dimension of size ``d`` before calling ``f``."""

    def __init__(self, d, f):
        super().__init__()
        self.n = nn.LayerNorm(d)  # normalisation applied to the input
        self.f = f                # callable that receives the normalised input

    def forward(self, x, **k):
        """Return ``f(LayerNorm(x), **k)``."""
        normed = self.n(x)
        return self.f(normed, **k)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear(d -> hd), GELU, Linear(hd -> d)."""

    def __init__(self, d, hd):
        super().__init__()
        # Layers are created in application order so parameter initialisation
        # consumes the RNG in the same sequence as the layers are applied.
        expand = nn.Linear(d, hd)
        project = nn.Linear(hd, d)
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Run ``x`` through the MLP; the trailing dimension stays of size ``d``."""
        return self.net(x)

# ---------- window helpers ----------------------------------
def rel_dist(ws):
    """Return pairwise (row, col) offsets between all cells of a ws x ws grid.

    Cells are enumerated row-major. The result has shape (ws**2, ws**2, 2)
    and entry ``[a, b]`` equals ``coords[b] - coords[a]``.
    """
    cells = [[row, col] for row in range(ws) for col in range(ws)]
    coords = torch.tensor(cells)
    # Broadcast (1, N, 2) against (N, 1, 2) to get every pairwise difference.
    return coords.unsqueeze(0) - coords.unsqueeze(1)

def full_mask(ws):
    """Return an all-zero (ws**2, ws**2) additive mask (non-shifted case, no -inf)."""
    tokens = ws**2
    return torch.zeros((tokens, tokens))

def shift_masks(ws, disp):
    """Build the two additive attention masks used by shifted windows.

    Tokens inside a window are enumerated row-major over a ws x ws grid.

    Args:
        ws: window edge length.
        disp: shift displacement in cells (``0`` yields all-zero masks).

    Returns:
        ``(m_ul, m_lr)``, two independent float tensors of shape
        (ws**2, ws**2) holding ``0`` where attention is allowed and ``-inf``
        where it is blocked:

        * ``m_ul`` blocks pairs where exactly one token lies in the last
          ``disp`` rows of the window (upper/lower split).
        * ``m_lr`` blocks pairs where exactly one token lies in the last
          ``disp`` columns of the window (left/right split).
    """
    tokens = ws * ws
    neg_inf = float('-inf')

    # Upper/lower: the last disp rows are the last disp*ws flattened tokens.
    m_ul = torch.zeros(tokens, tokens)
    m_ul[-disp*ws:, :-disp*ws] = neg_inf
    m_ul[:-disp*ws, -disp*ws:] = neg_inf

    # Left/right: start from fresh zeros. Deriving this mask from m_ul would
    # copy the upper/lower entries into it, and, because a reshape of a
    # contiguous tensor may be a view, would also write the left/right
    # entries back into m_ul.
    m_lr = torch.zeros(ws, ws, ws, ws)      # indexed (h1, w1, h2, w2)
    m_lr[:, -disp:, :, :-disp] = neg_inf
    m_lr[:, :-disp, :, -disp:] = neg_inf
    m_lr = m_lr.reshape(tokens, tokens)
    return m_ul, m_lr

# ---------- WindowAttention with auto-window ----------------
class WindowAttention(nn.Module):
    """Multi-head self-attention computed inside square windows.

    ``forward`` unpacks its input as (B, H, W, C) and splits the H x W grid
    into non-overlapping ``ws x ws`` windows, where ``ws`` is the largest
    value ``<= base_win`` that divides both H and W.

    Args:
        dim: size of the trailing input dimension (input width of ``to_qkv``
            and output width of ``to_out``).
        heads: number of attention heads.
        head_dim: width of each head.
        base_win: preferred (maximum) window edge length.
        shifted: when true, the input is rolled by ``-(ws // 2)`` along
            dims 1 and 2 before windowing, the last row / last column of
            windows are masked, and the output is rolled back.
        rel_pos: when true, ``pos_emb`` is a (2*base_win-1, 2*base_win-1)
            table indexed by relative offsets; otherwise it is a
            (base_win**2, base_win**2) table added to the logits directly.
    """

    def __init__(self, dim, heads, head_dim,
                 base_win, shifted, rel_pos=True):
        super().__init__()
        self.heads = heads
        self.shift = shifted
        self.base_w = base_win
        self.rel_pos = rel_pos
        inner = heads * head_dim
        self.scale = head_dim ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_out = nn.Linear(inner, dim)

        # Non-persistent buffers, filled lazily in ``_ensure_buffers`` because
        # the effective window size is only known once an input is seen.
        self.register_buffer('_mask_ul', None, persistent=False)
        self.register_buffer('_mask_lr', None, persistent=False)
        self.register_buffer('_rel_idx', None, persistent=False)
        if rel_pos:
            table_shape = (2*base_win-1, 2*base_win-1)
        else:
            table_shape = (base_win**2, base_win**2)
        self.pos_emb = nn.Parameter(torch.empty(table_shape))
        nn.init.trunc_normal_(self.pos_emb, std=.02)

    # --- util ------------------------------------------------
    def _ensure_buffers(self, ws, device):
        """(Re)build the index and mask buffers when they do not match ``ws``."""
        tokens = ws**2
        if self._rel_idx is None or self._rel_idx.size(0) != tokens:
            # Offsets lie in [-(ws-1), ws-1]; adding ws-1 makes them >= 0.
            self._rel_idx = rel_dist(ws).to(device) + ws - 1
        if self._mask_ul is None or self._mask_ul.size(0) != tokens:
            mask_ul, mask_lr = shift_masks(ws, ws // 2)
            self._mask_ul = mask_ul.to(device)
            self._mask_lr = mask_lr.to(device)

    # ---------------------------------------------------------
    def forward(self, x):
        """Apply windowed attention to ``x`` (unpacked as B, H, W, C)."""
        _, H, W, _ = x.shape

        # Choose the largest window that divides both H and W. This must be
        # done before the cyclic shift so that the shift uses the same
        # displacement (ws // 2) as the masks built in ``_ensure_buffers``.
        ws = self.base_w
        while H % ws or W % ws:
            ws -= 1
        assert ws >= 1, 'window collapsed to <1'
        disp = ws // 2

        if self.shift:
            x = torch.roll(x, shifts=(-disp, -disp), dims=(1, 2))

        h = self.heads
        nwh, nww = H // ws, W // ws
        qkv = self.to_qkv(x).chunk(3, dim=-1)

        def win(t):
            # (B, H, W, h*d) -> (B, h, num_windows, tokens_per_window, d)
            return rearrange(
                t, 'b (nwh wh) (nww ww) (h d) -> b h (nwh nww) (wh ww) d',
                nwh=nwh, nww=nww, h=h, wh=ws, ww=ws)
        q, k, v = map(win, qkv)
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.rel_pos or self.shift:
            self._ensure_buffers(ws, x.device)

        # positional
        if self.rel_pos:
            # NOTE(review): indices are offset by ws-1, not base_w-1, so when
            # the window shrinks a zero offset selects a different table
            # entry than at full size -- confirm this is intended.
            idx = self._rel_idx
            dots += self.pos_emb[idx[..., 0], idx[..., 1]]
        else:
            dots += self.pos_emb[:ws**2, :ws**2]

        # shifted masks
        if self.shift:
            # Windows are flattened row-major: the last nww entries are the
            # bottom row of windows, every nww-th entry starting at nww-1 is
            # the right-most column of windows.
            dots[:, :, -nww:] += self._mask_ul
            dots[:, :, nww-1::nww] += self._mask_lr

        attn = dots.softmax(dim=-1)
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(
            out, 'b h (nwh nww) (wh ww) d -> b (nwh wh) (nww ww) (h d)',
            nwh=nwh, nww=nww, h=h, wh=ws, ww=ws)
        out = self.to_out(out)

        if self.shift:
            # Undo the cyclic shift applied above.
            out = torch.roll(out, shifts=(disp, disp), dims=(1, 2))
        return out

# ---------- Swin blocks & stages -----------------------------
class SwinBlock(nn.Module):
    """Residual pre-norm window attention followed by a residual pre-norm MLP.

    Args:
        d: feature width.
        h: number of attention heads.
        hd: width of each head.
        mlp: hidden width of the feed-forward network.
        shift: whether the attention uses shifted windows.
        win: base window size passed to the attention.
        rel: whether the attention uses relative position embeddings.
    """

    def __init__(self, d, h, hd, mlp, shift, win, rel):
        super().__init__()
        attention = WindowAttention(d, h, hd, win, shift, rel)
        self.attn = Residual(PreNorm(d, attention))
        feed_forward = FeedForward(d, mlp)
        self.ff = Residual(PreNorm(d, feed_forward))

    def forward(self, x):
        """Apply the attention sub-block, then the feed-forward sub-block."""
        attended = self.attn(x)
        return self.ff(attended)

class PatchMerging(nn.Module):
    """Downsample by folding each f x f patch into channels, then project.

    Args:
        cin: number of input channels.
        cout: number of output features per merged patch.
        f: patch edge length (also the stride).
    """

    def __init__(self, cin, cout, f):
        super().__init__()
        self.f = f
        self.unfold = nn.Unfold(f, stride=f)
        self.lin = nn.Linear(cin * f * f, cout)

    def forward(self, x):
        """Map (B, C, H, W) input to (B, H // f, W // f, cout) output."""
        batch, _, height, width = x.shape
        rows = height // self.f
        cols = width // self.f
        # Unfold yields (B, C*f*f, rows*cols); restore the 2-D patch grid and
        # move the flattened patch features to the last axis for the Linear.
        patches = self.unfold(x)
        patches = patches.view(batch, -1, rows, cols)
        return self.lin(patches.permute(0, 2, 3, 1))

class Stage(nn.Module):
    """Patch merging followed by alternating regular / shifted Swin blocks.

    Args:
        cin: input channels fed to the patch merging.
        cout: feature width after merging, used by every block.
        nlayer: number of blocks; must be even (regular + shifted pairs).
        f: downscaling factor of the patch merging.
        heads: attention heads per block.
        hd: width of each head.
        win: base window size.
        rel: whether blocks use relative position embeddings.
    """

    def __init__(self, cin, cout, nlayer, f, heads, hd, win, rel):
        super().__init__()
        assert nlayer % 2 == 0
        self.merge = PatchMerging(cin, cout, f)
        # Each pair is a non-shifted block followed by a shifted one; the MLP
        # hidden width is four times the feature width.
        self.blks = nn.ModuleList([
            SwinBlock(cout, heads, hd, cout*4, shifted, win, rel)
            for _ in range(nlayer//2)
            for shifted in (False, True)
        ])

    def forward(self, x):
        """Merge patches, run every block, and return channels-first output."""
        feats = self.merge(x)
        for block in self.blks:
            feats = block(feats)
        return feats.permute(0, 3, 1, 2)   # channels-last -> NCHW

# ---------- Swin Transformer --------------------------------
class SwinTransformer(nn.Module):
    """Four-stage Swin Transformer classifier.

    Stage widths are ``hidden_dim`` times 1, 2, 4 and 8. After the last stage
    the features are averaged over dims 2 and 3, passed through
    LayerNorm + Linear, and the logits are divided by ``temp``.

    Keyword Args:
        hidden_dim: width of the first stage.
        layers: number of blocks in each of the four stages.
        heads: attention heads in each of the four stages.
        channels: input image channels.
        num_classes: number of output logits.
        head_dim: width of each attention head.
        window_size: base attention window size.
        down: downscaling factor of each stage.
        rel_pos: whether to use relative position embeddings.
        temp: divisor applied to the output logits.
    """

    def __init__(self, *, hidden_dim=96, layers=(2,2,6,2), heads=(3,6,12,24),
                 channels=3, num_classes=1000, head_dim=32, window_size=7,
                 down=(4,2,2,2), rel_pos=True, temp=1.0):
        super().__init__()
        self.temp = temp

        widths = (hidden_dim, hidden_dim*2, hidden_dim*4, hidden_dim*8)

        def make_stage(i, cin):
            # Stage i maps cin -> widths[i] using the i-th per-stage settings.
            return Stage(cin, widths[i], layers[i], down[i],
                         heads[i], head_dim, window_size, rel_pos)

        self.s1 = make_stage(0, channels)
        self.s2 = make_stage(1, widths[0])
        self.s3 = make_stage(2, widths[1])
        self.s4 = make_stage(3, widths[2])

        final_width = widths[3]
        self.head = nn.Sequential(
            nn.LayerNorm(final_width),
            nn.Linear(final_width, num_classes))

    def forward(self, x):
        """Return temperature-scaled class logits for input ``x``."""
        for stage in (self.s1, self.s2, self.s3, self.s4):
            x = stage(x)
        pooled = x.mean(dim=(2, 3))         # global average pool
        logits = self.head(pooled)
        return logits / self.temp

# ---------- factory helpers (match ViT / DeiT naming) ----------
def swin_tiny(temp: float = 1.0, **kw):
    """Build the tiny preset: hidden_dim 96, layers (2, 2, 6, 2), heads (3, 6, 12, 24).

    ``temp`` and any extra keyword arguments are forwarded to ``SwinTransformer``.
    """
    preset = dict(hidden_dim=96,
                  layers=(2, 2, 6, 2),
                  heads=(3, 6, 12, 24))
    return SwinTransformer(**preset, temp=temp, **kw)

def swin_small(temp: float = 1.0, **kw):
    """Build the small preset: hidden_dim 96, layers (2, 2, 18, 2), heads (3, 6, 12, 24).

    ``temp`` and any extra keyword arguments are forwarded to ``SwinTransformer``.
    """
    preset = dict(hidden_dim=96,
                  layers=(2, 2, 18, 2),
                  heads=(3, 6, 12, 24))
    return SwinTransformer(**preset, temp=temp, **kw)

def swin_big(temp: float = 1.0, **kw):
    """Build the big preset: hidden_dim 128, layers (2, 2, 18, 2), heads (4, 8, 16, 32).

    ``temp`` and any extra keyword arguments are forwarded to ``SwinTransformer``.
    """
    preset = dict(hidden_dim=128,
                  layers=(2, 2, 18, 2),
                  heads=(4, 8, 16, 32))
    return SwinTransformer(**preset, temp=temp, **kw)

def swin_large(temp: float = 1.0, **kw):
    """Build the large preset: hidden_dim 192, layers (2, 2, 18, 2), heads (6, 12, 24, 48).

    ``temp`` and any extra keyword arguments are forwarded to ``SwinTransformer``.
    """
    preset = dict(hidden_dim=192,
                  layers=(2, 2, 18, 2),
                  heads=(6, 12, 24, 48))
    return SwinTransformer(**preset, temp=temp, **kw)
