import argparse
import sys
from pathlib import Path

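# Prepend the tnet2 root (four `.parent` steps up from this file) to sys.path so
# `coarsebind_public` is importable when this script is run directly.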
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

from coarsebind_public.mol_encoder.mol_enc_inferrer import MolEncInferrer


def _format_embed_stats(label, embed):
    """Return an indented one-line summary of ``embed``: its shape, mean and std.

    ``label`` is left-justified to 13 characters (the width of
    ``"SMILES embed:"``) so consecutive embedding lines stay aligned.
    ``embed`` only needs a ``shape`` attribute and ``mean()``/``std()``
    methods whose results accept ``:.4f`` formatting. The concrete type
    returned by ``MolEncInferrer`` is not visible here; presumably it is an
    array/tensor — TODO confirm.
    """
    return (
        f"  {label:<13} shape {embed.shape}, "
        f"mean={embed.mean():.4f}, std={embed.std():.4f}"
    )


def main(argv=None):
    """Run the MolEnc quickstart: embed a few test molecules and print stats.

    Builds a ``MolEncInferrer`` and runs ``predict`` on three small SMILES
    strings. For each successful result it prints the shape/mean/std of the
    SMILES embedding and of the E3NN embedding, each only when it is not
    ``None``. Failed results are printed with their error message instead.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``. Every option defaults to the value this
            script previously hard-coded, so running it without arguments
            behaves as before.
    """
    parser = argparse.ArgumentParser(description="MolEnc quickstart")
    # The "s3://path" defaults are placeholders: supply real locations on the
    # command line instead of editing this file.
    parser.add_argument(
        "--model-uri", default="s3://path", help="URI of the MolEnc model."
    )
    parser.add_argument(
        "--smiles-tokenizer-path",
        default="s3://path",
        help="Path/URI of the SMILES tokenizer.",
    )
    parser.add_argument(
        "--graph-tokenizer-path",
        default="s3://path",
        help="Path/URI of the graph tokenizer.",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        help="Device passed to MolEncInferrer (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size passed to predict() (default: %(default)s).",
    )
    args = parser.parse_args(argv)

    print("=" * 60)
    print("MolEnc Quickstart")
    print("=" * 60)
    print(f"Model: {args.model_uri}")
    print(f"Device: {args.device}")
    print()

    # Initialize inferrer (always generates dual embeddings)
    print("Initializing inferrer...")
    print("This will generate both SMILES-based and 3D conformer-based embeddings")
    print()
    inferrer = MolEncInferrer(
        model_uri=args.model_uri,
        smiles_tokenizer_path=args.smiles_tokenizer_path,
        graph_tokenizer_path=args.graph_tokenizer_path,
        device=args.device,
    )

    # Test molecules
    smiles = [
        "CCO",  # Ethanol
        "c1ccccc1",  # Benzene
        "CC(C)C",  # Isobutane
    ]

    print(f"Processing {len(smiles)} molecules...")
    print()

    # Run inference
    results = inferrer.predict(smiles, batch_size=args.batch_size, prog_bar=False)

    # Display results
    print("Results:")
    print("-" * 60)
    for result in results:
        if result.error:
            print(f"✗ {result.smiles:15s} -> ERROR: {result.error_msg}")
        else:
            print(f"✓ {result.smiles:15s}")
            if result.smiles_embed is not None:
                print(_format_embed_stats("SMILES embed:", result.smiles_embed))
            if result.e3nn_embed is not None:
                print(_format_embed_stats("E3NN embed:", result.e3nn_embed))
        print()

    print("=" * 60)
    print("Done! For more examples, see mol_enc_example.py")
    print("=" * 60)


if __name__ == "__main__":
    # Run the quickstart only when this file is executed directly, not when imported.
    main()
