#!/usr/bin/env python3
"""
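Multi-method dimensionality-reduction visualization script.
Projects features to 2D with the methods chosen via --methods (currently t-SNE
and PCA; the UMAP / Isomap / LLE options are disabled in the CLI). t-SNE (and
UMAP, when enabled) can optionally be preceded by a PCA pre-reduction step
(--pre_pca_ratio).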
"""

import os
import json
import sys
import random
import argparse
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from rich.console import Console

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
from umap import UMAP

from train_linear import bbh_linear, collect_input_dims

# Directory holding this script, and the repository root two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir, os.pardir))

# Prepend the repository root (once) so the `src.*` packages imported below
# resolve when this file is executed directly as a script.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.train_utils.tsne_vis_linear import load_model_data, extract_test_features
from src.vis_utils.tsne_vis_linear import safe_rich_print, set_global_seed, maybe_pre_pca, reduce_dim, scatter_2d, create_benchmark_comparison_plot

# Global matplotlib styling applied to every figure produced by this script:
# the SimHei font (presumably so CJK text in plots renders), the ASCII minus
# sign instead of the Unicode one (a common workaround when a CJK font lacks
# that glyph), and enlarged font sizes for titles, labels, ticks and legends.
plt.rcParams.update({
    'font.family': ['SimHei'],
    'axes.unicode_minus': False,
    'font.size': 14,
    'axes.titlesize': 18,
    'axes.labelsize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 14,
})

# Shared rich console used for status output.
console = Console()

# Models visualized by default; override with --models.
# Alternative presets used previously:
#   ["llama3.2-3b-instruct", "falcon3-10b-instruct", "qwen2.5-7b-instruct", "granite3.1-8b-instruct", "qwen2.5-14b-instruct"]
#   ["ministral-8b-instruct-2410", "qwen2.5-3b-instruct", "llama3.1-8b-instruct", "internlm2.5-7b-chat", "qwen2.5-14b-instruct"]
#   ["ministral-8b-instruct-2410", "falcon3-3b-instruct", "internlm2.5-20b-chat", "falcon3-10b-instruct", "qwen2.5-7b-instruct"]
#   ["qwen2.5-3b-instruct", "falcon3-3b-instruct", "llama3.2-3b-instruct"]
#   ["qwen2.5-14b-instruct", "phi4-14b", "internlm2.5-20b-chat", "falcon3-10b-instruct"]
_DEFAULT_MODELS = ["qwen2.5-7b-instruct", "internlm2.5-7b-chat", "llama3.1-8b-instruct",
                   "ministral-8b-instruct-2410", "granite3.1-8b-instruct"]


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as True, so this converter is used instead.

    Args:
        value: Raw command-line string, or an already-parsed bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"无法解析布尔值: {value!r}")


def _parse_run_timestamp(name):
    """Return the datetime encoded as ``%Y%m%d_%H%M%S`` in the first 15 chars of ``name``, or None."""
    try:
        return datetime.strptime(name[:15], "%Y%m%d_%H%M%S")
    except ValueError:
        return None


def _select_latest_run_dir(base_dir):
    """Pick the most recent run subdirectory inside ``base_dir``.

    Subdirectories whose names start with a ``%Y%m%d_%H%M%S`` timestamp are
    ranked by that timestamp; other subdirectories are ignored so a stray
    folder cannot shadow the real runs. Only when no name carries a timestamp
    is the lexicographically largest subdirectory name returned.

    Returns:
        ``(name, timestamped)``: ``name`` is the chosen subdirectory name (None
        when ``base_dir`` has no subdirectories) and ``timestamped`` tells
        whether it was chosen by its timestamp.
    """
    subdirs = [d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
    if not subdirs:
        return None, False

    stamped = []
    for name in subdirs:
        ts = _parse_run_timestamp(name)
        if ts is not None:
            stamped.append((ts, name))
    if stamped:
        return max(stamped)[1], True
    return max(subdirs), False


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="多方法降维可视化")
    parser.add_argument("--feats_dir", default=os.path.join(PROJECT_ROOT, "feats/bbh/bbh_feats_split"), help="特征目录")
    parser.add_argument("--results_dir", default=os.path.join(PROJECT_ROOT, "results/bbh/bbh_results"), help="结果目录")
    parser.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "results/bbh/bbh_results"), help="输出目录")
    # Fall back to CPU when CUDA is unavailable instead of crashing on .to("cuda").
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="设备")
    parser.add_argument("--layer_type", default="last", choices=["middle", "last", "second_last"])
    parser.add_argument("--operate_type", default="last_token", choices=["last_token", "avg_with_prompt", "avg_without_prompt"])
    parser.add_argument("--max_samples", type=int, default=800, help="最大样本数")
    parser.add_argument("--level_filter", default=None, help="只处理指定level的数据，例如'Level 5'")
    # `--filter_incomplete` alone means True; an explicit true/false value is also accepted.
    parser.add_argument("--filter_incomplete", type=_str2bool, nargs="?", const=True, default=False,
                        help="是否过滤pred_answer为'Incomplete'的样本")
    parser.add_argument("--dot_size", type=int, default=20, help="点的大小")
    parser.add_argument("--alpha", type=float, default=0.5, help="点的透明度")
    parser.add_argument("--models", nargs="+", default=list(_DEFAULT_MODELS), help="要可视化的模型列表")
    parser.add_argument("--target_dim", type=int, default=4096, help="bbh_linear 的目标维度")

    # Method selection
    parser.add_argument("--methods", nargs="+",
                        default=["tsne", "pca"],
                        choices=["tsne", "pca"],
                        help="选择一种或多种降维方法")

    # Optional PCA pre-reduction (only applied to t-SNE/UMAP)
    parser.add_argument("--pre_pca_ratio", type=float, default=0.95,
                        help="t-SNE/UMAP 的 PCA 预降维方差保留比例（0 关闭）")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")

    # t-SNE
    parser.add_argument("--tsne_perplexity", type=float, default=30.0)
    parser.add_argument("--tsne_metric", default="cosine")

    # Options for methods not currently selectable via --methods; kept for reference.
    # # UMAP
    # parser.add_argument("--umap_n_neighbors", type=int, default=15)
    # parser.add_argument("--umap_min_dist", type=float, default=0.1)
    # parser.add_argument("--umap_metric", default="cosine")

    # # Isomap
    # parser.add_argument("--isomap_n_neighbors", type=int, default=15)
    # parser.add_argument("--isomap_metric", default="minkowski",
    #                     help="常用：minkowski/euclidean/manhattan 等")
    # parser.add_argument("--isomap_p", type=float, default=2.0,
    #                     help="Minkowski 距离的 p，p=2 为欧氏")

    # # LLE
    # parser.add_argument("--lle_n_neighbors", type=int, default=15)
    # parser.add_argument("--lle_method", default="standard",
    #                     choices=["standard", "modified", "hessian", "ltsa"])
    return parser


def main():
    """Render the dimensionality-reduction comparison plots for the latest run.

    Steps:
      1. Parse CLI arguments and seed all RNGs via ``set_global_seed``.
      2. Resolve the newest timestamped subdirectory of ``--output_dir`` and
         use it as both the results and the output directory.
      3. Load feature indices for ``--models`` via ``load_model_data``.
      4. Build a ``bbh_linear`` model sized from the collected input dims.
      5. Hand everything to ``create_benchmark_comparison_plot``.

    Prints an error and returns early when the output directory is missing,
    has no subdirectories, or model data fails to load.
    """
    args = _build_arg_parser().parse_args()
    set_global_seed(args.seed)

    # isdir (not exists): a regular file here would make os.listdir fail.
    if not os.path.isdir(args.output_dir):
        safe_rich_print(f"错误: 输出目录不存在 {args.output_dir}")
        return

    latest_run, timestamped = _select_latest_run_dir(args.output_dir)
    if latest_run is None:
        safe_rich_print(f"错误: 在 {args.output_dir} 中未找到子目录")
        return
    if not timestamped:
        safe_rich_print(f"警告: {args.output_dir} 中没有以 %Y%m%d_%H%M%S 时间戳命名的子目录，按字符串排序选用 {latest_run}")

    # NOTE(review): any --results_dir passed on the command line is overwritten
    # here; results are always read from the latest run dir under --output_dir.
    args.output_dir = os.path.join(args.output_dir, latest_run)
    args.results_dir = args.output_dir

    datasets = ["bbh"]

    # Load the feature index for the models to visualize.
    try:
        models_data = load_model_data(args.feats_dir, layer_type=args.layer_type,
                                      datasets=datasets, models=args.models,
                                      operate_type=args.operate_type)
    except Exception as e:  # top-level CLI boundary: report and stop
        safe_rich_print(f"错误: 加载模型数据失败: {e}")
        return

    models_list = list(models_data.keys())
    in_dims = collect_input_dims(args.feats_dir, args.layer_type, datasets,
                                 models_list, args.operate_type)

    model = bbh_linear(target_dim=args.target_dim, in_dims=in_dims).to(args.device)

    # Generate the multi-method visualization.
    create_benchmark_comparison_plot(model, datasets, models_data, args.results_dir, args.output_dir,
                                     device=args.device, max_samples=args.max_samples,
                                     methods=args.methods, args=args, level_filter=args.level_filter,
                                     filter_incomplete=args.filter_incomplete)
    console.print("[bold green]可视化完成！[/bold green]")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()