import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import time
import ollama
import re
import math
from typing import List, Dict, Tuple
import random
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score


# Configure the root logger at INFO level (logging.basicConfig only takes
# effect if the root logger has no handlers yet) and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SelfGatingModule(nn.Module):
    """Feature-wise self-gating: ``x * (1 + sigmoid(proj(gate(x))))``.

    ``gate`` squeezes the features to ``feature_dim // 4`` channels and
    ``proj`` expands them back to ``feature_dim``. ``proj`` is
    zero-initialised, so right after construction every gate value is
    ``sigmoid(0) = 0.5`` and the module scales its input uniformly by 1.5.
    """

    def __init__(self, feature_dim):
        super().__init__()
        bottleneck = feature_dim // 4
        self.gate = nn.Linear(feature_dim, bottleneck)
        self.proj = nn.Linear(bottleneck, feature_dim)
        # Start from a constant gate so the module begins as a fixed rescale.
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        scores = self.proj(self.gate(x))
        return x * (1.0 + torch.sigmoid(scores))

class SafeBatchNorm1d(nn.Module):
    """``nn.BatchNorm1d`` that tolerates a single value per channel in training.

    In training mode ``nn.BatchNorm1d`` cannot compute batch statistics when
    each channel has only one value. That happens for a ``(1, C)`` input or a
    ``(1, C, 1)`` input. In that case this wrapper skips normalisation and
    applies only the affine transform (if enabled). Every other input,
    including ``(1, C, L)`` with ``L > 1``, is delegated to the wrapped layer.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features, eps, momentum, affine, track_running_stats)
        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        # Alias the wrapped layer's affine parameters so callers (e.g. weight
        # initialisers) can reach them as ``module.weight`` / ``module.bias``.
        if affine:
            self.weight = self.bn.weight
            self.bias = self.bn.bias

    def forward(self, x):
        """Normalise ``x`` of shape ``(N, C)`` or ``(N, C, L)``.

        Raises:
            ValueError: if ``x`` is not 2D or 3D.
        """
        if x.dim() not in (2, 3):
            raise ValueError(f'expected 2D or 3D input (got {x.dim()}D input)')

        # BatchNorm averages over N values per channel for (N, C) input and
        # over N * L values for (N, C, L) input.
        values_per_channel = x.size(0) * (x.size(2) if x.dim() == 3 else 1)
        if values_per_channel > 1 or not self.training:
            return self.bn(x)

        if not self.affine:
            return x
        shape = (1, -1) if x.dim() == 2 else (1, -1, 1)
        return x * self.weight.view(shape) + self.bias.view(shape)

class EnhancedEncoder(nn.Module):
    """MLP encoder with batch norm, dropout, a self-gating output stage and a residual.

    For ``hidden_dims = [h1, ..., hk]`` every stage is
    ``Linear -> SafeBatchNorm1d -> ReLU``. Each stage except the last is
    followed by ``Dropout(0.2``), and the last by a ``SelfGatingModule``.
    ``forward`` returns ``encoder(x) + 0.1 * residual``. The residual is
    ``x`` itself when ``input_dim == hk``; otherwise it is a
    ``Linear + SafeBatchNorm1d`` projection of ``x``.

    Args:
        input_dim: width of the input features.
        hidden_dims: layer widths; defaults to ``[512, 256, 128]``.

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim, hidden_dims=None):
        super().__init__()
        # ``None`` avoids a shared mutable default; the effective default is unchanged.
        hidden_dims = [512, 256, 128] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        dims = [input_dim] + hidden_dims
        last_stage = len(hidden_dims) - 1
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers += [nn.Linear(d_in, d_out), SafeBatchNorm1d(d_out), nn.ReLU()]
            if i == last_stage:
                layers.append(SelfGatingModule(d_out))
            else:
                layers.append(nn.Dropout(0.2))
        self.encoder = nn.Sequential(*layers)

        self.has_residual = input_dim == hidden_dims[-1]
        if not self.has_residual:
            self.residual_proj = nn.Sequential(
                nn.Linear(input_dim, hidden_dims[-1]),
                SafeBatchNorm1d(hidden_dims[-1])
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-normal (gain 0.5) for Linear layers, identity affine for batch norm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight, gain=0.5)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, SafeBatchNorm1d):
            if m.affine:
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        elif isinstance(m, SelfGatingModule):
            # ``Module.apply`` visits children before their parent, so the
            # Linear branch above has already overwritten ``m.proj``. Restore
            # the zero init that SelfGatingModule sets up (uniform starting gate).
            nn.init.zeros_(m.proj.weight)
            nn.init.zeros_(m.proj.bias)

    def forward(self, x):
        encoded = self.encoder(x)
        residual = x if self.has_residual else self.residual_proj(x)
        return encoded + 0.1 * residual

class BaseAgent(nn.Module):
    """Common base for clustering agents.

    Records the input and output dimensionality. It also selects the compute
    device: CUDA when available, otherwise CPU.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda') if has_gpu else torch.device('cpu')

class MultiviewAgent(BaseAgent):
    """Per-view clustering agent: shared encoder plus cluster and projection heads.

    ``forward`` returns soft cluster assignments. As a side effect it stores
    the 128-d encoder output in ``self.features`` and the 32-d
    projection-head output in ``self.projection``.
    """

    def __init__(self, input_dim: int, n_clusters: int):
        super().__init__(input_dim, n_clusters)

        def make_head(out_dim):
            # 128-d encoder output -> 64 -> out_dim, placed on the agent's device.
            return nn.Sequential(
                nn.Linear(128, 64),
                SafeBatchNorm1d(64),
                nn.ReLU(),
                nn.Linear(64, out_dim)
            ).to(self.device)

        self.encoder = EnhancedEncoder(
            input_dim=input_dim,
            hidden_dims=[512, 256, 128]
        ).to(self.device)
        self.cluster_head = make_head(n_clusters)
        self.projection_head = make_head(32)

    def forward(self, x: torch.Tensor, adj: torch.Tensor = None) -> torch.Tensor:
        """Return softmax cluster probabilities for ``x`` (``adj`` is ignored)."""
        feats = self.encoder(x.to(self.device))
        cluster_logits = self.cluster_head(feats)
        self.features = feats
        self.projection = self.projection_head(feats)
        return F.softmax(cluster_logits, dim=1)

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` without recording gradients."""
        moved = x.to(self.device)
        with torch.no_grad():
            return self.encoder(moved)

    def get_projection(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and apply the projection head, without recording gradients."""
        moved = x.to(self.device)
        with torch.no_grad():
            return self.projection_head(self.encoder(moved))

class AdversarialAgent(BaseAgent):
    """Clustering agent with extra contrastive and discriminator heads.

    ``forward`` returns softmax cluster probabilities and caches three outputs:
    ``self.features`` (encoder output), ``self.contrast_features`` (32-d
    contrast-head output) and ``self.adversarial_score`` (discriminator output
    in ``(0, 1)``).
    """

    def __init__(self, input_dim: int, n_clusters: int):
        super().__init__(input_dim, n_clusters)

        self.encoder = EnhancedEncoder(
            input_dim=input_dim,
            hidden_dims=[512, 256, 128]
        ).to(self.device)

        self.contrast_head = nn.Sequential(
            nn.Linear(128, 64),
            SafeBatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32)
        ).to(self.device)

        self.cluster_head = nn.Sequential(
            nn.Linear(128, 64),
            SafeBatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, n_clusters)
        ).to(self.device)

        self.discriminator = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid()
        ).to(self.device)

    def forward(self, x: torch.Tensor, adj: torch.Tensor = None) -> torch.Tensor:
        """Return softmax cluster probabilities for ``x`` (``adj`` is ignored)."""
        x = x.to(self.device)
        features = self.encoder(x)
        logits = self.cluster_head(features)

        self.features = features
        self.contrast_features = self.contrast_head(features)
        self.adversarial_score = self.discriminator(features)

        return F.softmax(logits, dim=1)

    def generate_adversarial_samples(self, x: torch.Tensor, epsilon: float = 0.1) -> torch.Tensor:
        """Return an FGSM-perturbed copy of ``x``.

        The agent's own hard cluster assignments serve as pseudo-labels. The
        input is moved by ``epsilon * sign(grad)`` in the direction that
        increases the cross-entropy against them, then clamped to ``[0, 1]``
        (NOTE(review): assumes inputs are scaled to ``[0, 1]``).

        The caller's tensor is not modified, no gradients are accumulated into
        the model's parameters, and the returned tensor is detached. If the
        agent is in training mode, this forward pass still applies dropout and
        updates BatchNorm running statistics.
        """
        # Work on a private leaf copy so the caller's tensor is left untouched.
        x_src = x.to(self.device).detach().clone().requires_grad_(True)
        # enable_grad keeps this usable even when called under torch.no_grad().
        with torch.enable_grad():
            logits = self.cluster_head(self.encoder(x_src))
            target = torch.argmax(logits, dim=1)
            loss = F.cross_entropy(logits, target)
            # autograd.grad returns the input gradient without touching param.grad.
            (input_grad,) = torch.autograd.grad(loss, x_src)

        adv_x = x_src.detach() + epsilon * input_grad.sign()
        return torch.clamp(adv_x, 0, 1)

class FCR(nn.Module):
    """Frequency-domain contrastive regularisation.

    Anchors, positives and negatives are compared through the magnitudes of
    their FFT along dim 1 (phase is discarded) using mean L1 distance. Each
    term is ``d(a_i, p_i) / (d(a_i, n_j) + 1e-7)``, and the loss is the mean
    over all terms.
    """

    def __init__(self, multi_n_num=2):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.multi_n_num = multi_n_num

    def sample_with_j(self, k, n, j):
        """Pick up to ``n`` distinct indices from ``range(k)``, always including ``j``.

        ``n`` is capped at ``k`` and ``j`` is clamped into ``[0, k - 1]``.
        ``j`` is always the first element. The remaining indices are drawn
        with ``random.sample`` from the others.
        """
        n = min(n, k)
        j = max(0, min(j, k - 1))
        if k <= 1:
            return [j]

        others = [idx for idx in range(k) if idx != j]
        extra = min(n - 1, len(others))
        picked = [j]
        if extra > 0:
            picked.extend(random.sample(others, extra))
        return picked

    def forward(self, a, p, n):
        """Return the mean spectral distance ratio for anchors ``a``, positives ``p``, negatives ``n``."""
        if a.size(0) < 1:
            return torch.tensor(0.0, device=a.device)

        batch = a.size(0)
        a_mag, p_mag, n_mag = (torch.fft.fft(t, dim=1).abs() for t in (a, p, n))

        if batch == 1:
            # Single sample: compare against the given negative only.
            d_ap = self.l1(a_mag[0], p_mag[0])
            d_an = self.l1(a_mag[0], n_mag[0])
            return d_ap / (d_an + 1e-7)

        per_anchor = min(self.multi_n_num, batch)
        total = 0
        for i in range(batch):
            d_ap = self.l1(a_mag[i], p_mag[i])
            for j in self.sample_with_j(batch, per_anchor, i):
                total = total + d_ap / (self.l1(a_mag[i], n_mag[j]) + 1e-7)
        return total / (per_anchor * batch)

class DMLp(nn.Module):
    """Two-layer MLP on the last axis: ``dim -> int(dim * growth_rate) -> dim`` with GELU."""

    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden = int(dim * growth_rate)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class SMFA(nn.Module):
    """Feature refinement block over a small set of feature vectors.

    Input ``features`` is ``(batch, num_features, dim)``. ``linear_0`` maps it
    to ``2 * dim`` channels, which are split into ``y`` and ``x``:

    * ``x`` branch: multi-head self-attention over the ``num_features``
      axis. The attention output is combined with the variance of ``x`` over
      that axis, using learned ``alpha`` / ``beta``. The result goes through
      ``linear_1`` and GELU and is multiplied back onto ``x``.
    * ``y`` branch: a ``DMLp`` MLP.

    The two branches are summed and projected by ``linear_2``. ``dim`` must
    be divisible by 4, the number of attention heads.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear_0 = nn.Linear(dim, dim * 2)
        self.linear_1 = nn.Linear(dim, dim)
        self.linear_2 = nn.Linear(dim, dim)

        self.lde = DMLp(dim, 2)

        self.attention = nn.MultiheadAttention(dim, 4, batch_first=True)

        self.gelu = nn.GELU()

        # Per-channel scale for the attention output and the variance term.
        self.alpha = nn.Parameter(torch.ones(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, features):
        # Unpacking doubles as a check that ``features`` is 3D.
        _, num_features, _ = features.shape

        y, x = self.linear_0(features).chunk(2, dim=-1)

        attn_output, _ = self.attention(x, x, x)
        if num_features > 1:
            x_var = torch.var(x, dim=1, keepdim=True)
        else:
            # The default (unbiased) torch.var is NaN for a single sample; a
            # single feature vector has zero spread, so use 0 instead.
            x_var = torch.zeros_like(x[:, :1, :])
        x_l = x * self.gelu(self.linear_1(attn_output * self.alpha + x_var * self.beta))

        y_d = self.lde(y)

        return self.linear_2(x_l + y_d)

class EnhancedAggregator:
    """Fuses per-agent cluster predictions into one soft assignment.

    Each prediction is weighted by a softmax over per-agent scores. The base
    score is the agent's mean max-probability (confidence). When exactly one
    feature tensor per prediction is supplied (and there are at least two),
    the score is the confidence times the mean L2 norm of that agent's
    SMFA-enhanced features. The weights are plain numbers (detached), so no
    gradient flows through them.
    """

    def __init__(self, num_agents):
        self.num_agents = num_agents
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Uniform prior weights (not read by ``aggregate``).
        self.base_weights = torch.ones(num_agents).to(self.device) / num_agents

        self.performance_history = []

        # SMFA is built lazily, once the feature width is known.
        self.feature_dim = None
        self.smfa = None

    def _init_smfa(self, dim):
        """Create the SMFA enhancer for ``dim``-wide features."""
        self.feature_dim = dim
        self.smfa = SMFA(dim).to(self.device)

    def _enhance_features(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Jointly refine per-agent features with SMFA.

        All tensors must share a shape, assumed ``(batch, dim)``. They are
        stacked along a new agent axis, refined, and split back into a list.
        """
        if not features:
            return features

        features = [f.to(self.device) for f in features]

        if self.smfa is None:
            self._init_smfa(features[0].size(1))

        stacked = torch.stack(features, dim=1)  # [batch_size, num_features, dim]
        enhanced = self.smfa(stacked)  # [batch_size, num_features, dim]
        return list(enhanced.unbind(dim=1))

    def _confidence_weights(self, predictions: List[torch.Tensor]) -> torch.Tensor:
        """Softmax over each prediction's mean max-probability."""
        scores = [torch.max(p, dim=1)[0].mean().item() for p in predictions]
        return F.softmax(torch.tensor(scores, device=self.device), dim=0)

    def aggregate(self, predictions: List[torch.Tensor], features: List[torch.Tensor] = None) -> torch.Tensor:
        """Return the weighted sum of ``predictions``.

        Args:
            predictions: one ``(batch, ...)`` tensor per agent. If the shapes
                differ, each is flattened to ``(batch, -1)``.
            features: optional per-agent feature tensors. They are used only
                when there are at least two and exactly one per prediction.
                Otherwise confidence-only weights are used, so that no agent
                is silently dropped.

        Raises:
            ValueError: if ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError("aggregate() needs at least one prediction")

        predictions = [p.to(self.device) for p in predictions]

        if not all(p.shape == predictions[0].shape for p in predictions):
            predictions = [p.view(p.size(0), -1) for p in predictions]

        weights = None
        if features and len(features) > 1 and len(features) == len(predictions):
            try:
                # The scores are turned into plain floats below, so skip building a graph.
                with torch.no_grad():
                    enhanced_features = self._enhance_features(features)
                    scores = [
                        (torch.max(pred, dim=1)[0].mean() * torch.norm(feat, dim=1).mean()).item()
                        for pred, feat in zip(predictions, enhanced_features)
                    ]
                weights = F.softmax(torch.tensor(scores, device=self.device), dim=0)
            except Exception as e:
                logger.warning(f"特征增强失败，使用基本权重: {str(e)}")

        if weights is None:
            weights = self._confidence_weights(predictions)

        aggregated = torch.zeros_like(predictions[0])
        for weight, pred in zip(weights, predictions):
            aggregated += weight * pred

        return aggregated

class DynamicRouter:
    """Chooses which agent to trust, optionally asking a local LLM.

    The LLM window covers epochs ``[min_epoch, min_epoch + llm_epochs)`` and
    applies only when ``use_llm`` is true. Inside it, a local Ollama model is
    asked for a policy at most once every ``llm_interval`` epochs, and the
    answer is cached in between. Outside the window, or if anything fails,
    a heuristic policy picks the agent with the highest ``overall_score``.

    A policy is a dict with these keys:

    * ``primary_agent``: int
    * ``fallback_agents``: list with one int
    * ``confidence_threshold``: float
    * ``fusion_method``: ``"weighted_average"`` or ``"simple_average"``
    * ``used_llm``: bool
    """

    # Report fields copied into the per-agent metric dicts (in this order).
    _METRIC_KEYS = (
        'confidence', 'entropy', 'cluster_balance', 'feature_quality',
        'memory_usage', 'compute_time', 'resource_efficiency', 'consistency',
    )

    def __init__(self, device="cuda", config=None):
        self.device = device
        self.routing_history = []
        self.llm_contributions = []
        self.performance_history = []

        # ``config`` may be omitted; every setting then takes its default.
        config = config or {}
        self.min_epoch = config.get("min_epoch", 100)  # first epoch of the LLM window
        self.llm_epochs = config.get("llm_epochs", 10)  # length of the LLM window in epochs
        self.use_llm = config.get("use_llm", True)  # master switch for LLM routing
        self.llm_interval = config.get("llm_interval", 5)  # minimum epochs between LLM calls

        self.current_epoch = 0

        # LLM result cache
        self.last_llm_epoch = -1  # epoch of the most recent LLM call
        self.cached_policy = None  # policy produced by that call

    def update_epoch(self, epoch):
        """Record the current epoch (drives the LLM window and the cache)."""
        self.current_epoch = epoch

    @staticmethod
    def _safe_float(value) -> float:
        """Convert ``value`` to a finite float, mapping failures, NaN and inf to 0.0."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        return number if math.isfinite(number) else 0.0

    def _get_agent_metrics(self, agent_reports: List[Dict]) -> List[Dict]:
        """Extract the unsupervised metrics from each agent report.

        Missing, non-numeric, NaN and infinite values become 0.0.
        ``overall_score`` is computed from the cleaned values as
        ``0.3*confidence + 0.3*consistency + 0.2*feature_quality
        + 0.2*resource_efficiency``. Each entry's ``id`` is the report's
        position in ``agent_reports``.
        """
        agent_metrics = []
        for i, report in enumerate(agent_reports):
            agent_info = {'id': i}
            for key in self._METRIC_KEYS:
                agent_info[key] = self._safe_float(report.get(key, 0))
            # Computed after cleaning so one bad metric does not zero the whole score.
            agent_info['overall_score'] = self._safe_float(
                agent_info['confidence'] * 0.3 +
                agent_info['consistency'] * 0.3 +
                agent_info['feature_quality'] * 0.2 +
                agent_info['resource_efficiency'] * 0.2
            )
            agent_metrics.append(agent_info)
        return agent_metrics

    @staticmethod
    def _format_agent_metrics(agent: Dict) -> str:
        """Render one agent's metrics as the text block embedded in the LLM prompt."""
        return (
            f"Agent-{agent['id']}性能指标:\n"
            "- 预测质量:\n"
            f"  • 预测置信度: {agent['confidence']:.3f}\n"
            f"  • 预测熵: {agent['entropy']:.3f}\n"
            f"  • 聚类平衡度: {agent['cluster_balance']:.3f}\n"
            f"  • 特征质量: {agent['feature_quality']:.3f}\n"
            "- 资源效率:\n"
            f"  • 内存使用: {agent['memory_usage']:.1f}MB\n"
            f"  • 计算时间: {agent['compute_time']:.3f}s\n"
            f"  • 资源效率: {agent['resource_efficiency']:.3f}\n"
            "- 稳定性:\n"
            f"  • 一致性指标: {agent['consistency']:.3f}\n"
            f"  • 整体评分: {agent['overall_score']:.3f}\n\n"
        )

    def generate_policy(self, agent_reports: List[Dict]) -> Dict:
        """Return the routing policy for the current epoch.

        Inside the LLM window, the model is queried when no policy is cached
        or when at least ``llm_interval`` epochs have passed since the last
        query; otherwise the cached policy is reused. Outside the window, and
        on any error, the heuristic default policy is returned.
        """
        try:
            in_llm_window = self.min_epoch <= self.current_epoch < self.min_epoch + self.llm_epochs
            if not (self.use_llm and in_llm_window):
                return self._get_default_policy(self._get_agent_metrics(agent_reports))

            epochs_since_last_llm = self.current_epoch - self.last_llm_epoch
            if self.cached_policy and epochs_since_last_llm < self.llm_interval:
                return self.cached_policy

            agent_metrics = self._get_agent_metrics(agent_reports)
            metrics_str = "".join(self._format_agent_metrics(agent) for agent in agent_metrics)
            num_agents = len(agent_metrics)

            prompt = f"""
请基于以下真实的无监督指标数据，生成最优的路由策略。

当前系统状态:
{metrics_str}

注意: 当前系统中只有{num_agents}个Agent (编号从0到{num_agents-1})。

请根据上述真实数据，通过多维度对比来选择最优方案。重点考虑:
1. 预测的确定性和稳定性
2. 聚类结构的质量
3. 计算资源的使用效率
4. 整体表现的平衡性

输出格式(请确保输出以下格式的单行数字):
<主要Agent编号>,<备用Agent编号>,<置信度阈值(0.6-0.95)>,<融合方式(1=加权/0=简单)>

示例输出: 0,1,0.75,1或者1,0,0.81,0

请先进行详细分析，然后在最后一行给出上述格式的决策数字。
"""

            response = ollama.chat(
                'deepseek-r1:1.5b',
                messages=[{'role': 'user', 'content': prompt}]
            )

            self.cached_policy = self._parse_llm_response(response.message.content, agent_metrics)
            self.last_llm_epoch = self.current_epoch
            return self.cached_policy

        except Exception as e:
            logger.error(f"策略生成失败: {e}")
            return self._get_default_policy(self._get_agent_metrics(agent_reports))

    def _parse_llm_response(self, response_text: str, agent_metrics: List[Dict]) -> Dict:
        """Extract a policy from the LLM's free-form answer.

        Lines are scanned from last to first. The first line that contains a
        comma and at least four numbers is read as
        ``primary,fallback,threshold,fusion``:

        * An out-of-range primary index yields the default policy.
        * An out-of-range or duplicate fallback becomes
          ``(primary + 1) % num_agents``.
        * The threshold is clamped to ``[0.6, 0.95]``.
        * Fusion ``1`` means weighted average; anything else means simple
          average.

        If no line parses, the default policy is returned.
        """
        try:
            logger.info(f"💭 决策分析:\n{response_text}")
            num_agents = len(agent_metrics)
            prefixes = ["输出：", "结果：", "决策：", "策略：", "输出格式："]

            for line in reversed(response_text.strip().split('\n')):
                clean_line = line.strip()
                for prefix in prefixes:
                    if clean_line.startswith(prefix):
                        clean_line = clean_line[len(prefix):].strip()

                if ',' not in clean_line:
                    continue

                numbers = []
                for part in clean_line.split(','):
                    # Keep a leading minus so "-1" is rejected as an index instead of read as "1".
                    num_match = re.search(r'(-?\d+\.?\d*)', part.strip())
                    if num_match:
                        numbers.append(num_match.group(1))

                # primary agent, fallback agent, confidence threshold, fusion method
                if len(numbers) < 4:
                    continue

                try:
                    primary_agent = int(float(numbers[0]))
                    fallback_agent = int(float(numbers[1]))
                    conf_threshold = float(numbers[2])
                    fusion_method = int(float(numbers[3]))

                    if primary_agent >= num_agents or primary_agent < 0:
                        logger.warning(f"Invalid primary agent index: {primary_agent}, using default")
                        return self._get_default_policy(agent_metrics)

                    if fallback_agent >= num_agents or fallback_agent < 0:
                        logger.warning(f"Invalid fallback agent index: {fallback_agent}, using default")
                        fallback_agent = (primary_agent + 1) % num_agents

                    # Primary and fallback must differ whenever there is a choice.
                    if fallback_agent == primary_agent and num_agents > 1:
                        fallback_agent = (primary_agent + 1) % num_agents

                    conf_threshold = max(0.6, min(0.95, conf_threshold))
                    use_weighted = fusion_method == 1

                    policy = {
                        "primary_agent": primary_agent,
                        "fallback_agents": [fallback_agent],  # exactly one fallback agent
                        "confidence_threshold": conf_threshold,
                        "fusion_method": "weighted_average" if use_weighted else "simple_average",
                        "used_llm": True
                    }

                    logger.info("\n✨ 路由决策详情:")
                    primary_agent_metrics = agent_metrics[primary_agent]
                    logger.info(f"选择Agent-{primary_agent}作为主要Agent:")
                    logger.info(f"- 预测置信度: {primary_agent_metrics['confidence']:.3f}")
                    logger.info(f"- 预测一致性: {primary_agent_metrics['consistency']:.3f}")
                    logger.info(f"- 特征质量: {primary_agent_metrics['feature_quality']:.3f}")
                    logger.info(f"- 资源效率: {primary_agent_metrics['resource_efficiency']:.3f}")
                    logger.info(f"- 备用Agent: {fallback_agent}")
                    logger.info(f"- 置信度阈值: {conf_threshold:.2f}")
                    # Same test as the policy above, so the log matches the decision.
                    logger.info(f"- 融合策略: {'加权平均' if use_weighted else '简单平均'}")

                    return policy
                except (ValueError, IndexError) as e:
                    logger.warning(f"解析数字失败: {str(e)}, 尝试下一行")
                    continue

            logger.warning("无法从响应中提取有效数字格式，使用默认策略")
            return self._get_default_policy(agent_metrics)

        except Exception as e:
            logger.error(f"解析LLM响应失败: {str(e)}")
            return self._get_default_policy(agent_metrics)

    def _get_default_policy(self, agent_metrics: List[Dict]) -> Dict:
        """Build the heuristic policy.

        The agent with the highest ``overall_score`` (first one on ties) is
        primary. The lowest other index is the fallback; with a single agent,
        the fallback is the primary itself.

        Raises:
            ValueError: if ``agent_metrics`` is empty.
        """
        if not agent_metrics:
            raise ValueError("cannot build a routing policy without any agents")

        scores = [agent.get('overall_score', 0) for agent in agent_metrics]
        primary_idx = scores.index(max(scores))

        others = [idx for idx in range(len(agent_metrics)) if idx != primary_idx]
        fallback_idx = others[0] if others else primary_idx

        return {
            "primary_agent": primary_idx,
            "fallback_agents": [fallback_idx],  # exactly one fallback agent
            "confidence_threshold": 0.7,  # default confidence threshold
            "fusion_method": "weighted_average",
            "used_llm": False
        }

class ClusteringMetrics:
    """Clustering evaluation metrics: accuracy (Hungarian), NMI, ARI and purity."""

    @staticmethod
    def calculate_metrics(labels_true: np.ndarray, labels_pred: np.ndarray,
                          features: np.ndarray = None) -> Dict[str, float]:
        """Return ``{"accuracy", "nmi", "ari", "purity"}`` for the given labelling.

        ``features`` is accepted for API compatibility but is not used.
        """
        return {
            "accuracy": ClusteringMetrics.calculate_acc(labels_true, labels_pred),
            "nmi": normalized_mutual_info_score(labels_true, labels_pred),
            "ari": adjusted_rand_score(labels_true, labels_pred),
            "purity": ClusteringMetrics.calculate_purity(labels_true, labels_pred),
        }

    @staticmethod
    def calculate_acc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Best-match clustering accuracy via the Hungarian algorithm.

        Labels are remapped to dense indices first, so arbitrary integer
        labels are handled correctly, including negative ones such as -1.

        Raises:
            ValueError: if the label arrays differ in size or are empty.
        """
        y_true = np.asarray(y_true).astype(np.int64).ravel()
        y_pred = np.asarray(y_pred).astype(np.int64).ravel()
        if y_pred.size != y_true.size:
            raise ValueError(f"label arrays differ in size: {y_true.size} vs {y_pred.size}")
        if y_pred.size == 0:
            raise ValueError("cannot compute accuracy of empty label arrays")

        pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)
        true_ids, true_idx = np.unique(y_true, return_inverse=True)

        # w[c, k] = number of samples in predicted cluster c with true class k.
        w = np.zeros((pred_ids.size, true_ids.size), dtype=np.int64)
        np.add.at(w, (pred_idx.ravel(), true_idx.ravel()), 1)

        # Maximise matched counts by minimising (max - w); rectangular w is fine.
        rows, cols = linear_sum_assignment(w.max() - w)
        return float(w[rows, cols].sum() * 1.0 / y_pred.size)

    @staticmethod
    def calculate_purity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Fraction of samples that belong to their cluster's dominant class."""
        y_true = np.asarray(y_true).astype(np.int64)
        y_pred = np.asarray(y_pred).astype(np.int64)

        contingency = ClusteringMetrics.contingency_matrix(y_true, y_pred)
        return np.sum(np.amax(contingency, axis=1)) / np.sum(contingency)

    @staticmethod
    def contingency_matrix(labels_true: np.ndarray, labels_pred: np.ndarray) -> np.ndarray:
        """Return a float matrix with one row per predicted cluster and one column per true class.

        Rows and columns are ordered by sorted label value. Entry ``[i, j]``
        counts the samples in cluster ``i`` with class ``j``.
        """
        labels_true = np.asarray(labels_true).ravel()
        labels_pred = np.asarray(labels_pred).ravel()
        classes, class_idx = np.unique(labels_true, return_inverse=True)
        clusters, cluster_idx = np.unique(labels_pred, return_inverse=True)

        contingency = np.zeros((len(clusters), len(classes)))
        np.add.at(contingency, (cluster_idx.ravel(), class_idx.ravel()), 1)
        return contingency

class MultiviewClusterTrainer:
    """多视图聚类训练器核心实现"""
    
    def __init__(self, config: Dict):
        """Store ``config``, pick the compute device and build the router.

        ``config`` is passed unchanged to ``DynamicRouter``.
        """
        self.config = config
        gpu_ok = torch.cuda.is_available()
        self.device = torch.device('cuda') if gpu_ok else torch.device('cpu')
        self.current_epoch = 0  # training starts at epoch 0

        # Agent-routing policy provider, shared across epochs.
        self.router = DynamicRouter(device=self.device, config=self.config)
        
    def compute_loss(self, predictions, batch_views):
        """计算综合损失函数"""
        # 获取当前批次大小和设备
        batch_size = predictions.size(0)
        device = predictions.device

        # 对比损失 - 鼓励相似样本的特征接近
        contrast_loss = 0.0
        view_features = []

        # 获取每个视图的特征
        for view_idx, view_data in enumerate(batch_views):
            view_data = view_data.to(device)
            # 获取特征
            features = self.multiview_agents[view_idx].get_features(view_data)
            view_features.append(features)

        # 计算对比损失
        for i in range(len(view_features)):
            for j in range(i + 1, len(view_features)):
                # 计算特征间的余弦相似度
                sim_matrix = F.cosine_similarity(
                    view_features[i].unsqueeze(1),
                    view_features[j].unsqueeze(0),
                    dim=2
                )
                # 对角线上是对应样本的相似度，应该最大化
                pos_sim = torch.diag(sim_matrix)
                # 负样本相似度，应该最小化
                neg_sim = sim_matrix - torch.eye(batch_size, device=device) * 2.0

                # InfoNCE损失
                logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / 0.1
                labels = torch.zeros(batch_size, dtype=torch.long, device=device)
                contrast_loss += F.cross_entropy(logits, labels)

        # 添加频域对比正则化
        fcr_loss = 0.0
        if not hasattr(self, 'fcr'):
            self.fcr = FCR().to(device)

        # 对每对视图应用FCR
        for i in range(len(view_features)):
            for j in range(i + 1, len(view_features)):
                # 为每个样本选择负样本
                for k in range(batch_size):
                    # 选择一个不同的样本作为负样本
                    neg_idx = (k + 1) % batch_size
                    anchor = view_features[i][k:k + 1]
                    positive = view_features[j][k:k + 1]
                    negative = view_features[j][neg_idx:neg_idx + 1]

                    # 应用FCR
                    fcr_loss += self.fcr(anchor, positive, negative)

        # 聚类质量损失
        # 聚类一致性损失 - 鼓励清晰的聚类边界
        probs = F.softmax(predictions, dim=1)
        max_probs = torch.max(probs, dim=1)[0]
        cluster_clarity_loss = -torch.mean(torch.log(max_probs + 1e-10))

        # 聚类平衡损失 - 防止所有样本分到同一类
        cluster_size = torch.sum(probs, dim=0)  # 每个聚类的大小
        ideal_size = torch.ones_like(cluster_size) * batch_size / self.n_clusters
        cluster_balance_loss = torch.mean(torch.abs(cluster_size - ideal_size))

        # 一致性正则化 - 确保不同视图的预测一致
        consistency_loss = 0.0
        if len(batch_views) > 1:
            all_preds = []
            for view_idx, view_data in enumerate(batch_views):
                view_data = view_data.to(device)
                pred = F.softmax(self.multiview_agents[view_idx].cluster_head(
                    self.multiview_agents[view_idx].encoder(view_data)
                ), dim=1)
                all_preds.append(pred)

            # 计算预测之间的KL散度
            for i in range(len(all_preds)):
                for j in range(i + 1, len(all_preds)):
                    kl_ij = F.kl_div(all_preds[i].log(), all_preds[j], reduction='batchmean')
                    kl_ji = F.kl_div(all_preds[j].log(), all_preds[i], reduction='batchmean')
                    consistency_loss += (kl_ij + kl_ji) / 2

        # 综合损失 - 可调整的组件
        # 定义损失组件
        loss_components = {
            "contrast": contrast_loss,
            "fcr": fcr_loss,
            "clarity": cluster_clarity_loss,
            "balance": cluster_balance_loss,
            "consistency": consistency_loss
        }
        
        # 根据实验配置动态调整损失权重
        weights = {
            "contrast": 1.0,  # 对比损失权重
            "fcr": 0.5,       # 频域对比正则化权重
            "clarity": 0.5,   # 聚类清晰度权重
            "balance": 0.5,   # 聚类平衡权重
            "consistency": 0.3 # 一致性权重
        }
        
        # 计算最终的加权损失
        total_loss = sum(weights[k] * v for k, v in loss_components.items())
        
        return total_loss