# dil_models.py

import torch
import torch.nn as nn
import numpy as np
import timm
from methods.continual_iterative_model import IterativeLoRABlock, SimpleLoRA, SoftRouter
from fft_utils import extract_amp_spectrum, compare_spectra, extract_amp_spectrum_torch
import copy

class DomainIncrementalModel(nn.Module):
    """ViT wrapper for domain-incremental learning (DIL).

    The transformer blocks listed in ``args.target_block_indices`` are replaced by
    ``IterativeLoRABlock`` instances. Each of them holds, per domain, a
    ``SoftRouter``, a LoRA for Q, a LoRA for V and a step-embedding tensor. Every
    domain also gets its own deep copy of the classification head. Patch/position
    embeddings, untouched blocks and the final norm are shared with ``base_model``.

    At inference the domain of each sample is selected according to
    ``args.dil_method``:

    * ``'fft'``: nearest amplitude-spectrum key loaded via ``load_amp_keys``.
    * ``'key_value'`` / ``'key_value_learnable'``: highest cosine similarity
      between a backbone-only CLS feature and per-domain keys.
    * ``'kmeans'``: nearest K-Means center built by ``update_kmeans_keys_for_task``.
    """

    def __init__(self, args, base_model: timm.models.vision_transformer.VisionTransformer):
        super().__init__()
        self.args = args
        self.num_domains = len(args.domains)
        self.target_block_indices = args.target_block_indices

        # One ModuleDict per replaced block; each maps domain name -> module
        # for that block position (ModuleList of ModuleDicts).
        embed_dim = base_model.embed_dim
        self.head = base_model.head
        self.block_specific_routers = nn.ModuleList()
        self.block_specific_lora_qs = nn.ModuleList()
        self.block_specific_lora_vs = nn.ModuleList()
        self.block_specific_step_embeddings = nn.ModuleList()

        for _ in self.target_block_indices:
            self.block_specific_routers.append(nn.ModuleDict({
                domain: SoftRouter(embed_dim) for domain in args.domains
            }))
            self.block_specific_lora_qs.append(nn.ModuleDict({
                domain: SimpleLoRA(embed_dim, args.lora_rank) for domain in args.domains
            }))
            self.block_specific_lora_vs.append(nn.ModuleDict({
                domain: SimpleLoRA(embed_dim, args.lora_rank) for domain in args.domains
            }))
            # Step embeddings are bare tensors, so they live in a ParameterDict.
            step_embs_per_block = nn.ParameterDict()
            for domain in args.domains:
                step_emb = nn.Parameter(torch.zeros(args.num_recursion_steps, 1, 1, embed_dim))
                nn.init.trunc_normal_(step_emb, std=.02)
                step_embs_per_block[domain] = step_emb
            self.block_specific_step_embeddings.append(step_embs_per_block)

        self.domain_heads = nn.ModuleDict({
            domain: copy.deepcopy(base_model.head) for domain in args.domains
        })

        # --- Shared (domain-agnostic) modules, taken directly from base_model ---
        self.patch_embed = base_model.patch_embed
        self.cls_token = base_model.cls_token
        self.pos_embed = base_model.pos_embed
        self.pos_drop = base_model.pos_drop
        self.norm = base_model.norm
        self.head_drop = base_model.head_drop

        self.blocks = nn.ModuleList()
        replaced_block_counter = 0
        for i, blk in enumerate(base_model.blocks):
            if i in self.target_block_indices:
                iter_block = IterativeLoRABlock(
                    original_block=blk,
                    # Hand over the per-domain module dicts for this block position.
                    domain_routers=self.block_specific_routers[replaced_block_counter],
                    domain_lora_qs=self.block_specific_lora_qs[replaced_block_counter],
                    domain_lora_vs=self.block_specific_lora_vs[replaced_block_counter],
                    domain_step_embeddings=self.block_specific_step_embeddings[replaced_block_counter],
                )
                self.blocks.append(iter_block)
                replaced_block_counter += 1
            else:
                self.blocks.append(blk)

        # --- Components of the individual DIL domain-selection methods ---
        # 1. FFT. NOTE(review): this ModuleDict is not used in this file; the
        #    amplitude keys are registered as buffers by load_amp_keys.
        if args.dil_method == 'fft':
            self.domain_amp_keys = nn.ModuleDict()

        # 2. Key-value (fixed keys, or learnable keys for 'key_value_learnable').
        if 'key_value' in args.dil_method:
            is_learnable = (args.dil_method == 'key_value_learnable')
            print(f"Initializing Key-Value mechanism. Keys are learnable: {is_learnable}")

            self.domain_keys = nn.ParameterDict()
            for domain in args.domains:
                key = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=is_learnable)
                if is_learnable:
                    # Learnable keys start from a uniform random init; fixed keys
                    # stay zero until compute_and_set_domain_key fills them.
                    nn.init.uniform_(key, -1, 1)
                self.domain_keys[domain] = key

        # 3. K-Means: empty buffers that grow as tasks are added. Being buffers,
        #    they follow the module on .to(device).
        if args.dil_method == 'kmeans':
            self.register_buffer('all_kmeans_centers', torch.empty(0, embed_dim))
            self.register_buffer('center_to_task_id_map', torch.empty(0, dtype=torch.long))
            print(f"Initialized K-Means mechanism with n_clusters={args.kmeans_n_clusters} per domain.")

        # --- Freeze everything; set_active_domain unfreezes one domain's modules ---
        for param in self.parameters():
            param.requires_grad = False

        self.current_domain = self.args.domains[0]
        # When True, inference skips domain identification and uses current_domain.
        self.force_current_domain = False

    def load_amp_keys(self, amp_keys_dict):
        """Register precomputed amplitude-spectrum keys as buffers ``amp_key_<domain>``.

        Args:
            amp_keys_dict: mapping domain name -> key tensor. Each tensor is moved
                to the device of the model's parameters before registration.
        """
        print("Loading amplitude keys into model buffers...")

        model_device = next(self.parameters()).device

        for domain_name, key_tensor in amp_keys_dict.items():
            key_tensor = key_tensor.to(model_device)
            buffer_name = f'amp_key_{domain_name}'
            self.register_buffer(buffer_name, key_tensor)

        print(f"Loaded {len(amp_keys_dict)} keys to device {model_device}.")

    # --- Key-value helpers ---
    def _l2_normalize(self, x, dim=1, epsilon=1e-12):
        """Return ``x`` scaled to unit L2 norm along ``dim``.

        The squared norm is clamped to ``epsilon`` to avoid division by zero.
        """
        square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
        return x * torch.rsqrt(square_sum.clamp_min(epsilon))

    def _extract_vanilla_features(self, x):
        """Return the CLS feature computed without any domain-specific module.

        Replaced blocks are called with ``domain_name=None``, which
        IterativeLoRABlock treats as the vanilla path. Used for key computation,
        the pull constraint and inference-time domain identification.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)

        for blk in self.blocks:
            if isinstance(blk, IterativeLoRABlock):
                x, _ = blk(x, domain_name=None)
            else:
                x = blk(x)

        x = self.norm(x)
        return x[:, 0]  # CLS token

    @torch.no_grad()
    def compute_and_set_domain_key(self, domain_name: str, data_loader):
        """Set ``domain_keys[domain_name]`` to the mean vanilla CLS feature of a domain.

        Intended to be called by the training engine after a domain has been
        trained. It is a no-op for 'key_value_learnable', whose keys are trained.

        Args:
            domain_name: key of ``self.domain_keys`` to overwrite.
            data_loader: iterable of ``(samples, _)`` batches for that domain.

        The module is switched to eval mode for feature extraction and always
        put back into training mode afterwards, including on the early
        "no features" return and on exceptions.
        """
        if self.args.dil_method == 'key_value_learnable':
            print(f"Skipping key computation for domain '{domain_name}' - using learnable keys.")
            return

        self.eval()
        try:
            all_features = []
            print(f"Computing feature key for domain: {domain_name}...")

            device = next(self.parameters()).device
            for samples, _ in data_loader:
                samples = samples.to(device)
                cls_features = self._extract_vanilla_features(samples)
                all_features.append(cls_features.cpu())

            if not all_features:
                print(f"Warning: No features extracted for domain {domain_name}. Key not updated.")
                return

            mean_features = torch.cat(all_features, dim=0).mean(dim=0, keepdim=True)  # [1, embed_dim]

            key = self.domain_keys[domain_name]
            key.data.copy_(mean_features.to(key.device))
            print(f"Successfully computed and set key for domain '{domain_name}'.")
        finally:
            self.train()  # Return to training mode on every exit path.

    @torch.no_grad()
    def update_kmeans_keys_for_task(self, task_id: int, data_loader, n_clusters: int):
        """Cluster a task's vanilla CLS features and append the centers to the key pool.

        Args:
            task_id: stored for each new center in ``center_to_task_id_map``.
                At inference it is compared with indices into ``args.domains``,
                so it is assumed to equal the domain index.
            data_loader: iterable of ``(samples, _)`` batches for the task.
            n_clusters: requested number of centers. It is clamped to the number
                of extracted samples, since KMeans cannot fit more clusters than
                there are samples.

        Falls back to a single mean center when scikit-learn is not installed.
        The module is always put back into training mode afterwards.
        """
        self.eval()
        try:
            all_features = []
            print(f"Computing K-Means keys for Task {task_id+1} (n_clusters={n_clusters})...")

            # 1. Extract features with the plain backbone.
            device = next(self.parameters()).device
            for samples, _ in data_loader:
                samples = samples.to(device)
                cls_features = self._extract_vanilla_features(samples)
                all_features.append(cls_features.cpu().numpy())

            if not all_features:
                print(f"Warning: No features extracted for Task {task_id+1}. K-Means keys not updated.")
                return

            features_np = np.concatenate(all_features, axis=0)
            num_samples = features_np.shape[0]
            if n_clusters > num_samples:
                print(f"Warning: n_clusters={n_clusters} exceeds the {num_samples} extracted samples; "
                      f"using n_clusters={num_samples}.")
                n_clusters = num_samples

            # 2. Cluster (or fall back to the feature mean without sklearn).
            print(f"Running KMeans clustering on {num_samples} samples...")
            try:
                from sklearn.cluster import KMeans
            except ImportError:
                KMeans = None

            if KMeans is not None:
                kmeans = KMeans(n_clusters=n_clusters, random_state=self.args.seed, n_init='auto').fit(features_np)
                new_centers = torch.from_numpy(kmeans.cluster_centers_).float().to(device)
            else:
                print("Warning: sklearn not available. Falling back to mean calculation.")
                new_centers = torch.from_numpy(features_np).mean(dim=0, keepdim=True).float().to(device)

            # 3. Append the new centers to the global key pool (buffer reassignment).
            self.all_kmeans_centers = torch.cat([self.all_kmeans_centers, new_centers], dim=0)

            # 4. Record which task each new center belongs to.
            new_task_ids = torch.full((new_centers.shape[0],), fill_value=task_id, dtype=torch.long, device=device)
            self.center_to_task_id_map = torch.cat([self.center_to_task_id_map, new_task_ids], dim=0)

            print(f"Successfully added {new_centers.shape[0]} new keys for Task {task_id+1}. Total keys now: {self.all_kmeans_centers.shape[0]}")
        finally:
            self.train()  # Return to training mode on every exit path.

    def transfer_weights_from_previous_task(self, current_task_id: int):
        """Copy the previous domain's routers, LoRAs, step embeddings and head into the current domain.

        Domains are ordered by ``args.domains``. Task 0 has no predecessor and is
        skipped. Learnable domain keys are not transferred, so each domain
        keeps its own key.
        """
        if current_task_id == 0:
            print("Task 0 has no previous task. Skipping weight transfer.")
            return

        prev_domain_name = self.args.domains[current_task_id - 1]
        current_domain_name = self.args.domains[current_task_id]

        print(f"\n--- Transferring weights from '{prev_domain_name}' to '{current_domain_name}' ---")

        # 1. Per-block modules.
        for routers, lora_qs, lora_vs, step_embs in zip(
            self.block_specific_routers,
            self.block_specific_lora_qs,
            self.block_specific_lora_vs,
            self.block_specific_step_embeddings,
        ):
            routers[current_domain_name].load_state_dict(routers[prev_domain_name].state_dict())
            lora_qs[current_domain_name].load_state_dict(lora_qs[prev_domain_name].state_dict())
            lora_vs[current_domain_name].load_state_dict(lora_vs[prev_domain_name].state_dict())
            with torch.no_grad():
                step_embs[current_domain_name].data.copy_(step_embs[prev_domain_name].data)

        print(f"Transferred weights for {len(self.block_specific_routers)} blocks (Router, LoRA_Q, LoRA_V, Step_Embeddings).")

        # 2. Classification head.
        self.domain_heads[current_domain_name].load_state_dict(
            self.domain_heads[prev_domain_name].state_dict()
        )
        print("Transferred weights for classification head.")

        print("--- Weight transfer complete. ---\n")

    def set_active_domain(self, domain_name: str):
        """Make ``domain_name`` the current domain and make only its modules trainable.

        Unfreezes that domain's routers, LoRAs, step embeddings and head, and
        freezes the other domains' copies. Domain keys are trainable only in
        'key_value_learnable' mode; otherwise they are always frozen.

        Raises:
            ValueError: if ``domain_name`` is not in ``args.domains``.
        """
        if domain_name not in self.args.domains:
            raise ValueError(
                f"Unknown domain '{domain_name}'. Expected one of {list(self.args.domains)}."
            )
        self.current_domain = domain_name

        def _toggle_module_dict(module_dict):
            # Enable grads only for the entry belonging to the active domain.
            for d_name, module in module_dict.items():
                for param in module.parameters():
                    param.requires_grad = (d_name == domain_name)

        for block_idx in range(len(self.block_specific_routers)):
            _toggle_module_dict(self.block_specific_routers[block_idx])
            _toggle_module_dict(self.block_specific_lora_qs[block_idx])
            _toggle_module_dict(self.block_specific_lora_vs[block_idx])
            for d_name, param in self.block_specific_step_embeddings[block_idx].items():
                param.requires_grad = (d_name == domain_name)

        _toggle_module_dict(self.domain_heads)

        if hasattr(self, 'domain_keys'):
            is_learnable = self.args.dil_method == 'key_value_learnable'
            for d_name, param in self.domain_keys.items():
                param.requires_grad = is_learnable and (d_name == domain_name)

        print(f"Active domain set to '{domain_name}'. Trainable parameters have been updated.")

    def set_force_current_domain(self, force: bool):
        """Toggle forced use of ``current_domain`` at inference.

        Validation during training should set this to True.
        """
        self.force_current_domain = force

    def _forward_pass_with_domain(self, x, domain_name):
        """Run the model through ``domain_name``'s modules.

        Returns:
            ``(logits, all_router_info)``. ``all_router_info`` collects the
            ``info`` lists returned by the replaced blocks. In training with
            'key_value_learnable' and ``args.pull_constraint`` it starts with a
            ``('pull_sim', scalar)`` entry: the mean cosine similarity between
            vanilla CLS features and the domain key.
        """
        # 1. Patch + position embedding for the main path.
        x_embedded = self.patch_embed(x)
        x_embedded = self._pos_embed(x_embedded)

        all_router_info = []

        # 2. Optional pull constraint. It uses a separate vanilla pass that does
        #    not influence the main path below.
        if (self.training and self.args.dil_method == 'key_value_learnable'
            and hasattr(self.args, 'pull_constraint') and self.args.pull_constraint):

            # No gradient through the backbone features; only the key receives gradients.
            with torch.no_grad():
                query_feature = self._extract_vanilla_features(x)

            current_key = self.domain_keys[domain_name]  # [1, embed_dim]
            query_norm = self._l2_normalize(query_feature, dim=1)
            key_norm = self._l2_normalize(current_key, dim=1)

            # (B, C) @ (C, 1) -> (B, 1) -> (B,) -> mean -> scalar
            similarity = torch.matmul(query_norm, key_norm.t()).squeeze(-1)
            pull_sim = similarity.mean()

            all_router_info.append(('pull_sim', pull_sim))

        # 3. Main path with the domain-specific modules, used for the logits.
        x_main = x_embedded
        for blk in self.blocks:
            if isinstance(blk, IterativeLoRABlock):
                x_main, info = blk(x_main, domain_name=domain_name)
                if info:
                    all_router_info.extend(info)
            else:
                x_main = blk(x_main)

        x_main = self.norm(x_main)
        pooled = x_main[:, 0]
        pooled = self.head_drop(pooled)

        logits = self.domain_heads[domain_name](pooled)

        return logits, all_router_info

    def _to_global_domain_indices(self, local_indices, available_domains):
        """Map indices into ``available_domains`` to indices into ``args.domains``."""
        lookup = torch.tensor(
            [self.args.domains.index(d_name) for d_name in available_domains],
            dtype=local_indices.dtype,
            device=local_indices.device,
        )
        return lookup[local_indices]

    def _fallback_to_current_domain(self, x):
        """Route the whole batch to ``current_domain``.

        Returns ``(logits, domain_indices)``, the same shape as the
        auto-routing inference path.
        """
        logits, _ = self._forward_pass_with_domain(x, self.current_domain)
        domain_idx = self.args.domains.index(self.current_domain)
        indices = torch.full((x.shape[0],), domain_idx, dtype=torch.long, device=x.device)
        return logits, indices

    def forward(self, x):
        """Run a forward pass.

        * Training: uses ``current_domain`` (set via ``set_active_domain``) and
          returns ``(logits, router_info)``.
        * Inference with ``force_current_domain``: same as training, returns
          ``(logits, router_info)``.
        * Otherwise: identifies a domain per sample with ``args.dil_method``
          and returns ``(logits, global_domain_indices)``, where the indices
          point into ``args.domains``. If the method has no keys yet, every
          sample is routed to ``current_domain`` and its index is returned.

        Raises:
            NotImplementedError: for an unknown ``args.dil_method`` at inference.
        """
        # Training: the domain is fixed externally by set_active_domain.
        if self.training:
            return self._forward_pass_with_domain(x, self.current_domain)

        # Forced domain (e.g. validation while training a domain).
        if self.force_current_domain:
            return self._forward_pass_with_domain(x, self.current_domain)

        method = self.args.dil_method

        if method == 'fft':
            input_amp = extract_amp_spectrum_torch(x)
            distances = []
            available_domains = []

            # Only domains whose amplitude key has been loaded take part.
            for d_name in self.args.domains:
                buffer_name = f'amp_key_{d_name}'
                if hasattr(self, buffer_name):
                    key = getattr(self, buffer_name)
                    distances.append(compare_spectra(input_amp, key))
                    available_domains.append(d_name)

            if not distances:
                print(f"Warning: No FFT keys loaded, using current domain: {self.current_domain}")
                return self._fallback_to_current_domain(x)

            distances = torch.stack(distances, dim=1)  # [B, num_available]
            local_indices = torch.argmin(distances, dim=1)  # [B]
            global_domain_indices = self._to_global_domain_indices(local_indices, available_domains)

        elif 'key_value' in method:
            with torch.no_grad():
                query_feature = self._extract_vanilla_features(x)

            # A domain is considered available once its key is non-zero.
            available_domains = [name for name, key in self.domain_keys.items()
                                 if key.abs().sum() > 1e-6]

            if not available_domains:
                print(f"Warning: No feature keys computed, using current domain: {self.current_domain}")
                return self._fallback_to_current_domain(x)

            domain_key_tensors = torch.cat([self.domain_keys[d] for d in available_domains], dim=0)

            # Cosine similarity; the most similar key wins.
            query_norm = self._l2_normalize(query_feature, dim=1)
            keys_norm = self._l2_normalize(domain_key_tensors, dim=1)
            similarity = torch.matmul(query_norm, keys_norm.t())  # [B, num_available]

            local_indices = torch.argmax(similarity, dim=1)  # [B]
            global_domain_indices = self._to_global_domain_indices(local_indices, available_domains)

        elif method == 'kmeans':
            if self.all_kmeans_centers.shape[0] == 0:
                print(f"Warning: K-Means keys not computed yet. Using current domain: {self.current_domain}")
                return self._fallback_to_current_domain(x)

            with torch.no_grad():
                query_feature = self._extract_vanilla_features(x)  # [B, embed_dim]

            # Euclidean distance to every stored center, then map the nearest one to its task.
            distances = torch.cdist(query_feature, self.all_kmeans_centers)  # [B, total_keys]
            closest_center_indices = torch.argmin(distances, dim=1)  # [B]
            global_domain_indices = self.center_to_task_id_map[closest_center_indices]  # [B]

        else:
            raise NotImplementedError(f"Inference logic for method '{self.args.dil_method}' is not implemented.")

        # Run each predicted-domain subset through its own modules.
        B = x.shape[0]
        final_logits = torch.zeros(B, self.args.nb_classes, device=x.device, dtype=x.dtype)

        for i, domain_name in enumerate(self.args.domains):
            mask = (global_domain_indices == i)
            if mask.any():
                subset_logits, _ = self._forward_pass_with_domain(x[mask], domain_name)
                final_logits[mask] = subset_logits.to(final_logits.dtype)

        return final_logits, global_domain_indices

    def _pos_embed(self, x):
        """Prepend the CLS token (if any), add position embeddings (if any), apply dropout."""
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return self.pos_drop(x)