import json
import os
import glob
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.spatial.distance import cosine, euclidean
from scipy.stats import wasserstein_distance
import warnings
warnings.filterwarnings('ignore')


def load_model_and_tokenizer(model_path, device='cuda'):
    """
    Load a model and its tokenizer, move the model to ``device`` and put it in eval mode.

    If the tokenizer defines no pad token but does define an EOS token, the EOS
    token is reused for padding. Batched tokenization with ``padding=True`` (as
    done by ``extract_embeddings``) raises otherwise. Padded positions are still
    marked 0 in the attention mask, so pooling can ignore them.

    Args:
        model_path: path or hub id passed to ``from_pretrained``
        device: device the model is moved to (e.g. 'cuda' or 'cpu')

    Returns:
        (model, tokenizer)
    """
    print(f"正在加载模型: {model_path}")
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; only point this at model paths you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # Decoder-only tokenizers (GPT-2, LLaMA, ...) often ship without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model = model.to(device)
    model.eval()
    print(f"模型已加载到 {device}")
    return model, tokenizer


def extract_embeddings(texts, model, tokenizer, device='cuda', max_length=512, batch_size=16, pooling='mean'):
    """
    Encode texts and pool the final-layer hidden states into one vector per text.

    Args:
        texts: list of strings
        model: model whose forward output exposes ``last_hidden_state``
        tokenizer: tokenizer called with ``return_tensors='pt'``; its output must
            contain ``attention_mask`` (1 = real token, 0 = padding)
        device: device the input tensors are moved to (should match the model)
        max_length: inputs longer than this many tokens are truncated
        batch_size: number of texts per forward pass
        pooling: 'mean', 'max', 'cls' or 'last'. Padding positions are excluded
            in every mode. 'cls' / 'last' select the first / last real token, so
            both right- and left-padding tokenizers are handled.

    Returns:
        numpy.ndarray: float32 array with one row per text

    Raises:
        ValueError: for an unknown ``pooling`` or an empty ``texts``
    """
    # Validate up front instead of after tokenizing and running the first batch.
    if pooling not in ('mean', 'max', 'cls', 'last'):
        raise ValueError(f"Unknown pooling method: {pooling}")

    embeddings = []

    for i in tqdm(range(0, len(texts), batch_size), desc="提取 embeddings"):
        batch_texts = texts[i:i + batch_size]

        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            # Only the last layer is used; output_hidden_states=True would keep
            # every layer's activations in memory for nothing.
            hidden = model(**inputs).last_hidden_state  # assumed (batch, seq_len, hidden_dim) — TODO confirm for remote-code models
            mask = inputs['attention_mask'].to(hidden.device)  # (batch, seq_len)
            rows = torch.arange(hidden.size(0), device=hidden.device)

            if pooling == 'mean':
                weights = mask.unsqueeze(-1).to(hidden.dtype)
                # clamp avoids 0/0 for a row that has no real tokens
                pooled = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)
            elif pooling == 'max':
                # Push padding positions to the dtype minimum so they never win the max.
                is_pad = (mask == 0).unsqueeze(-1)
                pooled = hidden.masked_fill(is_pad, torch.finfo(hidden.dtype).min).max(1).values
            elif pooling == 'cls':
                # First real token: position 0 with right padding, first non-pad slot with left padding.
                pooled = hidden[rows, mask.argmax(1)]
            else:  # 'last'
                # Last real token: find the first 1 in the reversed mask and map it back.
                last_idx = mask.size(1) - 1 - mask.flip(1).argmax(1)
                pooled = hidden[rows, last_idx]

            # .float() first: numpy has no bfloat16, so .numpy() on a bf16 tensor fails.
            embeddings.append(pooled.float().cpu().numpy())

    if not embeddings:
        raise ValueError("texts is empty; nothing to embed")
    return np.vstack(embeddings)


def read_responses_from_file(file_path):
    """
    Read all response strings from a JSONL file.

    Each non-blank line is parsed as JSON. For a JSON object, its ``'output'``
    field is used if it is a string, otherwise its ``'response'`` field if that
    is a string. The following are skipped rather than aborting the whole file:
    malformed JSON, lines that are not objects (strings, lists, numbers), and
    records with neither string field.

    Args:
        file_path: path to a ``.jsonl`` file (read as UTF-8)

    Returns:
        list[str]: the responses in file order, or ``[]`` if the file does not exist
    """
    responses = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A JSON string such as "output ..." would pass an `in` test and
                # then crash on indexing, so require a real object.
                if not isinstance(record, dict):
                    continue
                for key in ('output', 'response'):
                    value = record.get(key)
                    # Only strings can be tokenized downstream.
                    if isinstance(value, str):
                        responses.append(value)
                        break
    except FileNotFoundError:
        return []
    return responses


def dimensionality_reduction(embeddings, method='pca', n_components=2, random_state=42,
                             perplexity=30.0, tsne_pca_dims=50):
    """
    Reduce ``embeddings`` to ``n_components`` dimensions with PCA or t-SNE.

    Args:
        embeddings: 2-D array of shape (n_samples, n_features)
        method: 'pca' or 'tsne'
        n_components: target dimensionality
        random_state: random seed passed to PCA / t-SNE
        perplexity: t-SNE perplexity. scikit-learn requires perplexity < n_samples,
            so it is clamped to n_samples - 1 for small inputs.
        tsne_pca_dims: for t-SNE, inputs wider than this are first projected with
            PCA to this many dimensions (or to n_samples if that is smaller, since
            PCA cannot yield more components than samples)

    Returns:
        numpy.ndarray of shape (n_samples, n_components)

    Raises:
        ValueError: for an unknown ``method``
    """
    print(f"使用 {method.upper()} 进行降维到 {n_components} 维...")
    n_samples, n_features = embeddings.shape

    if method == 'pca':
        reducer = PCA(n_components=n_components, random_state=random_state)
    elif method == 'tsne':
        # t-SNE is commonly run on a ~50-dim PCA projection of high-dimensional data.
        if n_features > tsne_pca_dims:
            pre_dims = min(tsne_pca_dims, n_samples)
            print(f"  先使用 PCA 降维到 {pre_dims} 维...")
            embeddings = PCA(n_components=pre_dims, random_state=random_state).fit_transform(embeddings)

        effective_perplexity = min(perplexity, max(n_samples - 1, 1))
        if effective_perplexity != perplexity:
            print(f"  perplexity 调整为 {effective_perplexity} (样本数 {n_samples})")

        tsne_kwargs = dict(n_components=n_components, random_state=random_state,
                           perplexity=effective_perplexity)
        try:
            # scikit-learn >= 1.5 calls the iteration budget max_iter; n_iter was removed in 1.7.
            reducer = TSNE(max_iter=1000, **tsne_kwargs)
        except TypeError:
            # Older scikit-learn only accepts n_iter.
            reducer = TSNE(n_iter=1000, **tsne_kwargs)
    else:
        raise ValueError(f"Unknown method: {method}")

    reduced = reducer.fit_transform(embeddings)

    if method == 'pca':
        explained_variance = reducer.explained_variance_ratio_
        print(f"  解释方差比: {explained_variance}")
        print(f"  累计解释方差: {explained_variance.sum():.4f}")

    return reduced


def compute_distribution_metrics(embeddings1, embeddings2, max_wasserstein_dims=10):
    """
    Compute summary distances between two sets of embeddings.

    Args:
        embeddings1: array of shape (N, D)
        embeddings2: array of shape (M, D)
        max_wasserstein_dims: the per-dimension 1-D Wasserstein distance is
            averaged over the first ``min(D, max_wasserstein_dims)`` dimensions
            (default 10, to save time); ``None`` uses every dimension

    Returns:
        dict with float values, in this key order:
            'center_euclidean': Euclidean distance between the two mean vectors
            'center_cosine':    cosine distance between the two mean vectors
            'std1', 'std2':     per-dimension standard deviation, averaged over
                                dimensions (spread of each set)
            'wasserstein_mean': mean 1-D Wasserstein distance over the selected
                                dimensions (NaN if zero dimensions were selected)
    """
    # float64 avoids accumulating float32 rounding error across many samples.
    emb1 = np.asarray(embeddings1, dtype=np.float64)
    emb2 = np.asarray(embeddings2, dtype=np.float64)

    center1 = emb1.mean(axis=0)
    center2 = emb2.mean(axis=0)

    metrics = {
        'center_euclidean': float(euclidean(center1, center2)),
        'center_cosine': float(cosine(center1, center2)),
        'std1': float(emb1.std(axis=0).mean()),
        'std2': float(emb2.std(axis=0).mean()),
    }

    n_dims = emb1.shape[1]
    if max_wasserstein_dims is not None:
        n_dims = min(n_dims, max_wasserstein_dims)
    per_dim = [wasserstein_distance(emb1[:, dim], emb2[:, dim]) for dim in range(n_dims)]
    metrics['wasserstein_mean'] = float(np.mean(per_dim)) if per_dim else float('nan')

    return metrics


def plot_comparison(embeddings1, embeddings2, output_path, step, dir1_name='Dir1', dir2_name='Dir2',
                    method='PCA', metrics=None):
    """
    Scatter-plot two sets of 2-D reduced embeddings and save the figure.

    Each set is drawn as semi-transparent points in its own color. Its mean is
    marked with a star, and a dashed line joins the two means.

    Args:
        embeddings1: first set of reduced embeddings (N x 2); columns 0 and 1 are plotted
        embeddings2: second set of reduced embeddings (M x 2)
        output_path: image path passed to ``savefig``
        step: training step shown in the title
        dir1_name: legend label for the first set
        dir2_name: legend label for the second set
        method: reduction-method name used in axis labels and title
        metrics: optional dict; if given, its 'center_euclidean' and
            'center_cosine' values are appended to the title
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 10))
    try:
        point_alpha = 0.5
        point_size = 20
        groups = [(embeddings1, dir1_name, '#1f77b4'), (embeddings2, dir2_name, '#ff7f0e')]

        for points, name, color in groups:
            ax.scatter(points[:, 0], points[:, 1],
                       c=color, label=name, alpha=point_alpha, s=point_size, edgecolors='none')

        # Legend text is ASCII: matplotlib's default font has no CJK glyphs, so
        # Chinese labels render as empty boxes (and the glyph warnings are silenced
        # file-wide).
        centers = []
        for points, name, color in groups:
            center = points.mean(axis=0)
            centers.append(center)
            ax.scatter(center[0], center[1], c=color, marker='*', s=500,
                       edgecolors='black', linewidths=2, label=f'{name} center', zorder=5)

        ax.plot([centers[0][0], centers[1][0]], [centers[0][1], centers[1][1]],
                'k--', linewidth=2, alpha=0.5, label='center line')

        ax.set_xlabel(f'{method} Component 1', fontsize=13, fontweight='bold')
        ax.set_ylabel(f'{method} Component 2', fontsize=13, fontweight='bold')

        title = f'Representation Distribution Comparison (Step {step})\n{method} Visualization'
        if metrics:
            title += f"\ncenter euclidean: {metrics['center_euclidean']:.4f}, center cosine: {metrics['center_cosine']:.4f}"
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        ax.legend(fontsize=10, loc='best', framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"图表已保存到: {output_path}")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _map_steps_to_files(file_paths):
    """
    Map the integer step prefix of each filename to its path.

    A name like ``1200_16384.jsonl`` maps step 1200 to that path. Names whose
    text before the first ``_`` is not an integer are skipped. Paths are visited
    in sorted order, so if two files share a step the outcome is deterministic
    (the lexicographically last path wins) and a warning is printed.

    Args:
        file_paths: iterable of file paths

    Returns:
        dict: {step (int): file path}
    """
    step_to_file = {}
    for file_path in sorted(file_paths):
        filename = os.path.basename(file_path)
        try:
            step = int(filename.split('_')[0])
        except ValueError:
            continue
        if step in step_to_file:
            print(f"警告: step {step} 匹配到多个文件，使用 {filename}")
        step_to_file[step] = file_path
    return step_to_file


def compare_two_directories_representations(dir1, dir2, model_path, pattern="*_16384.jsonl",
                                           reduction_method='pca', pooling='mean',
                                           max_length=512, batch_size=16,
                                           output_dir='representation_analysis',
                                           device='cuda', max_samples=None):
    """
    Compare, step by step, the representation distributions of matching files in two directories.

    For every step present in both directories:
    - read the responses;
    - embed them with the model;
    - reduce both sets jointly to 2-D;
    - compute distribution metrics on the full embeddings;
    - save a scatter plot to ``output_dir``.

    Args:
        dir1: first directory
        dir2: second directory
        model_path: model used to extract embeddings
        pattern: glob pattern for the files; the step is the integer before the first '_'
        reduction_method: 'pca' or 'tsne'
        pooling: pooling mode passed to ``extract_embeddings``
        max_length: maximum token length
        batch_size: embedding batch size
        output_dir: directory for the plots (created only once there is work to do)
        device: device for the model
        max_samples: if truthy, at most this many responses per file are used

    Returns:
        list[dict] | None: one result dict per compared step, or ``None`` if no
        matching files or no common steps were found
    """
    # Validate and pair inputs before loading the (possibly multi-GB) model, so a
    # wrong directory or pattern fails immediately.
    files1 = glob.glob(os.path.join(dir1, pattern))
    files2 = glob.glob(os.path.join(dir2, pattern))

    if not files1:
        print(f"错误: 在 {dir1} 中没有找到匹配 {pattern} 的文件")
        return None
    if not files2:
        print(f"错误: 在 {dir2} 中没有找到匹配 {pattern} 的文件")
        return None

    step_to_file1 = _map_steps_to_files(files1)
    step_to_file2 = _map_steps_to_files(files2)

    common_steps = sorted(step_to_file1.keys() & step_to_file2.keys())

    if not common_steps:
        print("错误: 两个目录没有共同的 step 文件")
        return None

    print(f"找到 {len(common_steps)} 个共同的 steps: {common_steps}")

    os.makedirs(output_dir, exist_ok=True)
    model, tokenizer = load_model_and_tokenizer(model_path, device=device)

    dir1_name = os.path.basename(dir1.rstrip('/'))
    dir2_name = os.path.basename(dir2.rstrip('/'))

    comparison_results = []

    for step in common_steps:
        print(f"\n{'='*60}")
        print(f"处理 Step {step}")
        print(f"{'='*60}")

        file1 = step_to_file1[step]
        file2 = step_to_file2[step]

        print(f"文件1: {os.path.basename(file1)}")
        print(f"文件2: {os.path.basename(file2)}")

        responses1 = read_responses_from_file(file1)
        responses2 = read_responses_from_file(file2)

        if not responses1 or not responses2:
            print("警告: 某个文件为空，跳过")
            continue

        if max_samples:
            responses1 = responses1[:max_samples]
            responses2 = responses2[:max_samples]

        print(f"样本数: {len(responses1)} vs {len(responses2)}")

        print("提取 Dir1 embeddings...")
        embeddings1 = extract_embeddings(responses1, model, tokenizer, device, max_length, batch_size, pooling)
        print(f"Dir1 embeddings shape: {embeddings1.shape}")

        print("提取 Dir2 embeddings...")
        embeddings2 = extract_embeddings(responses2, model, tokenizer, device, max_length, batch_size, pooling)
        print(f"Dir2 embeddings shape: {embeddings2.shape}")

        # Reduce both sets together so the 2-D coordinates share one space.
        all_embeddings = np.vstack([embeddings1, embeddings2])
        reduced_all = dimensionality_reduction(all_embeddings, method=reduction_method, n_components=2)

        split = len(embeddings1)
        reduced1 = reduced_all[:split]
        reduced2 = reduced_all[split:]

        print(f"降维后 shape: {reduced1.shape}, {reduced2.shape}")

        # Metrics are computed on the full-dimensional embeddings, not the 2-D projection.
        metrics = compute_distribution_metrics(embeddings1, embeddings2)
        print("度量结果:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.6f}")

        output_path = os.path.join(output_dir, f"step_{step}_{reduction_method}.png")
        plot_comparison(reduced1, reduced2, output_path, step,
                        dir1_name, dir2_name, reduction_method.upper(), metrics)

        comparison_results.append({
            'step': step,
            'num_samples1': len(responses1),
            'num_samples2': len(responses2),
            'embedding_dim': embeddings1.shape[1],
            **metrics
        })

    return comparison_results


def plot_metrics_over_steps(comparison_results, output_dir, dir1_name='Dir1', dir2_name='Dir2'):
    """
    Plot how each comparison metric changes with training step.

    Produces a 2x2 grid:
    - center Euclidean distance;
    - center cosine distance;
    - the two sets' standard deviations;
    - mean Wasserstein distance.

    The grid is saved as ``metrics_over_steps.png`` in ``output_dir``.

    Args:
        comparison_results: list of result dicts as returned by
            ``compare_two_directories_representations``; each needs the keys
            'step', 'center_euclidean', 'center_cosine', 'std1', 'std2', 'wasserstein_mean'
        output_dir: directory to write the image into
        dir1_name: legend label for the 'std1' curve
        dir2_name: legend label for the 'std2' curve
    """
    if not comparison_results:
        print("没有数据可以绘图")
        return

    steps = [r['step'] for r in comparison_results]

    # (grid position, title, y label, [(metric key, marker, color, legend label or None)])
    # Titles are ASCII because matplotlib's default font lacks CJK glyphs.
    panels = [
        ((0, 0), 'Center Euclidean Distance', 'Center Euclidean Distance',
         [('center_euclidean', 'o', '#1f77b4', None)]),
        ((0, 1), 'Center Cosine Distance', 'Center Cosine Distance',
         [('center_cosine', 's', '#ff7f0e', None)]),
        ((1, 0), 'Distribution Spread (Std)', 'Standard Deviation',
         [('std1', '^', '#2ca02c', dir1_name), ('std2', 'v', '#d62728', dir2_name)]),
        ((1, 1), 'Wasserstein Distance', 'Wasserstein Distance',
         [('wasserstein_mean', 'D', '#9467bd', None)]),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    try:
        for (row, col), title, ylabel, curves in panels:
            ax = axes[row, col]
            for key, marker, color, label in curves:
                extra = {'label': label} if label else {}
                ax.plot(steps, [r[key] for r in comparison_results],
                        marker=marker, linewidth=2, markersize=8, color=color, **extra)
            ax.set_xlabel('Training Step', fontsize=12, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
            ax.set_title(title, fontsize=14, fontweight='bold')
            if any(curve[3] for curve in curves):
                ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        output_path = os.path.join(output_dir, 'metrics_over_steps.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\n度量变化图已保存到: {output_path}")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: compare representation distributions of two runs.
    parser = argparse.ArgumentParser(description='对比两个目录中样本的表示（representation）分布差异')
    parser.add_argument('--dir1', type=str, required=True,
                        help='第一个目录路径')
    parser.add_argument('--dir2', type=str, required=True,
                        help='第二个目录路径')
    parser.add_argument('--model-path', type=str, required=True,
                        help='模型路径（用于提取 embeddings）')
    parser.add_argument('--pattern', '-p', type=str, default='*_16384.jsonl',
                        help='文件匹配模式 (默认: *_16384.jsonl)')
    parser.add_argument('--method', '-m', type=str, default='pca', choices=['pca', 'tsne'],
                        help='降维方法: pca 或 tsne (默认: pca)')
    parser.add_argument('--pooling', type=str, default='mean', choices=['mean', 'max', 'cls', 'last'],
                        help='Pooling 方式: mean, max, cls, last (默认: mean)')
    parser.add_argument('--max-length', type=int, default=512,
                        help='最大序列长度 (默认: 512)')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='批次大小 (默认: 16)')
    parser.add_argument('--output-dir', '-o', type=str, default='representation_analysis',
                        help='输出目录 (默认: representation_analysis)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='设备: cuda 或 cpu (默认: cuda)')
    parser.add_argument('--max-samples', type=int, default=None,
                        help='每个文件最大样本数（用于加速测试）')

    args = parser.parse_args()

    # Resolve the device before echoing the configuration, so the banner shows
    # the device actually used. startswith() also covers indexed devices such as
    # "cuda:1", which an equality test against "cuda" let through.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print("警告: CUDA 不可用，切换到 CPU")
        args.device = 'cpu'

    print("="*60)
    print("表示分布对比分析")
    print("="*60)
    print(f"目录1: {args.dir1}")
    print(f"目录2: {args.dir2}")
    print(f"模型: {args.model_path}")
    print(f"文件模式: {args.pattern}")
    print(f"降维方法: {args.method}")
    print(f"Pooling: {args.pooling}")
    print(f"设备: {args.device}")
    print(f"输出目录: {args.output_dir}")
    if args.max_samples:
        print(f"最大样本数: {args.max_samples}")
    print("="*60)
    print()

    comparison_results = compare_two_directories_representations(
        args.dir1, args.dir2,
        model_path=args.model_path,
        pattern=args.pattern,
        reduction_method=args.method,
        pooling=args.pooling,
        max_length=args.max_length,
        batch_size=args.batch_size,
        output_dir=args.output_dir,
        device=args.device,
        max_samples=args.max_samples
    )

    if comparison_results:
        # Summary table, one row per compared step.
        print("\n" + "="*80)
        print("对比结果汇总")
        print("="*80)
        print(f"{'Step':<10} {'中心欧氏距离':<18} {'中心余弦距离':<18} {'Wasserstein':<15}")
        print("-"*80)
        for r in comparison_results:
            print(f"{r['step']:<10} {r['center_euclidean']:<18.6f} {r['center_cosine']:<18.6f} {r['wasserstein_mean']:<15.6f}")
        print("="*80)

        dir1_name = os.path.basename(args.dir1.rstrip('/'))
        dir2_name = os.path.basename(args.dir2.rstrip('/'))
        plot_metrics_over_steps(comparison_results, args.output_dir, dir1_name, dir2_name)

        print(f"\n所有结果已保存到: {args.output_dir}")
