"""
Main Training Script for Grokking Detection
Simple interface that calls the model_class functions
"""

import argparse
from datetime import datetime

# Import our modules
from data_loader import create_dataloaders, get_available_subsets
from Grokking_model import (
    create_model_and_tokenizer,
    GrokkingDetector,
    get_model_info,
    train_with_grokking_detection  # This function should be in model_class.py
)


def main():
    """Run one grokking-detection experiment from the constants set below.

    Steps, in order: load the model and tokenizer, build the train/validation
    dataloaders, hand everything to ``train_with_grokking_detection``, then
    print a summary, save a metrics plot (PNG) and write a JSON results file
    into the current working directory.
    """
    import json

    # ========================================
    # CONFIGURATION - Change these as needed
    # ========================================

    # Model identifier passed to create_model_and_tokenizer.
    # Alternatives: "microsoft/phi-1_5", "microsoft/DialoGPT-small", etc.
    MODEL_NAME = "gpt2"

    # TOFU subset name. Alternatives: "forget01", "forget05", "forget10",
    # "retain90", etc.
    TOFU_SUBSET = "full"

    # Training hyperparameters
    MAX_STEPS = 5000
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-4
    MAX_LENGTH = 128
    EVAL_INTERVAL = 200
    USE_DUMMY = False

    rule = "=" * 60

    def banner(title, lead=""):
        # Three-line section header: rule / title / rule.
        print(lead + rule)
        print(title)
        print(rule)

    # ========================================
    # RUN TRAINING WITH GROKKING DETECTION
    # ========================================

    banner("GROKKING DETECTION EXPERIMENT")
    print(f"Model: {MODEL_NAME}")
    print(f"Dataset: TOFU ({TOFU_SUBSET})")
    print(f"Training steps: {MAX_STEPS}")
    print(f"Available subsets: {list(get_available_subsets().keys())}")

    # Model and tokenizer
    print("\n📦 Loading model and tokenizer...")
    model, tokenizer = create_model_and_tokenizer(MODEL_NAME)
    model_info = get_model_info(model)
    print(f"✓ Model loaded: {model_info['total_params']:,} parameters")

    # Dataloaders
    print("\n📚 Loading dataset...")
    loader_options = dict(
        tokenizer=tokenizer,
        subset=TOFU_SUBSET,
        max_length=MAX_LENGTH,
        batch_size=BATCH_SIZE,
        use_dummy=USE_DUMMY,
    )
    train_loader, val_loader = create_dataloaders(**loader_options)
    print(f"✓ Dataset loaded: {len(train_loader)} train batches, {len(val_loader)} val batches")

    # Training run
    print("\n🚀 Starting one-pass pretraining with grokking detection...")
    print(rule)

    detector = train_with_grokking_detection(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        eval_interval=EVAL_INTERVAL,
        warmup_steps=MAX_STEPS // 20,  # warmup lasts 5% of the run
        weight_decay=0.01,
        gradient_clip=1.0,
    )

    # ========================================
    # RESULTS
    # ========================================

    banner("TRAINING COMPLETED!", lead="\n")

    summary = detector.get_summary()
    print(f"Total training steps: {summary['total_steps']}")
    print(f"Grokking events detected: {summary['num_grokking_events']}")

    events = summary['grokking_events']
    if not events:
        print("No grokking events detected")
    else:
        print(f"First grokking at step: {summary['first_grokking_step']}")
        print(f"All grokking steps: {events}")

    # Output file names share one stem; '/' in hub-style model ids is
    # replaced so the stem is a single path component.
    file_stem = f"{MODEL_NAME.replace('/', '_')}_{MAX_STEPS}_steps"

    plot_filename = f"grokking_plot_{file_stem}.png"
    detector.plot_metrics(save_path=plot_filename)
    print(f"✓ Plot saved to: {plot_filename}")

    # Detailed results: configuration, timestamp, model info and summary
    results = {
        "model": MODEL_NAME,
        "tofu_subset": TOFU_SUBSET,
        "max_steps": MAX_STEPS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "timestamp": datetime.now().isoformat(),
        "model_info": model_info,
        "grokking_summary": summary,
    }

    results_filename = f"grokking_results_{file_stem}.json"
    with open(results_filename, 'w') as out_file:
        json.dump(results, out_file, indent=2)
    print(f"✓ Results saved to: {results_filename}")

    # Final summary
    banner("EXPERIMENT SUMMARY", lead="\n")
    print(f"✓ Trained {MODEL_NAME} for {MAX_STEPS} steps")
    print(f"✓ Used TOFU subset: {TOFU_SUBSET}")
    print(f"✓ Detected {summary['num_grokking_events']} grokking events")

    if not summary['grokking_events']:
        print("⚠️  No grokking detected - try longer training or different hyperparameters")
    else:
        print(f"🎯 First grokking occurred at step {summary['first_grokking_step']}")

    # Usage hints
    banner("TO USE DIFFERENT MODELS:", lead="\n")
    usage_lines = (
        "Just change MODEL_NAME at the top of main():",
        'MODEL_NAME = "microsoft/phi-1_5"       # For Phi-1.5',
        'MODEL_NAME = "microsoft/DialoGPT-small" # For DialoGPT',
        'MODEL_NAME = "gpt2"                     # For GPT-2',
        'MODEL_NAME = "scratch"                  # Train from scratch',
        "\nEverything else stays the same!",
    )
    for line in usage_lines:
        print(line)


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='TOFU Grokking Detection Training')
    parser.add_argument('--model', type=str, default='microsoft/Phi-4-mini-instruct',
                       help='Model to use (default: gpt2)')
    parser.add_argument('--subset', type=str, default='full',
                       help='TOFU subset to use (default: forget01)')
    parser.add_argument('--steps', type=int, default=11000000,
                       help='Training steps (default: 1000)')
    parser.add_argument('--batch-size', type=int, default=1,
                       help='Batch size (default: 4)')
    parser.add_argument('--lr', type=float, default=1e-4,
                       help='Learning rate (default: 5e-5)')
    parser.add_argument('--dummy', action='store_true',
                       help='Use dummy data for testing')

    args = parser.parse_args()

    # If command line args provided, use those
    if any(vars(args).values()):
        print(f"Using command line arguments:")
        print(f"Model: {args.model}, Subset: {args.subset}, Steps: {args.steps}")

        # Create model and data
        model, tokenizer = create_model_and_tokenizer(args.model)
        train_loader, val_loader = create_dataloaders(
            tokenizer=tokenizer,
            subset=args.subset,
            batch_size=args.batch_size,
            use_dummy=args.dummy
        )

        # Train with grokking detection
        detector = train_with_grokking_detection(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            max_steps=args.steps,
            learning_rate=args.lr,
            eval_interval=100,
            warmup_steps=args.steps // 20,
            weight_decay=0.01,
            gradient_clip=1.0
        )

        # Show results
        summary = detector.get_summary()
        print(f"\nTraining completed! Grokking events: {summary['num_grokking_events']}")
        if summary['grokking_events']:
            print(f"Grokking steps: {summary['grokking_events']}")

        # Save plot
        detector.plot_metrics(save_path=f"grokking_{args.model.replace('/', '_')}_{args.steps}.png")

    else:
        # Use main function defaults
        main()