import sys
from pathlib import Path

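# Add the tnet2 root (the fourth ancestor directory of this file) to sys.path
# so that the coarsebind_public imports below can be resolved.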
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

from coarsebind_public.coarsebind.coarsebind_inferrer import CoarseBindInferrer
from coarsebind_public.coarsebind.io_schema import IOSchemaCoarseBind


def _sequence_preview(sequence, limit=30):
    """Return a printable preview of the first chain in *sequence*.

    Args:
        sequence: List of chain strings; only the first entry is previewed.
        limit: Maximum number of characters shown before truncating.

    Returns:
        The first chain, cut to *limit* characters and suffixed with "..."
        when it is longer than *limit*; "<none>" when *sequence* is empty
        or None, so that reporting a result never fails on the preview.
    """
    if not sequence:
        return "<none>"
    first_chain = sequence[0]
    if len(first_chain) > limit:
        return f"{first_chain[:limit]}..."
    return f"{first_chain}"


def _print_result(index, result):
    """Print a human-readable summary of one prediction result.

    Args:
        index: 1-based position of the complex in the result list.
        result: Result object returned by ``CoarseBindInferrer.predict``;
            the attributes read here are ``target_name``, ``smiles``,
            ``sequence``, ``error``, ``error_msg``, ``disto_potency_output``
            and ``disto_output``.
    """
    print(f"\nComplex {index}: {result.target_name}")
    print(f"  SMILES: {result.smiles}")
    print(f"  Sequence: {_sequence_preview(result.sequence)}")

    # Failed predictions carry only an error message; nothing else to show.
    if result.error:
        print(f"  ✗ ERROR: {result.error_msg}")
        return

    print("  ✓ Prediction successful")

    # Potency predictions
    potency = result.disto_potency_output
    if potency:
        print(f"  IC50 (pIC50): {potency.pred_quant:.3f}")
        print(f"  Binary (active): {potency.pred_binary:.3f}")

        # Uncertainty bounds (if epinet_samples > 0). Both percentiles are
        # formatted with ":.3f", so both must be present before printing.
        p25 = potency.pred_25_pct
        if p25 is not None:
            p75 = potency.pred_75_pct
            if p75 is not None:
                print(f"  Uncertainty: [{p25:.3f}, {p75:.3f}] (25th-75th percentile)")

    # Distance predictions
    disto = result.disto_output
    if disto:
        print(f"  Tokens: {len(disto.res_type)}")
        print(f"  Distance bins shape: {disto.bin_probs.shape}")


def main():
    """Run the CoarseBind quickstart end to end.

    Builds a ``CoarseBindInferrer``, loads its models, runs prediction on two
    small example protein-ligand complexes and prints a summary of each
    result to stdout.
    """
    # Model configuration
    # NOTE: all four URIs are the same "s3://path" placeholder; replace them
    # with real artifact locations before running.
    model_doc_potency = "s3://path"
    smiles_tokenizer_path = "s3://path"
    graph_tokenizer_path = "s3://path"
    mol_enc_uri = "s3://path"

    device = "cuda"  # Use "cpu" if no GPU available

    banner = "=" * 80

    print(banner)
    print("CoarseBind Quickstart")
    print(banner)
    print(f"Model: {model_doc_potency}")
    print(f"Device: {device}")
    print()

    # Initialize inferrer
    print("Initializing ...")
    print()

    inferrer = CoarseBindInferrer(
        model_doc_potency=model_doc_potency,
        device=device,
        batch_size=1,
        max_tokens=512,
        run_coarse_cofold=True,
        mol_enc_uri=mol_enc_uri,
        mol_enc_smiles_tokenizer_path=smiles_tokenizer_path,
        mol_enc_graph_tokenizer_path=graph_tokenizer_path,
    )

    # Initialize models (this loads all components)
    inferrer.init_models()
    print("✓ Models loaded successfully")
    print()

    # Example input: protein sequence + ligand SMILES
    test_inputs = [
        IOSchemaCoarseBind(
            smiles="CCO",  # Ethanol
            sequence=["MKTAYIAKQRQISFVKSHFSRQ"],  # Short protein sequence
            target_name="test_protein_1",
        ),
        IOSchemaCoarseBind(
            smiles="c1ccccc1",  # Benzene
            sequence=["ARNDCEQGHILKMFPSTWYV"],  # Another short protein sequence
            target_name="test_protein_2",
        ),
    ]

    print(f"Running inference on {len(test_inputs)} protein-ligand complexes...")
    print()

    # Run inference
    results = inferrer.predict(
        input_data=test_inputs,
        pairformer_model_forward_kwargs={
            "recycling_steps": 3,
        },
        potency_model_forward_kwargs={
            "epinet_samples": 100,  # Use 100 for uncertainty quantification, 0 for faster inference
        },
    )

    # Display results
    print("Results:")
    print("-" * 80)
    for index, result in enumerate(results, start=1):
        _print_result(index, result)

    print()
    print(banner)
    print("Done! Key outputs:")
    print("  - disto_potency_output.pred_quant: IC50 prediction (pIC50)")
    print("  - disto_potency_output.pred_binary: Active/inactive classification")
    print("  - disto_output.bin_probs: Distance distribution predictions")
    print(banner)


# Run the quickstart only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
