#!/usr/bin/env python3
"""
绘图脚本：比较三个不同训练集在6个evaluation datasets上的表现
生成6张图，每张图对应一个数据集，每张图上有3条线对应3个训练集
"""

import json
import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple

# Configure matplotlib fonts (the sans-serif candidate list includes the
# CJK-capable SimHei / Arial Unicode MS) and turn off the Unicode minus glyph
# for negative tick labels.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans', 'SimHei', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
})

class TrainingComparisonPlotter:
    """Load per-epoch evaluation results of several training runs and plot them.

    Each input JSON file describes one run. For every epoch the six
    per-dataset accuracies are computed from ``correct`` / ``total`` counters,
    and the runs are drawn as lines that share one set of axes per dataset.
    """

    # Panel titles of the combined 3x2 figure, keyed by dataset name.
    # NOTE(review): the "-mse" suffix presumably names the loss/variant being
    # reported -- confirm before reusing this script for other experiments.
    _GRID_TITLES = {
        'helpful_rewrite': 'Helpful Domain - Rewrite-mse',
        'math_rewrite': 'Mathematical Reasoning - Rewrite-mse',
        'safety_rewrite': 'Safety Domain - Rewrite-mse',
        'helpful_reject': 'Helpful Domain - Reject-mse',
        'math_reject': 'Mathematical Reasoning - Reject-mse',
        'safety_reject': 'Safety Domain - Reject-mse',
    }

    # Titles of the stand-alone per-dataset figures, keyed by dataset name.
    _INDIVIDUAL_TITLES = {
        'helpful_rewrite': 'Helpful Domain - Rewrite -mse Performance',
        'math_rewrite': 'Mathematical Reasoning - Rewrite -mse Performance',
        'safety_rewrite': 'Safety Domain - Rewrite -mse Performance',
        'helpful_reject': 'Helpful Domain - Reject -mse Performance',
        'math_reject': 'Mathematical Reasoning - Reject -mse Performance',
        'safety_reject': 'Safety Domain - Reject -mse Performance',
    }

    # Every plot uses the same y-range (35%-100%) so figures are comparable.
    _Y_LIMITS = (0.35, 1.0)

    def __init__(self, results_dir: str = "prompt_decoder/training_results"):
        """Set up dataset names, display names and line styling.

        Args:
            results_dir: Directory of training results. It is only stored as a
                ``Path``; none of the methods of this class read it.
        """
        self.results_dir = Path(results_dir)
        # Ordered for a 3x2 grid: row 1 helpful, row 2 safety, row 3 math;
        # rewrite in the left column, reject in the right column.
        self.dataset_names = [
            'helpful_rewrite',
            'helpful_reject',
            'safety_rewrite',
            'safety_reject',
            'math_rewrite',
            'math_reject'
        ]

        # Human-readable dataset names (used by print_summary).
        self.dataset_display_names = {
            'helpful_reject': 'Helpful - Reject',
            'helpful_rewrite': 'Helpful - Rewrite',
            'math_reject': 'Math - Reject',
            'math_rewrite': 'Math - Rewrite',
            'safety_reject': 'Safety - Reject',
            'safety_rewrite': 'Safety - Rewrite'
        }

        # Colors and line styles; both are cycled by run index.
        self.colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
        self.linestyles = ['-', '--', '-.', ':']

    @staticmethod
    def _accuracy(dataset_stats):
        """Return correct/total for one dataset, or 0.0 when total is not positive."""
        if dataset_stats['total'] > 0:
            return dataset_stats['correct'] / dataset_stats['total']
        return 0.0

    def load_training_results(self, json_files: List[str]) -> Dict[str, Dict]:
        """Load training-result JSON files into per-run accuracy tables.

        Args:
            json_files: Paths of the JSON result files, one per training run.

        Returns:
            A dict mapping run name to ``{'epochs_data': {epoch: {dataset: accuracy}}}``.
            The run name is the file's ``run_name`` field, or the file's base
            name when that field is absent. If the name is already taken by an
            earlier file, a `` (2)``, `` (3)``, ... suffix is appended so that
            no run is silently overwritten.

        Missing files and files that fail to load are reported on stdout and
        skipped. Datasets absent from an epoch get an accuracy of 0.0.
        """
        results: Dict[str, Dict] = {}

        for json_file in json_files:
            if not os.path.exists(json_file):
                print(f"⚠️  File not found: {json_file}")
                continue

            # Best-effort loading: one bad file must not abort the others.
            try:
                # JSON is UTF-8 by specification; do not depend on the locale.
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                run_name = data.get('run_name', os.path.basename(json_file))

                # Per-epoch accuracy of each of the six datasets.
                epochs_data = {}
                for epoch_data in data.get('epochs', []):
                    rewritten_results = epoch_data.get('rewritten_results', {})
                    six_datasets = rewritten_results.get('six_datasets_accuracy', {})
                    epochs_data[epoch_data['epoch']] = {
                        dataset_name: self._accuracy(
                            six_datasets.get(dataset_name, {'correct': 0, 'total': 0})
                        )
                        for dataset_name in self.dataset_names
                    }

                # Keep both runs when two files resolve to the same name.
                key = run_name
                suffix = 2
                while key in results:
                    key = f"{run_name} ({suffix})"
                    suffix += 1
                if key != run_name:
                    print(f"⚠️  Duplicate run name '{run_name}' in {json_file}; using '{key}'")

                results[key] = {
                    'epochs_data': epochs_data
                }

                print(f"✅ Loaded {key}: {len(epochs_data)} epochs")

            except Exception as e:
                print(f"❌ Error loading {json_file}: {e}")

        return results

    def _plot_runs(self, ax, results, dataset_name, markersize):
        """Draw one accuracy-vs-epoch line per run for ``dataset_name`` on ``ax``."""
        for j, (run_name, run_data) in enumerate(results.items()):
            epochs_data = run_data['epochs_data']
            if not epochs_data:  # nothing to draw for this run
                continue

            epochs = sorted(epochs_data)
            accuracies = [epochs_data[epoch][dataset_name] for epoch in epochs]

            ax.plot(epochs, accuracies,
                    color=self.colors[j % len(self.colors)],
                    linestyle=self.linestyles[j % len(self.linestyles)],
                    linewidth=2, marker='o', markersize=markersize,
                    label=run_name)

    def _finish_axes(self, ax, title, xlabel, ylabel,
                     title_size, label_size, legend_size, **legend_kwargs):
        """Apply title, axis labels, grid, legend, y-range and percent ticks to ``ax``."""
        ax.set_title(title, fontsize=title_size, fontweight='bold')
        ax.set_xlabel(xlabel, fontsize=label_size)
        ax.set_ylabel(ylabel, fontsize=label_size)
        ax.grid(True, alpha=0.3)

        # Only draw a legend when at least one labelled line exists.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=legend_size, **legend_kwargs)

        ax.set_ylim(*self._Y_LIMITS)
        # Show the y axis as percentages.
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1%}'))

    @staticmethod
    def _save_figure(fig, output_dir, filename, description):
        """Save ``fig`` as ``output_dir/filename`` at 300 dpi and report the path."""
        output_file = os.path.join(output_dir, filename)
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"📊 Saved {description} to: {output_file}")

    def plot_comparison(self, results: Dict[str, Dict], output_dir: str = "plots"):
        """Draw all six datasets in one 3x2 figure, save it and show it.

        Args:
            results: Output of ``load_training_results``.
            output_dir: Directory (created if missing) that receives
                ``training_comparison_6_datasets.png``.
        """
        os.makedirs(output_dir, exist_ok=True)

        # 3x2 grid of panels, one panel per dataset.
        fig, axes = plt.subplots(3, 2, figsize=(16, 18))
        axes = axes.flatten()

        for i, dataset_name in enumerate(self.dataset_names):
            ax = axes[i]
            self._plot_runs(ax, results, dataset_name, markersize=4)
            self._finish_axes(ax, self._GRID_TITLES[dataset_name], 'Epoch', 'Accuracy',
                              title_size=14, label_size=12, legend_size=10)

        fig.tight_layout()
        self._save_figure(fig, output_dir, "training_comparison_6_datasets.png",
                          "comparison plot")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def plot_individual_datasets(self, results: Dict[str, Dict], output_dir: str = "plots"):
        """Save one stand-alone figure per dataset (figures are not shown).

        Args:
            results: Output of ``load_training_results``.
            output_dir: Directory (created if missing) that receives
                ``training_comparison_<dataset>.png`` for each dataset.
        """
        os.makedirs(output_dir, exist_ok=True)

        for dataset_name in self.dataset_names:
            fig, ax = plt.subplots(figsize=(10, 6))

            self._plot_runs(ax, results, dataset_name, markersize=6)
            self._finish_axes(ax, self._INDIVIDUAL_TITLES[dataset_name], 'Epoch', 'Accuracy',
                              title_size=16, label_size=14, legend_size=12)

            self._save_figure(fig, output_dir, f"training_comparison_{dataset_name}.png",
                              f"{dataset_name} plot")

            plt.close(fig)

    def _plot_overall_average(self, overall_data, key, marker, title,
                              output_dir, filename, description):
        """Plot the per-epoch average stored under ``key`` for every run, save and show it."""
        fig, ax = plt.subplots(figsize=(12, 8))

        for j, (run_name, epoch_data) in enumerate(overall_data.items()):
            if not epoch_data:
                continue

            epochs = sorted(epoch_data)
            averages = [epoch_data[epoch][key] for epoch in epochs]

            ax.plot(epochs, averages,
                    color=self.colors[j % len(self.colors)],
                    linestyle='-', linewidth=2.5,
                    marker=marker, markersize=6,
                    label=run_name)

        self._finish_axes(ax, title, 'Training Epoch', 'Average Accuracy',
                          title_size=16, label_size=14, legend_size=12, loc='best')

        self._save_figure(fig, output_dir, filename, description)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def plot_overall_comparison(self, results: Dict[str, Dict], output_dir: str = "plots"):
        """Plot the rewrite and reject accuracies averaged over the three domains.

        Two figures are produced, saved and shown:
        ``overall_rewrite_performance.png`` and ``overall_reject_performance.png``.
        Each value is the mean of the helpful, math and safety accuracies of
        that task type at one epoch.

        Args:
            results: Output of ``load_training_results``.
            output_dir: Directory (created if missing) that receives the images.
        """
        os.makedirs(output_dir, exist_ok=True)

        domains = ('helpful', 'math', 'safety')

        # run name -> epoch -> {'rewrite_avg': ..., 'reject_avg': ...}
        overall_data = {}
        for run_name, run_data in results.items():
            overall_data[run_name] = {
                epoch: {
                    'rewrite_avg': np.mean([accuracy_data[f'{d}_rewrite'] for d in domains]),
                    'reject_avg': np.mean([accuracy_data[f'{d}_reject'] for d in domains]),
                }
                for epoch, accuracy_data in run_data['epochs_data'].items()
            }

        self._plot_overall_average(
            overall_data, 'rewrite_avg', 'o',
            'Average Performance Across Rewrite Tasks -mse\n(Helpful, Math, and Safety Domains)',
            output_dir, "overall_rewrite_performance.png", "overall rewrite plot")

        self._plot_overall_average(
            overall_data, 'reject_avg', 's',
            'Average Performance Across Reject Tasks -mse\n(Helpful, Math, and Safety Domains)',
            output_dir, "overall_reject_performance.png", "overall reject plot")

    def print_summary(self, results: Dict[str, Dict]):
        """Print, for every run, its epoch count and the accuracies of its last epoch.

        Args:
            results: Output of ``load_training_results``.
        """
        print("\n" + "="*80)
        print("TRAINING RESULTS SUMMARY")
        print("="*80)

        for run_name, run_data in results.items():
            epochs_data = run_data['epochs_data']

            print(f"\n📈 {run_name}")
            print(f"   Epochs: {len(epochs_data)}")

            if epochs_data:
                # The "final" epoch is the largest epoch key.
                final_epoch = max(epochs_data.keys())
                final_accuracies = epochs_data[final_epoch]

                print(f"   Final Epoch {final_epoch} Accuracies:")
                for dataset_name in self.dataset_names:
                    accuracy = final_accuracies[dataset_name]
                    print(f"     {self.dataset_display_names[dataset_name]}: {accuracy:.1%}")

def main():
    """Command-line entry point: load result files, print a summary and plot.

    Flags:
        --json_files  One or more JSON result files; at least 2 are required.
        --output_dir  Directory that receives the images (default ``plots``).
        --individual  Additionally save one figure per dataset.
        --overall     Additionally plot the rewrite / reject averages.

    The combined 3x2 comparison figure is always generated. The function
    returns ``None`` early (after printing a message) when fewer than two
    files are given or when no file could be loaded.
    """
    parser = argparse.ArgumentParser(description="Plot training comparison across 6 evaluation datasets")
    # The help text states the same minimum that is enforced below.
    parser.add_argument("--json_files", nargs="+", required=True,
                        help="Paths to JSON result files (at least 2 for comparison)")
    parser.add_argument("--output_dir", default="plots",
                        help="Output directory for plots")
    parser.add_argument("--individual", action="store_true",
                        help="Generate individual plots for each dataset")
    parser.add_argument("--overall", action="store_true",
                        help="Generate overall comparison plot (rewrite vs reject averages)")

    args = parser.parse_args()

    if len(args.json_files) < 2:
        print("❌ Need at least 2 JSON files for comparison")
        return

    # Create the plotter.
    plotter = TrainingComparisonPlotter()

    # Load the results.
    print("🔄 Loading training results...")
    results = plotter.load_training_results(args.json_files)

    if not results:
        print("❌ No valid results loaded")
        return

    # Print the summary.
    plotter.print_summary(results)

    # Draw the combined comparison figure.
    print("\n🔄 Generating comparison plots...")
    plotter.plot_comparison(results, args.output_dir)

    if args.individual:
        print("\n🔄 Generating individual dataset plots...")
        plotter.plot_individual_datasets(results, args.output_dir)

    if args.overall:
        print("\n🔄 Generating overall comparison plot...")
        plotter.plot_overall_comparison(results, args.output_dir)

    print(f"\n✅ All plots saved to: {args.output_dir}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
