#!/usr/bin/env python3
"""
内存优化的N-gram分布KL散度分析工具

该脚本用于计算不同模型和不同budget之间ngram分布的KL散度，
支持GPU加速和内存优化。

使用方法:
python ngram_kl_analysis_optimized.py --data_dir /path/to/ngram_results --output_dir ./kl_results
"""

import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, Optional, Iterator
import warnings
import gc
import psutil
import sys
from tqdm import tqdm

# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Optional GPU support: torch is imported if available and HAS_TORCH records
# whether the import succeeded, so later code can check it before using torch.
try:
    import torch
    import torch.nn.functional as F
    HAS_TORCH = True
    print(f"PyTorch可用，CUDA设备: {torch.cuda.is_available()}")
except ImportError:
    HAS_TORCH = False
    print("PyTorch不可用，将使用CPU计算")

# Font fallback list for matplotlib (first available font is used), listing
# CJK-capable fonts first; also render the minus sign as an ASCII hyphen.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024

class MemoryOptimizedNgramKLAnalyzer:
    """内存优化的N-gram分布KL散度分析器"""
    
    def __init__(self, data_dir: str, output_dir: str = "./kl_results", 
                 max_memory_mb: int = 2000, use_gpu: bool = True):
        """
        Initialise the analyzer and create the output directory.

        Args:
            data_dir: Directory that contains one sub-directory per model run,
                each holding the n-gram result files.
            output_dir: Directory where CSV/JSON/PNG results are written
                (created if missing).
            max_memory_mb: Process RSS threshold in MB above which a garbage
                collection is triggered while loading files.
            use_gpu: Request GPU acceleration. It is only honoured when
                PyTorch was imported successfully and CUDA is available.
        """
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.max_memory_mb = max_memory_mb
        # Short-circuits before touching `torch` when the import failed.
        self.use_gpu = bool(use_gpu and HAS_TORCH and torch.cuda.is_available())

        os.makedirs(output_dir, exist_ok=True)

        if self.use_gpu:
            self.device = torch.device('cuda')
            print(f"使用GPU: {torch.cuda.get_device_name()}")
        else:
            # `torch` is only bound when the optional import succeeded; without
            # this guard the CPU fallback itself raised NameError.
            self.device = torch.device('cpu') if HAS_TORCH else None
            print("使用CPU计算")

        print(f"内存限制: {max_memory_mb}MB")
        print(f"初始内存使用: {get_memory_usage():.1f}MB")
    
    def _parse_model_budget(self, model_dir: str) -> Tuple[str, str]:
        """
        Split a model directory name into ``(model_name, budget)``.

        The name is tokenised on ``'_'``: the second-to-last token is the
        budget, the last token is ignored, and everything before them is
        re-joined as the model name. A name with exactly two tokens therefore
        yields an empty model name. A name without any underscore is returned
        unchanged with the budget ``"unknown"``.

        Args:
            model_dir: Directory name of one model run.

        Returns:
            Tuple of (model name, budget).
        """
        tokens = model_dir.split('_')
        if len(tokens) < 2:
            return model_dir, "unknown"

        *name_tokens, budget, _trailing = tokens
        return '_'.join(name_tokens), budget
    
    def _load_ngram_file_streaming(self, file_path: str) -> Iterator[Tuple[str, int]]:
        """
        Lazily read an n-gram file one line at a time.

        Everything before the first line that starts with 60 ``=`` characters
        is treated as header and skipped; such separator lines are skipped
        wherever they occur. After that, each non-empty line containing a tab
        is split at the first tab into ``ngram`` and ``count``. Lines whose
        count is not an integer are skipped. Any error while reading is
        printed and ends the iteration instead of propagating.

        Args:
            file_path: Path of the n-gram file.

        Yields:
            ``(ngram, count)`` tuples.
        """
        separator = "=" * 60
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                header_done = False
                for raw_line in handle:
                    text = raw_line.strip()
                    if not text:
                        continue

                    if text.startswith(separator):
                        header_done = True
                        continue

                    if not header_done:
                        continue

                    ngram, tab, count_str = text.partition('\t')
                    if not tab:
                        # No tab on this line, so it is not a data row.
                        continue

                    try:
                        yield ngram, int(count_str)
                    except ValueError:
                        continue
        except Exception as e:
            print(f"加载文件 {file_path} 时出错: {e}")
    
    def _load_ngram_file_batch(self, file_path: str, batch_size: int = 10000) -> Iterator[Dict[str, int]]:
        """
        Read an n-gram file and hand it out in dictionary chunks.

        A chunk is emitted once ``batch_size`` rows have been read into it
        (repeated n-grams overwrite each other but still count as rows). After
        each full chunk the process memory is checked and a garbage collection
        is run when it exceeds ``self.max_memory_mb``. A final partial chunk is
        emitted if it is non-empty.

        Args:
            file_path: Path of the n-gram file.
            batch_size: Number of rows per chunk.

        Yields:
            Dictionaries mapping n-gram to count.
        """
        pending = {}
        rows_in_pending = 0

        for ngram, ngram_count in self._load_ngram_file_streaming(file_path):
            pending[ngram] = ngram_count
            rows_in_pending += 1

            if rows_in_pending < batch_size:
                continue

            yield pending
            pending, rows_in_pending = {}, 0

            # Only checked between chunks, after the consumer took the last one.
            if get_memory_usage() > self.max_memory_mb:
                print("内存使用超过限制，执行垃圾回收...")
                gc.collect()

        # Trailing rows that did not fill a whole chunk.
        if pending:
            yield pending
    
    def calculate_kl_divergence_cpu(self, p_dist: Dict[str, int], q_dist: Dict[str, int], 
                                   smoothing: float = 1e-10) -> float:
        """
        Compute KL(P || Q) between two n-gram count tables in pure Python.

        Both tables are turned into probabilities over the union of their
        n-grams using additive smoothing:
        ``(count + smoothing) / (total + smoothing * vocabulary_size)``.

        Args:
            p_dist: n-gram -> count for distribution P.
            q_dist: n-gram -> count for distribution Q.
            smoothing: Additive smoothing constant.

        Returns:
            The divergence in nats; ``0.0`` when both tables are empty and
            ``inf`` when either table has a total count of zero.
        """
        vocabulary = set(p_dist) | set(q_dist)
        if not vocabulary:
            return 0.0

        p_total = sum(p_dist.values())
        q_total = sum(q_dist.values())
        if p_total == 0 or q_total == 0:
            return float('inf')

        # The normalisers do not depend on the individual n-gram.
        p_norm = p_total + smoothing * len(vocabulary)
        q_norm = q_total + smoothing * len(vocabulary)

        divergence = 0.0
        for ngram in vocabulary:
            p_prob = (p_dist.get(ngram, 0) + smoothing) / p_norm
            q_prob = (q_dist.get(ngram, 0) + smoothing) / q_norm
            divergence += p_prob * np.log(p_prob / q_prob)

        return divergence
    
    def calculate_kl_divergence_gpu(self, p_dist: Dict[str, int], q_dist: Dict[str, int], 
                                   smoothing: float = 1e-10) -> float:
        """GPU版本的KL散度计算"""
        all_ngrams = list(set(p_dist.keys()) | set(q_dist.keys()))
        if not all_ngrams:
            return 0.0
        
        p_total = sum(p_dist.values())
        q_total = sum(q_dist.values())
        
        if p_total == 0 or q_total == 0:
            return float('inf')
        
        # 创建ngram到索引的映射
        ngram_to_idx = {ngram: i for i, ngram in enumerate(all_ngrams)}
        
        # 创建概率向量
        p_probs = torch.zeros(len(all_ngrams), device=self.device)
        q_probs = torch.zeros(len(all_ngrams), device=self.device)
        
        for ngram, count in p_dist.items():
            idx = ngram_to_idx[ngram]
            p_probs[idx] = count
        
        for ngram, count in q_dist.items():
            idx = ngram_to_idx[ngram]
            q_probs[idx] = count
        
        # 归一化并添加平滑
        p_probs = (p_probs + smoothing) / (p_total + smoothing * len(all_ngrams))
        q_probs = (q_probs + smoothing) / (q_total + smoothing * len(all_ngrams))
        
        # 计算KL散度
        kl_div = torch.sum(p_probs * torch.log(p_probs / q_probs))
        
        return kl_div.item()
    
    def calculate_kl_divergence(self, p_dist: Dict[str, int], q_dist: Dict[str, int], 
                               smoothing: float = 1e-10) -> float:
        """
        Compute KL(P || Q), dispatching to the GPU or CPU implementation.

        The GPU implementation is used only when ``self.use_gpu`` is set and
        the two tables together hold more than 1000 entries; otherwise the
        CPU implementation is used.

        Args:
            p_dist: n-gram -> count for distribution P.
            q_dist: n-gram -> count for distribution Q.
            smoothing: Additive smoothing constant passed through unchanged.

        Returns:
            The divergence returned by the selected implementation.
        """
        is_large = len(p_dist) + len(q_dist) > 1000
        if self.use_gpu and is_large:
            backend = self.calculate_kl_divergence_gpu
        else:
            backend = self.calculate_kl_divergence_cpu
        return backend(p_dist, q_dist, smoothing)
    
    def calculate_js_divergence(self, p_dist: Dict[str, int], q_dist: Dict[str, int], 
                               smoothing: float = 1e-10) -> float:
        """
        Compute the Jensen-Shannon divergence between two n-gram count tables.

        ``JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)`` with
        ``M = 0.5 * (P + Q)``. The result is symmetric and, in nats, lies in
        ``[0, ln 2]``. P and Q are obtained from the counts with the same
        additive smoothing used by the KL methods of this class.

        The previous implementation averaged KL(P || Q) and KL(Q || P), which
        is a symmetrised KL divergence rather than the Jensen-Shannon
        divergence: it is unbounded and dominated by the smoothing constant
        whenever the supports differ.

        Args:
            p_dist: n-gram -> count for distribution P.
            q_dist: n-gram -> count for distribution Q.
            smoothing: Additive smoothing constant.

        Returns:
            The divergence in nats; ``0.0`` when both tables are empty and
            ``inf`` when either table has a total count of zero (mirroring
            the KL methods).
        """
        vocabulary = list(set(p_dist.keys()) | set(q_dist.keys()))
        if not vocabulary:
            return 0.0

        p_total = sum(p_dist.values())
        q_total = sum(q_dist.values())
        if p_total == 0 or q_total == 0:
            return float('inf')

        size = len(vocabulary)
        p = np.array([p_dist.get(ngram, 0) for ngram in vocabulary], dtype=np.float64)
        q = np.array([q_dist.get(ngram, 0) for ngram in vocabulary], dtype=np.float64)
        p = (p + smoothing) / (p_total + smoothing * size)
        q = (q + smoothing) / (q_total + smoothing * size)
        mixture = 0.5 * (p + q)

        def _kl_to_mixture(dist):
            # Terms with zero probability contribute 0 (only reachable when
            # smoothing == 0); the mixture is positive wherever dist is.
            support = dist > 0
            return float(np.sum(dist[support] * np.log(dist[support] / mixture[support])))

        return 0.5 * _kl_to_mixture(p) + 0.5 * _kl_to_mixture(q)
    
    def compare_models_streaming(self, ngram_len: int = 3, file_type: str = "deduplicated", 
                                comparison_types: List[str] = None) -> List[Dict]:
        """
        Compare the n-gram distributions of every pair of model runs.

        Model runs are the sub-directories of ``self.data_dir`` that contain
        ``ngram_{ngram_len}_{file_type}.txt``. To bound memory, runs are
        processed in batches of five "anchor" runs that are held in memory;
        each anchor is compared with the other anchors of its batch and with
        every run of the later batches, which are loaded one at a time and
        released immediately. Every unordered pair is therefore compared
        exactly once (previously pairs that fell into different batches were
        silently skipped).

        Directory names are sorted so that the set of results and the
        direction of each KL(P || Q) do not depend on ``os.listdir`` order.

        Args:
            ngram_len: Length of the n-grams to compare.
            file_type: File type component of the n-gram file name.
            comparison_types: Comparison types to keep
                (``same_model_different_budgets``,
                ``different_models_same_budget``,
                ``different_models_different_budgets``); ``None`` keeps all.

        Returns:
            One result dictionary per retained pair.
        """
        print(f"Starting streaming comparison of {ngram_len}-gram distributions...")

        file_name = f"ngram_{ngram_len}_{file_type}.txt"
        model_dirs = sorted(
            entry for entry in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, entry))
            and os.path.exists(os.path.join(self.data_dir, entry, file_name))
        )

        print(f"Found {len(model_dirs)} model directories")

        results = []

        batch_size = 5  # number of anchor runs kept in memory at once
        total_batches = (len(model_dirs) + batch_size - 1) // batch_size
        for start in range(0, len(model_dirs), batch_size):
            print(f"Processing batch {start // batch_size + 1}/{total_batches}")

            anchors = [
                self._load_model_entry(model_dir, file_name)
                for model_dir in model_dirs[start:start + batch_size]
            ]

            # Pairs inside the current batch.
            for position, first in enumerate(anchors):
                for second in anchors[position + 1:]:
                    record = self._compare_model_pair(first, second, ngram_len, comparison_types)
                    if record is not None:
                        results.append(record)

            # Pairs between the current batch and every later run. Later runs
            # are streamed one at a time so at most batch_size + 1 tables are
            # resident simultaneously.
            for later_dir in model_dirs[start + batch_size:]:
                later = self._load_model_entry(later_dir, file_name)
                for first in anchors:
                    record = self._compare_model_pair(first, later, ngram_len, comparison_types)
                    if record is not None:
                        results.append(record)
                del later

            # Release the batch before loading the next one.
            del anchors
            gc.collect()

            print(f"Current memory usage: {get_memory_usage():.1f}MB")

        return results

    def _load_model_entry(self, model_dir: str, file_name: str) -> Dict:
        """Load one run's n-gram counts together with its parsed model name and budget."""
        model_name, budget = self._parse_model_budget(model_dir)
        file_path = os.path.join(self.data_dir, model_dir, file_name)

        ngram_data = {}
        for batch in self._load_ngram_file_batch(file_path, batch_size=5000):
            ngram_data.update(batch)

        if not ngram_data:
            print(f"Warning: no n-gram counts loaded from {file_path}; its comparisons will be skipped")

        return {
            'model_name': model_name,
            'budget': budget,
            'ngram_data': ngram_data
        }

    def _compare_model_pair(self, data1: Dict, data2: Dict, ngram_len: int,
                            comparison_types: Optional[List[str]]) -> Optional[Dict]:
        """Classify one pair of runs and return its result record, or None when the pair is filtered out."""
        if data1['model_name'] == data2['model_name']:
            comparison_type = 'same_model_different_budgets'
            record = {
                'model': data1['model_name'],
                'budget1': data1['budget'],
                'budget2': data2['budget'],
            }
        elif data1['budget'] == data2['budget']:
            comparison_type = 'different_models_same_budget'
            record = {
                'model1': data1['model_name'],
                'model2': data2['model_name'],
                'budget': data1['budget'],
            }
        else:
            comparison_type = 'different_models_different_budgets'
            record = {
                'model1': data1['model_name'],
                'model2': data2['model_name'],
                'budget1': data1['budget'],
                'budget2': data2['budget'],
            }

        # Decide before computing: the divergences are the expensive part.
        if comparison_types is not None and comparison_type not in comparison_types:
            return None

        # An empty table (missing or unreadable file) would yield 0 or inf and
        # distort the statistics and plots, so such pairs are left out.
        if not data1['ngram_data'] or not data2['ngram_data']:
            return None

        record['ngram_len'] = ngram_len
        record['kl_divergence'] = self.calculate_kl_divergence(
            data1['ngram_data'], data2['ngram_data']
        )
        record['js_divergence'] = self.calculate_js_divergence(
            data1['ngram_data'], data2['ngram_data']
        )
        record['comparison_type'] = comparison_type
        return record
    
    def analyze_all_comparisons_streaming(self, file_type: str = "deduplicated", 
                                        comparison_types: List[str] = None) -> Dict[str, List[Dict]]:
        """
        Run the model comparison for n-gram lengths 3, 4 and 5.

        For every length the results of ``compare_models_streaming`` are
        split by their ``comparison_type`` and stored under the key
        ``"{comparison_type}_ngram_{length}"``. Types without any result are
        omitted.

        Args:
            file_type: File type component of the n-gram file name.
            comparison_types: Comparison types to keep; ``None`` keeps all.

        Returns:
            Mapping from result key to the list of result dictionaries.
        """
        print("Starting streaming execution of all comparison analyses...")

        # Fixed order so the result keys are inserted deterministically.
        known_types = (
            'same_model_different_budgets',
            'different_models_same_budget',
            'different_models_different_budgets',
        )

        results = {}
        for ngram_len in [3, 4, 5]:
            print(f"\nAnalyzing {ngram_len}-gram...")

            comparison_results = self.compare_models_streaming(ngram_len, file_type, comparison_types)
            if not comparison_results:
                continue

            for kind in known_types:
                matching = [r for r in comparison_results if r['comparison_type'] == kind]
                if matching:
                    results[f'{kind}_ngram_{ngram_len}'] = matching

        return results
    
    def save_results_streaming(self, results: Dict[str, List[Dict]]):
        """
        Write the comparison results and summary statistics to disk.

        Result lists are merged across n-gram lengths per comparison type
        (the part of the key before ``"_ngram_"``). For each type two files
        are written into ``self.output_dir``:

        * ``{type}_all_ngrams.csv`` with one row per comparison,
        * ``{type}_summary.json`` with count, mean/std of KL and JS, min/max
          of KL and the sorted n-gram lengths.

        Statistics that are not finite (for example the standard deviation of
        a single row, which is NaN) are stored as ``null`` so that the summary
        is valid JSON; previously the non-standard tokens ``NaN``/``Infinity``
        were emitted.

        Args:
            results: Mapping from result key to list of result dictionaries.
        """
        print("Saving analysis results...")

        # Merge the per-length lists of each comparison type.
        grouped_results = {}
        for key, result_list in results.items():
            if not result_list:
                continue
            comparison_type = key.split('_ngram_')[0]
            grouped_results.setdefault(comparison_type, []).extend(result_list)

        def _json_number(value):
            # Plain float for finite values, None (JSON null) otherwise.
            value = float(value)
            return value if np.isfinite(value) else None

        for comparison_type, result_list in grouped_results.items():
            df = pd.DataFrame(result_list)
            csv_path = os.path.join(self.output_dir, f"{comparison_type}_all_ngrams.csv")
            df.to_csv(csv_path, index=False, encoding='utf-8')
            print(f"Saved results to: {csv_path}")

            kl_values = df['kl_divergence']
            js_values = df['js_divergence']
            summary = {
                'count': len(df),
                'mean_kl': _json_number(kl_values.mean()),
                'std_kl': _json_number(kl_values.std()),
                'mean_js': _json_number(js_values.mean()),
                'std_js': _json_number(js_values.std()),
                'min_kl': _json_number(kl_values.min()),
                'max_kl': _json_number(kl_values.max()),
                'ngram_lengths': sorted(df['ngram_len'].unique().tolist())
            }

            summary_path = os.path.join(self.output_dir, f"{comparison_type}_summary.json")
            with open(summary_path, 'w', encoding='utf-8') as f:
                # allow_nan=False guarantees strictly valid JSON output.
                json.dump(summary, f, indent=2, ensure_ascii=False, allow_nan=False)
            print(f"Saved summary statistics to: {summary_path}")
    
    def create_visualizations_optimized(self, results: Dict[str, List[Dict]]):
        """
        Render one 2x3 figure per comparison type and save it as PNG.

        Result lists are merged across n-gram lengths per comparison type
        (the part of the key before ``"_ngram_"``). Panels: KL histogram,
        JS histogram, KL-vs-JS scatter, number of comparisons per n-gram
        length, KL box plot per n-gram length, and a KL box plot grouped by
        model or budget depending on which columns the results contain.
        Files are written to ``self.output_dir`` as
        ``{type}_analysis_all_ngrams.png``.

        Args:
            results: Mapping from result key to list of result dictionaries.
        """
        print("Creating visualization charts...")

        # Merge the per-length lists of each comparison type.
        grouped_results = {}
        for key, result_list in results.items():
            if not result_list:
                continue
            comparison_type = key.split('_ngram_')[0] if '_ngram_' in key else key
            grouped_results.setdefault(comparison_type, []).extend(result_list)

        title_mapping = {
            'same_model_different_budgets': 'Same Model Different Budgets - KL Divergence Analysis',
            'different_models_same_budget': 'Different Models Same Budget - KL Divergence Analysis',
            'different_models_different_budgets': 'Different Models Different Budgets - KL Divergence Analysis'
        }

        for comparison_type, result_list in grouped_results.items():
            if not result_list:
                continue

            df = pd.DataFrame(result_list)
            kl_values = df['kl_divergence']
            js_values = df['js_divergence']

            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            fig.suptitle(title_mapping.get(comparison_type, f'{comparison_type} - KL Divergence Analysis'), fontsize=16)

            # Row 1, left: distribution of KL values.
            kl_hist_ax = axes[0, 0]
            kl_hist_ax.hist(kl_values, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            kl_hist_ax.set_title('KL Divergence Distribution')
            kl_hist_ax.set_xlabel('KL Divergence')
            kl_hist_ax.set_ylabel('Frequency')

            # Row 1, middle: distribution of JS values.
            js_hist_ax = axes[0, 1]
            js_hist_ax.hist(js_values, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
            js_hist_ax.set_title('JS Divergence Distribution')
            js_hist_ax.set_xlabel('JS Divergence')
            js_hist_ax.set_ylabel('Frequency')

            # Row 1, right: KL against JS.
            scatter_ax = axes[0, 2]
            scatter_ax.scatter(kl_values, js_values, alpha=0.6)
            scatter_ax.set_title('KL vs JS Divergence')
            scatter_ax.set_xlabel('KL Divergence')
            scatter_ax.set_ylabel('JS Divergence')

            # Row 2, left: how many comparisons exist per n-gram length.
            per_length = df['ngram_len'].value_counts().sort_index()
            count_ax = axes[1, 0]
            count_ax.bar(per_length.index, per_length.values, alpha=0.7, color='lightgreen')
            count_ax.set_title('N-gram Length Distribution')
            count_ax.set_xlabel('N-gram Length')
            count_ax.set_ylabel('Count')

            # Row 2, middle: KL box plot per n-gram length.
            length_series = []
            length_labels = []
            for ngram_len in sorted(df['ngram_len'].unique()):
                subset = df[df['ngram_len'] == ngram_len]['kl_divergence']
                if len(subset) > 0:
                    length_series.append(subset)
                    length_labels.append(f'{ngram_len}-gram')

            if length_series:
                axes[1, 1].boxplot(length_series, labels=length_labels)
                axes[1, 1].set_title('KL Divergence by N-gram Length')
                axes[1, 1].set_ylabel('KL Divergence')
                axes[1, 1].tick_params(axis='x', rotation=45)

            # Row 2, right: grouping column depends on the result schema.
            if 'model' in df.columns:
                # Same model, different budgets: group by model.
                group_column = 'model'
                panel_title = 'KL Divergence by Model'
                as_label = lambda value: value
            elif 'model1' in df.columns and 'budget' in df.columns:
                # Different models, same budget: group by budget.
                group_column = 'budget'
                panel_title = 'KL Divergence by Budget'
                as_label = lambda value: f'Budget {value}'
            elif 'model1' in df.columns:
                # Different models, different budgets: group by first model.
                group_column = 'model1'
                panel_title = 'KL Divergence by Model'
                as_label = lambda value: value
            else:
                group_column = None

            if group_column is not None:
                group_series = []
                group_labels = []
                for value in df[group_column].unique():
                    subset = df[df[group_column] == value]['kl_divergence']
                    if len(subset) > 0:
                        group_series.append(subset)
                        group_labels.append(as_label(value))

                if group_series:
                    axes[1, 2].boxplot(group_series, labels=group_labels)
                    axes[1, 2].set_title(panel_title)
                    axes[1, 2].set_ylabel('KL Divergence')
                    axes[1, 2].tick_params(axis='x', rotation=45)

            plt.tight_layout()

            plot_path = os.path.join(self.output_dir, f"{comparison_type}_analysis_all_ngrams.png")
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            plt.close()
            print(f"Saved chart to: {plot_path}")
    
    def print_summary(self, results: Dict[str, List[Dict]]):
        """
        Print a textual summary of the comparison results to stdout.

        Result lists are merged across n-gram lengths per comparison type
        (the part of the key before ``"_ngram_"``). For each type the number
        of comparisons, mean/std of KL and JS, the KL range, the n-gram
        lengths and the five comparisons with the largest KL divergence are
        printed.

        Args:
            results: Mapping from result key to list of result dictionaries.
        """
        banner = "=" * 60
        print("\n" + banner)
        print("N-gram Distribution KL Divergence Analysis Summary")
        print(banner)

        # Merge the per-length lists of each comparison type.
        grouped_results = {}
        for key, result_list in results.items():
            if not result_list:
                continue
            comparison_type = key.split('_ngram_')[0] if '_ngram_' in key else key
            grouped_results.setdefault(comparison_type, []).extend(result_list)

        # Human-readable heading per comparison type.
        type_mapping = {
            'same_model_different_budgets': 'Same Model Different Budgets',
            'different_models_same_budget': 'Different Models Same Budget',
            'different_models_different_budgets': 'Different Models Different Budgets'
        }

        for comparison_type, result_list in grouped_results.items():
            if not result_list:
                continue

            df = pd.DataFrame(result_list)
            kl_values = df['kl_divergence']
            js_values = df['js_divergence']

            print(f"\n{type_mapping.get(comparison_type, comparison_type)}:")
            print(f"  Number of comparisons: {len(df)}")
            print(f"  KL Divergence - Mean: {kl_values.mean():.4f}, Std: {kl_values.std():.4f}")
            print(f"  JS Divergence - Mean: {js_values.mean():.4f}, Std: {js_values.std():.4f}")
            print(f"  KL Divergence Range: [{kl_values.min():.4f}, {kl_values.max():.4f}]")
            print(f"  N-gram Lengths: {sorted(df['ngram_len'].unique().tolist())}")

            # Every row shares the frame's columns, so the line layout can be
            # chosen once per comparison type.
            has_model = 'model' in df.columns
            has_budget = 'budget' in df.columns

            print("  Top 5 KL Divergence Comparisons:")
            for _, row in df.nlargest(5, 'kl_divergence').iterrows():
                if has_model:
                    # Same model, different budgets.
                    print(f"    {row['model']}: {row['budget1']} vs {row['budget2']} = {row['kl_divergence']:.4f}")
                elif has_budget:
                    # Different models, same budget.
                    print(f"    {row['model1']} vs {row['model2']} (budget {row['budget']}) = {row['kl_divergence']:.4f}")
                else:
                    # Different models, different budgets.
                    print(f"    {row['model1']} vs {row['model2']} = {row['kl_divergence']:.4f}")


def main():
    """
    Command-line entry point.

    Parses the arguments, builds the analyzer, runs every comparison, saves
    the CSV/JSON results, optionally renders the charts and prints a summary.

    GPU use is on by default. ``--use_gpu`` is kept for backward
    compatibility, and ``--no_gpu`` switches it off (previously the
    ``store_true`` flag with ``default=True`` made it impossible to disable
    the GPU from the command line).
    """
    parser = argparse.ArgumentParser(description='内存优化的N-gram分布KL散度分析工具')
    parser.add_argument('--data_dir', type=str, 
                       default='/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/ngram_results',
                       help='ngram结果文件目录路径')
    parser.add_argument('--output_dir', type=str, 
                       default='/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/kl_results',
                       help='输出结果目录路径')
    parser.add_argument('--file_type', type=str, default='deduplicated',
                       choices=['deduplicated', 'full', 'stats'],
                       help='要分析的ngram文件类型')
    parser.add_argument('--max_memory_mb', type=int, default=10240,
                       help='最大内存使用限制（MB）')
    parser.add_argument('--use_gpu', dest='use_gpu', action='store_true', default=True,
                       help='使用GPU加速计算')
    # Shares the destination with --use_gpu so that the default can be turned off.
    parser.add_argument('--no_gpu', dest='use_gpu', action='store_false',
                       help='禁用GPU，强制使用CPU计算')
    parser.add_argument('--no_visualization', action='store_true',
                       help='跳过可视化图表生成')
    parser.add_argument('--comparison_types', nargs='+', 
                       choices=['same_model_different_budgets', 'different_models_same_budget', 'different_models_different_budgets'],
                       help='要进行的比较类型，可以指定多个。如果不指定则进行所有比较')

    args = parser.parse_args()

    analyzer = MemoryOptimizedNgramKLAnalyzer(
        args.data_dir, 
        args.output_dir, 
        max_memory_mb=args.max_memory_mb,
        use_gpu=args.use_gpu
    )

    # Run all comparisons (n-gram lengths 3, 4 and 5).
    results = analyzer.analyze_all_comparisons_streaming(args.file_type, args.comparison_types)

    analyzer.save_results_streaming(results)

    if not args.no_visualization:
        analyzer.create_visualizations_optimized(results)

    analyzer.print_summary(results)

    print(f"\nAnalysis completed! Results saved to: {args.output_dir}")
    print(f"Final memory usage: {get_memory_usage():.1f}MB")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
