#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
修复版记忆网络实验脚本
解决梯度计算、参数存储和训练逻辑问题
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import os
import re
import random
from collections import Counter
import jieba
import warnings
warnings.filterwarnings('ignore')

# Matplotlib: try SimHei first so Chinese labels render, fall back to DejaVu Sans,
# and draw minus signs as an ASCII hyphen (a common workaround with CJK fonts).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Seed every random number generator used by this script for reproducibility.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class TextProcessor:
    """Tokenise Chinese text with jieba and map it to fixed-length id sequences.

    Index 0 is reserved for ``<PAD>`` and index 1 for ``<UNK>``; the remaining
    ``max_vocab_size - 2`` ids are assigned to the most frequent tokens seen by
    ``build_vocab``.
    """

    def __init__(self, max_vocab_size=10000, max_seq_len=128):
        self.max_vocab_size = max_vocab_size
        self.max_seq_len = max_seq_len
        self.word2idx = {}
        self.idx2word = {}
        self.vocab_size = 0

    def build_vocab(self, texts):
        """(Re)build ``word2idx``/``idx2word`` from jieba token counts over *texts*."""
        print("构建词汇表...")
        freq = Counter()
        for doc in texts:
            freq.update(jieba.lcut(doc))

        top_tokens = freq.most_common(self.max_vocab_size - 2)

        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        # Real tokens start right after the two reserved ids.
        for token_id, (token, _freq) in enumerate(top_tokens, start=2):
            self.word2idx[token] = token_id
            self.idx2word[token_id] = token

        self.vocab_size = len(self.word2idx)
        print(f"词汇表大小: {self.vocab_size}")

    def text_to_sequence(self, text):
        """Return *text* as exactly ``max_seq_len`` ids: truncated, or right-padded with ``<PAD>``.

        Tokens missing from the vocabulary map to the ``<UNK>`` id.
        """
        ids = [
            self.word2idx[token] if token in self.word2idx else self.word2idx['<UNK>']
            for token in jieba.lcut(text)
        ]
        if len(ids) > self.max_seq_len:
            return ids[:self.max_seq_len]
        return ids + [self.word2idx['<PAD>']] * (self.max_seq_len - len(ids))

def create_synthetic_chinese_dataset(num_samples=5000, num_classes=10):
    """Generate a synthetic, class-balanced Chinese text-classification dataset.

    Sample ``i`` gets label ``i % num_classes``. Its text has 20-99 tokens:
    3-7 keywords drawn (with replacement) from its category's keyword list,
    topped up with generic filler words, shuffled and joined without spaces.

    Args:
        num_samples: number of texts to generate.
        num_classes: number of categories to use, between 1 and 10.

    Returns:
        ``(texts, labels, categories)`` where ``categories`` lists only the
        ``num_classes`` categories actually used, so ``categories[label]`` is
        each sample's category and ``len(categories) == num_classes``.

    Raises:
        ValueError: if ``num_classes`` is not in ``1..10``.
    """
    print("创建合成中文文本数据集...")

    categories = [
        "体育", "娱乐", "财经", "科技", "教育",
        "军事", "政治", "社会", "健康", "旅游"
    ]

    category_keywords = {
        "体育": ["比赛", "运动员", "足球", "篮球", "奥运会", "冠军", "训练", "教练"],
        "娱乐": ["电影", "明星", "音乐", "电视剧", "综艺", "演员", "导演", "票房"],
        "财经": ["股票", "经济", "投资", "银行", "市场", "公司", "利润", "增长"],
        "科技": ["技术", "创新", "人工智能", "互联网", "软件", "硬件", "研发", "数据"],
        "教育": ["学校", "学生", "老师", "教育", "学习", "考试", "课程", "大学"],
        "军事": ["军队", "国防", "武器", "军事", "安全", "战略", "演习", "装备"],
        "政治": ["政府", "政策", "法律", "国家", "领导", "会议", "决策", "改革"],
        "社会": ["社会", "民生", "问题", "发展", "建设", "服务", "管理", "改善"],
        "健康": ["健康", "医疗", "医院", "医生", "疾病", "治疗", "保健", "养生"],
        "旅游": ["旅游", "景点", "酒店", "旅行", "风景", "文化", "美食", "度假"]
    }

    # Previously an out-of-range value failed with IndexError (>10) or
    # ZeroDivisionError (0) deep inside the loop.
    if not 1 <= num_classes <= len(categories):
        raise ValueError(f"num_classes must be in 1..{len(categories)}, got {num_classes}")

    # Filler vocabulary shared by every sample (loop-invariant, built once).
    common_words = ["的", "是", "在", "有", "和", "与", "为", "了", "也", "都", "可以", "能够",
                    "进行", "实现", "发展", "提高", "加强", "促进", "推动", "支持"]

    texts = []
    labels = []

    for i in range(num_samples):
        category_idx = i % num_classes
        keywords = category_keywords[categories[category_idx]]

        # RNG calls stay in the original order so seeded runs are unchanged.
        text_length = np.random.randint(20, 100)
        num_keywords = np.random.randint(3, 8)
        text_words = list(np.random.choice(keywords, num_keywords, replace=True))

        remaining_length = text_length - len(text_words)
        if remaining_length > 0:
            text_words.extend(np.random.choice(common_words, remaining_length, replace=True))

        np.random.shuffle(text_words)
        texts.append("".join(text_words))
        labels.append(category_idx)

    return texts, labels, categories[:num_classes]

class TextCNN(nn.Module):
    """Convolutional text classifier: parallel n-gram filters + max-over-time pooling.

    Args:
        vocab_size: number of token ids accepted by the embedding.
        embed_dim: embedding width.
        num_classes: number of output logits.
        num_filters: feature maps per filter width.
        filter_sizes: iterable of convolution widths, in tokens. The default is
            an immutable tuple; a list default would be one object shared by
            every call.
    """

    def __init__(self, vocab_size, embed_dim=128, num_classes=10, num_filters=100, filter_sizes=(3, 4, 5)):
        super().__init__()
        # Materialise once: the sizes are iterated and also counted below.
        filter_sizes = tuple(filter_sizes)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Each kernel spans k tokens and the full embedding width, so it only
        # slides along the sequence axis.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, embed_dim)) for k in filter_sizes
        ])
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``.

        ``seq_len`` must be at least the largest filter size.
        """
        # (batch, seq_len) -> (batch, 1, seq_len, embed_dim): a single input channel.
        x = self.embedding(x).unsqueeze(1)

        pooled = []
        for conv in self.convs:
            # (batch, num_filters, seq_len - k + 1) after dropping the width-1 axis.
            feature_map = F.relu(conv(x)).squeeze(3)
            # Max over the whole time axis -> (batch, num_filters).
            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))

        x = self.dropout(torch.cat(pooled, 1))
        return self.fc(x)

class TextLSTM(nn.Module):
    """LSTM text classifier that predicts from the hidden state at the last real token.

    Inputs are right-padded id sequences (``TextProcessor`` pads with id 0).
    Reading the output at the final time step, as a plain ``[:, -1, :]`` does,
    would classify from a state produced after all the padding. Instead the
    output at the last non-``pad_idx`` position of each row is used.

    Args:
        vocab_size, embed_dim, hidden_dim, num_classes, num_layers: model sizes.
        pad_idx: padding token id (default 0). ``None`` restores the legacy
            behaviour of always using the final time step.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_classes=10, num_layers=2, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM applies dropout only between stacked layers; with one layer it
        # would be a no-op that only emits a warning.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=0.5 if num_layers > 1 else 0.0)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``."""
        lstm_out, _ = self.lstm(self.embedding(x))  # (batch, seq_len, hidden_dim)

        if self.pad_idx is None:
            last_hidden = lstm_out[:, -1, :]
        else:
            # Index of the last non-padding token per row; all-padding rows fall back to 0.
            positions = torch.arange(x.size(1), device=x.device).expand_as(x)
            last_idx = torch.where(x != self.pad_idx, positions, torch.zeros_like(positions)).max(dim=1).values
            last_hidden = lstm_out[torch.arange(x.size(0), device=x.device), last_idx]

        return self.fc(last_hidden)

class TextTransformer(nn.Module):
    """Transformer-encoder text classifier with learned positional embeddings and mean pooling."""

    def __init__(self, vocab_size, embed_dim=128, num_heads=8, num_layers=4, num_classes=10, max_seq_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # One learned vector per position (up to max_seq_len), broadcast over the batch.
        self.pos_encoding = nn.Parameter(torch.randn(1, max_seq_len, embed_dim))

        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer=layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``.

        ``seq_len`` must not exceed ``max_seq_len``. Pooling averages over every
        position, padding included.
        """
        seq_len = x.size(1)
        hidden = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        pooled = self.transformer(hidden).mean(dim=1)
        return self.fc(pooled)

class FixedMemoryNetwork(nn.Module):
    """LSTM text classifier with a small, externally written memory bank of feature vectors.

    ``forward`` only reads memory. Each sample's mean-pooled LSTM feature is
    matched by cosine similarity to its closest valid memory slot, and a
    sigmoid gate (``memory_fusion``) mixes the two before ``decoder``. Memory is
    written outside the autograd graph through ``add_memory_node`` and
    ``update_memory_node``.
    """

    def __init__(self, vocab_size, embed_dim=128, num_classes=10, max_seq_len=128,
                 max_memory_nodes=20, similarity_threshold=0.7):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len  # stored only; not used by this class's code
        self.max_memory_nodes = max_memory_nodes
        self.similarity_threshold = similarity_threshold

        # Text encoder. Only [0] (embedding) and [1] (LSTM) are used, by index,
        # in _encode. The Sequential is never called as a whole, because the
        # LSTM returns a tuple, and the pooling layer is kept only so the module
        # layout is unchanged. The LSTM has a single layer, so nn.LSTM's
        # inter-layer dropout would be a no-op (plus a warning) and is omitted.
        self.text_encoder = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.LSTM(embed_dim, embed_dim, batch_first=True),
            nn.AdaptiveAvgPool1d(1)
        )

        # Memory bank, held in Parameters so it follows .to() and is saved in
        # state_dict. memory_features is float with requires_grad=True, but
        # forward reads it only under no_grad, so the loss gives it no gradient.
        self.memory_features = nn.Parameter(torch.zeros(max_memory_nodes, embed_dim))
        self.memory_labels = nn.Parameter(torch.full((max_memory_nodes,), -1, dtype=torch.long), requires_grad=False)
        self.memory_access_count = nn.Parameter(torch.zeros(max_memory_nodes), requires_grad=False)
        self.memory_valid_mask = nn.Parameter(torch.zeros(max_memory_nodes, dtype=torch.bool), requires_grad=False)
        # Next never-used slot. It is a plain int, so it is not part of state_dict.
        self.next_memory_idx = 0

        # Small MLP novelty head mapping a feature to (0, 1). forward does not
        # use it, so the classification loss does not train it.
        self.novelty_estimator = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, 1),
            nn.Sigmoid()
        )

        # Gate producing per-dimension mixing weights in (0, 1) from [feature, memory].
        self.memory_fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid()
        )

        # Classification head.
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(embed_dim // 2, num_classes)
        )

        self.init_memory_nodes()

    def init_memory_nodes(self, num_nodes=5):
        """Seed the first ``num_nodes`` slots with small random vectors (label -1, access 0).

        ``num_nodes`` is clamped to the memory capacity. Previously a capacity
        below 5 raised IndexError at construction.
        """
        num_nodes = max(0, min(num_nodes, self.memory_features.size(0)))
        print(f"初始化 {num_nodes} 个记忆节点...")
        with torch.no_grad():
            for i in range(num_nodes):
                self.memory_features.data[i] = torch.randn(self.embed_dim) * 0.1
                self.memory_labels.data[i] = -1
                self.memory_access_count.data[i] = 0
                self.memory_valid_mask.data[i] = True
            self.next_memory_idx = num_nodes

    def _encode(self, x):
        """Embed and LSTM-encode token ids, then mean-pool over the time axis."""
        embedded = self.text_encoder[0](x)
        lstm_out, _ = self.text_encoder[1](embedded)
        return lstm_out.mean(dim=1)

    def extract_features(self, x):
        """Return the encoder feature for ``x`` without building an autograd graph."""
        with torch.no_grad():
            return self._encode(x)

    def calculate_novelty_score(self, feature):
        """Score how novel each feature is relative to stored memory (higher means more novel).

        The score is ``0.7 * (1 - max cosine similarity to the valid memories)
        + 0.3 * novelty_estimator(feature)``.

        Args:
            feature: one feature ``(embed_dim,)`` or a batch ``(batch, embed_dim)``.

        Returns:
            Tensor of shape ``(batch,)``, or ``(1,)`` for a single feature. If no
            slot is valid, every score is 1.0, so the result can always be used
            as a per-sample mask.
        """
        if feature.dim() == 1:
            feature = feature.unsqueeze(0)

        valid_memories = self.memory_features[self.memory_valid_mask]
        if valid_memories.size(0) == 0:
            return feature.new_ones(feature.size(0))

        # (batch, 1, D) vs (1, num_valid, D) -> (batch, num_valid)
        similarities = F.cosine_similarity(feature.unsqueeze(1), valid_memories.unsqueeze(0), dim=2)
        novelty_score = 1.0 - similarities.max(dim=1).values
        # squeeze(-1) rather than squeeze(), so the batch axis survives when batch == 1.
        neural_novelty = self.novelty_estimator(feature).squeeze(-1)
        return 0.7 * novelty_score + 0.3 * neural_novelty

    def add_memory_node(self, feature, label):
        """Store ``feature``/``label`` in memory; call outside the training graph.

        While never-used slots remain, the next one is filled. Once memory is
        full, the valid slot with the lowest access count is overwritten.

        Args:
            feature: tensor of shape ``(embed_dim,)``. It is detached and cloned.
            label: int or 0-d integer tensor.

        Returns:
            The slot index written, as an int, or -1 if nothing was written.
        """
        label_value = int(label)
        if self.next_memory_idx < self.memory_features.size(0):
            idx = self.next_memory_idx
            with torch.no_grad():
                self.memory_features.data[idx] = feature.detach().clone()
                self.memory_labels.data[idx] = label_value
                self.memory_access_count.data[idx] = 1
                self.memory_valid_mask.data[idx] = True
            self.next_memory_idx += 1
            print(f"添加记忆节点 {idx}，当前总数: {self.next_memory_idx}")
            return idx

        valid_indices = torch.where(self.memory_valid_mask)[0]
        if valid_indices.numel() == 0:
            return -1

        victim = int(valid_indices[torch.argmin(self.memory_access_count[valid_indices])])
        # Read the evicted slot's count before it is reset. The old code logged
        # it after the reset, so it always printed 1.
        evicted_count = self.memory_access_count[victim].item()
        with torch.no_grad():
            self.memory_features.data[victim] = feature.detach().clone()
            self.memory_labels.data[victim] = label_value
            self.memory_access_count.data[victim] = 1
        print(f"替换记忆节点 {victim}，访问次数: {evicted_count}")
        return victim

    def update_memory_node(self, node_idx, feature, label, alpha=0.1):
        """Blend ``feature`` into slot ``node_idx`` via an EMA and increment its access count.

        ``label`` is accepted but not used. ``alpha`` is the weight of the new feature.
        """
        with torch.no_grad():
            self.memory_features.data[node_idx] = (1 - alpha) * self.memory_features.data[node_idx] + alpha * feature.detach()
            self.memory_access_count.data[node_idx] += 1

    def forward(self, x):
        """Classify token ids ``x``; memory is read, never modified."""
        text_feature = self._encode(x)

        # Retrieval only selects an index and the retrieved row is used as a
        # constant, so it is computed without building an autograd graph.
        with torch.no_grad():
            valid_memories = self.memory_features[self.memory_valid_mask]
            if valid_memories.size(0) == 0:
                best_memory = None
            else:
                similarities = F.cosine_similarity(
                    text_feature.unsqueeze(1),
                    valid_memories.unsqueeze(0),
                    dim=2
                )
                _, best_memory_idx = similarities.max(dim=1)
                best_memory = valid_memories[best_memory_idx]

        if best_memory is None:
            fused_feature = text_feature
        else:
            fusion_weight = self.memory_fusion(torch.cat([text_feature, best_memory], dim=1))
            fused_feature = fusion_weight * text_feature + (1 - fusion_weight) * best_memory

        return self.decoder(fused_feature)

    def get_memory_statistics(self):
        """Summarise memory: valid slot count, capacity, mean access count, next free index."""
        valid_count = self.memory_valid_mask.sum().item()
        avg_access = self.memory_access_count[self.memory_valid_mask].mean().item() if valid_count > 0 else 0

        return {
            'num_memory_nodes': valid_count,
            'total_capacity': self.max_memory_nodes,
            'avg_access_count': avg_access,
            'next_memory_idx': self.next_memory_idx
        }

def train_enhanced_memory(model, train_loader, test_loader, criterion, optimizer, num_epochs=5,
                          max_new_memories=5):
    """Train a FixedMemoryNetwork and grow its memory at the end of every epoch.

    Each step is an ordinary supervised update; ``model(inputs)`` does not write
    memory. After the update, the batch features are re-extracted without
    gradients, and samples whose novelty score exceeds
    ``model.similarity_threshold`` become memory candidates. At the end of the
    epoch up to ``max_new_memories`` randomly chosen candidates are written via
    ``model.add_memory_node``.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``, one entry per epoch.
    """
    model.to(device)
    train_losses = []
    train_accuracies = []
    test_accuracies = []

    print(f"\n开始训练 FixedMemoryNetwork...")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        num_steps = 0
        epoch_memory_candidates = []

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            num_steps += 1
            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            # Collect high-novelty samples as memory candidates (no gradients).
            with torch.no_grad():
                text_feature = model.extract_features(inputs)
                # calculate_novelty_score may return a 0-d tensor (e.g. when no
                # memory slot is valid); normalise to one score per sample so it
                # can index the batch.
                novelty = model.calculate_novelty_score(text_feature).reshape(-1).expand(text_feature.size(0))
                high_novelty_mask = novelty > model.similarity_threshold
                if high_novelty_mask.any():
                    epoch_memory_candidates.extend(
                        zip(text_feature[high_novelty_mask], labels[high_novelty_mask])
                    )

            if i % 50 == 49:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

        # Grow memory once per epoch, with a random subset of this epoch's candidates.
        if epoch_memory_candidates:
            num_to_add = min(len(epoch_memory_candidates), max(0, max_new_memories))
            for feature, label in random.sample(epoch_memory_candidates, num_to_add):
                model.add_memory_node(feature, label)

        # Guard against an empty loader instead of dividing by zero.
        train_acc = 100 * correct_train / total_train if total_train else 0.0
        avg_loss = running_loss / num_steps if num_steps else 0.0
        train_losses.append(avg_loss)
        train_accuracies.append(train_acc)

        test_acc = evaluate_model(model, test_loader, "FixedMemoryNetwork")
        test_accuracies.append(test_acc)

        stats = model.get_memory_statistics()
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
        print(f'  记忆节点数: {stats["num_memory_nodes"]}/{stats["total_capacity"]}, 平均访问次数: {stats["avg_access_count"]:.2f}')

    return train_losses, train_accuracies, test_accuracies

def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs=5, model_name="Model"):
    """Train a standard classifier, evaluating on ``test_loader`` after every epoch.

    Returns:
        ``(train_losses, train_accuracies, test_accuracies)``, one entry per
        epoch: mean batch loss, training accuracy (%) and test accuracy (%).
    """
    model.to(device)
    history = {'loss': [], 'train_acc': [], 'test_acc': []}

    print(f"\n开始训练 {model_name}...")

    for epoch in range(num_epochs):
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        steps_per_epoch = len(train_loader)

        for step, (batch_x, batch_y) in enumerate(train_loader, start=1):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            n_correct += (logits.data.max(1)[1] == batch_y).sum().item()
            n_seen += batch_y.size(0)

            # Progress line every 50 steps.
            if step % 50 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{steps_per_epoch}], Loss: {batch_loss.item():.4f}')

        epoch_acc = 100 * n_correct / n_seen
        epoch_loss = loss_sum / len(train_loader)
        history['loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        epoch_test_acc = evaluate_model(model, test_loader, model_name)
        history['test_acc'].append(epoch_test_acc)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%, Test Acc: {epoch_test_acc:.2f}%')

    return history['loss'], history['train_acc'], history['test_acc']

def evaluate_model(model, data_loader, model_name="Model"):
    """Return the classification accuracy of ``model`` on ``data_loader``, in percent.

    The model is switched to eval mode and run without gradients. ``model_name``
    is accepted for symmetry with the other helpers but is unused. Raises
    ZeroDivisionError if the loader yields no samples.
    """
    model.eval()
    n_correct = n_total = 0

    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            preds = model(batch_x).data.max(1)[1]
            n_total += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()

    return 100 * n_correct / n_total

def measure_inference_time(model, data_loader, num_batches=10, model_name="Model"):
    """Return the mean forward-pass time per sample, in seconds.

    Times ``model(inputs)`` in eval mode, without gradients, for at most
    ``num_batches`` batches. CUDA kernels launch asynchronously, so on a CUDA
    device the stream is synchronised before each clock read; otherwise the
    timer would stop before the work finished. Host-to-device copies are
    excluded from the timing. ``model_name`` is unused. Returns 0.0 when no
    sample was processed, instead of dividing by zero.
    """
    model.eval()
    total_time = 0.0
    processed_samples = 0
    # Synchronisation is only meaningful (and only valid) on a CUDA device.
    sync = torch.cuda.synchronize if device.type == 'cuda' else (lambda: None)

    with torch.no_grad():
        for i, (inputs, _) in enumerate(data_loader):
            if i >= num_batches:
                break

            inputs = inputs.to(device)
            sync()  # make sure the copy has finished before starting the clock
            start_time = time.perf_counter()
            model(inputs)
            sync()
            total_time += time.perf_counter() - start_time
            processed_samples += inputs.size(0)

    return total_time / processed_samples if processed_samples else 0.0

def _plot_training_curves(results, path='fixed_memory_experiment_results.png'):
    """Plot loss, training accuracy and test accuracy per epoch for every model and save to *path*."""
    panels = [
        ('train_losses', 'Training Loss', 'Training Loss Comparison'),
        ('train_accuracies', 'Training Accuracy (%)', 'Training Accuracy Comparison'),
        ('test_accuracies', 'Test Accuracy (%)', 'Test Accuracy Comparison'),
    ]

    plt.figure(figsize=(15, 5))
    for position, (key, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for name, result in results.items():
            plt.plot(result[key], label=name, linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated runs in one process do not accumulate open figures.
    plt.close()


def _print_summary(results):
    """Print the final accuracy / inference-time comparison table."""
    print("\n=== 最终实验结果 ===")
    print(f"{'模型':<25} {'准确率':<10} {'推理时间(ms)':<15}")
    print("-" * 50)
    for name, result in results.items():
        print(f"{name:<25} {result['final_accuracy']:<10.2f} {result['inference_time']*1000:<15.4f}")


def run_experiments(num_epochs=5):
    """Run the full comparison: build data, train all models, plot, save and summarise.

    Args:
        num_epochs: training epochs for every model (default 5).

    Returns:
        dict mapping model name to its metrics: final test accuracy, mean
        per-sample inference time in seconds, per-epoch loss/accuracy curves
        and, for FixedMemoryNetwork, memory statistics. The same dict is
        written to ``fixed_memory_experiment_results.json``.
    """
    print("开始修复版记忆网络实验...")

    # Data
    texts, labels, categories = create_synthetic_chinese_dataset(num_samples=5000, num_classes=10)

    processor = TextProcessor(max_vocab_size=5000, max_seq_len=64)
    processor.build_vocab(texts)

    X = torch.tensor([processor.text_to_sequence(text) for text in texts], dtype=torch.long)
    y = torch.tensor(labels, dtype=torch.long)

    # Sequential 80/20 split. The generator assigns labels round-robin, so both
    # halves contain every class.
    train_size = int(0.8 * len(X))
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
    test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32, shuffle=False)

    print(f"训练集大小: {len(X_train)}")
    print(f"测试集大小: {len(X_test)}")
    print(f"词汇表大小: {processor.vocab_size}")
    print(f"类别数: {len(categories)}")
    print(f"类别: {categories}")

    # Models, built in a fixed order so seeded initialisation is reproducible.
    vocab_size = processor.vocab_size
    num_classes = len(categories)
    models = {
        'TextCNN': TextCNN(vocab_size, num_classes=num_classes),
        'TextLSTM': TextLSTM(vocab_size, num_classes=num_classes),
        'TextTransformer': TextTransformer(vocab_size, num_classes=num_classes),
        'FixedMemoryNetwork': FixedMemoryNetwork(vocab_size, num_classes=num_classes)
    }

    criterion = nn.CrossEntropyLoss()
    optimizers = {
        name: optim.Adam(model.parameters(), lr=0.001)
        for name, model in models.items()
    }

    # Training and evaluation
    results = {}
    for name, model in models.items():
        if name == "FixedMemoryNetwork":
            train_losses, train_accs, test_accs = train_enhanced_memory(
                model, train_loader, test_loader, criterion, optimizers[name], num_epochs=num_epochs
            )
        else:
            train_losses, train_accs, test_accs = train_model(
                model, train_loader, test_loader, criterion, optimizers[name],
                num_epochs=num_epochs, model_name=name
            )

        final_acc = evaluate_model(model, test_loader, name)
        inference_time = measure_inference_time(model, test_loader, model_name=name)

        results[name] = {
            'final_accuracy': final_acc,
            'inference_time': inference_time,
            'train_losses': train_losses,
            'train_accuracies': train_accs,
            'test_accuracies': test_accs
        }
        print(f"{name} - 最终准确率: {final_acc:.2f}%, 推理时间: {inference_time*1000:.4f}ms")

        if name == "FixedMemoryNetwork":
            memory_stats = model.get_memory_statistics()
            results[name]['memory_statistics'] = memory_stats
            print(f"  记忆节点数: {memory_stats['num_memory_nodes']}/{memory_stats['total_capacity']}, 平均访问次数: {memory_stats['avg_access_count']:.2f}")

    # Reporting
    _plot_training_curves(results)

    with open('fixed_memory_experiment_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    _print_summary(results)

    return results

if __name__ == "__main__":
    try:
        results = run_experiments()
    except Exception as e:
        # Top-level boundary: report the failure, then exit non-zero so shells
        # and CI can detect that the experiment did not complete.
        print(f"实验过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1)
    else:
        print("\n实验完成！结果已保存到 fixed_memory_experiment_results.json")
        print("图表已保存到 fixed_memory_experiment_results.png")
