#!/usr/bin/env python3
"""
Test script for ECG-QA dataset implementation.

This script tests:
1. ECG-QA repository cloning
2. PTB-XL dataset downloading  
3. Dataset loading and basic functionality
4. Sample data access

Usage:
    python test_ecgqa.py
"""

import sys
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', "src"))
def test_ecgqa_loader():
    """Test the ECG-QA loader functions."""
    print("Testing ECG-QA loader...")
    
    try:
        from time_series_datasets.ecg_qa.ecgqa_loader import (
            does_ecg_qa_exist, 
            does_ptbxl_exist,
            download_ecg_qa_if_not_exists,
            download_ptbxl_if_not_exists
        )
        
        print(f"ECG-QA exists: {does_ecg_qa_exist()}")
        print(f"PTB-XL exists: {does_ptbxl_exist()}")
        
        # Download if needed (this might take a while)
        print("Ensuring datasets are available...")
        download_ecg_qa_if_not_exists()
        download_ptbxl_if_not_exists()
        
        print(f"After download - ECG-QA exists: {does_ecg_qa_exist()}")
        print(f"After download - PTB-XL exists: {does_ptbxl_exist()}")
        
        return True
        
    except Exception as e:
        print(f"Error in loader test: {e}")
        return False

def test_ecgqa_dataset():
    """Test the ECG-QA dataset class."""
    print("\nTesting ECGQADataset...")
    
    try:
        from time_series_datasets.ecg_qa.ECGQADataset import ECGQADataset
        
        # Try to create dataset instances with limited samples for faster testing
        print("Creating dataset instances (limited to 5 samples each for testing)...")
        dataset = ECGQADataset(split="train", EOS_TOKEN="", max_samples=5)
        dataset_val = ECGQADataset(split="validation", EOS_TOKEN="", max_samples=5)  
        dataset_test = ECGQADataset(split="test", EOS_TOKEN="", max_samples=5)
        
        print(f"Dataset sizes:")
        print(f"  Train: {len(dataset)} samples")
        print(f"  Validation: {len(dataset_val)} samples")
        print(f"  Test: {len(dataset_test)} samples")
        
        if len(dataset) > 0:
            print(f"\nExamining first training sample:")
            sample = dataset[0]
            
            print(f"Sample keys: {list(sample.keys())}")
            print(f"Question: {sample.get('question', 'N/A')}")
            print(f"Answer: {sample['answer']}")
            print(f"Question type: {sample.get('question_type', 'N/A')}")
            print(f"ECG IDs: {sample.get('ecg_id', 'N/A')}")
            
            if 'time_series_text' in sample:
                print(f"Number of time series: {len(sample['time_series_text'])}")
                
                # Show first time series info and data
                if len(sample['time_series_text']) > 0:
                    first_ts = sample['time_series_text'][0]
                    if hasattr(first_ts, 'text'):
                        print(f"First time series label: {first_ts.text}")
                        print(f"First time series length: {len(first_ts.time_series)}")
                        
                        # Show the actual ECG data values
                        print(f"First 20 ECG data points: {first_ts.time_series[:20]}")
                        print(f"ECG data range: min={min(first_ts.time_series):.4f}, max={max(first_ts.time_series):.4f}")
                        
                        # Show statistics
                        import numpy as np
                        ecg_data = np.array(first_ts.time_series)
                        print(f"ECG statistics: mean={np.mean(ecg_data):.4f}, std={np.std(ecg_data):.4f}")
                        
                        # Show a few more leads if available
                        if len(sample['time_series_text']) > 1:
                            print(f"\nOther ECG leads available:")
                            for i, ts in enumerate(sample['time_series_text'][1:4], 1):  # Show up to 3 more leads
                                if hasattr(ts, 'text'):
                                    print(f"  Lead {i+1}: {ts.text}")
                                    print(f"    First 10 values: {ts.time_series[:10]}")
                                    ecg_data = np.array(ts.time_series)
                                    print(f"    Stats: mean={np.mean(ecg_data):.4f}, std={np.std(ecg_data):.4f}")
                    else:
                        print(f"Time series format issue - got {type(first_ts)}: {first_ts}")
                        
                        # If it's a string, try to parse the actual time series data
                        if 'time_series' in sample:
                            print(f"\nRaw time_series field type: {type(sample['time_series'])}")
                            if isinstance(sample['time_series'], list) and len(sample['time_series']) > 0:
                                print(f"First time series data (first 20 points): {sample['time_series'][:20]}")
                                print(f"Time series length: {len(sample['time_series'])}")
                                
                                import numpy as np
                                ts_data = np.array(sample['time_series'])
                                print(f"Time series stats: mean={np.mean(ts_data):.4f}, std={np.std(ts_data):.4f}")
            
            print(f"\nPre-prompt: {sample['pre_prompt']}...")
            print(f"Post-prompt: {sample['post_prompt']}...")
        
        return True
        
    except ImportError as e:
        print(f"Import error (likely missing wfdb): {e}")
        print("Please install wfdb: pip install wfdb")
        return False
    except Exception as e:
        print(f"Error in dataset test: {e}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """Run all tests."""
    print("ECG-QA Dataset Test Suite")
    print("=" * 50)
    
    # Test 1: Loader functionality
    loader_success = test_ecgqa_loader()
    
    # Test 2: Dataset functionality (only if loader works)
    if loader_success:
        dataset_success = test_ecgqa_dataset()
    else:
        dataset_success = False
        print("Skipping dataset test due to loader failure")
    
    # Summary
    print("\n" + "=" * 50)
    print("Test Results:")
    print(f"  Loader test: {'✓ PASS' if loader_success else '✗ FAIL'}")
    print(f"  Dataset test: {'✓ PASS' if dataset_success else '✗ FAIL'}")
    
    if loader_success and dataset_success:
        print("\n🎉 All tests passed! ECG-QA dataset is ready to use.")
        print("\nTo use the full dataset (without sample limits), create ECGQADataset without max_samples parameter:")
        print("  dataset = ECGQADataset(split='train', EOS_TOKEN='')")
    else:
        print("\n❌ Some tests failed. Check the error messages above.")
        
    return loader_success and dataset_success

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1) 