import os
import glob
import argparse
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
import inspect
import json

# ----------------------------------------------------
# 导入项目模块
# ----------------------------------------------------
from utils.testloss import TestLoss
from model_dict import get_model
from utils.normalizer import UnitTransformer

# ==============================================================================
#  Core evaluation: load one checkpoint and compute its global / per-region errors
# ==============================================================================
def _load_checkpoint(checkpoint_path, device):
    """``torch.load`` a trusted checkpoint that may contain pickled Python objects."""
    load_kwargs = {'map_location': device}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle the
    # argparse.Namespace stored under 'args'; releases before 1.13 lack the parameter.
    # SECURITY: weights_only=False executes arbitrary pickle code -- only load
    # checkpoints you trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(checkpoint_path, **load_kwargs)


def _region_error(loss_fn, preds, true, mask):
    """Return ``loss_fn`` over the entries selected by ``mask``, or NaN if it selects none."""
    if not mask.any():
        return float('nan')
    # NOTE(review): when ``mask`` has the full shape of ``preds``, boolean indexing
    # yields a 1-D tensor, so TestLoss sees every selected point as its own sample,
    # whereas the global error is computed on whole samples. Confirm this is the
    # intended regional metric.
    return loss_fn(preds[mask], true[mask]).item()


def evaluate_single_model(checkpoint_path, cli_args, true_depth, test_loader, y_normalizer, masks, device):
    """Load one model checkpoint and compute its global and per-region test errors.

    Args:
        checkpoint_path: Path to a saved checkpoint. Either a dict holding
            ``'args'`` (an ``argparse.Namespace`` or a plain dict of
            hyper-parameters) and ``'model'`` (the state dict), or a bare state dict.
        cli_args: Parsed CLI namespace. Used to rebuild the model
            hyper-parameters when the checkpoint carries no ``'args'``.
        true_depth: Value used as ``n_layers`` for bare state-dict checkpoints.
            It is ignored when the checkpoint stores its own ``'args'``.
        test_loader: Iterable of ``(x, y)`` batches. Only ``x`` is moved to ``device``.
        y_normalizer: Object whose ``decode`` maps raw model outputs back to the
            scale of ``y``.
        masks: Mapping with boolean arrays ``'complex_masks'`` and
            ``'simple_masks'``. They index the prediction tensor, whose first
            dimension is the sample axis. NOTE(review): whether the masks are
            per-sample or per-point is defined by the mask-preparation script;
            confirm there.
        device: Device the model runs on.

    Returns:
        dict with keys ``'global'``, ``'complex'`` and ``'simple'``. A region's
        error is NaN when its mask selects nothing.
    """
    checkpoint = _load_checkpoint(checkpoint_path, device)

    if 'args' in checkpoint:
        model_args_from_ckpt = checkpoint['args']
        if isinstance(model_args_from_ckpt, dict):
            # Some training scripts save vars(args) rather than the Namespace itself.
            model_args_from_ckpt = argparse.Namespace(**model_args_from_ckpt)
        state_dict = checkpoint['model']
    else:
        # Bare state dict: rebuild hyper-parameters from the CLI with the given depth.
        cli_args_dict = vars(cli_args).copy()
        cli_args_dict['n_layers'] = true_depth
        model_args_from_ckpt = argparse.Namespace(**cli_args_dict)
        state_dict = checkpoint

    # Number of points left on a 129-point axis after striding by the down-sample factor.
    s1 = int(((129 - 1) / getattr(model_args_from_ckpt, 'downsamplex', 1)) + 1)
    s2 = int(((129 - 1) / getattr(model_args_from_ckpt, 'downsampley', 1)) + 1)

    model_class = get_model(model_args_from_ckpt).Model
    constructor_params = inspect.signature(model_class.__init__).parameters
    model_params = vars(model_args_from_ckpt).copy()
    # The CLI spells it 'n_heads' while some model constructors expect 'n_head'.
    if 'n_head' in constructor_params and 'n_heads' in model_params:
        model_params['n_head'] = model_params['n_heads']

    # Keep only keyword arguments the constructor accepts, then pin the fixed task settings.
    final_model_params = {key: val for key, val in model_params.items() if key in constructor_params}
    final_model_params.update({'space_dim': 2, 'Time_Input': False, 'fun_dim': 0, 'out_dim': 1, 'H': s1, 'W': s2})

    model = model_class(**final_model_params).to(device)
    model.load_state_dict(state_dict)
    model.eval()

    myloss = TestLoss(size_average=True)
    # Predictions and targets are gathered on the CPU below, so the masks must stay
    # on the CPU too: indexing a CPU tensor with a CUDA mask raises a RuntimeError.
    complex_mask = torch.from_numpy(np.asarray(masks['complex_masks'])).bool()
    simple_mask = torch.from_numpy(np.asarray(masks['simple_masks'])).bool()

    all_preds = []
    all_true = []
    with torch.no_grad():
        for x, y in test_loader:
            out_enc = model(x.to(device), None).squeeze(-1)
            all_preds.append(y_normalizer.decode(out_enc).cpu())
            all_true.append(y.cpu())

    preds_tensor = torch.cat(all_preds, dim=0)
    true_tensor = torch.cat(all_true, dim=0)

    return {
        'global': myloss(preds_tensor, true_tensor).item(),
        'complex': _region_error(myloss, preds_tensor, true_tensor, complex_mask),
        'simple': _region_error(myloss, preds_tensor, true_tensor, simple_mask),
    }

# ==============================================================================
#  主流程
# ==============================================================================
def _build_arg_parser():
    """Create the command-line parser for the depth-study analysis."""
    parser = argparse.ArgumentParser('Analysis and Plotting for Transolver Depth Study')

    # One entry per trained model; the three lists are zipped together.
    parser.add_argument('--model_paths', type=str, nargs='+', required=True)
    parser.add_argument('--model_depths', type=int, nargs='+', required=True)
    parser.add_argument('--model_seeds', type=int, nargs='+', required=True)
    parser.add_argument('--mask_path', type=str, required=True)
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--baseline_depth', type=int, default=8)
    parser.add_argument("--gpu", type=str, default='0')
    parser.add_argument('--save_name', type=str, default='depth_impact_analysis.png')

    # Model hyper-parameters; evaluate_single_model uses them when a checkpoint
    # does not store its own 'args'.
    parser.add_argument('--model', type=str, default='Transolver_Structured_Mesh_2D')
    parser.add_argument('--n-hidden', type=int, default=128)
    parser.add_argument('--n-heads', type=int, default=8)
    parser.add_argument('--mlp_ratio', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--unified_pos', type=int, default=0)
    parser.add_argument('--ref', type=int, default=8)
    parser.add_argument('--slice_num', type=int, default=64)
    return parser


def _load_masks(mask_path):
    """Load the mask archive and parse the JSON metadata stored under its 'args' key."""
    try:
        masks = np.load(mask_path)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"掩码文件未找到: {mask_path}。请先运行 01_prepare_masks.py。") from err
    print(f"Masks loaded successfully from '{mask_path}'")

    if 'args' in masks:
        mask_args = json.loads(masks['args'].item())
        print("  -> Metadata from mask file loaded and parsed.")
    else:
        mask_args = {}
        print("  -> Warning: No 'args' metadata found in mask file. Using default values.")
    return masks, mask_args


def _load_test_data(data_path, mask_args, device, ntrain=1000, ntest=200):
    """Return ``(test_loader, y_normalizer)`` for the Pipe dataset.

    The down-sampling factors are read from the mask metadata (default 1). The
    normalizer is fitted on the outputs of the first ``ntrain`` samples; the test
    set is the last ``ntest`` samples.
    """
    r1, r2 = mask_args.get('downsamplex', 1), mask_args.get('downsampley', 1)
    # Number of points left on a 129-point axis after striding by r1 / r2.
    s1, s2 = int(((129 - 1) / r1) + 1), int(((129 - 1) / r2) + 1)

    input_x = os.path.join(data_path, 'Pipe_X.npy')
    input_y = os.path.join(data_path, 'Pipe_Y.npy')
    output_sigma = os.path.join(data_path, 'Pipe_Q.npy')

    input_coords = torch.from_numpy(np.stack([np.load(input_x), np.load(input_y)], axis=-1)).float()
    output = torch.from_numpy(np.load(output_sigma)[:, 0]).float()

    x_test = input_coords[-ntest:, ::r1, ::r2][:, :s1, :s2].reshape(ntest, -1, 2)
    y_test = output[-ntest:, ::r1, ::r2][:, :s1, :s2].reshape(ntest, -1)
    y_train = output[:ntrain, ::r1, ::r2][:, :s1, :s2].reshape(ntrain, -1)

    y_normalizer = UnitTransformer(y_train)
    y_normalizer.to(device)
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_test, y_test), batch_size=16, shuffle=False
    )
    return test_loader, y_normalizer


def _evaluate_all(args, test_loader, y_normalizer, masks, device):
    """Evaluate every (path, depth, seed) triple and return the results as a DataFrame.

    Checkpoints that fail are reported and skipped. Raises RuntimeError if none succeed.
    """
    results = []
    pbar = tqdm(zip(args.model_paths, args.model_depths, args.model_seeds),
                total=len(args.model_paths),
                desc="Evaluating Models")

    for path, depth, seed in pbar:
        pbar.set_postfix({'depth': depth, 'seed': seed})
        try:
            errors = evaluate_single_model(path, args, depth, test_loader, y_normalizer, masks, device)
        except Exception as e:
            # Best effort: one broken checkpoint should not abort the whole study.
            # tqdm.write keeps the progress bar intact.
            tqdm.write(f"\n处理 {path} 时发生严重错误: {e}")
            continue
        results.append({'depth': depth, 'seed': seed, **errors})

    if not results:
        raise RuntimeError("未能成功评估任何模型。请检查错误信息和文件路径。")
    return pd.DataFrame(results)


def _aggregate_by_depth(df, baseline_depth):
    """Aggregate errors by depth and add percentage improvements over the baseline.

    For each region ('complex', 'simple') this adds ``improvement_<region>`` and
    ``improvement_<region>_std``. The std column is the seed-to-seed spread
    expressed in the same percent-of-baseline units as the improvement.

    Raises:
        ValueError: if ``baseline_depth`` is not among the evaluated depths.
    """
    grouped = df.groupby('depth').agg({
        'global': ['mean', 'std'], 'complex': ['mean', 'std'], 'simple': ['mean', 'std']
    })
    grouped.columns = ['_'.join(col).strip() for col in grouped.columns.values]
    grouped = grouped.reset_index()
    print(grouped.to_string())

    baseline_row = grouped[grouped['depth'] == baseline_depth]
    if baseline_row.empty:
        raise ValueError(f"基线深度 L={baseline_depth} 在评估结果中未找到。请选择一个已评估的深度作为基线。")

    for region in ('complex', 'simple'):
        baseline_error = baseline_row[f'{region}_mean'].iloc[0]
        grouped[f'improvement_{region}'] = (baseline_error - grouped[f'{region}_mean']) / baseline_error * 100
        # The raw std is in relative-error units. Rescale it so the error bars
        # share the percentage axis of the improvement bars.
        grouped[f'improvement_{region}_std'] = grouped[f'{region}_std'] / baseline_error * 100
    return grouped


def _plot_depth_impact(grouped, baseline_depth, save_name):
    """Draw grouped bars of per-region improvement vs. depth and save them to ``save_name``."""
    try:
        plt.style.use('seaborn-v0_8-whitegrid')
    except OSError:
        # matplotlib < 3.6 only knows the old style name.
        plt.style.use('seaborn-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 8))

    depths = grouped['depth']
    x_positions = np.arange(len(depths))  # Categorical slots 0, 1, 2, ...
    bar_width = 0.35

    for offset, region, label, color in (
        (-bar_width / 2, 'complex', 'Complex Region', 'crimson'),
        (bar_width / 2, 'simple', 'Simple Region', 'mediumseagreen'),
    ):
        ax.bar(x_positions + offset,
               grouped[f'improvement_{region}'],
               width=bar_width,
               yerr=grouped[f'improvement_{region}_std'],
               label=label,
               color=color,
               capsize=5)

    ax.axhline(0, color='black', linestyle='-.', lw=1.5, label=f'Baseline (L={baseline_depth})')

    ax.set_xlabel('Network Depth (Number of Layers, L)', fontsize=14, weight='bold')
    # Label the categorical slots with the actual depth values.
    ax.set_xticks(x_positions)
    ax.set_xticklabels(depths)

    ax.set_ylabel('Relative Error Improvement (%)', fontsize=14, weight='bold')
    ax.set_title('Impact of Transolver Network Depth on Prediction Accuracy', fontsize=18, weight='bold', pad=20)
    ax.legend(fontsize=12, loc='best')
    ax.grid(True, which='major', axis='y', linestyle='--', linewidth=0.7)  # Horizontal grid lines only.
    ax.tick_params(axis='both', which='major', labelsize=12)
    fig.tight_layout()

    fig.savefig(save_name, dpi=300)
    plt.close(fig)


def _paired_ttest_report(df, baseline_depth):
    """Print paired t-tests (matched by seed) of complex-region error vs. the baseline depth.

    Best effort: any failure (e.g. duplicate seed/depth rows, missing baseline)
    is reported instead of raised.
    """
    try:
        df_pivot = df.pivot(index='seed', columns='depth', values='complex')
        data_baseline = df_pivot[baseline_depth]

        for depth_to_compare in df_pivot.columns:
            if depth_to_compare == baseline_depth:
                continue
            data_compare = df_pivot[depth_to_compare]

            valid_mask = data_baseline.notna() & data_compare.notna()
            if valid_mask.sum() < 2:
                print(f"Skipping L={baseline_depth} vs L={depth_to_compare}: Not enough paired samples.")
                continue

            _, p_value = ttest_rel(data_baseline[valid_mask], data_compare[valid_mask])
            # Positive => the compared depth has the lower (better) error.
            mean_diff = np.mean(data_baseline[valid_mask] - data_compare[valid_mask])

            print(f"Comparing L={baseline_depth} vs L={depth_to_compare}: p-value = {p_value:.5f}")
            if p_value < 0.05:
                verdict = "SIGNIFICANT improvement" if mean_diff > 0 else "SIGNIFICANT degradation"
                print(f"  -> Statistically {verdict} (L={depth_to_compare} is {'better' if mean_diff > 0 else 'worse'}).")
            else:
                print("  -> Difference is NOT statistically significant (p >= 0.05).")
    except Exception as e:
        print(f"Could not perform t-test: {e}")


def main():
    """Evaluate all checkpoints, aggregate by depth, plot improvements, and run t-tests."""
    args = _build_arg_parser().parse_args()

    if not (len(args.model_paths) == len(args.model_depths) == len(args.model_seeds)):
        raise ValueError("错误: --model_paths, --model_depths, 和 --model_seeds 列表长度必须一致。")

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print("--- 1. Loading Data and Masks ---")
    masks, mask_args = _load_masks(args.mask_path)
    test_loader, y_normalizer = _load_test_data(args.data_path, mask_args, device)
    print("Test data and masks are ready.")

    print("\n--- 2. Evaluating all specified model variants ---")
    df = _evaluate_all(args, test_loader, y_normalizer, masks, device)
    print("\n--- Raw Evaluation Results ---")
    print(df.to_string())

    print("\n--- 3. Aggregating results by depth ---")
    grouped = _aggregate_by_depth(df, args.baseline_depth)

    print("\n--- 4. Generating analysis plot (Bar Chart) ---")
    _plot_depth_impact(grouped, args.baseline_depth, args.save_name)
    print(f"\nAnalysis plot saved to '{args.save_name}'")

    print("\n--- 5. Statistical Validation (Paired t-test on Complex Region errors) ---")
    _paired_ttest_report(df, args.baseline_depth)


if __name__ == "__main__":
    main()