from datasets import Dataset
from typing import List, Tuple, Literal
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from prompt.text_time_series_prompt import TextTimeSeriesPrompt
from time_series_datasets.QADataset import QADataset
from time_series_datasets.ecg_qa.ecgqa_cot_loader import load_ecg_qa_cot_splits
import numpy as np

class ECGQACoTQADataset(QADataset):
    """
    ECG-QA Chain-of-Thought Dataset for question answering with electrocardiogram data.
    
    This dataset combines ECG time series data from PTB-XL with 
    question-answer pairs and chain-of-thought reasoning from the ECG-QA CoT dataset.
    
    Requires: pip install wfdb
    """
    
    def __init__(self, split: Literal["train", "test", "validation"], EOS_TOKEN: str, 
                 format_sample_str: bool = False, time_series_format_function=None,
                 max_samples: int = None, exclude_comparison: bool = False,
                 preload_processed_data: bool = True):
        """
        Initialize ECG-QA CoT Dataset.

        The subclass-specific options are stored on the instance and the
        remaining arguments are forwarded unchanged to the QADataset constructor.
        
        Args:
            split: Dataset split to load
            EOS_TOKEN: End-of-sequence token
            format_sample_str: Whether to format samples as strings
            time_series_format_function: Function to format time series data
            max_samples: Maximum number of samples per split (for testing).
                A falsy value (None or 0) disables the limit.
            exclude_comparison: If True, exclude comparison questions
                (samples whose question_type starts with "comparison")
            preload_processed_data: If True, preload processed ECG data for maximum speed (uses more memory). Default: True
        """
        # These options are read by _load_splits().
        # NOTE(review): presumably the QADataset constructor triggers _load_splits(),
        # so the attributes must be assigned before super().__init__() runs — confirm
        # against QADataset before reordering.
        self.max_samples = max_samples
        self.exclude_comparison = exclude_comparison
        self.preload_processed_data = preload_processed_data
        super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function)

    def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Load the train, validation and test splits of the ECG-QA CoT dataset.

        Depending on the instance options the splits are post-processed in
        this order: comparison questions are dropped (exclude_comparison),
        each split is truncated to its first max_samples rows, raw ECG
        records are preloaded, and finally the processed leads are preloaded
        (preload_processed_data).

        Returns:
            The (train, validation, test) datasets.

        Raises:
            ValueError: If exclude_comparison is set and a sample has no
                'question_type' value.
        """
        print("Loading ECG-QA CoT dataset splits...")
        train, val, test = load_ecg_qa_cot_splits()

        if self.exclude_comparison:
            print("Filtering out comparison questions...")

            def drop_comparison(split_data):
                # Rebuild the split from every record that is not a comparison question.
                kept = []
                for record in split_data:
                    q_type = record.get("question_type")
                    if q_type is None:
                        raise ValueError(f"Sample missing required 'question_type' field: {record}")
                    if q_type.startswith("comparison"):
                        continue
                    kept.append(record)
                return Dataset.from_list(kept)

            sizes_before = (len(train), len(val), len(test))
            train, val, test = (drop_comparison(part) for part in (train, val, test))

            print("Filtered out comparison questions:")
            for label, before, part in zip(("Train", "Val", "Test"), sizes_before, (train, val, test)):
                after = len(part)
                print(f"  {label}: {before} -> {after} ({before - after} removed)")

        limit = self.max_samples
        if limit:
            print(f"Limiting to {limit} samples per split for testing...")
            # Only splits longer than the limit are truncated; shorter ones are kept as-is.
            train, val, test = (
                part.select(range(limit)) if len(part) > limit else part
                for part in (train, val, test)
            )

        # Warm the raw-record cache, then (optionally) the processed-lead cache.
        self.preload_ecg_data([train, val, test])
        if self.preload_processed_data:
            self.preload_processed_ecg_data([train, val, test])

        return train, val, test

    def _get_answer(self, row) -> str:
        """Get the answer from the row, which is the chain-of-thought reasoning."""
        return row.get("rationale", "No chain-of-thought reasoning available.")

    def _get_pre_prompt(self, row) -> str:
        """Generate the pre-prompt explaining the task with clinical context."""
        question_type = row.get("question_type")
        if question_type is None:
            raise ValueError(f"Sample missing required 'question_type' field: {row}")
        
        question = row.get("question")
        if question is None:
            raise ValueError(f"Sample missing required 'question' field: {row}")
        
        # Get clinical context if available
        clinical_contexts = row.get("clinical_contexts", [])
        if not clinical_contexts:
            raise ValueError(f"Sample missing required 'clinical_contexts' field: {row}")
        clinical_context = clinical_contexts[0]
        
        base_prompt = f"""You are an expert cardiologist analyzing an ECG (electrocardiogram). 

Clinical Context: {clinical_context}

Your task is to examine the ECG signal and answer the following medical question:

Question: {question}

Instructions:
- Begin by analyzing the time series without assuming a specific answer.
- Think step-by-step about what the observed patterns suggest regarding the cardiac condition.
- Write your rationale as a single, natural paragraph — do not use bullet points, numbered steps, or section headings.
- Do **not** mention any final answer until the very end.
- Consider the ECG morphology, intervals, and any abnormalities that relate to the question."""
        
        

        
        return base_prompt 

    def _get_post_prompt(self, row) -> str:
        """Generate the post-prompt with possible answers and instructions."""
        # Try to get template-specific answers first
        template_id = row.get("template_id")
        if template_id is None:
            raise ValueError(f"Sample missing required 'template_id' field: {row}")
        
        possible_answers = ECGQACoTQADataset.get_possible_answers_for_template(template_id)
        
        if possible_answers:
            answers_text = ", ".join(possible_answers)
            prompt = f"""
Based on your analysis of the ECG data, select your answer from the following options:
{answers_text}

- Make sure that your last word is the answer. You MUST end your response with "Answer: "
"""
        else:
            prompt = """
Based on your analysis of the ECG data, provide your answer.
Make sure that your last word is the answer. You MUST end your response with "Answer: "
"""
        
        return prompt.strip()

    # Class-level cache for template answers, shared by every instance.
    # None means "not loaded yet". _load_template_answers_cache() replaces it with a
    # dict mapping template_id -> list of answers parsed from the CSV (an empty dict
    # if loading failed); clear_caches() resets it to None.
    _template_answers_cache = None
    
    @classmethod
    def _load_template_answers_cache(cls):
        """Load template answers cache once."""
        if cls._template_answers_cache is None:
            try:
                import pandas as pd
                import ast
                from time_series_datasets.ecg_qa.ecgqa_loader import ECG_QA_DIR
                
                # Load template answers directly
                template_answers_path = os.path.join(ECG_QA_DIR, "ecgqa", "ptbxl", "answers_for_each_template.csv")
                template_df = pd.read_csv(template_answers_path)
                
                # Build cache dictionary
                cls._template_answers_cache = {}
                for _, row in template_df.iterrows():
                    template_id = row['template_id']
                    answers_str = row['classes']
                    try:
                        cls._template_answers_cache[template_id] = ast.literal_eval(answers_str)
                    except Exception as e:
                        print(f"Warning: Failed to parse answers for template {template_id}: {e}")
                        cls._template_answers_cache[template_id] = []
                        
                print(f"Loaded template answers cache with {len(cls._template_answers_cache)} templates")
                
            except Exception as e:
                print(f"Error loading template answers cache: {e}")
                cls._template_answers_cache = {}
    
    @staticmethod
    def get_possible_answers_for_template(template_id: int) -> List[str]:
        """Get possible answers for a specific template ID."""
        # Load cache if not already loaded
        ECGQACoTQADataset._load_template_answers_cache()
        
        # Return cached result
        return ECGQACoTQADataset._template_answers_cache.get(template_id, [])
    
    @staticmethod
    def get_labels() -> List[str]:
        """Get all possible answer labels from answers_for_each_template.csv."""
        try:
            import pandas as pd
            import ast
            from time_series_datasets.ecg_qa.ecgqa_loader import ECG_QA_DIR

            template_answers_path = os.path.join(ECG_QA_DIR, "ecgqa", "ptbxl", "answers_for_each_template.csv")
            df = pd.read_csv(template_answers_path)

            all_labels = set()
            for _, row in df.iterrows():
                classes_str = row.get("classes", "[]")
                try:
                    classes = ast.literal_eval(classes_str)
                    for c in classes:
                        if isinstance(c, str):
                            cleaned = c.strip()
                            if cleaned:
                                all_labels.add(cleaned)
                except Exception as e:
                    print(f"Warning: Failed to parse classes for template {row.get('template_id')}: {e}")
                    continue

            labels = sorted(all_labels)
            print(f"Loaded {len(labels)} unique labels from answers_for_each_template.csv")
            return labels
        except Exception as e:
            print(f"Error loading labels from CSV: {e}")
            return ["yes", "no", "not sure", "none", "normal", "abnormal"]
    
    # Class-level caches for ECG data, shared by every instance (and therefore by
    # the train/validation/test dataset objects created in the same process).
    # _ecg_data_cache maps the ecg_path string passed to _load_ecg_data() to the
    # tuple (raw signal array, original frequency label, target frequency label).
    _ecg_data_cache = {}
    # _processed_ecg_cache maps "<ecg_path>:lead_<index>" to the tuple
    # (normalized signal, mean, std) produced by _process_ecg_lead().
    _processed_ecg_cache = {}  # Cache for processed (downsampled + normalized) signals
    
    @classmethod
    def _load_ecg_data(cls, ecg_path: str) -> Tuple[np.ndarray, str, str]:
        """Load and cache ECG data for a given path."""
        if ecg_path not in cls._ecg_data_cache:
            try:
                import wfdb
                
                # Load ECG data using wfdb
                base_path = ecg_path.replace('.dat', '').replace('.hea', '')
                
                if not os.path.exists(base_path + '.dat'):
                    raise FileNotFoundError(f"ECG data file not found: {base_path}.dat")
                
                if not os.path.exists(base_path + '.hea'):
                    raise FileNotFoundError(f"ECG header file not found: {base_path}.hea")
                
                # Read the ECG record
                record = wfdb.rdrecord(base_path)
                
                # Get the signal data - shape is (samples, leads)
                ecg_signal = record.p_signal  # Physical signal
                
                if ecg_signal is None:
                    raise ValueError(f"ECG signal is None for file {base_path}")
                
                if ecg_signal.shape[0] == 0:
                    raise ValueError(f"ECG signal is empty (0 samples) for file {base_path}")
                
                # Determine frequency info
                if len(ecg_signal) > 1000:  # Likely 500Hz data
                    original_freq = "500Hz"
                    target_freq = "100Hz"
                else:  # Likely already 100Hz data
                    original_freq = "100Hz"
                    target_freq = "100Hz"
                
                # Cache the raw signal and frequency info
                cls._ecg_data_cache[ecg_path] = (ecg_signal, original_freq, target_freq)
                
            except Exception as e:
                raise RuntimeError(f"Failed to read ECG record from {base_path}: {str(e)}")
        
        return cls._ecg_data_cache[ecg_path]
    
    @classmethod
    def preload_ecg_data(cls, dataset_splits: List[Dataset]):
        """
        Warm the raw ECG cache for every record referenced by the given splits.

        A sample contributes all entries of its "ecg_paths" list, or, when
        that list is empty or absent, a path derived from the first element
        of its "ecg_id" list. Records that fail to load are reported with a
        warning and skipped.

        Args:
            dataset_splits: Iterable of datasets whose samples are scanned.
        """
        print("Preloading ECG data for faster formatting...")

        unique_paths = set()
        for split_data in dataset_splits:
            for record in split_data:
                listed_paths = record.get("ecg_paths", [])
                if listed_paths:
                    unique_paths.update(listed_paths)
                    continue

                # Fallback: construct path from ecg_id
                ids = record.get("ecg_id")
                if not (ids and isinstance(ids, list) and len(ids) > 0):
                    continue
                from time_series_datasets.ecg_qa.ecgqa_loader import get_ptbxl_ecg_path
                unique_paths.add(get_ptbxl_ecg_path(ids[0]) + ".dat")

        print(f"Found {len(unique_paths)} unique ECG files to preload...")

        from tqdm import tqdm
        for path in tqdm(unique_paths, desc="Preloading ECG data"):
            try:
                cls._load_ecg_data(path)
            except Exception as e:
                print(f"Warning: Failed to preload ECG data from {path}: {e}")

        # Reports the total size of the shared cache, not only this call's additions.
        print(f"Preloaded {len(cls._ecg_data_cache)} ECG files")
    
    @classmethod
    def preload_processed_ecg_data(cls, dataset_splits: List[Dataset]):
        """
        Preload processed ECG data (downsampled + normalized) for all samples.

        A sample contributes all entries of its "ecg_paths" list, or, when
        that list is empty or absent, a path derived from the first element
        of its "ecg_id" list. Every lead of every referenced record is then
        run through _process_ecg_lead() so the result lands in the
        processed-lead cache.

        Preloading is best effort: a record that cannot be loaded, or a lead
        that cannot be processed, is reported with a warning and skipped
        instead of aborting the whole preload.

        Args:
            dataset_splits: Iterable of datasets whose samples are scanned.
        """
        print("Preloading processed ECG data for maximum performance...")

        # Gather the unique record paths first so each record is inspected once,
        # no matter how many samples reference it.
        unique_paths = set()
        for dataset in dataset_splits:
            for sample in dataset:
                sample_ecg_paths = sample.get("ecg_paths", [])
                if sample_ecg_paths:
                    unique_paths.update(sample_ecg_paths)
                    continue

                # Fallback: construct path from ecg_id
                ecg_id = sample.get("ecg_id")
                if ecg_id and isinstance(ecg_id, list) and len(ecg_id) > 0:
                    from time_series_datasets.ecg_qa.ecgqa_loader import get_ptbxl_ecg_path
                    unique_paths.add(get_ptbxl_ecg_path(ecg_id[0]) + ".dat")

        # Expand every loadable record into (path, lead index) pairs.
        ecg_lead_combinations = set()
        for ecg_path in unique_paths:
            try:
                ecg_signal, _, _ = cls._load_ecg_data(ecg_path)
            except Exception as e:
                # Same tolerance as preload_ecg_data(): one bad record must not
                # prevent the remaining records from being preprocessed.
                print(f"Warning: Failed to load ECG data from {ecg_path}, skipping preprocessing: {e}")
                continue
            n_leads = ecg_signal.shape[1] if len(ecg_signal.shape) > 1 else 1
            ecg_lead_combinations.update((ecg_path, lead_idx) for lead_idx in range(n_leads))

        print(f"Found {len(ecg_lead_combinations)} unique ECG lead combinations to preprocess...")

        # Preprocess ECG leads with progress bar
        from tqdm import tqdm
        for ecg_path, lead_idx in tqdm(ecg_lead_combinations, desc="Preprocessing ECG leads"):
            try:
                cls._process_ecg_lead(ecg_path, lead_idx)
            except Exception as e:
                print(f"Warning: Failed to preprocess ECG lead {lead_idx} from {ecg_path}: {e}")

        # Reports the total size of the shared cache, not only this call's additions.
        print(f"Preprocessed {len(cls._processed_ecg_cache)} ECG leads")
    
    @classmethod
    def clear_caches(cls):
        """Clear all caches to free memory."""
        cls._ecg_data_cache.clear()
        cls._processed_ecg_cache.clear()
        cls._template_answers_cache = None
        print("Cleared all ECG-QA CoT dataset caches")
    
    @classmethod
    def get_cache_stats(cls) -> dict:
        """Get cache statistics for monitoring."""
        return {
            "ecg_data_cache_size": len(cls._ecg_data_cache),
            "processed_ecg_cache_size": len(cls._processed_ecg_cache),
            "template_answers_loaded": cls._template_answers_cache is not None,
            "template_answers_count": len(cls._template_answers_cache) if cls._template_answers_cache else 0
        }
    
    @classmethod
    def _process_ecg_lead(cls, ecg_path: str, lead_idx: int) -> Tuple[np.ndarray, float, float]:
        """Process and cache a single ECG lead (downsample + normalize)."""
        cache_key = f"{ecg_path}:lead_{lead_idx}"
        
        if cache_key not in cls._processed_ecg_cache:
            # Load raw ECG data
            ecg_signal, original_freq, target_freq = cls._load_ecg_data(ecg_path)
            
            # Extract lead signal
            if len(ecg_signal.shape) > 1:
                lead_signal = ecg_signal[:, lead_idx]
            else:
                lead_signal = ecg_signal
            
            if len(lead_signal) == 0:
                raise ValueError(f"Lead {lead_idx} is empty for file {ecg_path}")
            
            # Downsample if needed
            if len(lead_signal) > 1000:  # Likely 500Hz data
                downsampled_signal = lead_signal[::5]  # Downsample to 100Hz
            else:  # Already 100Hz data
                downsampled_signal = lead_signal
            
            if len(downsampled_signal) == 0:
                raise ValueError(f"Downsampled signal is empty for lead {lead_idx} in file {ecg_path}")
            
            # Normalize the signal
            mean_val = float(np.mean(downsampled_signal))
            std_val = float(np.std(downsampled_signal))
            
            if np.isnan(mean_val) or np.isnan(std_val):
                raise ValueError(f"NaN values detected in ECG signal statistics for lead {lead_idx} in file {ecg_path}")
            
            if std_val > 1e-6:  # Avoid division by zero
                normalized_signal = (downsampled_signal - mean_val) / std_val
            else:
                print(f"Warning: Lead {lead_idx} in file {ecg_path} has very low std deviation ({std_val}), signal may be flat")
                normalized_signal = downsampled_signal - mean_val
            
            # Verify normalized signal is valid
            if np.any(np.isnan(normalized_signal)) or np.any(np.isinf(normalized_signal)):
                raise ValueError(f"Invalid values (NaN/Inf) in normalized signal for lead {lead_idx} in file {ecg_path}")
            
            # Cache the processed signal and statistics
            cls._processed_ecg_cache[cache_key] = (normalized_signal, mean_val, std_val)
        
        return cls._processed_ecg_cache[cache_key]

    def _get_text_time_series_prompt_list(self, row) -> List[TextTimeSeriesPrompt]:
        """Load ECG data and convert to TextTimeSeriesPrompt format using cached data."""
        
        ecg_prompts = []
        ecg_paths = row.get("ecg_paths")
        if ecg_paths is None:
            raise ValueError(f"Sample missing required 'ecg_paths' field: {row}")
        
        if not ecg_paths:
            # Fallback: single ECG path
            ecg_id = row.get("ecg_id")
            if ecg_id is None:
                raise ValueError(f"Sample missing required 'ecg_id' field: {row}")
            
            if not isinstance(ecg_id, list) or len(ecg_id) == 0:
                raise ValueError(f"Sample 'ecg_id' must be a non-empty list: {ecg_id}")
            
            from time_series_datasets.ecg_qa.ecgqa_loader import get_ptbxl_ecg_path
            ecg_path = get_ptbxl_ecg_path(ecg_id[0]) + ".dat"
            ecg_paths = [ecg_path]
        
        for i, ecg_path in enumerate(ecg_paths):
            # Load ECG data to determine number of leads
            ecg_signal, original_freq, target_freq = self._load_ecg_data(ecg_path)
            
            # Load all available leads (typically 12 for standard ECG)
            if len(ecg_signal.shape) == 1:
                # Single lead case
                n_leads = 1
            elif len(ecg_signal.shape) == 2:
                n_leads = ecg_signal.shape[1]
                if ecg_signal.shape[1] < 12:
                    print(f"Warning: ECG file {ecg_path} has only {ecg_signal.shape[1]} leads, expected 12 for standard ECG")
            else:
                raise ValueError(f"Unexpected ECG signal shape {ecg_signal.shape} for file {ecg_path}")
            
            for lead_idx in range(n_leads):
                # Use cached processed signal
                normalized_signal, mean_val, std_val = self._process_ecg_lead(ecg_path, lead_idx)
                
                # Verify we have exactly 1000 samples (10 seconds at 100Hz)
                if len(normalized_signal) != 1000:
                    print(f"Warning: Lead {lead_idx} in file {ecg_path} has {len(normalized_signal)} samples, expected 1000 for 100Hz")
                
                # Create lead name
                lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
                lead_name = lead_names[lead_idx] if lead_idx < len(lead_names) else f"Lead_{lead_idx}"
                
                ecg_label = f"This is ECG Lead {lead_name}"
                if len(ecg_paths) > 1:
                    ecg_label += f" (Recording {i+1})"
                    
                ecg_label += f", it has mean {mean_val:.4f} and std {std_val:.4f}:"
                
                try:
                    ecg_prompts.append(
                        TextTimeSeriesPrompt(ecg_label, normalized_signal.tolist())
                    )
                except Exception as e:
                    raise RuntimeError(f"Failed to create TextTimeSeriesPrompt for lead {lead_name} in file {ecg_path}: {str(e)}")
        
        if not ecg_prompts:
            raise RuntimeError(f"No ECG prompts were created for sample. ECG paths attempted: {ecg_paths}")
        
        return ecg_prompts

    def _format_sample(self, row):
        """
        Format a sample via the parent class and attach the extra ECG-QA CoT fields.

        Fields present in the row are copied through unchanged (CoT metadata,
        the original ECG-QA question fields and the ECG references). The
        row's 'answer' is stored as 'correct_answer', and the answer options
        of the row's template are stored as 'possible_answers'.
        """
        formatted_sample = super()._format_sample(row)

        # Copied verbatim when present, in this order.
        passthrough_keys = (
            # CoT-specific fields
            'rationale',
            'cot_question_id',
            'cot_template_id',
            'cot_question_type',
            # Original ECG-QA fields
            'template_id',
            'question_type',
            'question',
            # ECG data fields
            'ecg_id',
            'ecg_paths',
            'clinical_contexts',
        )
        for key in passthrough_keys:
            if key in row:
                formatted_sample[key] = row[key]

        # Ground-truth answer according to ECG-QA
        if 'answer' in row:
            formatted_sample['correct_answer'] = row['answer']

        # Answer options for this template; any lookup failure yields an empty list.
        if 'template_id' in row and row['template_id'] is not None:
            try:
                options = ECGQACoTQADataset.get_possible_answers_for_template(row['template_id'])
            except Exception:
                options = []
            formatted_sample['possible_answers'] = options

        return formatted_sample

    def _format_sample_str(self, time_series_format_function, row):
        """
        Format a sample as a string via the parent class, keeping the fields needed for evaluation.

        The template and question fields present in the row are copied
        through unchanged, and the row's 'answer' is stored as
        'correct_answer'.
        """
        formatted_sample = super()._format_sample_str(time_series_format_function, row)

        # Copied verbatim when present, in this order.
        for key in ('template_id', 'cot_template_id', 'question_type', 'question'):
            if key in row:
                formatted_sample[key] = row[key]

        # Ground-truth answer to aid evaluation
        if 'answer' in row:
            formatted_sample['correct_answer'] = row['answer']

        return formatted_sample


if __name__ == "__main__":
    # Manual smoke test: builds small versions of all three splits, prints one
    # formatted sample and reports the cache statistics.

    def _shorten(text, limit):
        """Return *text*, cut to *limit* characters plus an ellipsis when it is longer."""
        if len(text) > limit:
            return text[:limit] + "..."
        return text

    print("Testing ECGQACoTQADataset...")

    try:
        # Only 5 samples per split to keep the run fast.
        small_splits = {
            split_name: ECGQACoTQADataset(split=split_name, EOS_TOKEN="", max_samples=5)
            for split_name in ("train", "validation", "test")
        }
        dataset = small_splits["train"]
        dataset_val = small_splits["validation"]
        dataset_test = small_splits["test"]

        print(f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}")

        cache_stats = ECGQACoTQADataset.get_cache_stats()
        print(f"\nCache statistics: {cache_stats}")

        if len(dataset) > 0:
            sample = dataset[0]
            print("\nSample keys:", sample.keys())
            print("Sample question:", sample.get("question", "N/A"))
            print("Sample answer (rationale):", _shorten(sample["answer"], 200))
            print("Sample question type:", sample.get("question_type", "N/A"))
            print("Sample ECG IDs:", sample.get("ecg_id", "N/A"))

            if "time_series_text" in sample:
                series_prompts = sample["time_series_text"]
                print("Time series prompts:", len(series_prompts))
                if len(series_prompts) > 0:
                    first_ts = series_prompts[0]
                    if hasattr(first_ts, 'text'):
                        print("First time series label:", first_ts.text)
                        print("First time series length:", len(first_ts.time_series))
                    else:
                        print("Time series format:", type(first_ts))

            print("Pre prompt:", _shorten(sample["pre_prompt"], 100))
            print("Post prompt:", sample["post_prompt"])

            # CoT-specific fields
            if 'rationale' in sample:
                print("CoT Rationale:", _shorten(sample['rationale'], 100))
            if 'cot_question_id' in sample:
                print("CoT Question ID:", sample['cot_question_id'])

        # Build one more dataset with processed-data preloading requested explicitly.
        print("\nTesting with processed data preloading...")
        dataset_fast = ECGQACoTQADataset(split="train", EOS_TOKEN="", max_samples=3, preload_processed_data=True)
        print(f"Fast dataset size: {len(dataset_fast)}")

        cache_stats_after = ECGQACoTQADataset.get_cache_stats()
        print(f"Cache statistics after preloading: {cache_stats_after}")

    except Exception as e:
        print(f"Error testing dataset: {e}")
        import traceback
        traceback.print_exc()
