import json
import os
import re
import numpy as np
import pandas as pd
from collections import Counter
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from openai import OpenAI
from tqdm import tqdm
import warnings
import random
import time
# Silence every warning category for the whole script run.
warnings.simplefilter('ignore')

# Set font for plots, and draw minus signs with a plain hyphen instead of U+2212.
plt.rcParams.update({
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

class DiversityAnalyzer:
    """
    Measure how diverse a model's responses are, question by question.

    Responses are loaded from JSONL result files under ``results_dir``. For each
    question, the set of responses (one per budget, or several samples at one
    budget) is scored with n-gram lexical metrics, TF-IDF cosine similarity and,
    when an OpenAI client is configured, embedding cosine similarity. Per-model
    results are stored as DataFrames in ``self.diversity_results``.
    """

    def __init__(self, results_dir: str, openai_api_key: str = None, openai_base_url: str = None, random_seed: int = 42, sample_size: int = 50):
        """
        Initialize the diversity analyzer.

        Args:
            results_dir: Path to the directory containing the result JSONL files.
            openai_api_key: OpenAI API key (optional). Without it the
                embedding-based metrics are skipped.
            openai_base_url: OpenAI-compatible base URL (optional).
            random_seed: Random seed for consistent sampling across models.
            sample_size: Number of questions to sample for analysis (also the
                number of random combinations drawn by the cross-question baseline).
        """
        self.results_dir = results_dir
        self.openai_api_key = openai_api_key
        self.random_seed = random_seed
        self.sample_size = sample_size

        # Seed the global generators for backward compatibility, but draw the
        # question subset from a private generator so unrelated random calls
        # elsewhere cannot change which questions get selected.
        random.seed(random_seed)
        np.random.seed(random_seed)
        self._rng = random.Random(random_seed)

        # Initialize OpenAI client
        if openai_api_key:
            if openai_base_url:
                self.openai_client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)
            else:
                self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            self.openai_client = None

        # Result key (model name or derived key) -> per-question results DataFrame
        self.diversity_results = {}

        # 0-based question positions chosen on the first successful load and
        # reused for every later model, so all models see the same questions.
        self.selected_indices = None

    def _select_questions(self, question_responses: Dict[str, object], num_questions: int) -> Dict[str, object]:
        """
        Restrict ``question_responses`` to the question subset shared by all models.

        On the first call with data, ``min(sample_size, num_questions)`` positions
        are drawn from ``range(num_questions)`` and stored in
        ``self.selected_indices``; later calls reuse them. Position ``idx`` maps
        to the id ``f"question_{idx + 1}"``, so questions stay aligned across
        models even when some positions are absent from ``question_responses``.

        Args:
            question_responses: question_id -> data, with ids "question_<k>" (1-based).
            num_questions: Number of question positions in the source file(s).

        Returns:
            The entries of ``question_responses`` at the selected positions,
            in ascending position order.
        """
        if num_questions <= 0:
            # Do not lock in an empty selection: that would make every later
            # model sample zero questions.
            print("Sampled 0 questions for analysis")
            return {}

        if self.selected_indices is None:
            # First time: draw random positions
            k = max(0, min(self.sample_size, num_questions))
            self.selected_indices = sorted(self._rng.sample(range(num_questions), k))
            print(f"Selected {len(self.selected_indices)} random questions with indices: {self.selected_indices[:10]}...")
        else:
            # Reuse the same positions for consistency across models
            print(f"Using pre-selected {len(self.selected_indices)} questions for consistency")

        sampled_question_responses = {}
        for idx in self.selected_indices:
            question_id = f"question_{idx + 1}"
            if question_id in question_responses:
                sampled_question_responses[question_id] = question_responses[question_id]

        print(f"Sampled {len(sampled_question_responses)} questions for analysis")
        return sampled_question_responses

    def load_data_by_question(self, model_name: str, budgets: List[int] = None) -> Dict[str, Dict[int, str]]:
        """
        Load a model's responses organized by question, with one response per budget.

        Each budget file is looked up as ``{model_name}_{budget}_test.jsonl`` or
        ``{model_name}-{budget}_test.jsonl`` in ``results_dir``. The k-th
        non-blank line of every budget file is assumed to belong to question
        "question_<k>". A line without ``generated_text`` keeps its position
        (that question simply lacks the budget), so later lines are not shifted
        onto the wrong question. Positions beyond the shortest file are ignored.

        Only the shared question subset is returned (see ``selected_indices``).

        Args:
            model_name: Model name (file-name prefix).
            budgets: Budgets to load; defaults to [512, 1024, 2048, 4096, 8192].

        Returns:
            Dict[str, Dict[int, str]]: question_id -> {budget -> response}
        """
        if budgets is None:
            budgets = [512, 1024, 2048, 4096, 8192]

        print(f"Loading data for model {model_name} with budgets {budgets}...")

        # budget -> responses by line position (None where generated_text is missing)
        budget_data = {}
        for budget in budgets:
            # Try different file naming patterns
            possible_paths = [
                os.path.join(self.results_dir, f"{model_name}_{budget}_test.jsonl"),
                os.path.join(self.results_dir, f"{model_name}-{budget}_test.jsonl"),
            ]
            file_path = next((path for path in possible_paths if os.path.exists(path)), None)

            if not file_path:
                print(f"Warning: No file found for model {model_name} with budget {budget}")
                continue

            try:
                responses = []
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        if line.strip():
                            data_point = json.loads(line)
                            if not isinstance(data_point, dict):
                                data_point = {}
                            responses.append(data_point.get('generated_text'))

                budget_data[budget] = responses
                loaded = sum(response is not None for response in responses)
                print(f"  Budget {budget}: Loaded {loaded} responses")

            except Exception as e:
                print(f"Error: Failed to load file {file_path}: {e}")
                continue

        # Organize by question position (assuming the same order across budgets)
        num_questions = min((len(responses) for responses in budget_data.values()), default=0)
        question_responses = {}
        for i in range(num_questions):
            per_budget = {
                budget: budget_data[budget][i]
                for budget in budgets
                if budget in budget_data and budget_data[budget][i] is not None
            }
            if per_budget:
                question_responses[f"question_{i + 1}"] = per_budget

        print(f"Organized {len(question_responses)} questions with responses across all budgets")

        # Use the same question-selection strategy as every other model
        return self._select_questions(question_responses, num_questions)

    def load_5samples_baseline_data(self, model_name: str, budget: int = 8192) -> Dict[str, List[str]]:
        """
        Load the 5-samples baseline data, grouping every 5 consecutive records into one question.

        The file ``{model_name}-5samples_{budget}_test.jsonl`` is expected to hold
        5 consecutive samples per prompt. A record without ``generated_text``
        keeps its slot so later groups stay aligned; a group containing such a
        slot is skipped, as is a trailing partial group. A warning is printed
        when the prompts inside one group differ, because then the file is not
        laid out as assumed.

        Args:
            model_name: Model name (file-name prefix).
            budget: Budget size used in the file name.

        Returns:
            Dict[str, List[str]]: question_id -> [response1, ..., response5],
            restricted to the shared question subset.
        """
        group_size = 5
        file_path = os.path.join(self.results_dir, f"{model_name}-5samples_{budget}_test.jsonl")

        if not os.path.exists(file_path):
            print(f"Warning: No file found for model {model_name} with 5samples at budget {budget}")
            return {}

        print(f"Loading 5samples baseline data from {file_path}...")

        # (prompt, response) per record position; None where a field is missing
        records = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data_point = json.loads(line)
                        if not isinstance(data_point, dict):
                            data_point = {}
                        records.append((data_point.get('prompt'), data_point.get('generated_text')))
        except Exception as e:
            print(f"Error: Failed to load file {file_path}: {e}")
            return {}

        print(f"Loaded {sum(text is not None for _, text in records)} total responses")

        # Group every `group_size` consecutive records into one question
        question_responses = {}
        num_questions = len(records) // group_size
        for q in range(num_questions):
            group = records[q * group_size:(q + 1) * group_size]
            responses = [text for _, text in group]
            if any(text is None for text in responses):
                print(f"Warning: question_{q + 1} has missing responses; skipping it")
                continue
            # Prompts may be unhashable (e.g. chat message lists), so compare with ==.
            prompts = [prompt for prompt, _ in group if prompt is not None]
            if any(prompt != prompts[0] for prompt in prompts[1:]):
                print(f"Warning: question_{q + 1} groups records with different prompts; check the file ordering")
            question_responses[f"question_{q + 1}"] = responses

        print(f"Grouped into {len(question_responses)} questions with 5 samples each")

        # Use the same question-selection strategy as every other model
        return self._select_questions(question_responses, num_questions)

    def _compute_all_metrics(self, responses: List[str], max_tokens: int = 500) -> Dict[str, float]:
        """
        Compute every diversity metric for one set of responses.

        Combines the 1/2/3-gram metrics (keys suffixed ``_<n>gram``), the TF-IDF
        semantic metrics and the OpenAI embedding metrics. Every metric sees the
        same truncation (first ``max_tokens`` whitespace tokens). Metrics that
        cannot be computed are simply absent from the result.
        """
        metrics = {}
        for n in [1, 2, 3]:
            for key, value in self.calculate_ngram_diversity(responses, n, max_tokens=max_tokens).items():
                metrics[f"{key}_{n}gram"] = value
        metrics.update(self.calculate_semantic_diversity(responses, max_tokens=max_tokens))
        metrics.update(self.calculate_openai_embedding_diversity(responses, max_tokens=max_tokens))
        return metrics

    def analyze_5samples_baseline_diversity(self, model_name: str, budget: int = 8192, max_tokens: int = 500) -> pd.DataFrame:
        """
        Analyze the diversity of the 5 samples drawn for each question at a single budget.

        Args:
            model_name: Model name (file-name prefix).
            budget: Budget size used in the file name.
            max_tokens: Only the first max_tokens whitespace-separated tokens of
                each response are analyzed; applies to every metric.

        Returns:
            pd.DataFrame: One row per question. It is also stored in
            ``self.diversity_results[f"{model_name}_5samples"]``.
        """
        print(f"\nAnalyzing 5samples baseline diversity for {model_name} with budget {budget}...")

        question_data = self.load_5samples_baseline_data(model_name, budget)
        if not question_data:
            print(f"No data found for model {model_name}")
            return pd.DataFrame()

        results = []

        print(f"Analyzing diversity for {len(question_data)} questions...")
        print("This may take a while due to OpenAI API calls...")

        for question_id, responses in tqdm(question_data.items(), desc=f"Processing {model_name} 5samples"):
            if len(responses) != 5:
                continue

            results.append({
                'model': f"{model_name}_5samples",
                'question_id': question_id,  # positional id, matches the per-budget model results
                'response_count': len(responses),
                'budget': budget,
                **self._compute_all_metrics(responses, max_tokens),
            })

        df = pd.DataFrame(results)
        self.diversity_results[f"{model_name}_5samples"] = df

        return df

    def truncate_to_first_n_tokens(self, text: str, n: int = 500) -> str:
        """
        Keep only the first n whitespace-separated tokens of a text.

        Args:
            text: Input text.
            n: Number of tokens to keep.

        Returns:
            str: The original text if it has at most n tokens. Otherwise, the
            first n tokens joined by single spaces.
        """
        if not text:
            return ""

        # Simple whitespace tokenization
        words = text.split()
        if len(words) <= n:
            return text

        return ' '.join(words[:n])

    def calculate_ngram_diversity(self, texts: List[str], n: int = 2, max_tokens: int = 500) -> Dict[str, float]:
        """
        Calculate n-gram diversity metrics for a set of responses.

        All n-grams of all (truncated) responses are pooled before counting.

        Args:
            texts: List of response texts.
            n: n-gram size (must be >= 1).
            max_tokens: Only the first max_tokens whitespace tokens of each text are used.

        Returns:
            Dict[str, float]: type_token_ratio, shannon_entropy (bits),
            unique_ngrams and total_ngrams. The result is empty when fewer than
            two texts are given, n < 1, or no n-gram can be formed.
        """
        if not texts or len(texts) < 2 or n < 1:
            return {}

        truncated_texts = [self.truncate_to_first_n_tokens(text, max_tokens) for text in texts]

        # Extract n-grams (simple whitespace tokenization)
        all_ngrams = []
        for text in truncated_texts:
            words = text.split()
            if len(words) >= n:
                all_ngrams.extend(' '.join(words[i:i + n]) for i in range(len(words) - n + 1))

        if not all_ngrams:
            return {}

        ngram_counts = Counter(all_ngrams)
        total_ngrams = len(all_ngrams)
        unique_ngrams = len(ngram_counts)

        # 1. Type-Token Ratio (TTR)
        ttr = unique_ngrams / total_ngrams

        # 2. Shannon entropy of the pooled n-gram distribution, in bits
        shannon = -sum((count / total_ngrams) * np.log2(count / total_ngrams)
                       for count in ngram_counts.values())

        return {
            'type_token_ratio': ttr,
            'shannon_entropy': shannon,
            'unique_ngrams': unique_ngrams,
            'total_ngrams': total_ngrams
        }

    def calculate_semantic_diversity(self, texts: List[str], max_tokens: int = 500) -> Dict[str, float]:
        """
        Calculate TF-IDF based semantic diversity for a set of responses.

        Args:
            texts: List of response texts.
            max_tokens: Only the first max_tokens whitespace tokens of each text are used.

        Returns:
            Dict[str, float]: semantic_diversity (1 - mean pairwise cosine
            similarity), avg_similarity and similarity_std. Returns an empty
            dict for fewer than two texts or when vectorization fails
            (e.g. every text is empty or stop words only).
        """
        if not texts or len(texts) < 2:
            return {}

        truncated_texts = [self.truncate_to_first_n_tokens(text, max_tokens) for text in texts]

        try:
            vectorizer = TfidfVectorizer(
                max_features=1000,
                stop_words='english',
                ngram_range=(1, 2),
                min_df=1,  # keep every term: there are only a few texts per question
                max_df=1.0  # likewise, keep terms shared by all texts
            )

            tfidf_matrix = vectorizer.fit_transform(truncated_texts)
            similarity_matrix = cosine_similarity(tfidf_matrix)

            # Strict upper triangle = each unordered pair of texts exactly once
            similarities = similarity_matrix[np.triu_indices(len(truncated_texts), k=1)]

            avg_similarity = np.mean(similarities)
            similarity_std = np.std(similarities)
            semantic_diversity = 1 - avg_similarity

            return {
                'semantic_diversity': semantic_diversity,
                'avg_similarity': avg_similarity,
                'similarity_std': similarity_std
            }

        except Exception as e:
            print(f"Error calculating semantic diversity: {e}")
            return {}

    def calculate_openai_embedding_diversity(self, texts: List[str], max_tokens: int = 500) -> Dict[str, float]:
        """
        Calculate semantic diversity using OpenAI embeddings (text-embedding-3-small).

        Each text is embedded separately, with up to 3 attempts and exponential
        backoff between them. If any text still fails, the whole set yields no metrics.

        Args:
            texts: List of response texts.
            max_tokens: Only the first max_tokens whitespace tokens of each text
                are used; the text is additionally cut to 8000 characters.

        Returns:
            Dict[str, float]: openai_semantic_diversity, openai_avg_similarity
            and openai_similarity_std. Returns an empty dict when no client is
            configured, fewer than two texts are given, or embedding fails.
        """
        if not self.openai_client or not texts or len(texts) < 2:
            return {}

        truncated_texts = [self.truncate_to_first_n_tokens(text, max_tokens) for text in texts]
        max_retries = 3

        try:
            embeddings = []
            for i, text in enumerate(truncated_texts):
                for attempt in range(max_retries):
                    try:
                        response = self.openai_client.embeddings.create(
                            input=text[:8000],  # Limit text length (characters)
                            model="text-embedding-3-small"
                        )
                        embeddings.append(response.data[0].embedding)
                        break
                    except Exception as e:
                        if attempt == max_retries - 1:
                            print(f"Failed to get embedding for text {i} after {max_retries} attempts: {e}")
                            return {}
                        # Exponential backoff (1s, 2s, ...) gives rate limits time to clear
                        time.sleep(2 ** attempt)

            if len(embeddings) != len(truncated_texts):
                return {}

            similarity_matrix = cosine_similarity(np.array(embeddings))

            # Strict upper triangle = each unordered pair of texts exactly once
            similarities = similarity_matrix[np.triu_indices(len(embeddings), k=1)]

            avg_similarity = np.mean(similarities)
            similarity_std = np.std(similarities)
            semantic_diversity = 1 - avg_similarity

            return {
                'openai_semantic_diversity': semantic_diversity,
                'openai_avg_similarity': avg_similarity,
                'openai_similarity_std': similarity_std
            }

        except Exception as e:
            print(f"Error calculating OpenAI embedding diversity: {e}")
            return {}

    def analyze_question_diversity(self, model_name: str, budgets: List[int] = None, max_tokens: int = 500) -> pd.DataFrame:
        """
        Analyze, for each question, the diversity of the responses across budgets.

        Args:
            model_name: Model name (file-name prefix).
            budgets: List of budgets; defaults to those of load_data_by_question.
            max_tokens: Only the first max_tokens whitespace tokens of each
                response are analyzed; applies to every metric.

        Returns:
            pd.DataFrame: One row per question with at least two responses.
            It is also stored in ``self.diversity_results[model_name]``.
        """
        question_data = self.load_data_by_question(model_name, budgets)
        if not question_data:
            print(f"No data found for model {model_name}")
            return pd.DataFrame()

        results = []

        print(f"\nAnalyzing diversity for {len(question_data)} questions...")
        print("This may take a while due to OpenAI API calls...")

        for question_id, budget_responses in tqdm(question_data.items(), desc=f"Processing {model_name}"):
            responses = list(budget_responses.values())

            if len(responses) < 2:
                continue

            results.append({
                'model': model_name,
                'question_id': question_id,
                'response_count': len(responses),
                **self._compute_all_metrics(responses, max_tokens),
            })

        df = pd.DataFrame(results)
        self.diversity_results[model_name] = df

        return df

    def analyze_baseline_diversity(self, baseline_model_name: str = None, budgets: List[int] = None, max_tokens: int = 500) -> pd.DataFrame:
        """
        Build a cross-question baseline by combining responses from different questions.

        For each of ``sample_size`` samples, one response is drawn per budget
        from the baseline model. Questions already used in the sample are
        avoided while possible, so the combination mixes different questions.
        If a budget has no responses, the draw comes from all budgets.
        Sampling uses a local RNG seeded with ``random_seed`` (global state is untouched).

        Args:
            baseline_model_name: Model whose files supply the responses. If
                None, the first analyzed per-budget model in
                ``diversity_results`` is used.
            budgets: List of budgets; defaults to [512, 1024, 2048, 4096, 8192].
            max_tokens: Only the first max_tokens whitespace tokens of each
                response are analyzed.

        Returns:
            pd.DataFrame: Baseline diversity results. They are also stored in
            ``self.diversity_results['baseline_random']``.
        """
        if budgets is None:
            budgets = [512, 1024, 2048, 4096, 8192]

        print(f"\nAnalyzing baseline diversity with {self.sample_size} random samples...")

        # If no baseline model is given, use the first available per-budget model
        if baseline_model_name is None:
            # Derived result keys do not correspond to per-budget result files.
            available_models = [
                name for name in self.diversity_results
                if name != 'baseline_random' and not name.endswith('_5samples')
            ]
            if not available_models:
                print("Error: No models available for baseline analysis")
                return pd.DataFrame()
            baseline_model_name = available_models[0]
            print(f"Using {baseline_model_name} as baseline model")

        baseline_data = self.load_data_by_question(baseline_model_name, budgets)

        if not baseline_data:
            print(f"Error: No data found for baseline model {baseline_model_name}")
            return pd.DataFrame()

        # Collect every non-empty response once, both pooled and per budget.
        all_pool: List[Tuple[str, str]] = []  # (question_id, response)
        pool_by_budget: Dict[int, List[Tuple[str, str]]] = {}
        for question_id, budget_responses in baseline_data.items():
            for budget, response in budget_responses.items():
                if response and response.strip():
                    all_pool.append((question_id, response))
                    pool_by_budget.setdefault(budget, []).append((question_id, response))

        print(f"Collected {len(all_pool)} available responses from {baseline_model_name}")

        if len(all_pool) < len(budgets):
            print(f"Warning: Not enough responses ({len(all_pool)}) for {len(budgets)} budgets")
            return pd.DataFrame()

        rng = np.random.RandomState(self.random_seed)
        baseline_results = []

        for sample_idx in range(self.sample_size):
            selected_responses = []
            selected_budgets = []
            selected_question_ids = []
            used_questions = set()

            # Draw one response per budget, preferring questions not yet in this sample
            for budget in budgets:
                pool = pool_by_budget.get(budget) or all_pool
                fresh = [item for item in pool if item[0] not in used_questions]
                candidates = fresh or pool  # fall back when every question is already used
                question_id, response = candidates[rng.choice(len(candidates))]

                selected_responses.append(response)
                selected_budgets.append(budget)
                selected_question_ids.append(question_id)
                used_questions.add(question_id)

            baseline_results.append({
                'model': 'baseline_random',
                'question_id': f'baseline_sample_{sample_idx+1}',
                'response_count': len(selected_responses),
                'budgets_used': selected_budgets,
                'source_questions': selected_question_ids,  # question ids the responses came from
                **self._compute_all_metrics(selected_responses, max_tokens),
            })

        baseline_df = pd.DataFrame(baseline_results)
        self.diversity_results['baseline_random'] = baseline_df

        print(f"Generated {len(baseline_results)} baseline samples using responses from {baseline_model_name}")
        return baseline_df

    def calculate_average_diversity(self, model_name: str) -> Dict[str, float]:
        """
        Calculate the mean and standard deviation of every numeric metric across questions.

        Args:
            model_name: Key in ``self.diversity_results``.

        Returns:
            Dict[str, float]: ``avg_<col>`` and ``std_<col>`` for each numeric
            column (NaNs are skipped). Returns an empty dict for unknown models.
        """
        if model_name not in self.diversity_results:
            return {}

        df = self.diversity_results[model_name]

        # Any numeric dtype (float32/int32 included); excludes bool, strings and list-valued columns
        numeric_df = df.select_dtypes(include='number')

        avg_metrics = {}
        for col in numeric_df.columns:
            if col in ['model', 'question_id']:
                continue
            avg_metrics[f'avg_{col}'] = numeric_df[col].mean()
            avg_metrics[f'std_{col}'] = numeric_df[col].std()

        return avg_metrics

    def plot_diversity_distribution(self, model_name: str, save_path: str = None):
        """
        Plot the distribution of the diversity metrics across questions.

        Draws a 2x2 grid with four panels:
          * TTR histograms (1-3 gram)
          * Shannon-entropy histograms (1-3 gram)
          * TF-IDF / OpenAI semantic-diversity histograms
          * a box plot of the main metrics

        Missing metric columns and NaN values are skipped. The figure is
        closed afterwards, so repeated calls do not accumulate open figures.

        Args:
            model_name: Key in ``self.diversity_results``.
            save_path: Image path to write (optional); parent directories are created.
        """
        if model_name not in self.diversity_results:
            print(f"No analysis results found for model {model_name}")
            return

        df = self.diversity_results[model_name]

        def column(name: str) -> pd.Series:
            # A metric can be absent (it failed for every question) or NaN for
            # some questions; the plots need finite values only.
            if name not in df.columns:
                return pd.Series(dtype=float)
            return df[name].dropna()

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle(f'{model_name} Response Diversity Distribution Across Questions', fontsize=16)

        # 1. Type-Token Ratio distribution
        ax = axes[0, 0]
        for n in [1, 2, 3]:
            ax.hist(column(f'type_token_ratio_{n}gram'), alpha=0.6, label=f'{n}-gram TTR', bins=20)
        ax.set_xlabel('Type-Token Ratio')
        ax.set_ylabel('Frequency')
        ax.set_title('Lexical Diversity Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 2. Shannon Entropy distribution
        ax = axes[0, 1]
        for n in [1, 2, 3]:
            ax.hist(column(f'shannon_entropy_{n}gram'), alpha=0.6, label=f'{n}-gram Shannon Entropy', bins=20)
        ax.set_xlabel('Shannon Entropy')
        ax.set_ylabel('Frequency')
        ax.set_title('Information Entropy Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 3. Semantic diversity distribution
        ax = axes[1, 0]
        ax.hist(column('semantic_diversity'), alpha=0.7, color='red', bins=20, label='TF-IDF Semantic Diversity')
        if 'openai_semantic_diversity' in df.columns:
            ax.hist(column('openai_semantic_diversity'), alpha=0.7, color='green', bins=20, label='OpenAI Semantic Diversity')
        ax.set_xlabel('Semantic Diversity')
        ax.set_ylabel('Frequency')
        ax.set_title('Semantic Diversity Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 4. Box plot of diversity metrics
        ax = axes[1, 1]
        diversity_metrics = ['type_token_ratio_2gram', 'shannon_entropy_2gram', 'semantic_diversity']
        if 'openai_semantic_diversity' in df.columns:
            diversity_metrics.append('openai_semantic_diversity')
        diversity_metrics = [metric for metric in diversity_metrics if metric in df.columns]
        if diversity_metrics:
            ax.boxplot([column(metric).to_numpy() for metric in diversity_metrics])
            # Label ticks directly: boxplot's `labels=` keyword is deprecated
            # since Matplotlib 3.9 (renamed `tick_labels`).
            ax.set_xticks(range(1, len(diversity_metrics) + 1))
            ax.set_xticklabels([metric.replace('_', '\n') for metric in diversity_metrics])
        ax.set_ylabel('Diversity Score')
        ax.set_title('Diversity Metrics Comparison')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if save_path:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Chart saved to: {save_path}")

        plt.close(fig)

    def save_results(self, output_path: str):
        """
        Save a per-model summary CSV and print the same summary.

        Metrics are rounded to 4 decimals. A metric that was never computed for
        a model is written as NaN (an empty CSV cell), not 0, because 0 is a
        meaningful diversity value.

        Args:
            output_path: Output CSV path; parent directories are created.
        """
        missing = float('nan')
        summary_results = []
        model_averages = {}

        for model_name, df in self.diversity_results.items():
            avg_metrics = self.calculate_average_diversity(model_name)
            model_averages[model_name] = avg_metrics

            summary_results.append({
                'Model': model_name,
                'Number_of_questions': len(df),
                'Average_TTR_2gram': round(avg_metrics.get('avg_type_token_ratio_2gram', missing), 4),
                'Average_Shannon_Entropy_2gram': round(avg_metrics.get('avg_shannon_entropy_2gram', missing), 4),
                'Average_Semantic_Diversity': round(avg_metrics.get('avg_semantic_diversity', missing), 4),
                'Average_OpenAI_Semantic_Diversity': round(avg_metrics.get('avg_openai_semantic_diversity', missing), 4)
            })

        # Save as CSV
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        summary_df = pd.DataFrame(summary_results)
        summary_df.to_csv(output_path, index=False)
        print(f"Summary results saved to: {output_path}")

        # Print the summary, 4 decimals
        print("\n=== Diversity Analysis Summary ===")
        for model_name, df in self.diversity_results.items():
            avg_metrics = model_averages[model_name]
            print(f"\nModel: {model_name}")
            print(f"Number of questions: {len(df)}")
            print(f"Average TTR (2-gram): {avg_metrics.get('avg_type_token_ratio_2gram', missing):.4f}")
            print(f"Average Shannon Entropy (2-gram): {avg_metrics.get('avg_shannon_entropy_2gram', missing):.4f}")
            print(f"Average Semantic Diversity: {avg_metrics.get('avg_semantic_diversity', missing):.4f}")
            if 'avg_openai_semantic_diversity' in avg_metrics:
                print(f"Average OpenAI Semantic Diversity: {avg_metrics['avg_openai_semantic_diversity']:.4f}")

def main():
    """
    Main function - run the diversity analysis end to end.

    Steps:
      1. Analyze the 5-samples baseline of ``l1-8b``. As the first load, it
         also fixes the question subset shared by every later model.
      2. Analyze each model in ``models_to_analyze`` across ``budgets``.
      3. Write the per-question results and the per-model summary to ``plots/``.

    OpenAI credentials are read from the ``OPENAI_API_KEY`` and
    ``OPENAI_BASE_URL`` environment variables. Without a key, the
    embedding-based metrics are skipped.
    """
    # Configuration parameters
    results_dir = "../results"  # Results file directory
    output_dir = "plots"  # Plots and CSV files are written here
    # Credentials come from the environment; never commit API keys to source control.
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    openai_base_url = os.environ.get("OPENAI_BASE_URL", "https://zjuapi.com/v1")
    budgets = [512, 1024, 2048, 4096, 8192]  # Budgets to analyze
    random_seed = 42  # Fixed random seed for consistent sampling
    sample_size = 2  # Number of questions to sample
    # NOTE(review): the per-budget model analyses below rely on the analyzer's
    # default truncation (500 whitespace tokens), while the 5-samples baseline
    # is analyzed with this value, so the two are computed on differently
    # truncated text. Confirm this is intended before comparing them.
    baseline_max_tokens = 9000

    # Models to analyze
    models_to_analyze = [
        'l1-8b',
        'l1-8b-ours-openr1',
        'l1-8b-ours-deepscaler-LUFFY-style',
        'l1-1.5',
        'seed-36b'
    ]

    if not openai_api_key:
        print("Warning: OPENAI_API_KEY is not set; OpenAI embedding metrics will be skipped")

    # savefig / to_csv do not create missing directories
    os.makedirs(output_dir, exist_ok=True)

    # Initialize analyzer with consistent sampling
    analyzer = DiversityAnalyzer(results_dir, openai_api_key, openai_base_url, random_seed, sample_size)

    # First analyze the 5-samples baseline data
    print(f"\n{'='*50}")
    print("Starting 5samples baseline analysis")
    print(f"Using random seed {random_seed} and sample size {sample_size}")
    print(f"{'='*50}")

    baseline_results_df = analyzer.analyze_5samples_baseline_diversity('l1-8b', budget=8192, max_tokens=baseline_max_tokens)

    if not baseline_results_df.empty:
        # Plot the baseline's diversity distribution
        plot_path = os.path.join(output_dir, "diversity_analysis_l1-8b_5samples_baseline.png")
        analyzer.plot_diversity_distribution('l1-8b_5samples', plot_path)

    # Analyze each model
    for model_name in models_to_analyze:
        print(f"\n{'='*50}")
        print(f"Starting analysis for model: {model_name}")
        print(f"Using random seed {random_seed} and sample size {sample_size}")
        print(f"{'='*50}")

        results_df = analyzer.analyze_question_diversity(model_name, budgets)

        if not results_df.empty:
            plot_path = os.path.join(output_dir, f"diversity_analysis_{model_name}.png")
            analyzer.plot_diversity_distribution(model_name, plot_path)

    # Merge all per-question results into one file
    print("\n合并所有结果...")
    all_results = [df for df in analyzer.diversity_results.values() if not df.empty]

    if all_results:
        combined_results = pd.concat(all_results, ignore_index=True)
        detailed_path = os.path.join(output_dir, "diversity_analysis_results_detailed.csv")
        combined_results.to_csv(detailed_path, index=False)
        print(f"合并结果已保存到: {detailed_path}")

    # Save the per-model summary
    analyzer.save_results(os.path.join(output_dir, "diversity_analysis_results.csv"))

if __name__ == "__main__":
    main()
