"""
Inference Demo for HEdit

This script demonstrates how to use HEdit for identifying anchor and trigger tokens,
and applying KV corrections during inference.
"""

import argparse
import torch
from pathlib import Path

from hedit import AnchorTokenDetector, TriggerTokenDetector, KVCorrectionMLP


def parse_args(argv=None):
    """Parse command-line options for the HEdit inference demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior of calling this function with no arguments. Passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``mlp_checkpoint``, ``anchor_k``, ``trigger_k``, ``attention_layer``,
        ``ffn_layer_start``, ``ffn_layer_end``, ``input_text``,
        ``input_file``, ``output_dir`` and ``device``.

    Raises:
        SystemExit: Raised by argparse when ``--model_path`` is missing or an
            argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="HEdit Inference Demo")

    # Model arguments
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the base language model")
    parser.add_argument("--mlp_checkpoint", type=str, default=None,
                        help="Path to trained MLP checkpoint (optional)")

    # Detection arguments
    parser.add_argument("--anchor_k", type=int, default=5,
                        help="Number of anchor tokens to identify")
    parser.add_argument("--trigger_k", type=int, default=5,
                        help="Number of trigger tokens to identify")
    parser.add_argument("--attention_layer", type=int, default=20,
                        help="Layer for attention analysis in anchor detection")
    parser.add_argument("--ffn_layer_start", type=int, default=20,
                        help="Start layer for FFN analysis in anchor detection")
    parser.add_argument("--ffn_layer_end", type=int, default=21,
                        help="End layer for FFN analysis in anchor detection")

    # Input arguments
    parser.add_argument("--input_text", type=str, default=None,
                        help="Input text to analyze")
    parser.add_argument("--input_file", type=str, default=None,
                        help="File containing input text (one per line)")

    # Output arguments
    parser.add_argument("--output_dir", type=str, default="./outputs",
                        help="Directory to save outputs")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run on (cuda/cpu)")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def detect_anchor_tokens(text, model_path, k_value, attention_layer, ffn_layer_start, ffn_layer_end):
    """Run anchor-token detection on ``text`` and print the findings.

    A fresh ``AnchorTokenDetector`` is built from the given settings on every
    call, asked to identify anchor tokens in ``text``, and then asked to print
    its own report.

    Args:
        text: Input text to analyze.
        model_path: Forwarded to the detector as ``model_path``.
        k_value: Forwarded to the detector as ``k_value``.
        attention_layer: Forwarded to the detector as ``attention_layer``.
        ffn_layer_start: Forwarded to the detector as ``ffn_layer_start``.
        ffn_layer_end: Forwarded to the detector as ``ffn_layer_end``.

    Returns:
        Whatever ``AnchorTokenDetector.identify_anchor_tokens`` returns.
    """
    rule = "=" * 60
    # Section banner: blank line, rule, title, rule.
    print("\n" + rule, "Anchor Token Detection", rule, sep="\n")

    detector_settings = {
        "model_path": model_path,
        "k_value": k_value,
        "attention_layer": attention_layer,
        "ffn_layer_start": ffn_layer_start,
        "ffn_layer_end": ffn_layer_end,
    }
    anchor_detector = AnchorTokenDetector(**detector_settings)

    detection = anchor_detector.identify_anchor_tokens(text)
    anchor_detector.print_results(detection)
    return detection


def detect_trigger_tokens(text, model_path, k_value):
    """Run trigger-token detection on ``text`` and print the findings.

    A fresh ``TriggerTokenDetector`` is built on every call, asked to identify
    trigger tokens in ``text``, and then asked to print its own report.

    Args:
        text: Input text to analyze.
        model_path: Forwarded to the detector as ``model_path``.
        k_value: Forwarded to the detector as ``k_value``.

    Returns:
        Whatever ``TriggerTokenDetector.identify_trigger_tokens`` returns.
    """
    rule = "=" * 60
    # Section banner: blank line, rule, title, rule.
    print("\n" + rule, "Trigger Token Detection", rule, sep="\n")

    trigger_detector = TriggerTokenDetector(model_path=model_path, k_value=k_value)
    detection = trigger_detector.identify_trigger_tokens(text)
    trigger_detector.print_results(detection)
    return detection


def load_mlp_model(checkpoint_path, device):
    """Load a trained ``KVCorrectionMLP`` from a checkpoint file.

    The checkpoint is expected to be a mapping containing ``'model_config'``
    (with ``'input_dim'`` and ``'output_dim'``) and ``'model_state_dict'``.
    The ``'epoch'`` and ``'val_loss'`` entries are used only for the printed
    summary and are reported as ``N/A`` when absent.

    Args:
        checkpoint_path: Path to the checkpoint file, or ``None``.
        device: Device passed to ``torch.load`` as ``map_location`` and to
            which the model is moved.

    Returns:
        The model in eval mode on ``device``, or ``None`` when
        ``checkpoint_path`` is ``None`` or does not point to a regular file.

    Raises:
        KeyError: If ``'model_config'`` or ``'model_state_dict'`` (or the
            dimension entries inside the config) are missing.
    """
    # is_file() rather than exists(): a directory path would otherwise pass
    # the check and then fail inside torch.load.
    if checkpoint_path is None or not Path(checkpoint_path).is_file():
        print("No MLP checkpoint provided or file not found, skipping KV correction")
        return None

    print("\n" + "="*60)
    print("Loading MLP Model")
    print("="*60)
    print(f"Checkpoint: {checkpoint_path}")

    # SECURITY NOTE: torch.load unpickles the file, and the path comes from
    # the command line. Only load checkpoints from trusted sources; consider
    # passing weights_only=True where the installed torch version supports it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create model with saved configuration
    config = checkpoint['model_config']
    model = KVCorrectionMLP(
        input_dim=config['input_dim'],
        output_dim=config['output_dim']
    )

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Training metadata is informational only; a checkpoint lacking it should
    # not abort loading after the weights were restored successfully.
    epoch = checkpoint.get('epoch')
    val_loss = checkpoint.get('val_loss')
    epoch_text = "N/A" if epoch is None else f"{epoch}"
    val_loss_text = "N/A" if val_loss is None else f"{val_loss:.6f}"

    print("Model loaded successfully")
    print(f"  Input dim: {config['input_dim']}")
    print(f"  Output dim: {config['output_dim']}")
    print(f"  Training epoch: {epoch_text}")
    print(f"  Validation loss: {val_loss_text}")

    return model


def _resolve_device(requested):
    """Return the torch device to use, falling back to CPU only for unavailable CUDA."""
    # Only a CUDA request depends on torch.cuda.is_available(); other device
    # strings (e.g. "cpu") are honored as given.
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA requested but not available, falling back to CPU")
        return torch.device("cpu")
    return torch.device(requested)


def _load_texts(input_text, input_file):
    """Return the list of texts to process from the CLI text, the file, or the built-in example."""
    # --input_text takes precedence over --input_file when both are given.
    if input_text:
        return [input_text]
    if input_file:
        with open(input_file, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            # One text per line; blank lines are skipped.
            return [line for line in stripped_lines if line]
    # Default example
    return [
        "The capital of France is Paris, which is known for its art, fashion, and culture. "
        "Many tourists visit the Eiffel Tower every year. The tower was built in 1889."
    ]


def _print_top_tokens(kind, tokens, limit=3):
    """Print up to ``limit`` leading entries of ``tokens`` under a "Top <kind> tokens:" heading."""
    if not tokens:
        return
    print(f"Top {kind} tokens:")
    for i, token in enumerate(tokens[:limit], 1):
        print(f"  {i}. '{token['decoded_text']}' at position {token['position']}")


def main():
    """Run the HEdit inference demo end to end.

    Parses command-line arguments, selects the device, gathers the input
    texts, creates the output directory, optionally loads the KV-correction
    MLP, and then runs anchor- and trigger-token detection on each text,
    printing a short summary of the leading tokens of each kind.
    """
    args = parse_args()

    # Setup device
    device = _resolve_device(args.device)
    print(f"Using device: {device}")

    # Prepare input texts
    texts = _load_texts(args.input_text, args.input_file)
    if not texts:
        # Happens when --input_file contains no non-empty lines.
        print("Warning: no non-empty input texts found; nothing to process")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load MLP model if checkpoint provided
    mlp_model = load_mlp_model(args.mlp_checkpoint, device)

    # Process each text
    for idx, text in enumerate(texts):
        print(f"\n{'='*60}")
        print(f"Processing Text {idx + 1}/{len(texts)}")
        print(f"{'='*60}")
        # Long inputs are previewed by their first 100 characters.
        print(f"Text: {text[:100]}..." if len(text) > 100 else f"Text: {text}")

        # Detect anchor tokens
        anchor_result = detect_anchor_tokens(
            text=text,
            model_path=args.model_path,
            k_value=args.anchor_k,
            attention_layer=args.attention_layer,
            ffn_layer_start=args.ffn_layer_start,
            ffn_layer_end=args.ffn_layer_end
        )

        # Detect trigger tokens
        trigger_result = detect_trigger_tokens(
            text=text,
            model_path=args.model_path,
            k_value=args.trigger_k
        )

        # Summary
        print("\n" + "="*60)
        print("Summary")
        print("="*60)
        anchor_tokens = anchor_result['anchor_tokens']
        print(f"Anchor tokens found: {len(anchor_tokens)}")
        _print_top_tokens("anchor", anchor_tokens)

        trigger_tokens = trigger_result['trigger_tokens']
        print(f"\nTrigger tokens found: {len(trigger_tokens)}")
        _print_top_tokens("trigger", trigger_tokens)

        # Explicit None check: do not rely on the truthiness of a module object.
        if mlp_model is not None:
            print("\n✓ MLP model loaded and ready for KV correction")
            print("  (Note: Full inference pipeline requires integration with model forward pass)")

    print("\n" + "="*60)
    print("Inference Demo Completed!")
    print("="*60)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
