import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
from pathlib import Path
import seaborn as sns
from scipy.integrate import simps  
import matplotlib.font_manager as fm
import geomloss
import pandas as pd
from scipy import stats


# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and
# 3.8 removed the old name (plt.style.use('seaborn') then raises OSError).
# Prefer the new name and fall back to the old one on older Matplotlib.
_PLOT_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_PLOT_STYLE)
# Render the minus sign as ASCII '-' so it displays with fonts lacking U+2212.
plt.rcParams['axes.unicode_minus'] = False

def test_model(model, test_loader, config):
    """Evaluate the best saved checkpoint on the test set and save plots and metrics.

    The checkpoint ``<model_save_path>/best_model.pth`` is loaded into ``model``,
    predictions are collected over ``test_loader`` and all results are written
    to ``<model_save_path>/eval_results``:

    * raw predicted / true tensors as ``.npy`` plus one CSV per output channel;
    * histogram + KDE comparison plots of channel 0 at 0%, 10%, ..., 100% of
      the prediction horizon;
    * ``evaluation_metrics.txt``: for 1D outputs the histogram MISE, Sinkhorn
      divergence and lower-tail CDF difference; for 2D outputs the 2D
      histogram MISE, Sinkhorn divergence and radial tail difference (plus 3D
      KDE surface plots);
    * a plot and CSV of a few randomly chosen sample paths (channel 0).

    Args:
        model: torch.nn.Module called as ``model(x, timestamps)``; its weights
            are replaced by the best checkpoint.
        test_loader: iterable (e.g. a DataLoader) yielding ``(x, y, timestamps)``
            batches; the last ``output_len`` steps of ``y`` are the targets.
        config: dict with
            - ``model_save_path``: directory containing ``best_model.pth``;
            - ``output_len``: length of the predicted sequence;
            - ``test_gpu`` (optional): CUDA device index; CPU is used when it
              is missing or CUDA is unavailable.

    Returns:
        None. All results are written to disk.

    Raises:
        FileNotFoundError: if ``best_model.pth`` does not exist.
        ValueError: if ``test_loader`` yields no batches, or (after the raw
            data and marginal plots are saved) if the output dimension is
            neither 1 nor 2.
    """
    use_gpu = torch.cuda.is_available() and 'test_gpu' in config
    test_device = torch.device(f"cuda:{config['test_gpu']}" if use_gpu else "cpu")
    save_dir = Path(config["model_save_path"])
    output_len = config["output_len"]
    num_bins = 300  # histogram bins for the 1D MISE

    model = _load_best_model(model, save_dir / "best_model.pth", test_device)
    all_preds, all_trues = _collect_predictions(model, test_loader, test_device, output_len)

    print("all_preds.shape:", all_preds.shape)
    print("all_trues.shape:", all_trues.shape)

    result_dir = save_dir / "eval_results"
    result_dir.mkdir(exist_ok=True)

    _save_raw_distributions(all_preds, all_trues, result_dir, output_len)
    _plot_marginal_distributions(all_preds, all_trues, result_dir, output_len)

    output_dim = all_preds.shape[-1]
    if output_dim == 1:
        _evaluate_1d(all_preds, all_trues, result_dir, test_device, num_bins)
    elif output_dim == 2:
        _evaluate_2d(all_preds, all_trues, result_dir, test_device)
    else:
        raise ValueError(f"不支持的输出维度: {output_dim}")

    _plot_sample_paths(all_preds, all_trues, result_dir, output_len)

    # Hand cached CUDA blocks used by the Sinkhorn computation back to the driver.
    torch.cuda.empty_cache()


def _load_best_model(model, best_model_path, device):
    """Load ``best_model_path`` into ``model``, move it to ``device``, set eval mode."""
    if not best_model_path.exists():
        raise FileNotFoundError(f"找不到最佳模型文件: {best_model_path}")

    state_dict = torch.load(best_model_path, map_location=device)
    # Checkpoints saved from a wrapped model (e.g. nn.DataParallel) prefix keys
    # with "module."; strip only that leading prefix so inner names survive.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _collect_predictions(model, test_loader, device, output_len):
    """Run ``model`` over ``test_loader`` and return concatenated CPU (preds, trues).

    Both tensors are concatenated along the sample axis; the model output is
    assumed to be (batch, output_len, output_dim) and the targets are the last
    ``output_len`` steps of ``y``.
    """
    preds, trues = [], []
    with torch.no_grad():
        for x, y, timestamps in test_loader:
            output = model(x.to(device), timestamps.to(device))
            preds.append(output.cpu())
            # Only the target steps matching the prediction horizon are kept.
            trues.append(y[:, -output_len:, :].cpu())
    if not preds:
        raise ValueError("test_loader yielded no batches")
    return torch.cat(preds, dim=0), torch.cat(trues, dim=0)


def _save_raw_distributions(all_preds, all_trues, result_dir, output_len):
    """Save predictions/targets as .npy and as one CSV per output channel."""
    pred_data = all_preds.numpy()
    true_data = all_trues.numpy()

    np.save(result_dir / 'density_predicted_distribution.npy', pred_data)
    np.save(result_dir / 'density_true_distribution.npy', true_data)

    # CSV copies (one row per sample, one column per time step) for inspection.
    columns_pred = [f'pred_t{i}' for i in range(output_len)]
    columns_true = [f'true_t{i}' for i in range(output_len)]
    for channel in range(pred_data.shape[-1]):
        pd.DataFrame(pred_data[:, :, channel], columns=columns_pred).to_csv(
            result_dir / f'density_predicted_distribution_channel_{channel}.csv',
            index=False,
        )
        pd.DataFrame(true_data[:, :, channel], columns=columns_true).to_csv(
            result_dir / f'density_true_distribution_channel_{channel}.csv',
            index=False,
        )


def _draw_hist(ax, data, color, name):
    """Draw a normalized 100-bin histogram of ``data`` on ``ax``."""
    ax.hist(data, bins=100, density=True, alpha=0.3, color=color,
            label=f'{name} (直方图)')


def _draw_kde(ax, data, color, name):
    """Draw a seaborn KDE curve of ``data`` on ``ax``."""
    sns.kdeplot(data=data, label=f'{name} (密度)', color=color, alpha=0.7,
                bw_adjust=0.5, ax=ax, warn_singular=False)


def _finish_density_axis(ax, title):
    """Apply the shared title/labels/grid/legend of the density panels."""
    ax.set_title(title)
    ax.set_xlabel('数值')
    ax.set_ylabel('密度')
    ax.grid(True, alpha=0.3)
    ax.legend()


def _plot_marginal_distributions(all_preds, all_trues, result_dir, output_len):
    """Plot true / predicted / overlaid densities of channel 0 at 0%..100% of the horizon."""
    percentages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    for percentage in percentages:
        time_point = (percentage * (output_len - 1)) // 100
        pred_data = all_preds[:, time_point, 0].numpy()
        true_data = all_trues[:, time_point, 0].numpy()

        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))

        _draw_hist(ax1, true_data, 'red', '真实值')
        _draw_kde(ax1, true_data, 'red', '真实值')
        _finish_density_axis(ax1, f'第 {percentage}% 时间点的真实数据分布')

        _draw_hist(ax2, pred_data, 'blue', '预测值')
        _draw_kde(ax2, pred_data, 'blue', '预测值')
        _finish_density_axis(ax2, f'第 {percentage}% 时间点的预测数据分布')

        _draw_hist(ax3, true_data, 'red', '真实值')
        _draw_hist(ax3, pred_data, 'blue', '预测值')
        _draw_kde(ax3, true_data, 'red', '真实值')
        _draw_kde(ax3, pred_data, 'blue', '预测值')
        _finish_density_axis(ax3, f'第 {percentage}% 时间点的分布对比')

        plt.tight_layout()
        plt.savefig(result_dir / f'distribution_comparison_{percentage}p.jpg',
                    dpi=300, bbox_inches='tight')
        plt.close(fig)


def _histogram_mise_1d(pred, true, num_bins):
    """Return the histogram-based integrated squared error between two 1D samples.

    Both samples share ``num_bins`` equal-width bins spanning their joint range.
    """
    # torch.histogram needs input and bins of the same floating dtype and does
    # not support half precision, so work in float64 throughout.
    pred = pred.to(torch.float64)
    true = true.to(torch.float64)
    data_min = min(pred.min().item(), true.min().item())
    data_max = max(pred.max().item(), true.max().item())
    if data_min == data_max:
        # Both samples are the same point mass (MISE is 0 for any bin width);
        # widen enough that the bin edges are distinct.
        data_min -= 0.5
        data_max += 0.5

    bins = torch.linspace(data_min, data_max, num_bins + 1, dtype=torch.float64)
    hist_true = torch.histogram(true, bins=bins, density=True)[0]
    hist_pred = torch.histogram(pred, bins=bins, density=True)[0]
    bin_width = (data_max - data_min) / num_bins
    return float(((hist_true - hist_pred) ** 2).sum().item() * bin_width)


def _sample_rows(pred_points, true_points, sample_size=5000):
    """Pick the same random subset of at most ``sample_size`` rows from both tensors."""
    indices = torch.randperm(len(pred_points))[:sample_size]
    return pred_points[indices], true_points[indices]


def _sinkhorn_distance(pred_points, true_points, device):
    """Return the geomloss Sinkhorn loss (p=2, blur=0.1) between two point clouds."""
    sinkhorn_loss = geomloss.SamplesLoss(
        loss='sinkhorn',
        p=2,
        blur=0.1,
        scaling=0.95,
        backend='tensorized',
    )
    value = sinkhorn_loss(pred_points.to(device), true_points.to(device))
    return float(value.cpu())


def _calculate_tail_diff(data1, data2, percentile=0.01):
    """Return |CDF1(T) - CDF2(T)| at a lower-tail threshold T, and T itself.

    ``percentile`` uses numpy's 0-100 scale, so the default 0.01 places T at
    the 0.01th percentile of the pooled samples.
    """
    threshold = np.percentile(np.concatenate([data1, data2]), percentile)
    tail_prob1 = (data1 <= threshold).mean()
    tail_prob2 = (data2 <= threshold).mean()
    return abs(tail_prob1 - tail_prob2), threshold


def _evaluate_1d(all_preds, all_trues, result_dir, device, num_bins):
    """Compute and save MISE, Sinkhorn and tail-CDF metrics for 1D outputs."""
    # NOTE(review): the report text says "last timestep", but index 0 (the
    # first predicted step) is evaluated, as the original code deliberately
    # switched from -1 to 0; confirm which step is intended.
    pred_step = all_preds[:, 0, :]
    true_step = all_trues[:, 0, :]

    mise = _histogram_mise_1d(pred_step.flatten(), true_step.flatten(), num_bins)
    # Subsample rows (not flattened values) so indices stay within range.
    pred_sampled, true_sampled = _sample_rows(pred_step, true_step)
    sinkhorn_distance = _sinkhorn_distance(pred_sampled, true_sampled, device)
    tail_diff, threshold = _calculate_tail_diff(
        true_step.flatten().numpy(),
        pred_step.flatten().numpy(),
        percentile=0.01,
    )

    with open(result_dir / 'evaluation_metrics.txt', 'w', encoding='utf-8') as f:
        f.write("Mean Integrated Squared Error (MISE) for last timestep:\n")
        f.write(f"MISE: {mise:.6f}\n")
        f.write("\nSinkhorn Distance for last timestep:\n")
        f.write(f"Sinkhorn: {sinkhorn_distance:.6f}\n")
        f.write("\n尾部累计概率差异 (Tail CDF Difference at 0.01 percentile):\n")
        f.write(f"Threshold (T) = {threshold:.6f}\n")
        f.write(f"Tail Difference: {tail_diff:.6f}\n")

    print(f"MISE for last timestep: {mise:.6f}")
    print(f"Sinkhorn distance for last timestep: {sinkhorn_distance:.6f}")
    print(f"\n尾部累计概率差异 (T = {threshold:.6f}):")
    print(f"Tail Difference: {tail_diff:.6f}")


def _plot_3d_distribution(data, title, save_path):
    """Fit a 2D Gaussian KDE to ``data`` (n, 2), save three 3D views, return the grid.

    Returns:
        tuple (xx, yy, z): 100x100 evaluation grid and KDE density values.
    """
    x, y = data[:, 0], data[:, 1]
    xx, yy = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
    kernel = stats.gaussian_kde(np.vstack([x, y]))
    z = np.reshape(kernel(np.vstack([xx.ravel(), yy.ravel()])), xx.shape)

    view_angles = [
        (30, 45),   # front-right, from above
        (30, 135),  # front-left, from above
        (90, 0),    # straight down
    ]

    fig = plt.figure(figsize=(15, 5))
    fig.suptitle(title, fontsize=16, y=1.05)
    for position, (elev, azim) in enumerate(view_angles, start=1):
        ax = fig.add_subplot(1, 3, position, projection='3d')
        ax.plot_surface(xx, yy, z, cmap='viridis', antialiased=True, alpha=0.8)
        ax.view_init(elev=elev, azim=azim)
        ax.set_xlabel('特征1')
        ax.set_ylabel('特征2')
        ax.set_zlabel('密度')
        ax.set_title(f'视角: 仰角={elev}°, 方位角={azim}°')

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    return xx, yy, z


def _plot_3d_comparison(pred_grid, true_grid, save_path, elev=30, azim=45):
    """Save predicted, true and overlaid KDE surfaces side by side."""
    fig = plt.figure(figsize=(15, 5))
    fig.suptitle('预测与真实分布对比', fontsize=16, y=1.05)
    panels = [
        ('预测分布', [(pred_grid, 'viridis', 0.8)]),
        ('真实分布', [(true_grid, 'viridis', 0.8)]),
        ('分布对比', [(pred_grid, 'Blues', 0.5), (true_grid, 'Reds', 0.5)]),
    ]
    for position, (panel_title, surfaces) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position, projection='3d')
        for (xx, yy, z), cmap, alpha in surfaces:
            ax.plot_surface(xx, yy, z, cmap=cmap, antialiased=True, alpha=alpha)
        ax.view_init(elev=elev, azim=azim)
        ax.set_title(panel_title)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)


def _padded_range(lo, hi, frac=0.05):
    """Widen [lo, hi] by ``frac`` of its span on each side (0.5 if the span is 0)."""
    span = hi - lo
    margin = span * frac if span > 0 else 0.5
    return lo - margin, hi + margin


def _calculate_2d_mise(pred_data, true_data, num_bins=50):
    """Return the 2D histogram integrated squared error between two (n, 2) samples.

    Both samples share a ``num_bins`` x ``num_bins`` grid over their joint,
    slightly padded, bounding box.
    """
    x_min, x_max = _padded_range(
        min(pred_data[:, 0].min(), true_data[:, 0].min()),
        max(pred_data[:, 0].max(), true_data[:, 0].max()),
    )
    y_min, y_max = _padded_range(
        min(pred_data[:, 1].min(), true_data[:, 1].min()),
        max(pred_data[:, 1].max(), true_data[:, 1].max()),
    )
    x_edges = np.linspace(x_min, x_max, num_bins + 1)
    y_edges = np.linspace(y_min, y_max, num_bins + 1)

    hist_pred, _, _ = np.histogram2d(pred_data[:, 0], pred_data[:, 1],
                                     bins=[x_edges, y_edges], density=True)
    hist_true, _, _ = np.histogram2d(true_data[:, 0], true_data[:, 1],
                                     bins=[x_edges, y_edges], density=True)

    cell_area = (x_edges[1] - x_edges[0]) * (y_edges[1] - y_edges[0])
    return np.sum((hist_pred - hist_true) ** 2) * cell_area


def _calculate_2d_tail_diff(data1, data2, percentile=0.01):
    """Return the gap in outer-tail probability of two (n, 2) samples, and the threshold.

    The tail is the set of points whose Euclidean distance from the origin is
    at least the ``(100 - percentile)``-th percentile of the pooled distances
    (numpy 0-100 scale; e.g. ``percentile=1`` -> the outermost 1%).
    """
    dist1 = np.sqrt(np.sum(data1 ** 2, axis=1))
    dist2 = np.sqrt(np.sum(data2 ** 2, axis=1))
    threshold = np.percentile(np.concatenate([dist1, dist2]), 100 - percentile)
    tail_prob1 = (dist1 >= threshold).mean()
    tail_prob2 = (dist2 >= threshold).mean()
    return abs(tail_prob1 - tail_prob2), threshold


def _evaluate_2d(all_preds, all_trues, result_dir, device):
    """Plot 3D KDE surfaces and save 2D MISE / Sinkhorn / tail metrics."""
    # NOTE(review): the KDE plots and MISE use the last predicted step (-1)
    # while Sinkhorn and the tail metric use the first step (0), exactly as
    # the original code did; confirm which step is intended.
    pred_last = all_preds[:, -1, :].numpy()
    true_last = all_trues[:, -1, :].numpy()

    pred_grid = _plot_3d_distribution(
        pred_last, '预测分布的3D密度图', result_dir / 'predicted_3d_distribution.png')
    true_grid = _plot_3d_distribution(
        true_last, '真实分布的3D密度图', result_dir / 'true_3d_distribution.png')
    _plot_3d_comparison(pred_grid, true_grid, result_dir / 'distribution_comparison_3d.png')

    mise = _calculate_2d_mise(pred_last, true_last, num_bins=50)

    pred_sampled, true_sampled = _sample_rows(all_preds[:, 0, :], all_trues[:, 0, :])
    sinkhorn_distance = _sinkhorn_distance(pred_sampled, true_sampled, device)
    tail_diff, threshold = _calculate_2d_tail_diff(
        true_sampled.numpy(),
        pred_sampled.numpy(),
        percentile=1,  # outermost 1% by distance from the origin
    )

    with open(result_dir / 'evaluation_metrics.txt', 'w', encoding='utf-8') as f:
        f.write("二维分布评估指标:\n")
        f.write("Mean Integrated Squared Error (MISE) for each dimension:\n")
        f.write(f"MISE: {mise:.6f}\n")
        f.write(f"\nSinkhorn Distance: {sinkhorn_distance:.6f}\n")
        f.write(f"\n二维尾部差异 (阈值距离 = {threshold:.6f}):\n")
        f.write(f"Tail Difference: {tail_diff:.6f}\n")

    print("\n二维分布评估指标:")
    print(f"MISE: {mise:.6f}")
    print(f"Sinkhorn distance: {sinkhorn_distance:.6f}")
    print(f"二维尾部差异 (阈值距离 = {threshold:.6f}): {tail_diff:.6f}")


def _gaussian_smooth(data, sigma=2):
    """Smooth a 1D array with a normalized Gaussian kernel of std ``sigma``.

    The kernel spans +/- int(4 * sigma) samples; the input is edge-padded so
    the output has the same length as ``data``.
    """
    half_width = int(4 * sigma)
    offsets = np.arange(-half_width, half_width + 1)
    kernel = np.exp(-offsets ** 2 / (2 * sigma ** 2))
    kernel /= kernel.sum()
    padded = np.pad(data, half_width, mode='edge')
    return np.convolve(padded, kernel, mode='valid')


def _plot_sample_paths(all_preds, all_trues, result_dir, output_len, num_samples=5):
    """Plot and save raw/smoothed predictions vs. truth for random samples (channel 0)."""
    selected_indices = torch.randperm(len(all_preds))[:num_samples]
    start_step = 100  # x-axis offset used for the horizon labels
    time_steps = np.arange(start_step, start_step + output_len)
    path_data = {'time_step': time_steps}

    fig, ax = plt.subplots(figsize=(15, 8))
    colors = sns.color_palette("husl", num_samples)

    for i, idx in enumerate(selected_indices):
        pred_path = all_preds[int(idx), :, 0].numpy()
        true_path = all_trues[int(idx), :, 0].numpy()
        pred_path_smooth = _gaussian_smooth(pred_path, sigma=1.5)

        path_data[f'原始预测_{i+1}'] = pred_path
        path_data[f'平滑预测_{i+1}'] = pred_path_smooth
        path_data[f'真实路径_{i+1}'] = true_path

        ax.plot(time_steps, pred_path, color=colors[i], linestyle='-', alpha=0.2,
                label=f'原始预测 {i+1}')
        ax.plot(time_steps, pred_path_smooth, color=colors[i], linestyle='-',
                alpha=0.8, linewidth=2, label=f'平滑预测 {i+1}')
        ax.plot(time_steps, true_path, color=colors[i], linestyle='--', alpha=0.8,
                label=f'真实路径 {i+1}')

    pd.DataFrame(path_data).to_csv(result_dir / 'sample_paths_data.csv', index=False)

    ax.set_title(f'预测路径与真实路径对比 (t={start_step}~{start_step + output_len})')
    ax.set_xlabel('时间步')
    ax.set_ylabel('数值')
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(result_dir / 'sample_paths_comparison.jpg', dpi=300, bbox_inches='tight')
    plt.close(fig)
