"""
ESM2 Quickstart Example

Minimal example to get started with ESM2 protein sequence embedding inference.

Usage:
    python esm2_quickstart.py
"""

import sys
from pathlib import Path

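# Add the tnet2 root (four .parent steps up from this file) to sys.path;
# presumably this is what lets the coarsebind_public import below resolve
# when the script is run directly.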
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

from coarsebind_public.esm2.esm2_inference import ESM2Infer


def _select_device(preferred):
    """Return ``preferred``, or ``"cpu"`` if CUDA was requested but is unusable.

    Args:
        preferred: Device string requested by the script (e.g. ``"cuda"``,
            ``"cuda:0"`` or ``"cpu"``).

    Returns:
        ``preferred`` unchanged when it is not a CUDA device, when torch
        cannot be imported (availability cannot be probed), or when
        ``torch.cuda.is_available()`` is true; otherwise ``"cpu"``.
    """
    if not preferred.startswith("cuda"):
        return preferred
    try:
        # Local import: torch is only needed here for the availability probe.
        import torch
    except ImportError:
        # Cannot probe; hand the requested device to ESM2Infer unchanged.
        return preferred
    return preferred if torch.cuda.is_available() else "cpu"


def _preview(sequence, limit=50):
    """Return the first ``limit`` characters of ``sequence``, plus "..." if cut."""
    return f"{sequence[:limit]}{'...' if len(sequence) > limit else ''}"


def main():
    """Embed three hard-coded protein sequences with ESM2Infer and print a report.

    Builds an ``ESM2Infer`` for the ``esm2_t33_650M_UR50D`` model, calls its
    ``predict`` method on the test sequences, and prints one entry per
    returned result. Each result is read through its ``error``, ``sequence``,
    ``embed``, ``model_name``, ``rep_layer`` and ``error_msg`` attributes.

    The closing banner reports success only when no result was flagged as an
    error; otherwise it states how many sequences failed.
    """
    # Configuration
    model_name = "esm2_t33_650M_UR50D"  # ESM2 650M parameter model
    rep_layer = 33  # Final layer for 650M model
    requested_device = "cuda"  # or "cpu"
    # Fall back to CPU instead of crashing on machines without a usable GPU.
    device = _select_device(requested_device)

    print("=" * 60)
    print("ESM2 Quickstart")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Representation Layer: {rep_layer}")
    print(f"Device: {device}")
    if device != requested_device:
        print(f"  ({requested_device} is not available; falling back to {device})")
    print()

    # Initialize inferrer
    print("Initializing ESM2 inferrer...")
    print("This will download the model if not cached locally")
    print()
    inferrer = ESM2Infer(model_name=model_name, rep_layer=rep_layer, device=device)

    # Test protein sequences
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQ",  # Short test sequence
        "ARNDCEQGHILKMFPSTWYV",  # All 20 amino acids
        "MKLLVLGLLLGTTVQSTSR",  # Another test sequence
    ]

    print(f"Processing {len(sequences)} protein sequences...")
    print()

    # Run inference
    results = inferrer.predict(sequences, chunk_size=16, max_seq_len=2048)

    # Display results
    print("Results:")
    print("-" * 60)
    # Count while iterating so the summary does not require len(results).
    total = 0
    failed = 0
    for result in results:
        total += 1
        ok = not result.error
        if not ok:
            failed += 1
        print(f"{'✓' if ok else '✗'} Sequence ({len(result.sequence)} residues)")
        print(f"  {_preview(result.sequence)}")
        if ok:
            print(
                f"  Embedding: shape {result.embed.shape}, "
                f"mean={result.embed.mean():.4f}, std={result.embed.std():.4f}"
            )
            print(f"  Model: {result.model_name}, layer {result.rep_layer}")
        else:
            print(f"  ERROR: {result.error_msg}")
        print()

    print("=" * 60)
    if failed:
        # Do not claim success when any sequence was reported as an error.
        print(f"Done, but {failed} of {total} sequences failed (see errors above)")
    else:
        print("Done! ESM2 embeddings generated successfully")
    print("=" * 60)
    print()
    print("Note:")
    print("  - Embeddings include start/stop tokens (length = seq_len + 2)")
    print("  - Embedding shape: [seq_len + 2, hidden_dim]")
    print("  - Use embed[1:-1] to get per-residue embeddings without special tokens")


# Run the quickstart only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
