import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pickle
import os
from matplotlib.animation import FuncAnimation
import glob

# Joint chains as sequences of joint indices. Every consecutive pair in a
# chain becomes one bone (start_joint, end_joint) that the plotting code
# draws as a line segment.
# NOTE(review): bones only reference joints 0-15, while each frame holds
# JOINTS_PER_FRAME (33) joints — confirm which joint layout these indices
# refer to.
_H36M_CHAINS = (
    (0, 1, 2, 3),
    (0, 4, 5, 6),
    (0, 7, 8, 9),
    (3, 10, 11, 12),
    (3, 13, 14, 15),
)
H36M_CONNECTIONS = [bone for chain in _H36M_CHAINS for bone in zip(chain, chain[1:])]

# Values stored per frame in the flattened sequences; frames are reshaped to
# (JOINTS_PER_FRAME, 3), i.e. three coordinates per joint.
DIMS_PER_FRAME = 99
JOINTS_PER_FRAME = DIMS_PER_FRAME // 3

def load_and_process_sample(npz_file, scaler_file):
    """Load one saved sample and convert its sequences to per-joint 3D points.

    Reads ``input_sequence``, ``target_sequence`` and ``reconstruction`` from
    *npz_file*. Each array is flattened, any trailing values that do not fill
    a whole frame of ``DIMS_PER_FRAME`` values are dropped, and the result is
    de-normalised with the pickled scaler's ``inverse_transform``. Finally it
    is reshaped to ``(frames, JOINTS_PER_FRAME, 3)``.

    If ``inverse_transform`` raises for a sequence, a message is printed and
    that sequence keeps its still-normalised values (best-effort fallback).

    Args:
        npz_file: Path to a ``.npz`` archive containing the three arrays.
        scaler_file: Path to a pickled object exposing ``inverse_transform``.
            Presumably a scikit-learn scaler fitted on DIMS_PER_FRAME
            features — verify against the training code.

    Returns:
        dict mapping ``'input'``, ``'target'`` and ``'reconstruction'`` (in
        that order) to arrays of shape ``(frames, JOINTS_PER_FRAME, 3)``.
        Each sequence uses its own frame count, so lengths may differ.

    Security:
        ``pickle.load`` can execute arbitrary code; only load scaler files
        from a trusted source.
    """
    # NpzFile keeps the archive open; the context manager closes it. Indexing
    # returns in-memory arrays that stay valid after the file is closed.
    with np.load(npz_file) as data:
        raw_sequences = {
            'input': data['input_sequence'],
            'target': data['target_sequence'],
            'reconstruction': data['reconstruction'],
        }

    # WARNING: unpickling runs arbitrary code — never point this at untrusted files.
    with open(scaler_file, 'rb') as f:
        scaler = pickle.load(f)

    sequences = {}
    for name, seq in raw_sequences.items():
        # Flatten so both flat and (frames, DIMS_PER_FRAME) storage work, then
        # compute the frame count for THIS sequence (they need not match).
        flat = np.asarray(seq).reshape(-1)
        frames = flat.size // DIMS_PER_FRAME
        # Drop a trailing partial frame, if any.
        reshaped = flat[:frames * DIMS_PER_FRAME].reshape(frames, DIMS_PER_FRAME)
        try:
            denormalized = scaler.inverse_transform(reshaped)
        except Exception as e:
            # Deliberate best-effort: keep normalised values rather than abort.
            print(f"Failed to inverse transform {name} sequence, error: {e}")
            denormalized = reshaped
        sequences[name] = np.asarray(denormalized).reshape(frames, JOINTS_PER_FRAME, 3)
    return sequences

def visualize_pose_comparison(sequences, sample_num):
    """Plot input, target and reconstruction side by side at three frames.

    Builds a 3x3 grid of 3D subplots. Columns are the input, target and
    reconstructed sequences; rows are the first, middle and last frame of
    each sequence. All subplots share the same cubic axis limits (global
    min/max over every coordinate, plus a 10% margin) so poses are directly
    comparable.

    Args:
        sequences: dict with keys ``'input'``, ``'target'`` and
            ``'reconstruction'``, each an array indexed as
            ``[frame, joint, coord]`` (as returned by load_and_process_sample).
        sample_num: label used in the figure title.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and
        closing it.

    Raises:
        ValueError: if any of the three sequences has no frames.
    """
    seq_props = {
        'input': {'color': 'b', 'title': 'Input Sequence'},
        'target': {'color': 'g', 'title': 'Target Sequence'},
        'reconstruction': {'color': 'r', 'title': 'Reconstructed Sequence'}
    }
    for seq_name in seq_props:
        if len(sequences[seq_name]) == 0:
            raise ValueError(f"Sequence '{seq_name}' has no frames to plot")

    # First / middle / last frame, computed per sequence so a shorter
    # sequence is never indexed past its own end.
    frame_indices = {
        name: [0, len(sequences[name]) // 2, len(sequences[name]) - 1]
        for name in seq_props
    }
    num_rows = 3
    num_cols = len(seq_props)

    fig = plt.figure(figsize=(15, 12))
    fig.suptitle(f'Sample {sample_num} - Pose Comparison', fontsize=16)

    # One shared cube of limits for every subplot keeps the poses comparable.
    all_data = np.concatenate([sequences[name] for name in seq_props])
    min_val = np.min(all_data)
    max_val = np.max(all_data)
    margin = (max_val - min_val) * 0.1
    lo, hi = min_val - margin, max_val + margin

    for row in range(num_rows):
        for col, (seq_name, props) in enumerate(seq_props.items()):
            frame_idx = frame_indices[seq_name][row]
            ax = fig.add_subplot(num_rows, num_cols, row * num_cols + col + 1, projection='3d')
            joints = sequences[seq_name][frame_idx]
            n_joints = joints.shape[0]
            for start, end in H36M_CONNECTIONS:
                # Skip bones whose endpoints are outside this skeleton.
                if start < n_joints and end < n_joints:
                    ax.plot([joints[start, 0], joints[end, 0]],
                            [joints[start, 1], joints[end, 1]],
                            [joints[start, 2], joints[end, 2]],
                            color=props["color"], linestyle='-', linewidth=2)
            ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color=props["color"], marker='o')
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
            ax.set_zlim(lo, hi)
            # Axis labels only on the bottom row, column titles only on the top.
            if row == num_rows - 1:
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_zlabel('Z')
            if row == 0:
                ax.set_title(f'{props["title"]}')
            ax.text2D(0.05, 0.95, f'Frame {frame_idx + 1}', transform=ax.transAxes)
    fig.tight_layout()
    # Leave room at the top for the suptitle.
    fig.subplots_adjust(top=0.93)
    return fig

def create_pose_animation(sequence, title, output_path, fps=10):
    """Render a sequence of 3D poses to an animation file using Pillow.

    Each frame draws the H36M_CONNECTIONS bones in blue and every joint as a
    red marker. Axis limits are fixed to the whole sequence's min/max plus a
    10% margin, so the view does not rescale between frames.

    Args:
        sequence: array indexed as ``[frame, joint, coord]``.
        title: prefix for the per-frame axes title.
        output_path: destination file passed to ``FuncAnimation.save``.
        fps: playback rate used for both the frame interval and the saved
            file. The default of 10 matches the previous 100 ms / 10 fps.

    Raises:
        ValueError: if *sequence* has no frames or *fps* is not positive.
    """
    num_frames = len(sequence)
    if num_frames == 0:
        raise ValueError("Cannot animate an empty pose sequence")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')
        min_val = np.min(sequence)
        max_val = np.max(sequence)
        margin = (max_val - min_val) * 0.1
        lo, hi = min_val - margin, max_val + margin

        def update(frame):
            # clear() discards previous artists and axis settings, so limits,
            # labels and title are re-applied on every frame.
            ax.clear()
            joints = sequence[frame]
            n_joints = joints.shape[0]
            for start, end in H36M_CONNECTIONS:
                if start < n_joints and end < n_joints:
                    ax.plot([joints[start, 0], joints[end, 0]],
                            [joints[start, 1], joints[end, 1]],
                            [joints[start, 2], joints[end, 2]],
                            color='blue', linestyle='-', linewidth=2)
            ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='red', marker='o')
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
            ax.set_zlim(lo, hi)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title(f'{title} - Frame {frame + 1}/{num_frames}')

        ani = FuncAnimation(fig, update, frames=num_frames,
                            interval=round(1000 / fps), repeat=True)
        ani.save(output_path, writer='pillow', fps=fps)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f"Animation saved to {output_path}")

def main():
    """Generate comparison plots and animations for every saved sample.

    Looks in ``./results`` for ``sample_*.npz`` files, processed in sorted
    order for reproducible runs, and for ``scaler.pkl``. For each sample it:

    - prints the shape and value range of every sequence;
    - saves a ``sample_<id>_comparison.png`` grid;
    - saves one ``sample_<id>_<name>.gif`` animation per sequence.

    All outputs go into the same directory.
    """
    results_dir = './results'
    scaler_file = os.path.join(results_dir, 'scaler.pkl')
    # glob returns files in arbitrary order; sort for deterministic output.
    sample_files = sorted(glob.glob(os.path.join(results_dir, 'sample_*.npz')))
    if not sample_files:
        print(f"No sample_*.npz files found in {results_dir}")
        return
    if not os.path.isfile(scaler_file):
        raise SystemExit(f"Scaler file not found: {scaler_file}")

    for sample_file in sample_files:
        # Id is the text between the first '_' and the following '.'.
        sample_num = os.path.basename(sample_file).split('_')[1].split('.')[0]
        print(f"Processing sample {sample_num}...")
        sequences = load_and_process_sample(sample_file, scaler_file)
        for name, seq in sequences.items():
            print(f"{name} shape: {seq.shape}")
            print(f"{name} range: [{np.min(seq):.4f}, {np.max(seq):.4f}]")

        fig = visualize_pose_comparison(sequences, sample_num)
        comparison_path = os.path.join(results_dir, f'sample_{sample_num}_comparison.png')
        try:
            # Save the figure we built, not whatever pyplot considers current.
            fig.savefig(comparison_path)
        finally:
            plt.close(fig)
        print(f"Comparison plot saved to {comparison_path}")

        for seq_name, sequence in sequences.items():
            output_path = os.path.join(results_dir, f'sample_{sample_num}_{seq_name}.gif')
            create_pose_animation(sequence, f'Sample {sample_num} - {seq_name.capitalize()}', output_path)

if __name__ == "__main__":
    main()
