#!/usr/bin/env python3
"""
Run Gradient Analysis

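Reproduces Figure 1 (layer-wise gradient localization analysis). Writes the
plot to <output_dir>/gradient_localization.pdf and the numerical statistics
to <output_dir>/gradient_stats.npz.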

Usage:
    python scripts/run_gradient_analysis.py --model_path outputs/gdo_dpo/checkpoint-1000 \
                                             --difficulty_scores outputs/gdo_dpo/difficulty_scores.npz \
                                             --output_dir outputs/analysis
"""

import argparse
import os
import sys
import torch
import numpy as np
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from src.analysis.gradient_analysis import GradientAnalyzer
from src.data.data_loader import PreferenceDatasetLoader


def parse_args(argv=None):
    """Parse command-line options for the gradient analysis script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` reads the real command
            line, so existing ``parse_args()`` calls behave exactly as before;
            passing a list makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace with ``model_path``, ``difficulty_scores``,
        ``dataset``, ``num_samples`` and ``output_dir``.

    Raises:
        SystemExit: on missing required options or an invalid choice
            (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(description="Run gradient analysis")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--difficulty_scores", type=str, required=True,
                        help="Path to precomputed difficulty scores (.npz)")
    parser.add_argument("--dataset", type=str, default="ultrafeedback",
                        choices=["ultrafeedback", "hh-rlhf"],
                        help="Preference dataset to analyze")
    parser.add_argument("--num_samples", type=int, default=1000,
                        help="Number of samples per tercile")
    parser.add_argument("--output_dir", type=str, default="outputs/analysis",
                        help="Directory for the plot and numerical results")
    return parser.parse_args(argv)


def _print_banner(title):
    """Print *title* framed by 60-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def main():
    """Run the layer-wise gradient localization analysis end to end.

    Steps: parse CLI options, load the checkpoint and tokenizer, load the
    selected preference dataset, attach the precomputed difficulty scores,
    compute gradient statistics by difficulty with ``GradientAnalyzer``, then
    write the Figure 1 plot (``gradient_localization.pdf``) and the numerical
    results (``gradient_stats.npz``) into ``--output_dir`` (created if
    missing).
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    _print_banner("Gradient Analysis for Layer-wise Localization")

    # Load model and tokenizer
    print(f"\nLoading model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so
    # batched padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    num_layers = model.config.num_hidden_layers
    print(f"Model has {num_layers} layers")

    # Load dataset
    print(f"\nLoading {args.dataset} dataset")
    data_loader = PreferenceDatasetLoader(tokenizer)
    if args.dataset == "ultrafeedback":
        dataset = data_loader.load_ultrafeedback(split='train')
    else:
        dataset = data_loader.load_hh_rlhf(split='train')

    # Load difficulty scores. np.load on an .npz returns an NpzFile that holds
    # an open file handle; dict() reads every array eagerly, so the file can
    # be closed right away via the context manager instead of leaking.
    print(f"\nLoading difficulty scores from {args.difficulty_scores}")
    with np.load(args.difficulty_scores) as npz:
        difficulty_scores = dict(npz)

    dataset = data_loader.attach_difficulty_scores(dataset, difficulty_scores)

    # The analyzer consumes a plain list of dicts rather than the dataset object.
    dataset_list = [dict(sample) for sample in dataset]

    # Previously hard-coded to "cuda", which fails on CPU-only hosts even
    # though device_map="auto" can place the model on CPU; fall back to CPU
    # when no GPU is visible (unchanged behavior on GPU machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    analyzer = GradientAnalyzer(
        model=model,
        num_layers=num_layers,
        device=device,
    )

    _print_banner("Analyzing gradient patterns by difficulty dimensions")

    gradient_stats = analyzer.analyze_by_difficulty(
        dataset=dataset_list,
        tokenizer=tokenizer,
        num_samples=args.num_samples,
        tercile_threshold=0.33,
    )

    # Plot results (Figure 1)
    output_path = os.path.join(args.output_dir, "gradient_localization.pdf")
    analyzer.plot_gradient_localization(gradient_stats, save_path=output_path)

    # Save numerical results. Each entry of gradient_stats is expected to be a
    # mapping whose values are stacked in insertion order (NOTE(review): the
    # inner keys are not saved — confirm their order is meaningful to
    # consumers). np.savez requires string keyword names, hence str(k).
    results_path = os.path.join(args.output_dir, "gradient_stats.npz")
    np.savez(
        results_path,
        **{str(k): np.array(list(v.values())) for k, v in gradient_stats.items()}
    )
    print(f"\nSaved numerical results to {results_path}")

    _print_banner("Gradient Analysis Complete!")


if __name__ == "__main__":
    # Script entry point: run the analysis only when executed directly,
    # not when this module is imported.
    main()
