import glob
import os
import pickle
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import orthogonal_procrustes
from scipy.spatial.transform import Rotation

# Pose layout: every frame holds JOINTS_PER_FRAME joints with COORDS_PER_JOINT
# values each; flat per-frame vectors therefore have DIMS_PER_FRAME entries.
JOINTS_PER_FRAME = 13
COORDS_PER_JOINT = 2
# Derived rather than hard-coded so it cannot drift out of sync with the joint count.
DIMS_PER_FRAME = JOINTS_PER_FRAME * COORDS_PER_JOINT

def procrustes_alignment(source, target, *, with_scale=False, allow_reflection=True):
    """Align ``source`` points onto ``target`` with an orthogonal Procrustes fit.

    Both point sets are centred on their centroids. The orthogonal matrix ``R``
    minimising ``||source_centered @ R - target_centered||_F`` is found with
    ``scipy.linalg.orthogonal_procrustes``, and the rotated points are moved onto
    the target centroid.

    Args:
        source: 2-D array (points, dims) of points to move.
        target: 2-D array of reference points, paired row-wise with ``source``.
        with_scale: keyword-only. If True, also fit the isotropic scale factor,
            giving a full similarity transform as used by many PA-MPJPE
            implementations. The default False keeps the rotation-only fit.
        allow_reflection: keyword-only. If False, restrict ``R`` to a proper
            rotation (det = +1) so a mirrored pose cannot be "aligned" by
            flipping it. The default True keeps the unconstrained
            ``orthogonal_procrustes`` result, which may be a reflection.

    Returns:
        Array shaped like ``source`` containing the aligned points.
    """
    source = np.asarray(source)
    target = np.asarray(target)
    source_centroid = np.mean(source, axis=0)
    target_centroid = np.mean(target, axis=0)
    source_centered = source - source_centroid
    target_centered = target - target_centroid
    R, singular_value_sum = orthogonal_procrustes(source_centered, target_centered)
    if not allow_reflection and np.linalg.det(R) < 0:
        # Kabsch correction: flip the axis of the smallest singular value so that
        # det(R) = +1. This gives the best proper rotation for the same problem.
        U, S, Vt = np.linalg.svd(source_centered.T @ target_centered)
        U[:, -1] *= -1
        R = U @ Vt
        singular_value_sum = S[:-1].sum() - S[-1]
    aligned = np.dot(source_centered, R)
    if with_scale:
        norm_sq = np.sum(source_centered ** 2)
        # When all source points coincide there is no extent to scale; keep scale 1.
        if norm_sq > 0:
            aligned = aligned * (singular_value_sum / norm_sq)
    return aligned + target_centroid

def calculate_mpjpe(pred, target):
    """Return the Mean Per Joint Position Error between ``pred`` and ``target``.

    The Euclidean distance is taken over the last axis and then averaged over
    every remaining element (e.g. all frames and joints).
    """
    offsets = pred - target
    squared_lengths = np.sum(offsets ** 2, axis=-1)
    return np.mean(np.sqrt(squared_lengths))

def calculate_pampjpe(pred, target):
    """Return the Procrustes-aligned MPJPE (PA-MPJPE).

    Each frame ``pred[i]`` is aligned onto ``target[i]`` with
    ``procrustes_alignment``. The per-joint Euclidean error (last axis) is then
    averaged over all frames and joints.

    Args:
        pred: predicted poses indexed by frame first; presumably
            (frames, joints, coords) as produced by ``load_and_process_sample``.
        target: ground-truth poses with exactly the same shape as ``pred``.

    Returns:
        Mean aligned per-joint distance. NumPy yields NaN (with a RuntimeWarning)
        when there are zero frames.

    Raises:
        ValueError: if ``pred`` and ``target`` have different shapes.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target shapes differ: {pred.shape} vs {target.shape}"
        )
    # Float buffer on purpose: np.zeros_like(pred) would inherit an integer
    # dtype from integer input and truncate the aligned coordinates.
    aligned_pred = np.empty(pred.shape, dtype=np.float64)
    for i, (pred_frame, target_frame) in enumerate(zip(pred, target)):
        aligned_pred[i] = procrustes_alignment(pred_frame, target_frame)
    return np.mean(np.sqrt(np.sum((aligned_pred - target) ** 2, axis=-1)))

def calculate_mpjve(pred, target):
    """Return the Mean Per Joint Velocity Error between ``pred`` and ``target``.

    Velocities are first-order differences between consecutive entries along
    axis 0. The result is the mean Euclidean distance (over the last axis)
    between predicted and target velocities.
    """
    def _frame_deltas(seq):
        # Difference of each entry with its predecessor along axis 0.
        return seq[1:] - seq[:-1]

    velocity_error = _frame_deltas(pred) - _frame_deltas(target)
    return np.mean(np.sqrt(np.sum(velocity_error ** 2, axis=-1)))

def calculate_diversity(sequences):
    """Return the mean pairwise MPJPE over all unordered pairs of ``sequences``.

    Every pair (i, j) with i < j is scored with ``calculate_mpjpe``. The scores
    are summed in that order and divided by the number of pairs. With fewer than
    two sequences there are no pairs and the integer ``0`` is returned.
    """
    total = 0
    n_pairs = 0
    for idx, first in enumerate(sequences):
        for second in sequences[idx + 1:]:
            total += calculate_mpjpe(first, second)
            n_pairs += 1
    return total / n_pairs if n_pairs else total

def load_and_process_sample(npz_file, scaler_file):
    """Load one saved sample and return its sequences as per-joint arrays.

    The ``.npz`` archive is read for its ``input_sequence``, ``target_sequence``
    and ``reconstruction`` entries, each assumed to be a flat 1-D array (TODO
    confirm against the writer). Each sequence is processed as follows:

    1. Truncate it to a whole number of frames.
    2. Reshape it to ``(frames, DIMS_PER_FRAME)``.
    3. Pass it through the pickled scaler's ``inverse_transform``.
    4. Reshape it to ``(frames, JOINTS_PER_FRAME, 2)``.

    Args:
        npz_file: path to the sample archive.
        scaler_file: path to a pickled object exposing ``inverse_transform``.

    Returns:
        dict with keys ``'input'``, ``'target'`` and ``'reconstruction'``.

    Warns:
        RuntimeWarning: if ``inverse_transform`` raises. The still-normalized
        values are used instead, so metrics on them are not in original units.
    """
    # Context manager closes the underlying zip file; the arrays are read fully
    # into memory on access, so they remain valid after the block.
    with np.load(npz_file) as data:
        raw_sequences = {
            'input': data['input_sequence'],
            'target': data['target_sequence'],
            'reconstruction': data['reconstruction'],
        }
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted scaler files.
    with open(scaler_file, 'rb') as f:
        scaler = pickle.load(f)
    sequences = {}
    for name, seq in raw_sequences.items():
        # Frame count is derived per sequence so input, target and
        # reconstruction of different lengths each reshape correctly.
        frames = len(seq) // DIMS_PER_FRAME
        reshaped = seq[:frames * DIMS_PER_FRAME].reshape(frames, DIMS_PER_FRAME)
        try:
            denormalized = scaler.inverse_transform(reshaped)
        except Exception as exc:
            # Best-effort fallback kept, but no longer silent.
            warnings.warn(
                f"scaler.inverse_transform failed for {name!r} in {npz_file}: {exc}; "
                "using normalized values",
                RuntimeWarning,
                stacklevel=2,
            )
            denormalized = reshaped
        sequences[name] = np.asarray(denormalized).reshape(frames, JOINTS_PER_FRAME, 2)
    return sequences

# (metrics key, panel title, y-axis label) for each bar panel, top to bottom.
_METRIC_PANELS = (
    ('mpjpe', 'Mean Per Joint Position Error (MPJPE)', 'Error (pixel)'),
    ('pampjpe', 'Procrustes Aligned MPJPE (PA-MPJPE)', 'Error (pixel)'),
    ('mpjve', 'Mean Per Joint Velocity Error (MPJVE)', 'Error (pixel/frame)'),
)


def _summary_stats(values):
    """Return the mean, population standard deviation, min and max of ``values``."""
    return {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
    }


def _write_stats_section(f, title, stats, *, trailing_blank_line=True):
    """Write one ``TITLE:`` block of summary statistics to the open text file ``f``."""
    f.write(f"{title}:\n")
    f.write(f"  Mean: {stats['mean']:.4f}\n")
    f.write(f"  Std Dev: {stats['std']:.4f}\n")
    f.write(f"  Min: {stats['min']:.4f}\n")
    f.write(f"  Max: {stats['max']:.4f}\n")
    if trailing_blank_line:
        f.write("\n")


def _plot_metric_panel(ax, sample_ids, values, title, ylabel, bar_color, avg_line_color):
    """Draw one per-sample bar chart with an average line and rotated value labels."""
    bars = ax.bar(sample_ids, values, color=bar_color, alpha=0.7)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Sample ID', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    avg = np.mean(values)
    ax.axhline(y=avg, color=avg_line_color, linestyle='-', linewidth=2, label=f'Average: {avg:.4f}')
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.4f}', xy=(bar.get_x() + bar.get_width() / 2, height), xytext=(0, 3),
                    textcoords="offset points", ha='center', va='bottom', fontsize=9, rotation=90)
    ax.legend()


def visualize_metrics(metrics, output_dir):
    """Plot per-sample metrics and write summary statistics into ``output_dir``.

    Writes two files:

    - ``evaluation_metrics.png``: one bar panel each for MPJPE, PA-MPJPE and
      MPJVE, with the across-sample mean drawn as a horizontal line.
    - ``evaluation_stats.txt``: mean / population std / min / max per metric,
      plus a DIVERSITY section when every sample carries a ``'diversity'`` entry.

    Args:
        metrics: mapping of sample id -> dict holding at least ``'mpjpe'``,
            ``'pampjpe'`` and ``'mpjve'`` (optionally ``'diversity'``). Sample
            ids are sorted before plotting.
        output_dir: existing directory the output files are written into.

    If ``metrics`` is empty a RuntimeWarning is issued and nothing is written.
    """
    if not metrics:
        warnings.warn("visualize_metrics: no metrics given; nothing written",
                      RuntimeWarning, stacklevel=2)
        return
    sample_ids = sorted(metrics)
    values_by_metric = {
        key: [metrics[s][key] for s in sample_ids] for key, _, _ in _METRIC_PANELS
    }
    bar_color = '#4C72B0'
    avg_line_color = '#C44E52'
    # style.context scopes ggplot to this figure instead of mutating global rcParams.
    with plt.style.context('ggplot'):
        fig, axs = plt.subplots(len(_METRIC_PANELS), 1, figsize=(10, 15))
        try:
            for ax, (key, title, ylabel) in zip(axs, _METRIC_PANELS):
                _plot_metric_panel(ax, sample_ids, values_by_metric[key], title, ylabel,
                                   bar_color, avg_line_color)
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, 'evaluation_metrics.png'), dpi=300)
        finally:
            # Close this specific figure (not whatever is "current"), even on error.
            plt.close(fig)
    with open(os.path.join(output_dir, 'evaluation_stats.txt'), 'w', encoding='utf-8') as f:
        f.write("Evaluation Metrics Statistics\n")
        f.write("============================\n\n")
        for key, _, _ in _METRIC_PANELS:
            _write_stats_section(f, key.upper(), _summary_stats(values_by_metric[key]))
        # Diversity is optional; report it only when every sample has it, so a
        # partially populated dict cannot raise KeyError.
        if all('diversity' in metrics[s] for s in sample_ids):
            diversity_values = [metrics[s]['diversity'] for s in sample_ids]
            _write_stats_section(f, "DIVERSITY", _summary_stats(diversity_values),
                                 trailing_blank_line=False)

def main(results_dir='./results'):
    """Evaluate every ``sample_*.npz`` in ``results_dir`` and report metrics.

    For each sample, the reconstruction is compared with the target using MPJPE,
    PA-MPJPE and MPJVE. When more than one reconstruction is available and all
    share one shape, their mean pairwise MPJPE is stored as ``'diversity'`` on
    every sample. Plots and a stats file are written by ``visualize_metrics``,
    and the mean and std of each metric are printed.

    Args:
        results_dir: directory containing ``scaler.pkl`` and the sample files
            (default ``'./results'``).
    """
    scaler_file = os.path.join(results_dir, 'scaler.pkl')
    sample_files = sorted(glob.glob(os.path.join(results_dir, 'sample_*.npz')))
    all_metrics = {}
    all_reconstructions = []
    for sample_file in sample_files:
        id_text = os.path.basename(sample_file).split('_')[1].split('.')[0]
        try:
            sample_num = int(id_text)
        except ValueError:
            # The glob also matches names like "sample_best.npz"; skip them
            # instead of aborting the whole evaluation.
            warnings.warn(f"Skipping {sample_file}: {id_text!r} is not an integer sample id",
                          RuntimeWarning)
            continue
        sequences = load_and_process_sample(sample_file, scaler_file)
        target = sequences['target']
        recon = sequences['reconstruction']
        all_reconstructions.append(recon)
        all_metrics[sample_num] = {
            'mpjpe': calculate_mpjpe(recon, target),
            'pampjpe': calculate_pampjpe(recon, target),
            'mpjve': calculate_mpjve(recon, target),
        }
    if not all_metrics:
        # Plotting and the summary below need at least one evaluated sample.
        print(f"No evaluable sample_*.npz files found in {results_dir}")
        return
    if len(all_reconstructions) > 1:
        # Pairwise MPJPE between reconstructions requires identical shapes.
        if len({recon.shape for recon in all_reconstructions}) == 1:
            diversity = calculate_diversity(all_reconstructions)
            for sample_metrics in all_metrics.values():
                sample_metrics['diversity'] = diversity
        else:
            warnings.warn("Skipping diversity: reconstructions have different shapes",
                          RuntimeWarning)
    visualize_metrics(all_metrics, results_dir)
    for label, key in (("MPJPE", 'mpjpe'), ("PA-MPJPE", 'pampjpe'), ("MPJVE", 'mpjve')):
        values = [sample_metrics[key] for sample_metrics in all_metrics.values()]
        print(f"{label} mean:", np.mean(values), "std:", np.std(values))

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
