import torch
import torch.nn as nn
import math
import numpy as np
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg, Block
from timm.models.registry import register_model
from timm.layers import get_norm_layer, get_act_layer, DropPath, Mlp


class KINDLinear(nn.Module):
    """Low-rank linear layer with a shared "gene" path and a gated "class" path.

    The layer computes ``v_gene(u_gene(x) * sigma_gene)`` and, when a class path
    exists and a gate is supplied, adds ``v_cls(u_cls(x) * sigma_cls * task_gate)``.
    Dropout is applied to each path's output separately and the optional bias
    is added last.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        gene_size: Bottleneck width of the shared gene path.
        cls_size: Bottleneck width of the gated class path; ``None`` or a
            non-positive value disables that path entirely.
        bias: Whether to add a learnable bias to the output.
        drop_rate: Dropout probability applied to each path's output; values
            ``<= 0`` replace dropout with an identity.
    """

    def __init__(self, in_features, out_features, gene_size, cls_size=None, bias=True, drop_rate=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gene_size = gene_size
        self.cls_size = cls_size
        self.drop_rate = drop_rate

        # Gene path: used for every task.
        self.u_gene = nn.Linear(in_features, gene_size, bias=False)
        self.sigma_gene = nn.Parameter(torch.empty((gene_size,)))
        self.v_gene = nn.Linear(gene_size, out_features, bias=False)

        # Class path: only contributes when a task gate is passed to forward().
        if self._has_cls_path:
            self.u_cls = nn.Linear(in_features, cls_size, bias=False)
            self.sigma_cls = nn.Parameter(torch.empty((cls_size,)))
            self.v_cls = nn.Linear(cls_size, out_features, bias=False)

        # One dropout module shared by both paths.
        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self._init_weights()

    @property
    def _has_cls_path(self):
        """True when the layer owns class-path parameters."""
        return self.cls_size is not None and self.cls_size > 0

    def _init_weights(self):
        """Initialise both factor pairs Kaiming-uniform and the scales to one."""
        # Gene path.
        nn.init.kaiming_uniform_(self.u_gene.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.v_gene.weight, a=math.sqrt(5))
        nn.init.ones_(self.sigma_gene)

        # Class path.
        if self._has_cls_path:
            nn.init.kaiming_uniform_(self.u_cls.weight, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.v_cls.weight, a=math.sqrt(5))
            nn.init.ones_(self.sigma_cls)

        # Bias: uniform in +-1/sqrt(fan_in), with fan_in taken from u_gene.
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.u_gene.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, task_gate=None):
        """Apply the layer.

        Args:
            x: Input tensor whose last dimension is ``in_features``; any number
                of leading dimensions is accepted.
            task_gate: Optional gate whose last dimension must equal
                ``cls_size``. It is multiplied into the class-path bottleneck
                and must broadcast against it. When ``None`` the class path is
                skipped.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.

        Raises:
            ValueError: If ``task_gate``'s last dimension differs from
                ``cls_size``.
        """
        # The scale vectors broadcast over every leading dimension. Explicit
        # unsqueeze calls would pin the result to rank 3 and add a spurious
        # leading axis for 1-D / 2-D inputs.
        gene_hidden = self.u_gene(x) * self.sigma_gene
        out = self.dropout(self.v_gene(gene_hidden))

        if self._has_cls_path and task_gate is not None:
            if task_gate.size(-1) != self.cls_size:
                raise ValueError(f"task_gate size {task_gate.size(-1)} does not match cls_size {self.cls_size}")

            cls_hidden = self.u_cls(x) * self.sigma_cls * task_gate
            out = out + self.dropout(self.v_cls(cls_hidden))

        if self.bias is not None:
            out = out + self.bias

        return out


class KINDAttention(nn.Module):
    """Multi-head self-attention whose QKV and output projections are KINDLinear layers.

    Args:
        dim: Channel width of the input tokens.
        num_heads: Number of attention heads; ``dim`` is split evenly across them.
        qkv_bias: Whether the fused QKV projection carries a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
        gene_size: Gene-path width passed to both KINDLinear layers.
        cls_size: Class-path width passed to both KINDLinear layers.
        drop_rate: Internal dropout rate passed to both KINDLinear layers.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                 gene_size=512, cls_size=None, drop_rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        # Attention logits are scaled by 1 / sqrt(per-head width).
        self.scale = (dim // num_heads) ** -0.5

        # Fused projection producing queries, keys and values in one pass.
        self.qkv = KINDLinear(dim, dim * 3, gene_size, cls_size, bias=qkv_bias, drop_rate=drop_rate)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = KINDLinear(dim, dim, gene_size, cls_size, drop_rate=drop_rate)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, task_gate=None):
        """Attend over the token axis of ``x`` (unpacked as ``[B, N, C]``).

        ``task_gate`` is forwarded unchanged to both KINDLinear projections.
        Returns a tensor of the same ``[B, N, C]`` shape.
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        fused = self.qkv(x, task_gate)
        fused = fused.reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> [3, B, heads, N, per_head]; unbind rather than tuple-unpack a tensor.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged, task_gate))


class KINDMlp(nn.Module):
    """Two-layer feed-forward network built from KINDLinear layers.

    Args:
        in_features: Input width.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class instantiated between the two layers.
        drop: Dropout probability applied after the activation and after ``fc2``.
        gene_size: Gene-path width passed to both KINDLinear layers.
        cls_size: Class-path width passed to both KINDLinear layers.
        drop_rate: Internal dropout rate passed to both KINDLinear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 drop=0., gene_size=512, cls_size=None, drop_rate=0.1):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = KINDLinear(in_features, hidden_features, gene_size, cls_size, drop_rate=drop_rate)
        self.act = act_layer()
        self.fc2 = KINDLinear(hidden_features, out_features, gene_size, cls_size, drop_rate=drop_rate)
        # A single dropout module is reused at both positions in forward().
        self.drop = nn.Dropout(drop)

    def forward(self, x, task_gate=None):
        """Run fc1 -> activation -> dropout -> fc2 -> dropout, gating both linear layers."""
        hidden = self.drop(self.act(self.fc1(x, task_gate)))
        return self.drop(self.fc2(hidden, task_gate))


class KINDBlock(nn.Module):
    """Pre-norm Transformer block using KINDAttention and KINDMlp.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention QKV projection carries a bias.
        drop: Dropout probability for the attention output and inside the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: DropPath probability on both residual branches; ``0``
            disables it.
        act_layer: Activation class used inside the MLP.
        norm_layer: Normalisation class applied before each branch.
        gene_size: Gene-path width for every KINDLinear in the block.
        cls_size: Class-path width for every KINDLinear in the block.
        drop_rate: Internal dropout rate for every KINDLinear in the block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 gene_size=512, cls_size=None, drop_rate=0.1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = KINDAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            gene_size=gene_size,
            cls_size=cls_size,
            drop_rate=drop_rate,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = KINDMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            gene_size=gene_size,
            cls_size=cls_size,
            drop_rate=drop_rate,
        )

    def forward(self, x, task_gate=None):
        """Apply the attention and MLP residual branches, passing ``task_gate`` to both."""
        attn_branch = self.attn(self.norm1(x), task_gate)
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x), task_gate)
        x = x + self.drop_path(mlp_branch)
        return x


class KINDFusionVisionTransformer(VisionTransformer):
    """
    KIND风格的融合Vision Transformer
    分离基因知识（所有任务共享）和类别知识（任务特定）
    """

    def __init__(self, gene_size=15 * 256, task_configs=None, drop_rate=0.1, drop_path_rate=0.1, *args, **kwargs):
        """Build the base ViT, then swap in KIND blocks and attach task heads.

        Args:
            gene_size: Gene-path width used by every KIND block.
            task_configs: Mapping ``{task_name: cls_size}``. When falsy, the
                three tasks ``congestion``, ``drc`` and ``ir_drop`` are used,
                each sized ``patch_size ** 2``.
            drop_rate: Dropout rate forwarded to the KIND blocks.
            drop_path_rate: Upper DropPath rate used when building the blocks.
            *args: Extra positional arguments forwarded to ``VisionTransformer``.
            **kwargs: Keyword arguments forwarded to ``VisionTransformer``.
                ``patch_size``, ``embed_dim`` and ``in_chans`` are also read
                here; ``num_classes`` is always replaced by ``0``.

        Raises:
            KeyError: If ``patch_size``, ``embed_dim`` or ``in_chans`` is
                missing from ``kwargs``.
        """
        # Whatever the caller supplied, the base model is built with num_classes=0.
        kwargs['num_classes'] = 0

        # Per-point bookkeeping: a "point" is one of the patch_size**2 positions.
        self.gene_size = gene_size
        points = kwargs['patch_size'] ** 2
        self.patch_points = points
        self.embed_dim_of_one_point = round(kwargs['embed_dim'] / points)
        self.gene_size_of_one_point = round(self.gene_size / points)
        self.total_cls_dim_of_one_point = self.embed_dim_of_one_point - self.gene_size_of_one_point

        if not task_configs:
            task_configs = {'congestion': points, 'drc': points, 'ir_drop': points}
        self.task_configs = task_configs
        self.num_tasks = len(self.task_configs)
        self.one_cls_dim_of_one_point = round(self.total_cls_dim_of_one_point / self.num_tasks)
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate

        # Combined class-path width across all tasks.
        self.total_cls_size = sum(self.task_configs.values())
        # Filled in below: task name -> (start, end) column range in the gate.
        self.task_gate_pos = {}

        super().__init__(*args, **kwargs)

        # Lay the tasks out side by side. Each one gets
        # round(cls_size / patch_points) consecutive gate columns.
        cursor = 0
        for task_name, cls_size in self.task_configs.items():
            width = round(cls_size / self.patch_points)
            self.task_gate_pos[task_name] = (cursor, cursor + width)
            cursor += width

        # Per-pixel embedding used by forward_features().
        self.patch_embed2 = Mlp(in_features=kwargs['in_chans'], out_features=self.embed_dim_of_one_point)

        # Swap the stock Transformer blocks for KIND blocks.
        self._replace_blocks_with_kind()

        # share_mlp is fixed to 0, so the per-task heads are always the ones built.
        self.share_mlp = 0
        if not self.share_mlp:
            self._create_task_heads()
        else:
            self.task_heads = KINDMlp(in_features=self.embed_dim_of_one_point,
                                      out_features=self.num_tasks,
                                      gene_size=self.gene_size_of_one_point,
                                      cls_size=self.one_cls_dim_of_one_point)

        self._init_kind_weights()

    def _replace_blocks_with_kind(self):
        """Replace ``self.blocks`` with the same number of ``KINDBlock`` modules.

        The head count is read from the first existing block. The MLP ratio
        (4.0) and ``qkv_bias=True`` are fixed here rather than read from the
        old blocks. DropPath grows linearly from 0 on the first block to
        ``self.drop_path_rate`` on the last one.
        """
        depth = len(self.blocks)
        shared = dict(
            dim=self.embed_dim,
            num_heads=self.blocks[0].attn.num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop=self.drop_rate,
            attn_drop=self.drop_rate * 0.5,  # attention dropout is half the base rate
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            gene_size=self.gene_size,
            cls_size=self.total_cls_size,
            drop_rate=self.drop_rate,
        )

        kind_blocks = []
        for index in range(depth):
            if depth > 1:
                path_rate = self.drop_path_rate * index / (depth - 1)
            else:
                path_rate = 0.0
            kind_blocks.append(KINDBlock(drop_path=path_rate, **shared))

        self.blocks = nn.ModuleList(kind_blocks)

    def _create_task_heads(self):
        """Create ``self.task_heads``: one single-output ``Mlp`` per configured task.

        Each head maps ``embed_dim_of_one_point`` features to one value. The
        configured ``cls_size`` is only logged; it does not affect the head.
        """
        heads = nn.ModuleDict()
        self.task_heads = heads

        for task_name, cls_size in self.task_configs.items():
            print(f"Creating task head for {task_name} with cls_size: {cls_size}")
            heads[task_name] = Mlp(in_features=self.embed_dim_of_one_point, out_features=1)

    def _init_kind_weights(self):
        """Hook for extra KIND-specific initialisation; currently does nothing.

        ``KINDLinear`` already initialises its own parameters when constructed,
        so no further work is performed here.
        """
        return None

    def generate_task_gate(self, batch_size, task_name, device):
        """Build the flat gate vector that selects one task's class columns.

        Args:
            batch_size: Accepted for interface compatibility; not used.
            task_name: ``"thermal"`` opens every gate entry. A key of
                ``self.task_gate_pos`` opens only that task's column range.
                Any other value yields an all-zero gate.
            device: Device on which the gate tensor is created.

        Returns:
            1-D tensor of length ``self.total_cls_size`` holding zeros and ones
            (the gate is not expanded along the batch dimension).
        """
        if task_name == "thermal":
            return torch.ones(self.total_cls_size, device=device)

        # View the gate as [patch_points, columns] so a task's column slice is
        # switched on for every point at once.
        gate = torch.zeros(self.total_cls_size, device=device).reshape(self.patch_points, -1)

        column_range = self.task_gate_pos.get(task_name)
        if column_range is not None:
            first, last = column_range
            gate[:, first:last] = 1.0

        return gate.flatten()

    def forward_features(self, x, task_gate=None):
        """Embed a ``[B, C, H, W]`` input and run it through the first KIND block.

        Each pixel's channel vector is embedded with ``patch_embed2``, and the
        embeddings of one image row are concatenated into a single token,
        giving ``[B, H, W * embed_dim_of_one_point]``. Only ``self.blocks[0]``
        is applied, followed by ``self.norm``. The base model's positional
        embedding and patch dropout are not used.
        """
        batch, _, height, _ = x.shape

        channels_last = x.permute(0, 2, 3, 1)
        per_pixel = self.patch_embed2(channels_last)
        row_tokens = per_pixel.reshape(batch, height, -1)

        row_tokens = self.blocks[0](row_tokens, task_gate)
        return self.norm(row_tokens)

    def forward_head(self, x, task_name, task_gate=None):
        """Apply one task head and reshape its output to a channels-first map.

        Args:
            x: Features accepted by the selected head. After the head, the
                result is viewed as ``[B, H, W, -1]`` with ``H, W`` taken from
                ``self.patch_embed.img_size``.
                NOTE(review): this implies one token per pixel, which differs
                from the row tokens ``forward_features`` returns; confirm the
                intended caller.
            task_name: Key into ``self.task_heads``.
            task_gate: Optional gate passed to any ``KINDLinear`` layer found
                in the head; other layers are called with ``x`` only.

        Returns:
            Tensor of shape ``[B, C, H, W]``.

        Raises:
            ValueError: If ``task_name`` is not a key of ``self.task_heads``.
        """
        if task_name not in self.task_heads:
            raise ValueError(f"Unknown task: {task_name}")

        head = self.task_heads[task_name]
        # Heads may be containers of layers or a single module (the Mlp heads
        # built by _create_task_heads are not iterable), so fall back to
        # treating the head itself as the only layer.
        try:
            layers = list(head)
        except TypeError:
            layers = [head]

        for layer in layers:
            if isinstance(layer, KINDLinear):
                x = layer(x, task_gate)
            else:
                x = layer(x)

        # [B, ..., C] -> [B, H, W, C] -> [B, C, H, W]; height and width are
        # read separately so non-square images reshape correctly.
        height, width = self.patch_embed.img_size[0], self.patch_embed.img_size[1]
        x = x.view(x.shape[0], height, width, -1)
        return x.permute(0, 3, 1, 2)

    def forward(self, x, task_name=None):
        """Predict one task, or all three standard tasks, for a batch.

        Args:
            x: Input of shape ``[batch_size, channels, height, width]``.
            task_name: ``"thermal"``, ``"congestion"``, ``"drc"`` or
                ``"ir_drop"`` for a single task, or ``"all"`` to run the three
                standard tasks in that order and concatenate the results.

        Returns:
            For a single task, the head output permuted to channels-first.
            With the per-task heads created by ``_create_task_heads`` (one
            output feature each) this is ``[batch_size, 1, height, width]``.
            For ``"all"``, the three predictions concatenated along dim 1.

        Raises:
            AssertionError: If ``task_name`` is none of the values above
                (including the default ``None``).
            KeyError: If the requested task has no entry in ``self.task_heads``
                (e.g. ``"thermal"`` when it is absent from ``task_configs``).
        """
        standard_tasks = ['congestion', 'drc', 'ir_drop']

        if task_name == "all":
            # Features are recomputed per task because the gate changes them.
            all_predictions = [self._predict_task(x, task) for task in standard_tasks]
            return torch.cat(all_predictions, dim=1)

        if task_name != 'thermal':
            assert task_name in ["congestion", "drc", "ir_drop"]
        return self._predict_task(x, task_name)

    def _predict_task(self, x, task_name):
        """Run gate -> features -> head for one task and return a channels-first map."""
        task_gate = self.generate_task_gate(x.size(0), task_name, x.device)
        task_features = self.forward_features(x, task_gate)

        b, n, _ = task_features.shape
        if not self.share_mlp:
            head = self.task_heads[task_name]
            # Split each row token back into per-pixel feature vectors.
            task_features = task_features.reshape(b, n, -1, self.embed_dim_of_one_point)
            task_pred = head(task_features)
        else:
            # NOTE(review): the shared-head path feeds un-split row tokens and
            # the full-size gate to a head sized per point; share_mlp is fixed
            # to 0 in __init__, so this branch looks unexercised. Confirm
            # before enabling it.
            task_pred = self.task_heads(task_features, task_gate)
        return task_pred.permute(0, 3, 1, 2)

    def freeze_gene_knowledge(self):
        """Turn off gradients for every parameter whose name contains ``'gene'`` and print the count."""
        gene_params = [param for name, param in self.named_parameters() if 'gene' in name]
        for param in gene_params:
            param.requires_grad = False
        frozen_count = len(gene_params)
        print(f"冻结了 {frozen_count} 个基因知识参数")

    def freeze_cls_knowledge(self):
        """Turn off gradients for every parameter whose name contains ``'cls'`` and print the count.

        NOTE(review): the match is a plain substring test, so any inherited
        parameter with ``cls`` in its name would be frozen too. Confirm that
        is intended.
        """
        cls_params = [param for name, param in self.named_parameters() if 'cls' in name]
        for param in cls_params:
            param.requires_grad = False
        frozen_count = len(cls_params)
        print(f"冻结了 {frozen_count} 个类别知识参数")

    def unfreeze_cls_knowledge(self):
        """Turn gradients back on for every parameter whose name contains ``'cls'`` and print the count."""
        cls_params = [param for name, param in self.named_parameters() if 'cls' in name]
        for param in cls_params:
            param.requires_grad = True
        unfrozen_count = len(cls_params)
        print(f"解冻了 {unfrozen_count} 个类别知识参数")

    def load_cls_knowledge(self, cls_knowledge_path, target_task=None):
        """Load class-path parameters from a checkpoint written by ``save_cls_knowledge``.

        Entries of ``checkpoint["cls_state_dict"]`` are copied into the model
        when a parameter of the same name and shape exists; all others are
        skipped and reported.

        Args:
            cls_knowledge_path: Path of the checkpoint file.
            target_task: Only echoed in the summary; it does not filter what
                is loaded.

        Returns:
            ``True`` if at least one parameter was loaded, else ``False``.

        Raises:
            KeyError: If the checkpoint has no ``"cls_state_dict"`` entry.
        """
        # NOTE(review): torch.load unpickles the file, so only open
        # checkpoints from trusted sources.
        checkpoint = torch.load(cls_knowledge_path, map_location='cpu')

        current_state = self.state_dict()
        loaded_count = 0
        skipped_count = 0
        # Never incremented in this method; kept because the summary reports it.
        extracted_count = 0

        for name, param in checkpoint["cls_state_dict"].items():
            if name not in current_state:
                print(f"  跳过参数 {name}: 当前模型中不存在")
                skipped_count += 1
                continue
            if current_state[name].shape != param.shape:
                print(f"  跳过参数 {name}: 形状不匹配 {current_state[name].shape} vs {param.shape}")
                skipped_count += 1
                continue
            current_state[name] = param.clone()
            loaded_count += 1
            print(f"  加载类别知识参数: {name} {list(param.shape)}")

        # Push the patched state dict back into the model.
        self.load_state_dict(current_state)

        print(f"类别知识加载完成:")
        print(f"  直接加载: {loaded_count} 个参数")
        print(f"  提取加载: {extracted_count} 个参数")
        print(f"  跳过: {skipped_count} 个参数")
        print(f"  目标任务: {target_task if target_task else 'all'}")

        return (loaded_count + extracted_count) > 0

    def save_cls_knowledge(self, save_path, task_name=None):
        """Save only the class-path parameters (names containing ``'cls'``).

        The checkpoint holds ``cls_state_dict``, ``task_configs``,
        ``gene_size`` and ``task_name``.

        Args:
            save_path: Destination file path.
            task_name: Optional label stored in the checkpoint and echoed in
                the log output.
        """
        # Detach before cloning so the checkpoint holds plain tensors rather
        # than autograd-tracked copies of the parameters.
        cls_state_dict = {
            name: param.detach().clone()
            for name, param in self.named_parameters()
            if 'cls' in name
        }

        save_data = {
            'cls_state_dict': cls_state_dict,
            'task_configs': self.task_configs,
            'gene_size': self.gene_size,
            'task_name': task_name,
        }

        torch.save(save_data, save_path)
        print(f"类别知识已保存到: {save_path}")
        print(f"  保存了 {len(cls_state_dict)} 个类别知识参数")
        if task_name:
            print(f"  任务: {task_name}")

    def save_gene_knowledge(self, save_path):
        """Save gene-path parameters plus the remaining backbone parameters.

        Parameters are split by name:

        * names containing ``'gene'`` go to ``gene_state_dict``;
        * of the rest, everything goes to ``backbone_state_dict`` except
          names that contain ``'cls'`` without also containing
          ``'task_heads'`` or ``'patch_embed'``, which are left out.

        The checkpoint also records ``gene_size``, ``embed_dim`` and a
        ``model_config`` dict describing the architecture.

        Args:
            save_path: Destination file path.
        """
        gene_state_dict = {}
        backbone_state_dict = {}

        for name, param in self.named_parameters():
            # Detach before cloning so the checkpoint holds plain tensors
            # rather than autograd-tracked copies of the parameters.
            if 'gene' in name:
                gene_state_dict[name] = param.detach().clone()
            elif 'patch_embed' in name or 'task_heads' in name or 'cls' not in name:
                backbone_state_dict[name] = param.detach().clone()

        model_config = {
            'patch_size': getattr(self.patch_embed, 'patch_size', (16, 16)),
            'img_size': getattr(self.patch_embed, 'img_size', (256, 256)),
            # Fall back through both attribute spellings before the default of 36.
            'in_chans': getattr(self.patch_embed, 'in_chans',
                                getattr(self.patch_embed, 'in_channels', 36)),
            'embed_dim': self.embed_dim,
            'depth': len(self.blocks),
            'num_heads': self.blocks[0].attn.num_heads if len(self.blocks) > 0 else 8,
        }

        torch.save({
            'gene_state_dict': gene_state_dict,
            'backbone_state_dict': backbone_state_dict,  # kept separate from the gene parameters
            'gene_size': self.gene_size,
            'embed_dim': self.embed_dim,
            'model_config': model_config,
        }, save_path)
        print(f"Gene knowledge saved to: {save_path}")
        print(f"  Gene parameters: {len(gene_state_dict)}")
        print(f"  Backbone parameters: {len(backbone_state_dict)}")

    def load_gene_knowledge(self, load_path):
        """Load gene-path and backbone parameters from a checkpoint.

        Three checkpoint layouts are accepted, tried in this order:

        1. a dict with ``gene_state_dict`` (and optionally
           ``backbone_state_dict``), as written by ``save_gene_knowledge``;
        2. a dict with ``model_state_dict``, which is split by parameter name;
        3. anything else, where the checkpoint itself is treated as a state
           dict and split the same way.

        When splitting, names containing ``'gene'`` are gene parameters.
        Names containing none of ``'cls'``, ``'patch_embed'`` or
        ``'task_heads'`` are backbone parameters. Everything else is ignored.
        Only entries whose name and shape match the current model are copied.

        Args:
            load_path: Path of the checkpoint file.

        Returns:
            ``False`` if the checkpoint yielded no gene or backbone
            parameters, otherwise ``True``.
        """

        def split_by_name(state):
            """Partition a flat state dict into (gene, backbone) dicts."""
            genes, backbone = {}, {}
            for name, param in state.items():
                if 'gene' in name:
                    genes[name] = param
                elif 'cls' not in name and 'patch_embed' not in name and 'task_heads' not in name:
                    backbone[name] = param
            return genes, backbone

        # NOTE(review): torch.load unpickles the file, so only open
        # checkpoints from trusted sources.
        checkpoint = torch.load(load_path, map_location='cpu')

        if 'gene_state_dict' in checkpoint:
            gene_state_dict = checkpoint['gene_state_dict']
            backbone_state_dict = checkpoint.get('backbone_state_dict', {})
            print(f"Found gene_state_dict with {len(gene_state_dict)} parameters")
            print(f"Found backbone_state_dict with {len(backbone_state_dict)} parameters")
        elif 'model_state_dict' in checkpoint:
            gene_state_dict, backbone_state_dict = split_by_name(checkpoint['model_state_dict'])
            print(f"Extracted gene knowledge from model_state_dict: {len(gene_state_dict)} parameters")
            print(f"Extracted backbone parameters from model_state_dict: {len(backbone_state_dict)} parameters")
        else:
            # Legacy layout: the checkpoint is itself a flat state dict.
            gene_state_dict, backbone_state_dict = split_by_name(checkpoint)
            print(
                f"Using entire checkpoint as gene knowledge: {len(gene_state_dict)} gene + {len(backbone_state_dict)} backbone parameters")

        if not gene_state_dict and not backbone_state_dict:
            print(f"Warning: No gene knowledge or backbone parameters found in {load_path}")
            print(f"Available keys in checkpoint: {list(checkpoint.keys())}")
            return False

        current_state_dict = self.state_dict()
        loaded_count = 0
        missing_important = []
        missing_task_heads = []

        # Gene parameters: missing task-head entries are tracked separately.
        for name, param in gene_state_dict.items():
            if name not in current_state_dict:
                target = missing_task_heads if 'task_heads' in name else missing_important
                target.append(name)
                continue
            if current_state_dict[name].shape != param.shape:
                print(f"Shape mismatch for gene param {name}: {current_state_dict[name].shape} vs {param.shape}")
                continue
            current_state_dict[name] = param
            loaded_count += 1
            print(f"  加载基因知识参数: {name} {list(param.shape)}")

        # Backbone parameters: every missing name counts as important.
        for name, param in backbone_state_dict.items():
            if name not in current_state_dict:
                missing_important.append(name)
                continue
            if current_state_dict[name].shape != param.shape:
                print(
                    f"Shape mismatch for backbone param {name}: {current_state_dict[name].shape} vs {param.shape}")
                continue
            current_state_dict[name] = param
            loaded_count += 1
            print(f"  加载backbone知识参数: {name} {list(param.shape)}")

        if missing_important:
            print("Important parameters not found:")
            for name in missing_important:
                print(f"  {name}")

        if missing_task_heads:
            print(f"Task head gene parameters not found: {len(missing_task_heads)} (expected for new tasks)")

        self.load_state_dict(current_state_dict)
        print(f"Loaded {loaded_count} gene knowledge + backbone parameters from: {load_path}")
        return True


def create_kind_fusion_vit(gene_size=512, task_configs=None, drop_rate=0.1, drop_path_rate=0.1, **kwargs):
    """Construct a ``KINDFusionVisionTransformer``.

    Args:
        gene_size: Gene-path width for the KIND layers.
        task_configs: Optional ``{task_name: cls_size}`` mapping.
        drop_rate: Dropout rate for the KIND blocks.
        drop_path_rate: Upper DropPath rate for the KIND blocks.
        **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        The newly built model.
    """
    kind_options = {
        'gene_size': gene_size,
        'task_configs': task_configs,
        'drop_rate': drop_rate,
        'drop_path_rate': drop_path_rate,
    }
    return KINDFusionVisionTransformer(**kind_options, **kwargs)


@register_model
def kind_fusion_vit_base_patch16_224(pretrained=False, **kwargs):
    """Registered factory for the base KIND fusion ViT (patch 16, width 768, depth 12, 12 heads).

    Args:
        pretrained: Accepted for the registry's calling convention; ignored,
            no weights are loaded.
        **kwargs: Extra constructor arguments. The model constructor reads
            ``kwargs['in_chans']``, so ``in_chans`` has to be supplied here.

    Returns:
        The newly built model with 10% dropout and 10% DropPath.
    """
    preset = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        gene_size=512,
        task_configs={'congestion': 128, 'drc': 64, 'ir_drop': 64},
        drop_rate=0.1,
        drop_path_rate=0.1,
    )
    return create_kind_fusion_vit(**preset, **kwargs)
