#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
from mpl_toolkits.mplot3d import Axes3D  # 用于3D绘图
import torch
import torch.utils.data
from torch import nn, optim

# 仅设置这两行字体参数
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.sans-serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 21
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'

# 导入你训练时使用的数据集和模型代码（请根据实际项目路径调整）
from motion.dataset import MotionDataset
from models.model import *  # 包含 ESTAG、STFT 等模型定义
from models.model_t import EqMotion  # 如果需要

# ---------------------------
# 1. 命令行参数设置
# ---------------------------
parser = argparse.ArgumentParser(description='Motion Visualization')

# 数据加载相关参数
parser.add_argument('--batch_size', type=int, default=100,
                    help='输入 batch 大小（训练时用）')
parser.add_argument('--max_training_samples', type=int, default=3000,
                    help='最大训练样本数量')
parser.add_argument('--data_dir', type=str, default='motion',
                    help='数据目录')
parser.add_argument('--delta_frame', type=int, default=1,
                    help='帧间隔')
parser.add_argument('--num_past', type=int, default=10,
                    help='用于预测的过去帧数')
parser.add_argument('--time_gap', type=int, default=10,
                    help='过去帧与未来帧之间的时间间隔')
parser.add_argument('--case', type=str, default='walk', choices=['walk', 'run', 'basketball'],
                    help='运动类型')

# 模型相关参数（保留部分训练时的参数）
parser.add_argument('--model', type=str, default='estag',
                    help='可选模型: estag, stft, ...')
parser.add_argument('--n_layers', type=int, default=2,
                    help='网络层数')
parser.add_argument('--nf', type=int, default=16,
                    help='隐藏层维度')
parser.add_argument('--fft', type=eval, default=True,
                    help='是否使用 FFT')
parser.add_argument('--eat', type=eval, default=True,
                    help='是否使用 EAT')
parser.add_argument('--with_mask', action='store_true', default=False,
                    help='使用 EAT 时是否屏蔽未来帧')
parser.add_argument('--tempo', type=eval, default=True,
                    help='是否使用 temporal pooling')
parser.add_argument('--seed', type=int, default=1,
                    help='随机种子')
parser.add_argument('--lr', type=float, default=5e-3,
                    help='学习率')

# 新增参数：选择要可视化的样本索引（从数据集里挑选的样本）
parser.add_argument('--sample_idx', type=int, default=0,
                    help='要可视化的样本索引')

args = parser.parse_args()

# ---------------------------
# 2. 固定随机种子及设备设置
# ---------------------------
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------
# 3. 定义可视化函数（采用 notebook 中的版本）
# ---------------------------
def plot(x_gt, x_pred, edges, case, method):
    """
    利用 3D 散点图和连线显示真实数据与预测结果的骨架对比。

    参数：
      x_gt: (N, 3) 的 numpy 数组，真实位置（ground truth）
      x_pred: (N, 3) 的 numpy 数组，预测位置
      edges: list，每个元素为二元组，如 (i, j)，表示要连线的两个关节点索引
      case: 字符串，运动类型（例如 'walk'）
      method: 字符串，预测方法名称（例如 'estag' 或 'stft'）
    """
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, projection='3d')
    lw = 1.2

    # 散点：按照 x, z, y 顺序显示，并添加图例标签
    ax.scatter(x_gt[..., 0], x_gt[..., 2], x_gt[..., 1], color="r", label="GT", s=50)
    ax.scatter(x_pred[..., 0], x_pred[..., 2], x_pred[..., 1], color="b", label="Prediction", s=50)

    for edge in edges:
        # 这里 edge 为 (i, j) 形式，直接用 numpy 索引即可
        s_gt, t_gt = x_gt[edge]
        s_pred, t_pred = x_pred[edge]
        pos_gt = list(zip(s_gt, t_gt))
        pos_pred = list(zip(s_pred, t_pred))
        ax.plot(pos_gt[0], pos_gt[2], pos_gt[1], color='r', ls='-', linewidth=2)
        ax.plot(pos_pred[0], pos_pred[2], pos_pred[1], color='b', ls='-', linewidth=2)

    # 添加图例
    ax.legend()
    # 添加标题，若 method 为 'stft' 则标题显示为 NS-EGNN，否则按 method 大写显示
    caption = "NS-EGNN" if method.lower() == "stft" else method.upper()
    ax.set_title(f"{case.capitalize()} - {caption}")

    # 调用 tight_layout 去除白边
    plt.tight_layout()

    # 保存为 PDF 格式
    save_path = os.path.join("figures", f'{case}_{method}.pdf')
    plt.savefig(save_path)
    print(f"Figure saved to {save_path}")
    plt.show()


# ---------------------------
# 4. 定义预处理转换函数（保持与训练时一致）
# ---------------------------
def transform_tensor(d):
    """
    模仿训练时的预处理：
      - 如果 d 为 3 维，则使用 view(-1, d.size(2))（例如将 [1, n_nodes, feat] 转为 [n_nodes, feat]）
      - 如果 d 为 4 维，则先 permute(1, 0, 2, 3)（将 num_past 移到最前），再 reshape 成 [num_past, batch_size*n_nodes, feat]
    """
    if len(d.shape) == 3:
        return d.view(-1, d.size(2))
    else:
        return (d.permute(1, 0, 2, 3)).reshape(d.size(1), -1, d.size(3))


# ---------------------------
# 5. 主程序：加载数据、模型、推理及可视化
# ---------------------------
def main():
    # ① 加载数据集（仅用于获得模型推理所需的输入和 ground truth）
    dataset = MotionDataset(partition='test',
                            max_samples=600,
                            data_dir=args.data_dir,
                            delta_frame=args.delta_frame,
                            num_past=args.num_past,
                            case=args.case,
                            time_gap=args.time_gap)
    print(f"Dataset loaded: {args.case} (test partition), total samples: {len(dataset)}")

    # 选取指定样本（索引由 --sample_idx 指定）
    sample = dataset[args.sample_idx]
    # 假设 sample 返回 (loc_raw, edge_attr_raw, charges_raw, loc_end_raw)
    loc_raw, edge_attr_raw, charges_raw, loc_end_raw = sample

    # 为模拟 batch_size=1，给所有数据添加 batch 维度，并转到 device 上
    loc = loc_raw.unsqueeze(0).to(device)  # shape: [1, num_past, n_nodes, 3]
    edge_attr = edge_attr_raw.unsqueeze(0).to(device)  # shape: [1, num_past, num_edges, edge_feat]
    charges = charges_raw.unsqueeze(0).to(device)  # shape: [1, n_nodes, charge_dim]
    loc_end = loc_end_raw.unsqueeze(0).to(device)  # shape: [1, n_nodes, 3]

    # 将数据转换为模型 forward 时所需的格式
    loc_t = transform_tensor(loc)  # 预期 shape: [num_past, n_nodes, 3]
    edge_attr_t = transform_tensor(edge_attr)  # 预期 shape: [num_past, num_edges, edge_feat]
    charges_t = transform_tensor(charges)  # 预期 shape: [n_nodes, charge_dim]
    loc_end_t = transform_tensor(loc_end)  # 预期 shape: [n_nodes, 3]

    print("Transformed shapes:")
    print(" loc_t:", loc_t.shape)
    print(" edge_attr_t:", edge_attr_t.shape)
    print(" charges_t:", charges_t.shape)
    print(" loc_end_t:", loc_end_t.shape)

    # ② 获取用于模型前向传播的边信息：调用数据集提供的 get_edges 方法
    n_nodes = loc_end_t.shape[0]  # 经过 transform 后，loc_end_t shape 为 [n_nodes, 3]
    edges_tensor = dataset.get_edges(1, n_nodes)  # 返回 (src, tgt)
    # 将 edges_tensor 转到 device 上
    src, tgt = edges_tensor
    src = src.to(device)
    tgt = tgt.to(device)
    edges_tensor = (src, tgt)
    print(f"Edges from dataset: number of edges = {src.shape[0]}")

    # ③ 同时加载用于可视化的边信息，从 pkl 文件中获得
    pkl_path = os.path.join("motion", f"motion_{args.case}.pkl")
    with open(pkl_path, "rb") as f:
        pkl_edges, _ = pkl.load(f)
    print(f"Edges loaded from {pkl_path} for visualization, number of edges: {len(pkl_edges)}")

    # ④ 加载模型：分别加载 ESTAG 和 STFT 模型（保存路径依据 args.case 构造）
    estag_model_path = os.path.join("logs", "motion_logs", f"exp_{args.case}", "estag", "saved_model.pth")
    stft_model_path = os.path.join("logs", "motion_logs", f"exp_{args.case}", "stft", "saved_model.pth")

    print("Loading ESTAG model from:", estag_model_path)
    estag_model = torch.load(estag_model_path, map_location=device)
    estag_model.eval()

    print("Loading STFT model from:", stft_model_path)
    stft_model = torch.load(stft_model_path, map_location=device)
    stft_model.eval()

    # ⑤ 推理：使用前 num_past 帧预测未来 1 帧（模型调用方式与训练时一致）
    with torch.no_grad():
        pred_estag = estag_model(charges_t, loc_t, edges_tensor, edge_attr_t)
        pred_stft = stft_model(charges_t, loc_t, edges_tensor, edge_attr_t)

    # 假设模型输出形状为 [n_nodes, 3]，即预测未来帧各关节点位置
    pred_estag_np = pred_estag.cpu().numpy()  # shape: [n_nodes, 3]
    pred_stft_np = pred_stft.cpu().numpy()  # shape: [n_nodes, 3]
    gt_np = loc_end_t.cpu().numpy()  # ground truth, shape: [n_nodes, 3]

    # ⑥ 可视化：调用 plot 函数对比真实数据与预测结果（使用 pkl_edges 进行绘图）
    plot(gt_np, pred_estag_np, pkl_edges, args.case, "estag")
    plot(gt_np, pred_stft_np, pkl_edges, args.case, "stft")


if __name__ == "__main__":
    main()