import os
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import argparse
from collections import defaultdict
import numpy as np

def count_words(file_path):
    """
    Count the whitespace-separated tokens in a UTF-8 text file.

    Tokens are split with ``str.split()``. Text without spaces (e.g. a run of
    Chinese characters) therefore counts as a single token.

    Args:
        file_path: path of the text file to read.

    Returns:
        The number of tokens. Returns 0 if the file cannot be opened/read
        (``OSError``) or is not valid UTF-8 (``UnicodeDecodeError``); the
        error is printed in that case.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Stream line by line instead of loading the whole file. Newlines
            # are whitespace, so no token can span two lines and the total
            # equals len(f.read().split()).
            return sum(len(line.split()) for line in f)
    except (OSError, UnicodeDecodeError) as e:
        # Only I/O and decoding failures are treated as "skip this file";
        # anything else is a programming error and should propagate.
        print(f"处理文件时出错 {file_path}: {str(e)}")
        return 0

def analyze_word_distribution(base_paths):
    """
    Collect word counts for every ``.txt`` file under each base directory.

    Each base path is walked recursively. Directories and files are visited in
    sorted order so repeated runs produce identical results.

    Args:
        base_paths: iterable of directory paths to scan.

    Returns:
        A ``defaultdict(list)`` mapping each base path to a list of
        ``{'file': <full path>, 'word_count': <int>}`` dicts. Files whose
        count is 0 (empty, or unreadable as reported by ``count_words``) are
        omitted. A base path with no counted files does not appear as a key.
    """
    word_counts = defaultdict(list)  # base path -> list of per-file records

    for base_path in base_paths:
        print(f"\n处理目录: {base_path}")

        # os.walk silently yields nothing for a missing path; surface it,
        # since a vanished directory would otherwise go unnoticed.
        if not os.path.isdir(base_path):
            print(f"警告: 目录不存在，已跳过: {base_path}")
            continue

        for root, dirs, files in os.walk(base_path):
            # os.walk/scandir order is arbitrary; sorting dirs in place also
            # fixes the order in which subdirectories are descended.
            dirs.sort()
            txt_files = sorted(name for name in files if name.endswith('.txt'))
            if not txt_files:
                continue

            for txt_file in tqdm(txt_files, desc=f"处理 {os.path.basename(root)}"):
                file_path = os.path.join(root, txt_file)
                word_count = count_words(file_path)
                if word_count > 0:  # record only files that produced a count
                    word_counts[base_path].append({
                        'file': file_path,
                        'word_count': word_count
                    })

    return word_counts

def plot_distribution(word_counts, output_dir, titles=None):
    """
    Plot one word-count histogram per directory and save summary tables.

    Each non-empty entry of ``word_counts`` gets a subplot. The subplot shows
    a density histogram recoloured with a coolwarm gradient, a KDE curve, and
    a box with the file count, mean and median.

    Outputs written to ``output_dir``:
      - ``word_distribution.png``: the figure.
      - ``word_statistics.csv``: per-directory summary statistics.
      - ``all_files_word_counts.csv``: the per-file counts.

    Args:
        word_counts: mapping of base path -> list of
            ``{'file': ..., 'word_count': ...}`` dicts, as produced by
            ``analyze_word_distribution``.
        output_dir: directory for the outputs; created if missing.
        titles: optional subplot titles, in the same order as ``word_counts``.
            Defaults to ``['ACS', 'Cell', 'Nature', 'Science']``. Entries
            beyond the end of the list are titled with their base path.

    Returns:
        ``(stats_df, files_df)``: the two DataFrames written to CSV. The
        ``directory`` column holds the base path, so that paths such as
        ``ACS/Story`` and ``Cell/Story`` remain distinguishable.
    """
    if titles is None:
        titles = ['ACS', 'Cell', 'Nature', 'Science']
    # Stats-box border colours: red, blue, green, orange (cycled per entry).
    colors = ['#FF9999', '#66B2FF', '#99FF99', '#FFCC99']

    # Keep each entry's position in word_counts so titles/colours stay aligned
    # with the caller's ordering even if an empty entry is skipped.
    entries = [
        (pos, path, counts)
        for pos, (path, counts) in enumerate(word_counts.items())
        if counts  # an empty list would break mean/min/max below
    ]

    # Grid layout: at most two columns, rows rounded up.
    num_plots = len(entries)
    rows = (num_plots + 1) // 2
    cols = min(2, num_plots)

    fig = plt.figure(figsize=(15, 10))
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('coolwarm')
    all_stats = []

    for subplot_idx, (pos, path, counts) in enumerate(entries, 1):
        ax = fig.add_subplot(rows, cols, subplot_idx)
        color = colors[pos % len(colors)]
        counts_list = [item['word_count'] for item in counts]

        # Density histogram; the bars are then recoloured left-to-right with
        # the gradient (the hist alpha of 0.7 is kept by set_facecolor).
        _, _, patches = ax.hist(counts_list, bins=30, density=True, alpha=0.7,
                                edgecolor='white')
        for face, patch in zip(cmap(np.linspace(0.2, 0.8, len(patches))), patches):
            patch.set_facecolor(face)

        # Kernel density estimate on top of the histogram.
        sns.kdeplot(data=counts_list, color='darkblue', linewidth=2, ax=ax)

        title = titles[pos] if pos < len(titles) else path
        ax.set_title(f'{title}', fontsize=20, fontweight='bold')
        ax.set_xlabel('Story Length', fontsize=20)
        ax.set_ylabel('', fontsize=10)

        stats = {
            'directory': path,
            'total_files': len(counts),
            'mean_words': round(sum(counts_list) / len(counts_list), 2),
            'min_words': min(counts_list),
            'max_words': max(counts_list),
            'median_words': round(pd.Series(counts_list).median(), 2)
        }
        all_stats.append(stats)

        # Summary box in the top-right corner, semi-transparent background.
        ax.text(0.98, 0.98,
                f"Files: {stats['total_files']}\n"
                f"Mean: {stats['mean_words']}\n"
                f"Median: {stats['median_words']}",
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                fontsize=20,
                bbox=dict(boxstyle='round',
                          facecolor='white',
                          alpha=0.8,
                          edgecolor=color,
                          pad=0.5))

        # Hide the top/right spines and enlarge the tick labels.
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(axis='both', labelsize=18)

    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, 'word_distribution.png'),
                dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

    stats_df = pd.DataFrame(all_stats)
    stats_df.to_csv(os.path.join(output_dir, 'word_statistics.csv'), index=False)

    files_df = pd.DataFrame([
        {'directory': path, 'file': item['file'], 'word_count': item['word_count']}
        for path, items in word_counts.items()
        for item in items
    ])
    files_df.to_csv(os.path.join(output_dir, 'all_files_word_counts.csv'), index=False)

    return stats_df, files_df

def plot_total_distribution(word_counts, output_dir):
    """
    Plot the pooled word-count distribution of all directories together.

    The plot has one density histogram recoloured with a coolwarm gradient,
    a KDE curve, and a box with the file count, mean and median. It is saved
    as ``total_distribution.png`` in ``output_dir``.

    Args:
        word_counts: mapping of base path -> list of
            ``{'file': ..., 'word_count': ...}`` dicts.
        output_dir: directory for the PNG; created if missing.

    Returns:
        None. If ``word_counts`` holds no records, a message is printed and
        nothing is drawn.
    """
    all_counts = [item['word_count'] for counts in word_counts.values() for item in counts]
    if not all_counts:
        # np.mean/np.median of an empty list would give nan with a RuntimeWarning.
        print("没有可用的字数数据，跳过总体分布图")
        return

    fig, ax = plt.subplots(figsize=(12, 8))

    # Density histogram, then a left-to-right gradient over the bars.
    _, _, patches = ax.hist(all_counts, bins=50, density=True, alpha=0.7,
                            edgecolor='white')
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('coolwarm')
    for face, patch in zip(cmap(np.linspace(0.2, 0.8, len(patches))), patches):
        patch.set_facecolor(face)

    # Kernel density estimate on top of the histogram.
    sns.kdeplot(data=all_counts, color='darkblue', linewidth=3, ax=ax)

    ax.set_xlabel('Story Length', fontsize=30)
    ax.set_ylabel('', fontsize=20)

    stats = {
        'total_files': len(all_counts),
        'mean_words': round(np.mean(all_counts), 2),
        'median_words': round(np.median(all_counts), 2)
    }

    ax.text(0.8, 0.8,
            f"Total Files: {stats['total_files']}\n"
            f"Mean: {stats['mean_words']}\n"
            f"Median: {stats['median_words']}",
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=30,
            bbox=dict(boxstyle='round',
                      facecolor='white',
                      alpha=0.8,
                      edgecolor='#66B2FF',
                      pad=0.5))

    # Hide the top/right spines; large tick labels for publication figures.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(axis='both', which='major', labelsize=30)

    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, 'total_distribution.png'),
                dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def main():
    """
    Run the whole pipeline for the four journal story directories.

    The steps are:
      1. Set Times New Roman as the global Matplotlib font.
      2. Count the words of every .txt file in each directory.
      3. Draw the per-directory and pooled distribution figures.
      4. Print a summary table and the list of files written.
    """
    plt.rcParams['font.family'] = 'Times New Roman'  # applies to every figure below

    output_dir = './Dataset/statistics/'
    journal_dirs = [
        'ACS/Story',
        'Cell/Story',
        'Nature/Story',
        'Science/Story',
    ]

    counts_by_dir = analyze_word_distribution(journal_dirs)
    summary_df, _ = plot_distribution(counts_by_dir, output_dir)
    plot_total_distribution(counts_by_dir, output_dir)

    # Console report: summary table followed by a legend of the output files.
    report_lines = [
        "\n=== 字数统计汇总 ===",
        summary_df.to_string(index=False),
        f"\n结果已保存到: {output_dir}",
        "- word_distribution.png: 分布图",
        "- total_distribution.png: 总体分布图",
        "- word_statistics.csv: 统计汇总",
        "- all_files_word_counts.csv: 详细文件列表",
    ]
    for line in report_lines:
        print(line)

if __name__ == "__main__":
    # Execute only when run as a script, not when imported.
    main()