#!/usr/bin/env python3
"""
Analyze Training Dynamics

Reproduces Figure 4 (curriculum dynamics) and, when --compare_baselines is
given, Figure 5 (gradient variance).

Usage:
    python scripts/analyze_training_dynamics.py --curriculum_history outputs/gdo_dpo/curriculum_history.npz \
                                                  --output_dir outputs/analysis
"""

import argparse
import os
import sys
import numpy as np
from pathlib import Path

# Put the directory two levels above this file (presumably the repository
# root, given the `scripts/` path in the usage example) at the front of
# sys.path so the `src` package imported below can be found when this file
# is run directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.analysis.training_dynamics import TrainingDynamicsAnalyzer


def parse_args(argv=None):
    """Parse the command-line options of the training-dynamics script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior existing callers rely on.

    Returns:
        argparse.Namespace with ``curriculum_history`` (str, required),
        ``output_dir`` (str, defaults to ``"outputs/analysis"``) and
        ``compare_baselines`` (list of str, or ``None`` when the flag is
        not given).
    """
    parser = argparse.ArgumentParser(description="Analyze training dynamics")
    parser.add_argument("--curriculum_history", type=str, required=True,
                        help="Path to curriculum history .npz file")
    parser.add_argument("--output_dir", type=str, default="outputs/analysis",
                        help="Directory where the generated plots are written")
    parser.add_argument("--compare_baselines", type=str, nargs='*',
                        help="Paths to baseline curriculum histories for comparison")
    return parser.parse_args(argv)


def _baseline_label(baseline_path, taken):
    """Return a legend label for *baseline_path* that is not in *taken*.

    The label is the name of the directory that contains the file. When the
    path has no directory component the file stem is used instead, and a
    numeric suffix is appended if the label is already in use, so that no
    two histories are stored under the same key.
    """
    label = os.path.basename(os.path.dirname(baseline_path))
    if not label:
        label = os.path.splitext(os.path.basename(baseline_path))[0]
    if label not in taken:
        return label
    suffix = 2
    while f"{label} ({suffix})" in taken:
        suffix += 1
    return f"{label} ({suffix})"


def main():
    """Load a curriculum history and write the training-dynamics plots.

    Reads the ``.npz`` archive given by ``--curriculum_history`` and plots
    the curriculum dynamics to ``curriculum_dynamics.pdf`` in the output
    directory. When ``--compare_baselines`` is given, the
    ``gradient_variance`` series of every baseline archive (and of the main
    archive, labelled ``GDO-DPO``) are plotted to ``gradient_variance.pdf``.

    Raises:
        KeyError: If the main archive lacks ``lambda_sem`` or ``lambda_unc``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("\n" + "="*60)
    print("Training Dynamics Analysis")
    print("="*60)

    # Load curriculum history
    print(f"\nLoading curriculum history from {args.curriculum_history}")
    gdo_variance = None
    # np.load keeps the archive's file handle open until the NpzFile is
    # closed, so copy everything we need out of it inside the with-block.
    with np.load(args.curriculum_history) as history_data:
        history = {
            'lambda_sem': history_data['lambda_sem'].tolist(),
            'lambda_unc': history_data['lambda_unc'].tolist(),
            'Srep': history_data['Srep'].tolist() if 'Srep' in history_data else [],
            'Adisc': history_data['Adisc'].tolist() if 'Adisc' in history_data else [],
        }
        # The variance series is only needed for the baseline comparison.
        if args.compare_baselines and 'gradient_variance' in history_data:
            gdo_variance = history_data['gradient_variance'].tolist()

    print(f"Loaded {len(history['lambda_sem'])} curriculum update steps")

    # Create analyzer
    analyzer = TrainingDynamicsAnalyzer()

    # Plot curriculum dynamics (Figure 4)
    print("\nGenerating curriculum dynamics plot (Figure 4)...")
    dynamics_path = os.path.join(args.output_dir, "curriculum_dynamics.pdf")
    analyzer.plot_curriculum_dynamics(history, save_path=dynamics_path)

    # If baseline comparisons provided, plot gradient variance (Figure 5)
    if args.compare_baselines:
        print("\nGenerating gradient variance comparison (Figure 5)...")

        variance_histories = {}
        # 'GDO-DPO' is reserved for the main run so that a baseline stored in
        # a directory of that name cannot be overwritten by it.
        used_labels = {'GDO-DPO'}

        # Load baselines
        for baseline_path in args.compare_baselines:
            with np.load(baseline_path) as baseline_data:
                if 'gradient_variance' not in baseline_data:
                    print(f"Warning: no 'gradient_variance' in {baseline_path}; skipping")
                    continue
                series = baseline_data['gradient_variance'].tolist()

            label = _baseline_label(baseline_path, used_labels)
            used_labels.add(label)
            variance_histories[label] = series

        # Add GDO-DPO (last, matching the original insertion order)
        if gdo_variance is not None:
            variance_histories['GDO-DPO'] = gdo_variance

        if variance_histories:
            variance_path = os.path.join(args.output_dir, "gradient_variance.pdf")
            analyzer.plot_gradient_variance(variance_histories, save_path=variance_path)
        else:
            print("No gradient variance data found; skipping Figure 5.")

    print("\n" + "="*60)
    print("Analysis Complete!")
    print(f"Plots saved to {args.output_dir}")
    print("="*60)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
