# vit_parallel_repa_local.py
"""
Parallel ViT with multi-branch MHSA where branch 0 is global attention and
other branches are local (2D patch-grid aware) attention.

Author: ChatGPT (generated)
Date: 2025-08-26
"""

import math
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------
# Patch embedding (unchanged)
# -------------------------
class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to an ``embed_dim``-sized vector.

    Args:
        img_size: Image side length used to derive ``n_patches``. It is stored
            but not checked against the tensors passed to ``forward``.
        patch_size: Side length of one square patch.
        in_chans: Number of input channels given to the convolution.
        embed_dim: Number of output channels of the convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` with the patch convolution and return ``[B, N, D]`` tokens."""
        feature_map = self.proj(x)
        # Collapse the two spatial axes into one token axis, then move the
        # channel axis last.
        tokens = torch.flatten(feature_map, start_dim=2)
        return tokens.transpose(1, 2)  # [B, N, D]


# -------------------------
# Custom multi-head self-attention (MHSA) core (unchanged)
# -------------------------
class CustomMultiHeadAttention(nn.Module):
    """Multi-head self-attention exposed as two separate steps.

    ``forward_qkv`` produces the per-head queries, keys and values, and
    ``forward_from_qkv`` turns them (plus optional precomputed attention
    probabilities) into the projected output. There is no combined
    ``forward``; callers invoke the two steps themselves.

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to attention probabilities that
            this module computes itself.
    """

    def __init__(self, embed_dim, num_heads=6, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Layers are created in the order q, k, v, out so that seeded
        # initialisation stays reproducible.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward_qkv(self, x):
        """Project ``x`` of shape [B, N, C] to Q, K, V, each [B, heads, N, head_dim]."""
        batch, tokens, _ = x.shape
        per_head = (batch, tokens, self.num_heads, self.head_dim)
        q, k, v = (
            layer(x).view(per_head).transpose(1, 2)
            for layer in (self.q_proj, self.k_proj, self.v_proj)
        )
        return q, k, v

    def forward_from_qkv(self, q, k, v, attn_probs: Optional[torch.Tensor] = None):
        """Aggregate values with attention and apply the output projection.

        Args:
            q, k, v: Tensors of shape [B, heads, N, head_dim].
            attn_probs: Optional precomputed probabilities [B, heads, N, N].
                When given they are used as-is (no softmax, no dropout).

        Returns:
            ``(out, attn_probs)`` where ``out`` is [B, N, C] and ``attn_probs``
            is either the supplied tensor or the one computed here.
        """
        if attn_probs is None:
            scale = math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) / scale
            attn_probs = self.dropout(F.softmax(scores, dim=-1))

        batch, tokens = q.shape[0], q.shape[2]
        context = torch.matmul(attn_probs, v)  # [B, heads, N, hd]
        # Put the head axis next to head_dim and merge them back into C.
        merged = context.transpose(1, 2).contiguous().view(batch, tokens, self.embed_dim)
        return self.out_proj(merged), attn_probs


# -------------------------
# Local attention (2D patch-grid aware)
# -------------------------
class LocalMultiHeadAttention(CustomMultiHeadAttention):
    """Multi-head attention restricted to a 2D window on the patch grid.

    Token 0 is assumed to be the CLS token: it may attend to every token and
    every token may attend to it. Tokens 1..N-1 are assumed to be patches in
    row-major order on a ``grid_size`` x ``grid_size`` grid; a patch may attend
    to patches whose row and column both differ by at most ``win_radius``.
    """

    def __init__(self, embed_dim, num_heads=6, dropout=0.0, win_radius: int = 3, grid_size: Optional[int] = None):
        """
        Args:
            embed_dim: Token embedding size; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability for attention probabilities computed here.
            win_radius: Window radius in patches (e.g. 3 => 7x7 window).
            grid_size: sqrt(num_patches); required to map tokens to 2D coordinates.
        """
        super().__init__(embed_dim, num_heads, dropout)
        assert grid_size is not None, "grid_size (sqrt(num_patches)) must be provided for 2D local attention"
        self.win_radius = win_radius
        self.grid_size = grid_size

    def _build_2d_local_mask(self, N, device, dtype):
        """Return an additive [N, N] mask: 0 where attention is allowed, -inf elsewhere.

        Index 0 is the CLS token (row 0 and column 0 are all zeros); indices
        1..N-1 are patches in row-major order.

        Raises:
            AssertionError: if ``grid_size ** 2 != N - 1``.
        """
        if N <= 1:
            return torch.zeros((N, N), device=device, dtype=dtype)

        n_patches = N - 1
        g = self.grid_size
        assert g * g == n_patches, f"grid_size {g} doesn't match n_patches {n_patches}"

        # Row / column coordinate of every patch, in row-major order.
        axis = torch.arange(g, device=device)
        rows = axis.repeat_interleave(g)  # 0,0,...,1,1,...
        cols = axis.repeat(g)             # 0,1,...,0,1,...

        # Pairwise window test, done in one shot instead of a Python loop
        # that writes a single mask element per iteration.
        row_near = (rows.unsqueeze(1) - rows.unsqueeze(0)).abs() <= self.win_radius
        col_near = (cols.unsqueeze(1) - cols.unsqueeze(0)).abs() <= self.win_radius
        in_window = row_near & col_near  # [n_patches, n_patches]

        patch_block = torch.full((n_patches, n_patches), float("-inf"), device=device, dtype=dtype)
        patch_block[in_window] = 0.0

        # Start from zeros so the CLS row and column stay fully open.
        mask = torch.zeros((N, N), device=device, dtype=dtype)
        mask[1:, 1:] = patch_block
        return mask  # [N, N]

    def forward_from_qkv(self, q, k, v, attn_probs: Optional[torch.Tensor] = None):
        """Windowed attention over ``q, k, v`` of shape [B, heads, N, hd].

        When ``attn_probs`` is ``None`` the local mask is added to the scaled
        dot-product scores before the softmax and dropout is applied. When
        ``attn_probs`` is supplied it is used verbatim: the local mask is NOT
        applied here, so the caller must already have masked those
        probabilities if locality is required.

        Returns:
            ``(out, attn_probs)`` with ``out`` of shape [B, N, C].
        """
        if attn_probs is None:
            N = q.shape[2]
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B,heads,N,N]
            # The [N, N] mask broadcasts over the batch and head axes.
            attn_scores = attn_scores + self._build_2d_local_mask(N, q.device, q.dtype)
            attn_probs = self.dropout(torch.softmax(attn_scores, dim=-1))

        # Value aggregation and the output projection are identical to the
        # global variant, so reuse the parent implementation.
        return super().forward_from_qkv(q, k, v, attn_probs=attn_probs)


# -------------------------
# Multi-branch MHSA with lambda-weighting (modified to accept grid_size)
# -------------------------
class LambdaMultiBranchMHSA(nn.Module):
    """
    Multi-branch MHSA with symmetric lambda_off controlling cross-branch logits mixing.
    Branch 0 is global attention; every other branch is 2D-local attention.

    Behavior:
      - For each branch j compute L_j = Q_j @ K_j^T / sqrt(d)
      - For branch i compute L_i_weighted = L_i + lambda_off * (sum_{j != i} L_j)
      - Normalize L_i_weighted with a LayerNorm over the heads axis, then divide by temperature
      - For local branches add the branch's additive window mask (0 / -inf)
      - probs_i = softmax(...), output_i = out_proj_i(probs_i @ V_i)
      - Final output = average(outputs)

    Args:
        embed_dim: Token embedding size.
        num_heads: Heads per branch.
        num_branches: Number of attention branches (first global, rest local).
        dropout: Dropout probability handed to each branch module.
        temperature_mode: "fixed", "adaptive" or "learnable"; any other value
            behaves like "fixed".
        base_temperature: Base softmax temperature.
        local_win_radius: Window radius (in patches) of the local branches.
        grid_size: sqrt(num_patches); required.
    """

    def __init__(self, embed_dim, num_heads=6, num_branches=2, dropout=0.0,
                 temperature_mode="adaptive", base_temperature=1.0, local_win_radius=3,
                 grid_size: Optional[int] = None):
        super().__init__()
        self.num_branches = num_branches
        self.temperature_mode = temperature_mode  # "fixed", "adaptive", "learnable"
        self.base_temperature = base_temperature

        assert grid_size is not None, "LambdaMultiBranchMHSA requires grid_size (sqrt(num_patches)) for local branches"
        self.grid_size = grid_size

        self.branches: nn.ModuleList = nn.ModuleList()
        for i in range(num_branches):
            if i == 0:
                # first branch global attention
                self.branches.append(CustomMultiHeadAttention(embed_dim, num_heads=num_heads, dropout=dropout))
            else:
                # other branches local
                self.branches.append(LocalMultiHeadAttention(embed_dim, num_heads=num_heads, dropout=dropout,
                                                            win_radius=local_win_radius, grid_size=grid_size))

        self.norm = nn.LayerNorm(embed_dim)
        # LayerNorm applied across the heads axis of the mixed logits
        self.logits_norm = nn.LayerNorm(self.branches[0].num_heads)

        # Learnable temperature parameter
        if temperature_mode == "learnable":
            self.temperature = nn.Parameter(torch.tensor(base_temperature))

        # Local masks depend only on (radius, grid, N, device, dtype); keep them
        # so they are not rebuilt on every forward pass. Plain dict, not part
        # of the state_dict.
        self._local_mask_cache = {}

    def _get_temperature(self, lambda_off: float):
        """Calculate temperature based on the chosen mode"""
        if self.temperature_mode == "fixed":
            return self.base_temperature
        elif self.temperature_mode == "adaptive":
            # Adaptive temperature: increase temperature as lambda_off increases
            # When lambda_off=0: temperature = base_temperature
            # When lambda_off=1: temperature = base_temperature * num_branches (compensate for summing)
            return self.base_temperature * (1.0 + lambda_off * (self.num_branches - 1.0))
        elif self.temperature_mode == "learnable":
            return torch.clamp(self.temperature, min=0.1, max=10.0)  # Clamp to reasonable range
        else:
            return self.base_temperature

    def _local_mask_for(self, branch, n_tokens, device, dtype):
        """Return (and memoize) the additive [N, N] window mask of a local branch."""
        key = (branch.win_radius, branch.grid_size, n_tokens, device, dtype)
        mask = self._local_mask_cache.get(key)
        if mask is None:
            mask = branch._build_2d_local_mask(n_tokens, device, dtype)
            self._local_mask_cache[key] = mask
        return mask

    def forward(self, x, lambda_off: float):
        """
        x: [B,N,D]
        lambda_off: scalar in [0,1] controlling off-diagonal mixing. 0 => independent branches.

        Returns:
            (out, Lj_list, attn_probs_list): the branch-averaged output [B,N,D],
            the raw per-branch logits and the per-branch attention probabilities
            (each [B,heads,N,N]).
        """
        x_ = self.norm(x)
        n_tokens = x_.shape[1]

        # compute per-branch q,k,v and raw logits
        q_list, k_list, v_list = [], [], []
        Lj_list = []
        for branch in self.branches:
            q, k, v = branch.forward_qkv(x_)  # each [B,heads,N,hd]
            q_list.append(q)
            k_list.append(k)
            v_list.append(v)
            Lj = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(branch.head_dim)  # [B,heads,N,N]
            Lj_list.append(Lj)

        # sum of all Lj
        sum_L = sum(Lj_list)  # [B,heads,N,N]

        # Get temperature for this forward pass
        temperature = self._get_temperature(lambda_off)

        outs = []
        attn_probs_list = []
        # for each branch i, build weighted logits: L_i + lambda_off*(sum_L - L_i)
        for i, branch in enumerate(self.branches):
            Li = Lj_list[i]
            if self.num_branches == 1:
                L_i_weighted = Li
            else:
                L_i_weighted = Li + lambda_off * (sum_L - Li)

            # LayerNorm over the heads axis: move heads last, normalize, move back.
            logits_reshaped = L_i_weighted.permute(0, 2, 3, 1)  # [B, N, N, heads]
            normalized = self.logits_norm(logits_reshaped)
            # The temperature is applied after the normalization; dividing
            # before it would be cancelled by the LayerNorm's rescaling.
            L_i_scaled = normalized.permute(0, 3, 1, 2) / temperature  # [B, heads, N, N]

            # Local branches must not see outside their window. The mask is
            # added after the normalization because -inf entries would turn
            # the LayerNorm statistics into NaN. Every row keeps at least the
            # CLS column open, so the softmax stays finite.
            if isinstance(branch, LocalMultiHeadAttention):
                L_i_scaled = L_i_scaled + self._local_mask_for(
                    branch, n_tokens, L_i_scaled.device, L_i_scaled.dtype)

            attn_probs = torch.softmax(L_i_scaled, dim=-1)
            # NOTE(review): precomputed probabilities bypass the branch's
            # attention dropout, so `dropout` has no effect on this path --
            # confirm whether that is intended.
            out_i, _ = branch.forward_from_qkv(q_list[i], k_list[i], v_list[i], attn_probs=attn_probs)
            outs.append(out_i)
            attn_probs_list.append(attn_probs)

        # average outputs across branches (you may choose sum if you prefer)
        out = sum(outs) / float(self.num_branches)
        return out, Lj_list, attn_probs_list  # return Lj_list for potential fusion utils


# -------------------------
# Multi-branch MLP with lambda mixing (unchanged)
# -------------------------
class LambdaMultiBranchMLP(nn.Module):
    """Parallel MLP branches whose hidden pre-activations are mixed by ``lambda_off``.

    Each branch is ``LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout``.
    The output of the first Linear of branch ``i`` is replaced by
    ``h_i + lambda_off * sum_{j != i} h_j`` before the activation, and the
    branch outputs are averaged.

    Args:
        embed_dim: Input and output feature size.
        mlp_ratio: Hidden size is ``int(embed_dim * mlp_ratio)``.
        num_branches: Number of parallel branches.
        dropout: Dropout probability used twice inside every branch.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, num_branches=2, dropout=0.0):
        super().__init__()
        self.num_branches = num_branches
        hidden = int(embed_dim * mlp_ratio)

        branch_list = []
        for _ in range(num_branches):
            branch_list.append(
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, hidden),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden, embed_dim),
                    nn.Dropout(dropout),
                )
            )
        self.branches = nn.ModuleList(branch_list)

    def forward(self, x, lambda_off: float):
        """Return ``(out, ys)``: the branch average and the list of per-branch outputs."""
        # Stage 1: pre-activation hidden state of every branch (norm + first linear).
        pre_acts = [branch[1](branch[0](x)) for branch in self.branches]
        total = sum(pre_acts)
        single_branch = self.num_branches == 1

        # Stage 2: mix, then run the rest of each branch.
        ys = []
        for branch, h in zip(self.branches, pre_acts):
            if single_branch:
                mixed = h
            else:
                mixed = h + lambda_off * (total - h)
            _, _, act, drop1, fc2, drop2 = branch
            ys.append(drop2(fc2(drop1(act(mixed)))))

        out = sum(ys) / float(self.num_branches)
        return out, ys


# -------------------------
# Transformer block with lambda-controlled branches
# -------------------------
class ParallelTransformerBlockWithLambda(nn.Module):
    """One transformer block: multi-branch attention then multi-branch MLP.

    Both sub-layers are wrapped in a residual connection and receive the same
    ``lambda_off``. The sub-layers apply their own LayerNorm to the input.
    """

    def __init__(self, embed_dim=384, num_heads=6, mlp_ratio=4.0,
                 dropout=0.0, attn_branches=2, mlp_branches=2,
                 temperature_mode="adaptive", base_temperature=1.0,
                 local_win_radius=3, grid_size: Optional[int] = None):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads,
            num_branches=attn_branches,
            dropout=dropout,
            temperature_mode=temperature_mode,
            base_temperature=base_temperature,
            local_win_radius=local_win_radius,
            grid_size=grid_size,
        )
        self.attn = LambdaMultiBranchMHSA(embed_dim, **attn_kwargs)
        self.mlp = LambdaMultiBranchMLP(
            embed_dim,
            mlp_ratio=mlp_ratio,
            num_branches=mlp_branches,
            dropout=dropout,
        )

    def forward(self, x, lambda_off: float):
        """Apply the attention and MLP residual updates to ``x`` and return it."""
        # Both sub-layers return a tuple; only its first element is the update.
        attn_update = self.attn(x, lambda_off)[0]
        x = x + attn_update
        mlp_update = self.mlp(x, lambda_off)[0]
        return x + mlp_update


# -------------------------
# Custom ModuleList for transformer blocks
# -------------------------
class LambdaTransformerBlocks(nn.Module):
    """
    Sequential container for blocks whose forward takes ``(x, lambda_off)``.

    ``nn.Sequential`` cannot pass the extra argument, so the blocks are kept in
    an ``nn.ModuleList`` and called one after another with the same
    ``lambda_off``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, lambda_off: float):
        """Feed ``x`` through every block in order and return the final result."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, lambda_off)
        return hidden


# -------------------------
# ParallelViT (modified to pass grid_size to blocks)
# -------------------------
class ParallelViT(nn.Module):
    """Vision Transformer built from lambda-controlled multi-branch blocks.

    The image is patch-embedded, a CLS token is prepended, learned positional
    embeddings are added, and the sequence runs through ``depth`` blocks. The
    classification head is applied to the mean of the patch tokens (the CLS
    token is excluded from that mean).

    Args:
        img_size, patch_size, in_chans: Passed to the patch embedding.
        num_classes: Output size of the classification head.
        embed_dim, depth, num_heads, mlp_ratio, dropout: Transformer settings.
        attn_branches, mlp_branches: Branch counts of the attention / MLP sub-layers.
        temperature_mode, base_temperature: Passed to the attention sub-layers.
        local_win_radius: Window radius (in patches) of the local attention branches.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200,
                 embed_dim=192, depth=6, num_heads=12, mlp_ratio=4.0, dropout=0.1,
                 attn_branches=2, mlp_branches=2, temperature_mode="adaptive",
                 base_temperature=1.0, local_win_radius=3):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.n_patches
        grid_size = int(math.sqrt(num_patches))
        assert grid_size * grid_size == num_patches, "num_patches must be a perfect square for 2D local attention."

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # The blocks need lambda_off at call time, so they live in a custom
        # wrapper rather than nn.Sequential.
        block_kwargs = dict(
            attn_branches=attn_branches,
            mlp_branches=mlp_branches,
            temperature_mode=temperature_mode,
            base_temperature=base_temperature,
            local_win_radius=local_win_radius,
            grid_size=grid_size,
        )
        layers = []
        for _ in range(depth):
            layers.append(
                ParallelTransformerBlockWithLambda(embed_dim, num_heads, mlp_ratio, dropout, **block_kwargs)
            )
        self.blocks = LambdaTransformerBlocks(layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; meant to be used with ``self.apply``."""
        # The head bias is zeroed on every call, whatever ``m`` is.
        head_bias = self.head.bias
        if head_bias is not None:
            nn.init.zeros_(head_bias)

        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, lambda_off: float):
        """Return the head output for image batch ``x`` at mixing strength ``lambda_off``."""
        batch = x.shape[0]
        patches = self.patch_embed(x)                      # [B, N_patches, D]
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, patches), dim=1)          # [B, N_patches+1, D]
        tokens = self.pos_drop(tokens + self.pos_embed)
        tokens = self.norm(self.blocks(tokens, lambda_off))
        # Pool over the patch tokens only; index 0 (CLS) is skipped.
        pooled = tokens[:, 1:].mean(dim=1)
        return self.head(pooled)


def get_parallel_vit(num_classes=200, attn_branches=2, mlp_branches=2,
                     temperature_mode="adaptive", base_temperature=1.0,
                     dropout=0.1,
                     **kwargs):
    """Build a ``ParallelViT``.

    The named parameters are forwarded under the same names; any additional
    keyword arguments (e.g. ``img_size``, ``patch_size``, ``embed_dim``,
    ``depth``, ``num_heads``, ``local_win_radius``) are passed through to
    ``ParallelViT`` unchanged.
    """
    # The comma after ``dropout=dropout`` matters: without it the next line is
    # parsed as ``dropout ** kwargs`` (float to the power of a dict), which
    # raises TypeError on every call.
    return ParallelViT(num_classes=num_classes,
                       attn_branches=attn_branches,
                       mlp_branches=mlp_branches,
                       temperature_mode=temperature_mode,
                       base_temperature=base_temperature,
                       dropout=dropout,
                       **kwargs)


# -------------------------
# Lambda scheduler
# -------------------------
class LambdaScheduler:
    """
    Ramp lambda_off from 0 -> 1 over warmup_steps.

    Supported modes: "linear", "cosine", "exponential", "sqrt", "sine",
    "smoothstep".
    Usage: scheduler.get_lambda(global_step) -> returns lambda_off scalar in [0,1]
    """

    # Every ramp shape implemented by get_lambda.
    _MODES = ("linear", "cosine", "exponential", "sqrt", "sine", "smoothstep")

    def __init__(self, warmup_steps: int, mode: str = "linear"):
        assert warmup_steps >= 1
        assert mode in self._MODES, f"mode must be one of {self._MODES}, got {mode!r}"
        self.warmup_steps = warmup_steps
        self.mode = mode

    def get_lambda(self, step: float):
        """Return lambda_off for ``step``.

        Returns 0.0 for ``step <= 0`` and 1.0 for ``step >= warmup_steps``;
        in between the value follows the configured ramp evaluated at
        ``t = step / warmup_steps``.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value
                after construction.
        """
        if step <= 0:
            return 0.0
        if step >= self.warmup_steps:
            return 1.0
        t = step / float(self.warmup_steps)
        if self.mode == "linear":
            return float(t)
        elif self.mode == "cosine":
            return float(0.5 * (1 - math.cos(math.pi * t)))
        elif self.mode == "exponential":
            # Approaches 1 - exp(-5) (about 0.993) as t -> 1, so there is a
            # small jump to 1.0 when the warmup ends.
            return float(1 - math.exp(-5 * t))
        elif self.mode == "sqrt":
            return float(math.sqrt(t))
        elif self.mode == "sine":
            return float(math.sin(1 / 2 * math.pi * t))
        elif self.mode == "smoothstep":
            return float(3 * t ** 2 - 2 * t ** 3)
        # Fail loudly instead of printing and returning None, which would
        # only surface later as an obscure arithmetic error.
        raise ValueError(f"NO such kind of lambda policy {self.mode}")


# -------------------------
# Fusion utilities (kept as you had them, minor adjustments)
# -------------------------
def fuse_exact_to_runtime_attention(mhsa: LambdaMultiBranchMHSA, x: torch.Tensor):
    """
    Runtime fusion of all branches of ``mhsa`` into a single attention pass:
      - compute L_j for each branch (using branch q,k)
      - S = sum_j L_j
      - probs = softmax(S)
      - V_combined = sum_j V_j_proj(x)
      - out = Linear(probs @ V_combined) with weight/bias = sum of the
        branches' out_proj weights/biases

    NOTE(review): this function does not apply ``mhsa.logits_norm``, any
    temperature or any local-attention mask, and it sums the output
    projections rather than averaging per-branch outputs. Its result is
    therefore not guaranteed to equal ``mhsa(x, 1.0)[0]`` -- confirm the
    intended equivalence before relying on the "exact" in its name.

    Args:
        mhsa: Multi-branch attention module whose branches are fused.
        x: Input tokens [B, N, D]; ``mhsa.norm`` is applied to it first.

    Returns:
        Tensor of shape [B, N, D].
    """
    normed = mhsa.norm(x)
    batch, tokens, _ = normed.shape
    first = mhsa.branches[0]

    branch_logits = []
    branch_values = []
    for branch in mhsa.branches:
        q, k, v = branch.forward_qkv(normed)
        branch_logits.append(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(branch.head_dim))
        # Merge the heads back so values of different branches can be summed as [B, N, D].
        branch_values.append(v.transpose(1, 2).contiguous().view(batch, tokens, branch.embed_dim))

    probs = torch.softmax(sum(branch_logits), dim=-1)  # [B,heads,N,N]

    # Split the summed values into heads using the first branch's head layout.
    value_sum = sum(branch_values)  # [B,N,D]
    value_heads = value_sum.view(batch, tokens, first.num_heads, first.head_dim).transpose(1, 2)
    context = torch.matmul(probs, value_heads)  # [B,heads,N,hd]
    context = context.transpose(1, 2).contiguous().view(batch, tokens, first.embed_dim)

    # Combined output projection: element-wise sum of every branch's parameters.
    fused_weight = sum([b.out_proj.weight for b in mhsa.branches])
    if first.out_proj.bias is not None:
        fused_bias = sum([b.out_proj.bias for b in mhsa.branches])
    else:
        fused_bias = None
    return F.linear(context, fused_weight, fused_bias)


def merge_params_approx(mhsa: LambdaMultiBranchMHSA):
    """
    Approximate parameter-level merge of all branches into one attention module.

    For each of q_proj, k_proj, v_proj and out_proj the weights (and biases,
    when the first branch has one) are summed across branches and copied into
    a freshly created CustomMultiHeadAttention that uses the first branch's
    embed_dim, num_heads and dropout probability.

    NOTE: This is an approximation, not necessarily exact mathematically.

    Returns:
        The merged CustomMultiHeadAttention, placed on the device of the first
        parameter of ``mhsa``.
    """
    device = next(mhsa.parameters()).device
    reference = mhsa.branches[0]
    drop_p = reference.dropout.p if hasattr(reference, 'dropout') else 0.0
    merged = CustomMultiHeadAttention(
        reference.embed_dim,
        num_heads=reference.num_heads,
        dropout=drop_p,
    ).to(device)

    for proj_name in ('q_proj', 'k_proj', 'v_proj', 'out_proj'):
        sources = [getattr(branch, proj_name) for branch in mhsa.branches]
        target = getattr(merged, proj_name)

        target.weight.data.copy_(sum([layer.weight.data for layer in sources]))
        # The bias is summed only when the first branch has one; otherwise the
        # merged layer keeps its own freshly initialised bias.
        if sources[0].bias is not None:
            target.bias.data.copy_(sum([layer.bias.data for layer in sources]))

    return merged


# -------------------------
# Quick test (if run as script)
# -------------------------
if __name__ == "__main__":
    # Smoke test: build a tiny model (4x4 patch grid, one global and two local
    # attention branches) and push one random batch through it.
    smoke_cfg = dict(
        img_size=32, patch_size=8, in_chans=3, num_classes=10,
        embed_dim=64, depth=2, num_heads=4,
        attn_branches=3, mlp_branches=2, local_win_radius=1,
    )
    model = get_parallel_vit(**smoke_cfg)
    images = torch.randn(2, 3, 32, 32)
    logits = model(images, 0.5)
    print("Output shape:", logits.shape)  # should be [B, num_classes]
