"""
Demo script for testing ECG QA CoT (ECG Question Answering Chain-of-Thought) model from HuggingFace.

This script:
1. Loads a pretrained model from HuggingFace Hub
2. Loads the ECG QA CoT test dataset
3. Generates predictions on the evaluation set
4. Prints model outputs
"""

from opentslm.model.llm.OpenTSLM import OpenTSLM
from opentslm.time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset
from opentslm.time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
from torch.utils.data import DataLoader
from opentslm.model_config import PATCH_SIZE
import torch

# Model repository ID - change this to test different models
REPO_ID = "OpenTSLM/llama-3.2-1b-ecg-sp"


def main():
    print("=" * 60)
    print("ECG QA CoT Model Demo")
    print("=" * 60)

    # Load model from HuggingFace
    print(f"\n📥 Loading model from {REPO_ID}...")
    enable_lora = False
    if "-sp" in REPO_ID:
        enable_lora = True
    model = OpenTSLM.load_pretrained(
        REPO_ID, enable_lora=enable_lora, device="cuda" if torch.cuda.is_available() else "cpu"
    )

    # Create dataset
    print("\n📊 Loading ECG QA CoT test dataset...")
    test_dataset = ECGQACoTQADataset("test", EOS_TOKEN=model.get_eos_token())

    # Create data loader
    test_loader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=1,
        collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(batch, patch_size=PATCH_SIZE),
    )

    print(f"\n🔍 Running inference on {len(test_dataset)} test samples...")
    print("=" * 60)

    # Iterate over evaluation set
    for i, batch in enumerate(test_loader):
        # Generate predictions
        predictions = model.generate(batch, max_new_tokens=500)

        # Print results
        for sample, pred in zip(batch, predictions):
            print(f"\n📝 Sample {i + 1}:")
            if "pre_prompt" in sample:
                print(f"   Question: {sample['pre_prompt']}")
            if "template_id" in sample:
                print(f"   Template ID: {sample['template_id']}")
            if "ecg_id" in sample:
                print(f"   ECG ID: {sample['ecg_id']}")
            print(f"   Gold Answer: {sample.get('answer', 'N/A')}")
            print(f"   Model Output: {pred}")
            print("-" * 60)

        # Limit to first 5 samples for demo
        if i >= 4:
            print("\n✅ Demo complete! (Showing first 5 samples)")
            break


if __name__ == "__main__":
    main()
