#!/usr/bin/env python3
"""
Run CKA Similarity Analysis

Reproduces Table 4: CKA Similarity with SFT Base Model

Usage:
    python scripts/run_cka_analysis.py --trained_model outputs/gdo_dpo/final_model \
                                        --base_model meta-llama/Meta-Llama-3-8B-Instruct \
                                        --dataset ultrafeedback \
                                        --output_dir outputs/analysis
"""

import argparse
import os
import sys
import json
import torch
from pathlib import Path

# Put the directory two levels above this file (the parent of this script's
# directory) at the front of sys.path so the `src.*` imports below resolve
# when the script is run directly rather than as an installed package.
sys.path.insert(0, str(Path(__file__).parent.parent))

from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from src.analysis.cka_analysis import CKAAnalyzer
from src.data.data_loader import PreferenceDatasetLoader


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``trained_model`` and
        ``base_model`` (both required), ``dataset`` (``"ultrafeedback"`` or
        ``"hh-rlhf"``), ``num_samples``, ``output_dir`` and ``method``
        (``"linear"`` or ``"rbf"``).
    """
    cli = argparse.ArgumentParser(description="Run CKA similarity analysis")

    # (flag, add_argument keyword arguments), registered in --help order.
    option_specs = (
        ("--trained_model",
         dict(type=str, required=True, help="Path to trained model")),
        ("--base_model",
         dict(type=str, required=True, help="Path or name of base SFT model")),
        ("--dataset",
         dict(type=str, default="ultrafeedback",
              choices=["ultrafeedback", "hh-rlhf"])),
        ("--num_samples",
         dict(type=int, default=1000,
              help="Number of samples for CKA computation")),
        ("--output_dir",
         dict(type=str, default="outputs/analysis")),
        ("--method",
         dict(type=str, default="linear", choices=["linear", "rbf"])),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def main():
    """Compare a trained model with its SFT base model via CKA and save the scores.

    Steps, in order:
      1. Parse the CLI arguments and create the output directory.
      2. Load both causal-LM checkpoints in bfloat16 with ``device_map="auto"``.
      3. Load the tokenizer of the trained checkpoint, using the EOS token as
         the pad token when none is defined.
      4. Split the layer indices of the trained model into a lower and an
         upper half.
      5. Load the ``test`` split of the selected preference dataset and keep
         at most ``--num_samples`` examples.
      6. Compute one CKA score per layer range with ``CKAAnalyzer`` and print
         the scores.
      7. Write the scores to ``<output_dir>/cka_results.json``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    banner = "=" * 60

    print("\n" + banner)
    print("CKA Similarity Analysis")
    print(banner)

    # Load models
    print(f"\nLoading trained model from {args.trained_model}")
    trained_model = AutoModelForCausalLM.from_pretrained(
        args.trained_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print(f"\nLoading base model: {args.base_model}")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    # Load tokenizer (taken from the trained checkpoint, not the base model)
    tokenizer = AutoTokenizer.from_pretrained(args.trained_model)
    if tokenizer.pad_token is None:
        # Batched collation needs a pad token; reuse EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    # Get number of layers (read from the trained model's config only)
    num_layers = trained_model.config.num_hidden_layers
    print(f"\nModel has {num_layers} layers")

    # Define layer ranges (following Table 4): lower half and upper half.
    # With an odd layer count the upper half gets the extra layer.
    half = num_layers // 2
    layer_ranges = {
        f"Layers 0-{half - 1}": list(range(0, half)),
        f"Layers {half}-{num_layers - 1}": list(range(half, num_layers)),
    }

    print("\nLayer ranges:")
    for range_name, layers in layer_ranges.items():
        print(f"  {range_name}: {len(layers)} layers")

    # Load dataset
    print(f"\nLoading {args.dataset} dataset")
    data_loader_obj = PreferenceDatasetLoader(tokenizer)

    if args.dataset == "ultrafeedback":
        dataset = data_loader_obj.load_ultrafeedback(split='test')
    else:
        dataset = data_loader_obj.load_hh_rlhf(split='test')

    # Limit samples: keep only the first --num_samples examples.
    if len(dataset) > args.num_samples:
        dataset = dataset.select(range(args.num_samples))

    print(f"Using {len(dataset)} samples for CKA computation")

    # Create dataloader; shuffle=False keeps the sample order fixed.
    dataloader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        collate_fn=data_loader_obj.collate_fn
    )

    # Create CKA analyzer. Fall back to CPU when no GPU is visible instead of
    # hard-coding "cuda"; on CUDA machines the behaviour is unchanged.
    # NOTE(review): assumes CKAAnalyzer accepts any torch device string — confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    analyzer = CKAAnalyzer(device=device)

    # Compute CKA
    print("\n" + banner)
    print(f"Computing {args.method} CKA similarity")
    print(banner)

    cka_results = analyzer.compute_model_cka(
        model1=trained_model,
        model2=base_model,
        dataloader=dataloader,
        layer_ranges=layer_ranges,
        max_samples=args.num_samples,
        method=args.method
    )

    # The return type of compute_model_cka is not visible from this script.
    # Coerce every score to a built-in float so json.dump below cannot fail
    # on numpy/torch scalar types after the expensive computation has run.
    cka_results = {name: float(score) for name, score in cka_results.items()}

    # Print results
    print("\n" + banner)
    print("CKA Similarity Results")
    print(banner)
    for range_name, cka_score in cka_results.items():
        print(f"{range_name}: {cka_score:.4f}")

    # Save results
    results_path = os.path.join(args.output_dir, "cka_results.json")
    with open(results_path, 'w', encoding="utf-8") as f:
        json.dump(cka_results, f, indent=2)

    print(f"\nResults saved to {results_path}")

    print("\n" + banner)
    print("CKA Analysis Complete!")
    print(banner)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
