import torch
import torch.nn as nn
import math
from typing import List, Optional, Dict
import torch.nn.functional as F
from ptflops import get_model_complexity_info


# -------------------------
# Cosine Similarity Monitor
# -------------------------
class CosineSimilarityMonitor:
    """Record pairwise cosine similarity between the outputs of parallel branches.

    Each ``record_*`` call appends one statistics dict produced by
    ``compute_pairwise_cosine_similarity``, tagged with ``layer_idx`` and a
    ``type`` string ('attention' or 'mlp').
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all recorded statistics."""
        self.attn_similarities = []
        self.mlp_similarities = []
        # NOTE(review): never appended to by this class.
        self.block_similarities = []

    @staticmethod
    def compute_pairwise_cosine_similarity(tensors: List[torch.Tensor],
                                           normalize_dim: int = -1) -> Dict[str, torch.Tensor]:
        """
        Compute the cosine similarity between every pair of tensors in ``tensors``.

        For each pair (i, j) the cosine similarity is taken along ``normalize_dim``
        and then averaged over all remaining positions, giving one scalar per pair.

        Args:
            tensors: list of tensors of the same shape, e.g. [B, N, D] branch outputs.
            normalize_dim: dimension along which cosine similarity is computed.

        Returns:
            Dict with keys:
              'pairwise_similarities': 1-D tensor with one value per pair i < j,
                  ordered row by row over the upper triangle;
              'mean_similarity', 'min_similarity', 'max_similarity',
              'std_similarity': 0-dim summaries of the pairwise values (std is the
                  unbiased estimate, or 0 when there is only a single pair);
              'similarity_matrix': [n, n] symmetric matrix with ones on the diagonal.
            With fewer than two tensors there are no pairs: every summary is 0 and
            the two tensor-valued entries are empty.
        """
        n_branches = len(tensors)
        if n_branches < 2:
            return {
                'pairwise_similarities': torch.tensor([]),
                'mean_similarity': torch.tensor(0.0),
                'min_similarity': torch.tensor(0.0),
                'max_similarity': torch.tensor(0.0),
                'std_similarity': torch.tensor(0.0),
                'similarity_matrix': torch.tensor([]),
            }

        device = tensors[0].device
        similarity_matrix = torch.eye(n_branches, device=device)
        pairwise_values = []

        # Upper triangle only; the matrix is filled symmetrically.
        for i in range(n_branches):
            for j in range(i + 1, n_branches):
                sim = F.cosine_similarity(tensors[i], tensors[j], dim=normalize_dim)
                mean_sim = sim.mean().item()  # average over all remaining positions
                similarity_matrix[i, j] = mean_sim
                similarity_matrix[j, i] = mean_sim
                pairwise_values.append(mean_sim)

        pairwise = torch.tensor(pairwise_values, device=device)
        # Tensor.std() is unbiased (divides by n - 1), so it is NaN for a single
        # pair (exactly two branches); report zero spread in that case instead.
        if pairwise.numel() > 1:
            std = pairwise.std()
        else:
            std = torch.zeros((), device=device)

        return {
            'pairwise_similarities': pairwise,
            'mean_similarity': pairwise.mean(),
            'min_similarity': pairwise.min(),
            'max_similarity': pairwise.max(),
            'std_similarity': std,
            'similarity_matrix': similarity_matrix,
        }

    def record_attn_similarities(self, branch_outputs: List[torch.Tensor],
                                 layer_idx: int = 0):
        """Record similarity statistics for attention-branch outputs of one layer."""
        sim_stats = self.compute_pairwise_cosine_similarity(branch_outputs)
        sim_stats['layer_idx'] = layer_idx
        sim_stats['type'] = 'attention'
        self.attn_similarities.append(sim_stats)

    def record_mlp_similarities(self, branch_outputs: List[torch.Tensor],
                                layer_idx: int = 0):
        """Record similarity statistics for MLP-branch outputs of one layer."""
        sim_stats = self.compute_pairwise_cosine_similarity(branch_outputs)
        sim_stats['layer_idx'] = layer_idx
        sim_stats['type'] = 'mlp'
        self.mlp_similarities.append(sim_stats)

    def get_summary_stats(self) -> Dict[str, float]:
        """Aggregate all recorded entries into plain floats.

        For each of attention / MLP (only if something was recorded): the mean of
        the per-record means, the minimum of the per-record minima and the maximum
        of the per-record maxima.
        """
        stats = {}

        if self.attn_similarities:
            attn_means = [s['mean_similarity'].item() for s in self.attn_similarities]
            stats['attn_mean_similarity'] = sum(attn_means) / len(attn_means)
            stats['attn_min_similarity'] = min(s['min_similarity'].item() for s in self.attn_similarities)
            stats['attn_max_similarity'] = max(s['max_similarity'].item() for s in self.attn_similarities)

        if self.mlp_similarities:
            mlp_means = [s['mean_similarity'].item() for s in self.mlp_similarities]
            stats['mlp_mean_similarity'] = sum(mlp_means) / len(mlp_means)
            stats['mlp_min_similarity'] = min(s['min_similarity'].item() for s in self.mlp_similarities)
            stats['mlp_max_similarity'] = max(s['max_similarity'].item() for s in self.mlp_similarities)

        return stats

    def print_summary(self):
        """Print the aggregated statistics from ``get_summary_stats``."""
        stats = self.get_summary_stats()
        print("\n=== Cosine Similarity Summary ===")
        for key, value in stats.items():
            print(f"{key}: {value:.4f}")


# -------------------------
# Patch embedding
# -------------------------
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch linearly.

    The embedding is a Conv2d whose kernel size and stride both equal
    ``patch_size``, so each output location corresponds to exactly one patch.
    ``n_patches`` is the token count produced for an ``img_size`` x ``img_size`` input.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images (presumably [B, in_chans, H, W]) to patch tokens [B, N, embed_dim]."""
        feature_map = self.proj(x)                 # [B, D, H/p, W/p]
        tokens = feature_map.flatten(start_dim=2)  # [B, D, N]
        return tokens.transpose(1, 2)              # [B, N, D]


# -------------------------
# Custom multi-head self-attention (MHSA) core
# -------------------------
class CustomMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention exposed as two separate stages.

    ``forward_qkv`` projects tokens to per-head Q/K/V; ``forward_from_qkv`` turns
    them into the projected output, optionally using caller-supplied attention
    probabilities. Splitting the stages lets a caller build custom attention
    logits (e.g. mixed across branches) while reusing this module's projections.
    The module defines no ``forward`` of its own.
    """

    def __init__(self, embed_dim, num_heads=6, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.head_dim = embed_dim // num_heads

        # Creation order (q, k, v, out) determines seeded parameter init.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected):
        """Reshape [B, N, embed_dim] into [B, heads, N, head_dim]."""
        batch, tokens, _ = projected.shape
        return projected.view(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward_qkv(self, x):
        """
        Project tokens x of shape [B, N, embed_dim] to Q, K, V, each [B, heads, N, head_dim].
        """
        return (
            self._split_heads(self.q_proj(x)),
            self._split_heads(self.k_proj(x)),
            self._split_heads(self.v_proj(x)),
        )

    def forward_from_qkv(self, q, k, v, attn_probs: Optional[torch.Tensor] = None):
        """
        Compute the attention output from Q, K, V of shape [B, heads, N, head_dim].

        When ``attn_probs`` ([B, heads, N, N]) is None, probabilities are computed as
        dropout(softmax(Q K^T / sqrt(head_dim))). When it is given, it is used as-is
        (no dropout is applied) and returned unchanged.

        Returns:
            (output [B, N, embed_dim] after out_proj, attention probabilities used)
        """
        if attn_probs is None:
            scale = math.sqrt(self.head_dim)
            logits = torch.matmul(q, k.transpose(-2, -1)) / scale
            attn_probs = self.dropout(torch.softmax(logits, dim=-1))

        batch, _, tokens, _ = q.shape
        context = torch.matmul(attn_probs, v)                  # [B, heads, N, head_dim]
        merged = context.transpose(1, 2).contiguous().view(batch, tokens, self.embed_dim)
        return self.out_proj(merged), attn_probs


# -------------------------
# Multi-branch MHSA with lambda-weighting and similarity monitoring
# -------------------------
class LambdaMultiBranchMHSA(nn.Module):
    """
    Multi-branch MHSA with a symmetric ``lambda_off`` controlling cross-branch logit mixing.

    Behavior:
      - For each branch j compute L_j = Q_j @ K_j^T / sqrt(head_dim)
      - For branch i compute L_i_weighted = L_i + lambda_off * sum_{j != i} L_j
      - probs_i = softmax(L_i_weighted / T), output_i = out_proj_i(probs_i @ V_i)
      - Final output = mean of output_i over branches

    The temperature T comes from ``_get_temperature(lambda_off)`` and therefore
    honours ``temperature_mode`` / ``base_temperature``. With the defaults
    ("adaptive", base_temperature=1.0), T = 1 + lambda_off**2 * (num_branches - 1).

    NOTE: the attention probabilities are handed precomputed to
    ``CustomMultiHeadAttention.forward_from_qkv``, which then skips its dropout,
    so the ``dropout`` argument has no effect on this path.
    """

    def __init__(self, embed_dim, num_heads=6, num_branches=2, dropout=0.0,
                 temperature_mode="adaptive", base_temperature=1.0):
        super().__init__()
        self.num_branches = num_branches
        self.temperature_mode = temperature_mode  # "fixed", "adaptive", "learnable"
        self.base_temperature = base_temperature

        self.branches: nn.ModuleList = nn.ModuleList([
            CustomMultiHeadAttention(embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_branches)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # NOTE: not used by forward(); kept so parameter count and state_dict keys
        # stay unchanged for existing checkpoints.
        self.logits_norm = nn.LayerNorm(self.branches[0].num_heads)

        # Learnable temperature; float() so an int base_temperature can still require grad.
        if temperature_mode == "learnable":
            self.temperature = nn.Parameter(torch.tensor(float(base_temperature)))

    def _get_temperature(self, lambda_off: float):
        """Return the softmax temperature ``forward`` divides the mixed logits by.

        - "fixed": ``base_temperature``
        - "adaptive": ``base_temperature * (1 + lambda_off**2 * (num_branches - 1))``.
          This equals ``base_temperature`` at lambda_off=0 and
          ``base_temperature * num_branches`` at lambda_off=1, where every branch's
          weighted logits equal the sum of all branch logits.
        - "learnable": the learned scalar clamped to [0.1, 10.0]
        - any other value: ``base_temperature``
        """
        if self.temperature_mode == "fixed":
            return self.base_temperature
        elif self.temperature_mode == "adaptive":
            return self.base_temperature * (1.0 + (lambda_off ** 2) * (self.num_branches - 1.0))
        elif self.temperature_mode == "learnable":
            return torch.clamp(self.temperature, min=0.1, max=10.0)  # keep in a sane range
        else:
            return self.base_temperature

    @staticmethod
    def _pairwise_diversity(feats_list: List[torch.Tensor]) -> torch.Tensor:
        """
        Mean squared cosine similarity over all branch pairs.

        feats_list: list of tensors with shape [B, N, D] (or [B, D]).
        Returns a scalar tensor (mean over batch/tokens, then over pairs); it is 0
        when there is only one branch. Used as a loss, minimizing it pushes branch
        outputs toward mutual orthogonality.
        """
        n = len(feats_list)
        if n < 2:
            return feats_list[0].new_tensor(0.0)
        div = 0.0
        cnt = 0
        for i in range(n):
            for j in range(i + 1, n):
                # Per-token cosine similarity -> [B, N] or [B]
                sim = F.cosine_similarity(feats_list[i], feats_list[j], dim=-1)
                div = div + sim.pow(2).mean()  # squared, so both signs are penalized
                cnt += 1
        return div / cnt

    def forward(self, x, lambda_off: float, get_diversity_loss=False,
                monitor: Optional[CosineSimilarityMonitor] = None, layer_idx: int = 0):
        """
        x: [B, N, D]
        lambda_off: scalar in [0, 1] controlling off-diagonal mixing; 0 => independent branches.
        get_diversity_loss: also return the pairwise-diversity loss of the branch outputs.
        monitor: optional CosineSimilarityMonitor that records branch-output similarities.
        layer_idx: layer index passed to the monitor.

        Returns:
            (out [B, N, D], per-branch raw logits L_j, per-branch attention probs)
            plus the diversity loss as a 4th element when ``get_diversity_loss``.
        """
        x_ = self.norm(x)
        # Per-branch q, k, v and raw logits L_j (returned so fusion utilities can reuse them).
        q_list, k_list, v_list = [], [], []
        Lj_list = []
        for branch in self.branches:
            q, k, v = branch.forward_qkv(x_)  # each [B, heads, N, hd]
            q_list.append(q)
            k_list.append(k)
            v_list.append(v)
            Lj = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(branch.head_dim)  # [B, heads, N, N]
            Lj_list.append(Lj)

        sum_L = sum(Lj_list)  # [B, heads, N, N]

        # Temperature for this forward pass; applied below (it was previously computed but unused).
        temperature = self._get_temperature(lambda_off)

        outs = []
        attn_probs_list = []
        # For each branch i: weighted logits L_i + lambda_off * (sum_L - L_i).
        for i in range(self.num_branches):
            Li = Lj_list[i]
            if self.num_branches == 1:
                L_i_weighted = Li
            else:
                L_i_weighted = Li + lambda_off * (sum_L - Li)

            attn_probs = torch.softmax(L_i_weighted / temperature, dim=-1)

            # Precomputed probs => forward_from_qkv applies no attention dropout.
            out_i, _ = self.branches[i].forward_from_qkv(q_list[i], k_list[i], v_list[i], attn_probs=attn_probs)
            outs.append(out_i)
            attn_probs_list.append(attn_probs)

        # Record similarity between the attention branch outputs.
        if monitor is not None:
            monitor.record_attn_similarities(outs, layer_idx)

        # Average the branch outputs.
        out = sum(outs) / float(self.num_branches)
        if get_diversity_loss:
            div_attn = self._pairwise_diversity(outs)
            return out, Lj_list, attn_probs_list, div_attn
        else:
            return out, Lj_list, attn_probs_list


# -------------------------
# Multi-branch MLP with lambda mixing and similarity monitoring
# -------------------------
class LambdaMultiBranchMLP(nn.Module):
    """Parallel MLP branches whose hidden pre-activations are mixed across branches.

    Each branch is LayerNorm -> Linear(D, hidden) -> GELU -> Dropout ->
    Linear(hidden, D) -> Dropout. After the first Linear, branch i's hidden state
    becomes h_i + lambda_off * sum_{j != i} h_j. The remaining layers of the
    branch are then applied, and the branch outputs are averaged.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, num_branches=2, dropout=0.0):
        super().__init__()
        self.num_branches = num_branches
        hidden_dim = int(embed_dim * mlp_ratio)

        def make_branch():
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, embed_dim),
                nn.Dropout(dropout),
            )

        self.branches = nn.ModuleList([make_branch() for _ in range(num_branches)])

    @staticmethod
    def _pairwise_diversity(feats_list: List[torch.Tensor]) -> torch.Tensor:
        """Mean squared cosine similarity (along the last dim) over all branch pairs; 0 for one branch."""
        if len(feats_list) < 2:
            return feats_list[0].new_tensor(0.0)
        pairs = [(left, right)
                 for pos, left in enumerate(feats_list)
                 for right in feats_list[pos + 1:]]
        total = 0.0
        for left, right in pairs:
            total = total + F.cosine_similarity(left, right, dim=-1).pow(2).mean()
        return total / len(pairs)

    def forward(self, x, lambda_off: float, get_diversity_loss=False,
                monitor: Optional[CosineSimilarityMonitor] = None, layer_idx: int = 0):
        """
        Args:
            x: input tokens (same layout the branch LayerNorm expects, e.g. [B, N, D]).
            lambda_off: cross-branch mixing strength applied to the hidden pre-activations.
            get_diversity_loss: also return the pairwise-diversity loss of the branch outputs.
            monitor: optional CosineSimilarityMonitor recording branch-output similarities.
            layer_idx: layer index passed to the monitor.

        Returns:
            (averaged output, list of branch outputs) plus the diversity loss as a
            3rd element when ``get_diversity_loss``.
        """
        # Stage 1: LayerNorm -> fc1 for every branch (pre-activation hidden states).
        pre_acts = [branch[1](branch[0](x)) for branch in self.branches]

        # Stage 2: mix each hidden state with the sum of the other branches' states.
        if self.num_branches == 1:
            mixed = pre_acts
        else:
            total_hidden = sum(pre_acts)
            mixed = [h + lambda_off * (total_hidden - h) for h in pre_acts]

        # Stage 3: GELU -> Dropout -> fc2 -> Dropout for every branch.
        branch_outs = []
        for branch, hidden in zip(self.branches, mixed):
            _, _, act, drop_hidden, fc2, drop_out = branch
            branch_outs.append(drop_out(fc2(drop_hidden(act(hidden)))))

        # Record similarity between the MLP branch outputs.
        if monitor is not None:
            monitor.record_mlp_similarities(branch_outs, layer_idx)

        averaged = sum(branch_outs) / float(self.num_branches)
        if not get_diversity_loss:
            return averaged, branch_outs
        return averaged, branch_outs, self._pairwise_diversity(branch_outs)


# -------------------------
# Transformer block with lambda-controlled branches and monitoring
# -------------------------
class ParallelTransformerBlockWithLambda(nn.Module):
    """Transformer block: residual multi-branch attention followed by residual multi-branch MLP.

    Both sub-layers normalize their own input internally (the attention module's
    ``norm`` and the LayerNorm at the start of each MLP branch).
    """

    def __init__(self, embed_dim=384, num_heads=6, mlp_ratio=4.0,
                 dropout=0.0, attn_branches=2, mlp_branches=2,
                 temperature_mode="adaptive", base_temperature=1.0, drop_path=0.1,
                 attn_div_weight: float = 1.0, mlp_div_weight: float = 1.0):
        super().__init__()
        self.attn = LambdaMultiBranchMHSA(embed_dim, num_heads=num_heads,
                                          num_branches=attn_branches, dropout=dropout,
                                          temperature_mode=temperature_mode,
                                          base_temperature=base_temperature)
        self.mlp = LambdaMultiBranchMLP(embed_dim, mlp_ratio=mlp_ratio,
                                        num_branches=mlp_branches, dropout=dropout)
        # NOTE: ``drop_path`` is accepted but ignored; stochastic depth is disabled
        # (Identity, and not applied in forward).
        self.drop_path = nn.Identity()
        self.attn_div_weight = attn_div_weight
        self.mlp_div_weight = mlp_div_weight

    def forward(self, x, lambda_off: float, get_diversity_loss=False,
                monitor: Optional[CosineSimilarityMonitor] = None, layer_idx: int = 0):
        """
        Args:
            x: tokens [B, N, D].
            lambda_off: cross-branch mixing strength forwarded to both sub-layers.
            get_diversity_loss: also return this block's diversity loss.
            monitor: optional CosineSimilarityMonitor forwarded to both sub-layers.
            layer_idx: layer index used when recording into ``monitor``.

        Returns:
            x, or (x, block_div_loss) when ``get_diversity_loss``, where
            block_div_loss = (attn_div_weight * div_attn + mlp_div_weight * div_mlp) / 2.
        """
        if get_diversity_loss:
            attn_out, _, _, div_attn = self.attn(x, lambda_off, get_diversity_loss, monitor, layer_idx)
            x = x + attn_out
            mlp_out, _, div_mlp = self.mlp(x, lambda_off, get_diversity_loss, monitor, layer_idx)
            x = x + mlp_out
            # Parenthesized so the two weighted terms are averaged; without the
            # parentheses only the MLP term would be halved.
            block_div_loss = (self.attn_div_weight * div_attn + self.mlp_div_weight * div_mlp) / 2
            return x, block_div_loss
        x = x + self.attn(x, lambda_off, monitor=monitor, layer_idx=layer_idx)[0]
        x = x + self.mlp(x, lambda_off, monitor=monitor, layer_idx=layer_idx)[0]
        return x


# -------------------------
# Custom ModuleList for transformer blocks with monitoring
# -------------------------
class LambdaTransformerBlocks(nn.Module):
    """
    Sequential container for blocks whose forward also takes ``lambda_off``
    (which nn.Sequential cannot pass through). Each block receives its position
    in the stack as ``layer_idx``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, lambda_off: float, get_diversity_loss=False,
                monitor: Optional[CosineSimilarityMonitor] = None):
        """Run all blocks in order.

        Returns x, or (x, mean of the per-block diversity losses) when
        ``get_diversity_loss`` is set.
        """
        if not get_diversity_loss:
            for depth_idx, layer in enumerate(self.blocks):
                x = layer(x, lambda_off, monitor=monitor, layer_idx=depth_idx)
            return x

        div_sum = 0.0
        for depth_idx, layer in enumerate(self.blocks):
            x, layer_div = layer(x, lambda_off, get_diversity_loss, monitor, layer_idx=depth_idx)
            div_sum += layer_div
        return x, div_sum / len(self.blocks)


# -------------------------
# Main ViT model with monitoring support
# -------------------------
class ParallelViT(nn.Module):
    """Vision Transformer built from lambda-mixed multi-branch blocks.

    A class token and learned positional embeddings are added to the patch
    tokens. Classification uses the mean of the final patch tokens; the class
    token (index 0) takes part in attention but is excluded from that mean.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200,
                 embed_dim=192, depth=6, num_heads=12, mlp_ratio=4.0, dropout=0.1,
                 attn_branches=2, mlp_branches=2, temperature_mode="adaptive",
                 base_temperature=1.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        seq_len = self.patch_embed.n_patches + 1  # +1 for the class token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        block_options = dict(
            attn_branches=attn_branches,
            mlp_branches=mlp_branches,
            temperature_mode=temperature_mode,
            base_temperature=base_temperature,
        )
        # Custom wrapper (not nn.Sequential) because each block needs lambda_off.
        self.blocks = LambdaTransformerBlocks([
            ParallelTransformerBlockWithLambda(embed_dim, num_heads, mlp_ratio, dropout, **block_options)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Order matters for seeded runs: pos_embed, then cls_token, then per-module init.
        for learned_embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(learned_embedding, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init: truncated-normal Linear weights, zero biases, unit LayerNorm.

        Also zeroes the classifier bias on every call. Modules of other types
        (e.g. the patch-embedding Conv2d) keep PyTorch's default init.
        """
        classifier_bias = self.head.bias
        if classifier_bias is not None:
            nn.init.zeros_(classifier_bias)
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, lambda_off=0, get_diversity_loss=False,
                monitor: Optional[CosineSimilarityMonitor] = None):
        """Return class logits, or (logits, diversity_loss) when ``get_diversity_loss``."""
        patches = self.patch_embed(x)
        cls = self.cls_token.expand(patches.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, patches), dim=1) + self.pos_embed)

        diversity_loss = None
        if get_diversity_loss:
            tokens, diversity_loss = self.blocks(tokens, lambda_off, get_diversity_loss, monitor)
        else:
            tokens = self.blocks(tokens, lambda_off, monitor=monitor)

        pooled = self.norm(tokens)[:, 1:].mean(dim=1)  # mean over patch tokens only
        logits = self.head(pooled)
        return (logits, diversity_loss) if get_diversity_loss else logits


def get_parallel_vit(num_classes=200, attn_branches=2, mlp_branches=2, depth=6,
                     temperature_mode="adaptive", base_temperature=1.0, dropout=0.0):
    """Build a ParallelViT with the given options; all other arguments use ParallelViT defaults.

    Note that ``dropout`` defaults to 0.0 here, whereas ParallelViT's own default is 0.1.
    """
    options = {
        "num_classes": num_classes,
        "attn_branches": attn_branches,
        "mlp_branches": mlp_branches,
        "depth": depth,
        "temperature_mode": temperature_mode,
        "base_temperature": base_temperature,
        "dropout": dropout,
    }
    return ParallelViT(**options)

# -------------------------
# Lambda scheduler
# -------------------------
class LambdaScheduler:
    """
    Ramp ``lambda_off`` from 0 to 1 over ``warmup_steps`` steps, starting after ``start_steps``.

    Usage: ``scheduler.get_lambda(global_step)`` -> float lambda_off.

    Returns 0.0 for step <= start_steps and 1.0 for step >= start_steps + warmup_steps.
    In between, t = (step - start_steps) / warmup_steps lies in (0, 1) and is mapped by ``mode``:
      - "linear":           t
      - "cosine":           0.5 * (1 - cos(pi * t))
      - "exponential":      1 - exp(-5 t)   (about 0.993 just before the end, then 1.0)
      - "sqrt":             sqrt(t)
      - "sine":             sin(pi * t / 2)
      - "smoothstep":       3 t^2 - 2 t^3
      - "sin2":             sin(2 pi t)^2   (oscillates: 1 at t = 0.25 and 0.75, 0 at t = 0.5)
      - "piecewise_linear": fast / slow / fast ramp. The first and last ``pw_a`` fraction
                            of t use a slope ``pw_r`` times the middle slope; slopes are
                            normalized so the unshifted ramp spans 0 -> 1. The result is
                            then shifted up by ``pw_offset`` and clipped to [0, 1].
    """

    def __init__(self, warmup_steps: int, mode: str = "linear", pw_a: float = 0.1, pw_r: float = 10.0,
                 start_steps: int = 0, pw_offset: float = 0.2):
        assert warmup_steps >= 1
        assert mode in ("linear", "cosine", "exponential", "sqrt", "sine", "smoothstep", "sin2", "piecewise_linear")
        if mode == "piecewise_linear":
            # Both fast sections together may not exceed the whole ramp, and the
            # slope normalization 1 / (1 - 2a + 2ar) needs r > 0 to stay positive.
            assert 0.0 <= pw_a <= 0.5, "pw_a must be in [0, 0.5]"
            assert pw_r > 0.0, "pw_r must be positive"
        self.warmup_steps = warmup_steps
        self.start_steps = start_steps
        self.mode = mode
        self.pw_a = pw_a
        self.pw_r = pw_r
        self.pw_offset = pw_offset

    def get_lambda(self, step: float):
        """Return lambda_off in [0, 1] for the given (global) step."""
        if step <= self.start_steps:
            return 0.0
        if step >= self.start_steps + self.warmup_steps:
            return 1.0
        t = (step - self.start_steps) / float(self.warmup_steps)
        if self.mode == "linear":
            return float(t)
        elif self.mode == "cosine":
            return float(0.5 * (1 - math.cos(math.pi * t)))
        elif self.mode == "exponential":
            return 1 - math.exp(-5 * t)
        elif self.mode == "sqrt":
            return float(math.sqrt(t))
        elif self.mode == "sine":
            return float(math.sin(1 / 2 * math.pi * t))
        elif self.mode == "smoothstep":
            return float(3 * t ** 2 - 2 * t ** 3)
        elif self.mode == "sin2":
            return math.sin(2 * math.pi * t) ** 2
        elif self.mode == "piecewise_linear":
            a, r = self.pw_a, self.pw_r
            s_slow = 1.0 / (1 - 2 * a + 2 * a * r)
            s_fast = r * s_slow
            if t < a:
                v = s_fast * t
            elif t < 1 - a:
                v = s_fast * a + s_slow * (t - a)
            else:
                v = s_fast * a + s_slow * (1 - 2 * a) + s_fast * (t - (1 - a))
            return float(min(max(v + self.pw_offset, 0.0), 1.0))
        else:
            raise ValueError(f"Unknown mode: {self.mode}")


# -------------------------
# Fusion utilities
# -------------------------
def fuse_exact_to_runtime_attention(mhsa: LambdaMultiBranchMHSA, x: torch.Tensor,
                                    temperature: Optional[float] = None):
    """
    Evaluate a LambdaMultiBranchMHSA at lambda_off == 1 with one shared softmax.

    At lambda_off == 1 every branch's weighted logits equal S = sum_j L_j, so all
    branches share the attention probabilities P = softmax(S / T). The averaged
    branch output

        (1/n) * sum_i out_proj_i(P @ V_i)

    is computed here as P @ [V_1 | ... | V_n] (one batched matmul), followed by a
    single linear layer with weight [W_out_1 | ... | W_out_n] / n and bias
    mean_i(b_out_i). Per-branch projections are still used to form L_j and V_j.

    Args:
        mhsa: the multi-branch attention module to evaluate.
        x: input tokens [B, N, D].
        temperature: softmax temperature T. Defaults to ``num_branches``, which is the
            divisor 1 + lambda_off**2 * (num_branches - 1) that ``mhsa.forward`` uses at
            lambda_off == 1 in its default configuration.

    Returns:
        [B, N, D] tensor. It matches ``mhsa(x, 1.0)[0]`` up to floating-point error
        whenever ``temperature`` equals the temperature ``mhsa.forward`` uses at
        lambda_off == 1. Like ``mhsa.forward``, no attention dropout is applied.
    """
    x_ = mhsa.norm(x)
    branches = list(mhsa.branches)
    n_branches = len(branches)
    B, N, D = x_.shape
    if temperature is None:
        temperature = float(n_branches)

    logits_sum = None
    v_list = []
    for branch in branches:
        q, k, v = branch.forward_qkv(x_)  # each [B, heads, N, hd]
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(branch.head_dim)
        logits_sum = logits if logits_sum is None else logits_sum + logits
        v_list.append(v)

    probs = torch.softmax(logits_sum / temperature, dim=-1)  # [B, heads, N, N], shared

    # probs broadcasts over the leading branch dim: one matmul for all branches.
    ctx = torch.matmul(probs, torch.stack(v_list, dim=0))  # [n, B, heads, N, hd]
    # Merge heads exactly like forward_from_qkv: [n, B, N, heads * hd] = [n, B, N, D].
    ctx = ctx.transpose(2, 3).reshape(n_branches, B, N, D)
    # Concatenate branches along features, branch-major: [B, N, n * D].
    ctx = ctx.permute(1, 2, 0, 3).reshape(B, N, n_branches * D)

    # Column block i of the weight is W_out_i, so F.linear yields sum_i W_out_i ctx_i; /n averages.
    weight = torch.cat([b.out_proj.weight for b in branches], dim=1) / n_branches  # [D, n * D]
    bias = None
    if branches[0].out_proj.bias is not None:
        bias = sum(b.out_proj.bias for b in branches) / n_branches
    return F.linear(ctx, weight, bias)


def merge_params_approx(mhsa: LambdaMultiBranchMHSA):
    """
    Approximate parameter-level merge of a multi-branch MHSA into one attention module.

    Sums the weights (and biases, if present) of q_proj, k_proj, v_proj and
    out_proj across all branches. The sums are loaded into a fresh
    ``CustomMultiHeadAttention`` on the same device and dtype as ``mhsa``'s
    parameters.

    NOTE: This is an approximation, not an equivalent of ``mhsa.forward``:
      - summed Q and K introduce cross-branch terms Q_i K_j^T into the logits;
      - no lambda / temperature scaling is applied;
      - summed V / out_proj are not averaged over branches;
      - ``mhsa.norm`` is not part of the returned module, so the caller must
        apply it to the input first.
    """
    ref_branch = mhsa.branches[0]
    anchor = next(mhsa.parameters())  # device/dtype source for the merged module
    dropout = ref_branch.dropout.p if hasattr(ref_branch, 'dropout') else 0.0
    merged = CustomMultiHeadAttention(ref_branch.embed_dim, num_heads=ref_branch.num_heads, dropout=dropout)
    merged = merged.to(device=anchor.device, dtype=anchor.dtype)

    with torch.no_grad():
        for name in ('q_proj', 'k_proj', 'v_proj', 'out_proj'):
            layers = [getattr(b, name) for b in mhsa.branches]
            target = getattr(merged, name)
            target.weight.copy_(sum(layer.weight for layer in layers))
            if layers[0].bias is not None:
                target.bias.copy_(sum(layer.bias for layer in layers))

    return merged


def get_model_flops(model, img_size=224, device="cuda"):
    """Report and return the MAC count and parameter count of ``model`` via ptflops.

    The model is profiled on a single 3-channel ``img_size`` x ``img_size`` input.
    Both values are returned as human-readable strings (``as_strings=True``).
    ``device`` is accepted for API compatibility but is not used; the model is
    profiled wherever it already lives.
    """
    input_res = (3, img_size, img_size)
    macs, params = get_model_complexity_info(
        model, input_res,
        as_strings=True, print_per_layer_stat=False, verbose=False,
    )
    print(f"[Model Stats] FLOPs (MACs): {macs}, Params: {params}")
    return macs, params


# -------------------------
# Demo and testing functions
# -------------------------
def demo_cosine_similarity_monitoring():
    """Demonstrate the cosine-similarity monitor on a small, randomly initialized model.

    Builds a 3-layer ParallelViT with 4 attention and 4 MLP branches. Runs one
    no-grad forward pass in eval mode at lambda_off=0.5 on random images, then
    prints per-layer and aggregated branch-similarity statistics.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A shallow model keeps the demo fast.
    vit = ParallelViT(
        num_classes=1000,
        attn_branches=4,
        mlp_branches=4,
        depth=3,
        embed_dim=192,
        num_heads=12,
    ).to(device)
    monitor = CosineSimilarityMonitor()

    batch_size, img_size = 4, 224
    images = torch.randn(batch_size, 3, img_size, img_size, device=device)

    print("=== 开始 Cosine Similarity 监控演示 ===")

    vit.eval()
    with torch.no_grad():
        logits = vit(images, lambda_off=0.5, monitor=monitor)

        print(f"模型输出形状: {logits.shape}")
        print(f"记录的注意力层相似度数量: {len(monitor.attn_similarities)}")
        print(f"记录的MLP层相似度数量: {len(monitor.mlp_similarities)}")

        # Per-layer statistics and the full branch-similarity matrices.
        print("\n=== 详细相似度信息 ===")
        for record in monitor.attn_similarities:
            print(f"Layer {record['layer_idx']} Attention - "
                  f"Mean: {record['mean_similarity']:.4f}, "
                  f"Min: {record['min_similarity']:.4f}, "
                  f"Max: {record['max_similarity']:.4f}, "
                  f"Std: {record['std_similarity']:.4f}")
            print(f"  Similarity matrix:\n{record['similarity_matrix']}")

        for record in monitor.mlp_similarities:
            print(f"Layer {record['layer_idx']} MLP - "
                  f"Mean: {record['mean_similarity']:.4f}, "
                  f"Min: {record['min_similarity']:.4f}, "
                  f"Max: {record['max_similarity']:.4f}, "
                  f"Std: {record['std_similarity']:.4f}")
            print(f"  Similarity matrix:\n{record['similarity_matrix']}")

        # Aggregated statistics over all layers.
        monitor.print_summary()


if __name__ == "__main__":
    # Smoke run: print the architecture of a small 4-branch model, then run the
    # cosine-similarity monitoring demo.
    model = ParallelViT(
        num_classes=1000,
        attn_branches=4,
        mlp_branches=4,
        depth=3,
    )
    print(model)

    # Profiling recipe (disabled):
    #   - FLOPs/params: get_model_flops(model, img_size=224, device="cpu")
    #   - peak GPU memory: move the model and a batch to CUDA, call
    #     torch.cuda.reset_peak_memory_stats(device), run a forward pass under
    #     torch.no_grad() (or forward + backward + optimizer.step() for training),
    #     then read torch.cuda.max_memory_allocated(device) / 1024**3 for GB.

    print("\n" + "=" * 50)
    demo_cosine_similarity_monitoring()