# iterative_token_model_soft_insert_blk.py

import torch
import torch.nn as nn
from functools import partial
import timm.models.vision_transformer
from timm.layers import Mlp
import torch.nn.functional as F
from typing import Optional, Union

# --- 基础模块 (SimpleLoRA, SoftRouter等保持不变) ---
class SimpleLoRA(nn.Module):
    """Low-rank adapter computing ``lora_B(lora_A(x))``.

    ``lora_A`` projects ``dim -> rank`` and ``lora_B`` projects ``rank -> dim``; both
    are bias-free linear layers. ``lora_B`` is zero-initialised, so a freshly built
    adapter returns all zeros and therefore initially adds nothing to whatever its
    output is summed with.

    Args:
        dim: Input and output feature size.
        rank: Size of the low-rank bottleneck.
    """

    def __init__(self, dim: int, rank: int):
        super().__init__()
        down_proj = nn.Linear(dim, rank, bias=False)
        up_proj = nn.Linear(rank, dim, bias=False)
        # Random down-projection, zero up-projection -> zero output at initialisation.
        nn.init.kaiming_uniform_(down_proj.weight, a=0.01)
        nn.init.zeros_(up_proj.weight)
        self.lora_A = down_proj
        self.lora_B = up_proj

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the down- then up-projection to the last dimension of ``x``."""
        bottleneck = self.lora_A(x)
        return self.lora_B(bottleneck)

class SoftRouter(nn.Module):
    """Token router returning a 0/1 selection mask plus sigmoid probabilities.

    A linear layer scores every token. The score is divided by a temperature that
    decays linearly from ``tau_start`` to ``tau_end`` over the first ``total_steps``
    training-mode forward calls, and stays at ``tau_end`` afterwards.

    Args:
        d_model: Size of the last (channel) dimension of the input.
        tau_start: Temperature before any training-mode call.
        tau_end: Temperature once the schedule is exhausted.
        total_steps: Number of training-mode calls over which the temperature is
            annealed. Values <= 0 mean "use ``tau_end`` immediately".
    """

    def __init__(self, d_model, tau_start=1.0, tau_end=0.1, total_steps=30_000):
        super().__init__()
        self.linear = nn.Linear(d_model, 1)
        self.register_buffer("tau_start", torch.tensor(tau_start))
        self.register_buffer("tau_end", torch.tensor(tau_end))
        self.total_steps = total_steps
        # Plain Python counter, not a buffer: it is not stored in state_dict, so a
        # freshly constructed instance that loads a checkpoint restarts at 0.
        self.step = 0

    def _temperature(self):
        """Return the current temperature of the linear annealing schedule."""
        if self.total_steps > 0:
            remaining = max(0, (self.total_steps - self.step) / self.total_steps)
        else:
            # Guard: total_steps == 0 used to raise ZeroDivisionError.
            remaining = 0.0
        return (self.tau_start - self.tau_end) * remaining + self.tau_end

    def forward(self, x):
        """Score tokens and produce a selection mask.

        Args:
            x: Token features whose last dimension is ``d_model``
                (the original comment documents it as [B, T, C]).

        Returns:
            ``(mask, p)``, both shaped like ``x`` with the last dim replaced by 1.
            ``p`` is the sigmoid probability. In training mode ``mask`` is a
            Bernoulli(p) sample whose forward value is the hard 0/1 sample but whose
            gradient flows through ``p`` (straight-through). In eval mode ``mask``
            is the hard threshold ``p > 0.5``.
        """
        # Only training calls advance the annealing schedule. Previously every call,
        # including evaluation / inference passes, consumed schedule steps.
        # NOTE(review): the counter counts forward calls, not optimizer steps; a caller
        # that invokes the router several times per step advances it several times.
        if self.training:
            self.step += 1

        logits = self.linear(x) / self._temperature()  # [B, T, 1]
        p = torch.sigmoid(logits)                       # probability, gradient carrier

        if self.training:
            # Straight-through: sample a hard mask, backpropagate through p.
            u = torch.rand_like(p)
            hard = (u < p).float()                      # 0/1
            mask = hard.detach() - p.detach() + p
        else:
            mask = (p > 0.5).float()                    # deterministic hard selection

        return mask, p  # mask selects tokens; p may be used as a soft weight

class IterativeLoRABlock(nn.Module):
    """ViT block applied recursively with per-domain LoRA adapters and token routing.

    The block does not own its LoRA adapters / routers / step embeddings. It stores
    references to containers holding the modules for every domain and picks the
    entries for the ``domain_name`` passed to ``forward``. Because the containers are
    assigned as attributes, they are registered as submodules: their parameters show
    up in this block's ``parameters()`` / ``state_dict()``.

    The sub-layers of the wrapped pretrained block are shared and frozen. With
    ``domain_name=None`` a plain (non-iterative, LoRA-free) forward is run, intended
    for domain-agnostic feature extraction.
    """

    def __init__(
        self,
        original_block: "timm.models.vision_transformer.Block",
        # Containers holding the modules of ALL domains, keyed by domain name.
        domain_routers: nn.ModuleDict,
        domain_lora_qs: nn.ModuleDict,
        domain_lora_vs: nn.ModuleDict,
        domain_step_embeddings: nn.ParameterDict,
    ):
        super().__init__()

        # --- References (not copies) to the shared per-domain containers ---
        self.routers = domain_routers
        self.lora_qs = domain_lora_qs
        self.lora_vs = domain_lora_vs
        self.step_embeddings = domain_step_embeddings

        # --- Reuse the sub-layers of the pretrained block (shared by all domains) ---
        self.norm1, self.attn = original_block.norm1, original_block.attn
        self.drop_path1 = original_block.drop_path1
        self.norm2, self.mlp = original_block.norm2, original_block.mlp
        self.drop_path2 = original_block.drop_path2
        # timm Blocks carry LayerScale as ls1/ls2 (Identity unless init_values is set);
        # fall back to Identity for blocks that do not define them.
        self.ls1 = getattr(original_block, "ls1", nn.Identity())
        self.ls2 = getattr(original_block, "ls2", nn.Identity())

        # Freeze ONLY the shared pretrained sub-layers. Iterating self.parameters() here
        # would also reach the per-domain routers / LoRAs / step embeddings registered
        # above and silently freeze them for every block sharing those containers.
        shared_layers = (
            self.norm1, self.attn, self.ls1, self.drop_path1,
            self.norm2, self.mlp, self.ls2, self.drop_path2,
        )
        for layer in shared_layers:
            for param in layer.parameters():
                param.requires_grad = False

    def apply_lora_and_attention(
        self,
        x: torch.Tensor,
        lora_q: "SimpleLoRA",
        lora_v: "SimpleLoRA",
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention of the shared ``self.attn`` with LoRA deltas on q and v.

        Args:
            x: Input of shape (B, N, C); the caller passes it already normalised.
            lora_q: Module whose output (reshaped to heads) is added to the queries.
            lora_v: Module whose output (reshaped to heads) is added to the values.
            key_padding_mask: Optional bool tensor of shape (B, N); True marks keys
                that must not be attended to.

        Returns:
            Tensor of shape (B, N, C) after the output projection and its dropout.
        """
        B, N, C = x.shape
        num_heads = self.attn.num_heads
        qkv = self.attn.qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, heads, N, head_dim)

        delta_q = lora_q(x).reshape(B, N, num_heads, -1).permute(0, 2, 1, 3)
        delta_v = lora_v(x).reshape(B, N, num_heads, -1).permute(0, 2, 1, 3)
        q, v = q + delta_q, v + delta_v

        # Mirror timm Attention's optional per-head q/k normalisation (Identity unless
        # built with qk_norm; the attributes are absent on older timm versions).
        q_norm = getattr(self.attn, "q_norm", None)
        k_norm = getattr(self.attn, "k_norm", None)
        if q_norm is not None:
            q = q_norm(q)
        if k_norm is not None:
            k = k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.attn.scale
        if key_padding_mask is not None:
            # (B, N) -> (B, heads, N_query, N_key): mask the same keys for every query.
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2) \
                                   .expand(-1, num_heads, N, -1)
            attn.masked_fill_(mask, float("-inf"))
        attn = attn.softmax(dim=-1)
        # A sample whose keys are ALL masked gives an all -inf row -> NaN after softmax;
        # replace with 0 so those queries get zero attention weights.
        attn = torch.nan_to_num(attn, nan=0.0)
        attn = self.attn.attn_drop(attn)

        x_out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_out = self.attn.proj(x_out)
        x_out = self.attn.proj_drop(x_out)
        return x_out

    def _process_block(
        self,
        x: torch.Tensor,
        lora_q: "SimpleLoRA",
        lora_v: "SimpleLoRA",
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pre-norm transformer block (attention then MLP, each residual) using LoRA attention."""
        attn_out = self.apply_lora_and_attention(self.norm1(x), lora_q, lora_v, key_padding_mask)
        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

    def vanilla_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard, non-iterative ViT block forward (no LoRA, no routing).

        Used for feature extraction in the Key-Value method to avoid domain-specific
        bias. Matches timm's ``Block.forward`` including LayerScale (ls1/ls2).
        """
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

    def forward(self, x: torch.Tensor, domain_name: Optional[str] = None):
        """Run the block iteratively with a domain's modules, or as a plain ViT block.

        Args:
            x: Token features of shape (B, N, C).
            domain_name: Key into the per-domain containers. ``None`` selects
                ``vanilla_forward`` (no LoRA, no routing).

        Returns:
            ``(x_out, routing)``. ``x_out`` has the shape of ``x``. ``routing`` holds one
            ``(p_soft, step_mask)`` tuple per executed recursion step: the detached
            router probabilities squeezed to (B, N), and the bool (B, N) mask of tokens
            processed at that step. It is ``[]`` for the vanilla path.

        Raises:
            KeyError: if ``domain_name`` is missing from a per-domain container.
        """
        # --- Vanilla path (feature extraction for the Key-Value method) ---
        if domain_name is None:
            return self.vanilla_forward(x), []  # no routing info

        # --- Iterative path (training and normal inference) ---
        router = self.routers[domain_name]
        lora_q = self.lora_qs[domain_name]
        lora_v = self.lora_vs[domain_name]
        step_embeddings_param = self.step_embeddings[domain_name]

        B, N, C = x.shape
        x_canvas = x.clone()
        list_of_routing_data = []
        # Tokens still eligible for processing; once dropped, a token stays dropped.
        active_mask = torch.ones((B, N), device=x.device, dtype=torch.bool)

        # One recursion step per row of the domain's step-embedding table.
        num_recursion_steps = step_embeddings_param.shape[0]

        for step in range(num_recursion_steps):
            if not active_mask.any():
                break

            # NOTE(review): the router is called once per recursion step, so a router
            # with an internal step counter advances it this many times per forward.
            mask_ste, p_soft = router(x_canvas)
            mask_ste = mask_ste.squeeze(-1)

            hard_mask = (mask_ste > 0.5).detach()
            step_mask = hard_mask & active_mask
            # Unselected / retired tokens are hidden as attention keys.
            key_padding_mask = ~step_mask
            x_with_step = x_canvas + step_embeddings_param[step]

            x_processed = self._process_block(
                x_with_step,
                lora_q,
                lora_v,
                key_padding_mask)

            # NOTE(review): delta is measured against x_canvas rather than x_with_step,
            # so the step embedding itself is also written into selected tokens —
            # confirm this is intended.
            delta = x_processed - x_canvas
            # mask_ste carries the router's straight-through gradient; active_mask
            # zeroes tokens retired in earlier steps.
            gate = mask_ste * active_mask.float()
            gated_delta = delta * gate.unsqueeze(-1)
            x_canvas = x_canvas + gated_delta

            active_mask = step_mask
            list_of_routing_data.append(
                (p_soft.detach().squeeze(-1), step_mask.clone())
            )

        return x_canvas, list_of_routing_data

# --- 工厂函数 (用于创建基础ViT模型，暂不修改) ---
def create_iterative_token_routed_vit(model_name: str, pretrained: bool, **kwargs):
    """Create the plain timm backbone that the iterative model starts from.

    The returned model is exactly what ``timm.create_model`` builds; no blocks are
    replaced here. Swapping ViT blocks for ``IterativeLoRABlock`` is expected to be
    done by the caller (e.g. the DIL model).

    Args:
        model_name: timm model identifier forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        **kwargs: Any further options, forwarded unchanged to ``timm.create_model``.

    Returns:
        The model instance returned by ``timm.create_model``.
    """
    return timm.create_model(model_name, pretrained=pretrained, **kwargs)