#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Audio Feature Extraction Core Module

Main Features:
1. Speech transcription using Whisper
2. Word-level alignment using Wav2Vec2
3. CTC-based forced alignment
4. Audio feature extraction based on word-level timestamps

Usage:
    extractor = AudioFeatureExtractor(language="en")
    features = extractor.extract_features(audio_array, sample_rate)

Core Implementation:
1. Load Whisper and Wav2Vec2 models
2. Perform speech recognition with Whisper
3. Use Wav2Vec2 for word-level alignment
4. Extract audio features based on word-level timestamps
"""

import os
import json
import warnings
import tempfile
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from dataclasses import dataclass
import numpy as np
import torch
import torchaudio
import librosa
import soundfile as sf
import parselmouth
from tqdm import tqdm

# Triton optimization support
try:
    import triton
    import triton.language as tl
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False
    print("Warning: Triton not available, using original implementation")

# HuggingFace libraries
from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    pipeline
)

# ===== Configuration Constants =====
# Audio processing parameters
SAMPLE_RATE = 16000  # Standard sample rate for Whisper and Wav2Vec2 (Hz)
MAX_DURATION = 30    # Maximum audio segment length (seconds)

# Supported languages and corresponding Wav2Vec2 models
# (language code -> HuggingFace model name; a language missing from this
# table is rejected with ValueError when the alignment model is loaded)
DEFAULT_ALIGN_MODELS = {
    "en": "facebook/wav2vec2-large-960h-lv60-self",
}

# Languages without space separation
# (for these, spaces are not converted to the "|" word separator and every
# aligned character is emitted as its own word)
LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]


# ===== Data Structure Definitions =====
@dataclass
class WordSegment:
    """
    Word-level segment information

    Attributes:
        word: Word text
        start: Start time (seconds), or None
        end: End time (seconds), or None
        score: Confidence score, or None
    """
    word: str
    start: Optional[float]
    end: Optional[float]
    score: Optional[float]


@dataclass
class AlignedSegment:
    """
    Aligned sentence segment

    Attributes:
        text: Sentence text
        start: Start time (seconds), or None
        end: End time (seconds), or None
        words: Word-level information list
    """
    text: str
    start: Optional[float]
    end: Optional[float]
    words: List[WordSegment]


class AudioFeatureExtractor:
    """
    Audio Feature Extractor Core Class
    
    Constructing an instance loads a Whisper model (transcription) and a
    Wav2Vec2 CTC model (alignment) and moves both to the selected device.
    
    Main Features:
    1. Speech transcription using Whisper
    2. Word-level feature extraction
    3. CTC-based forced alignment
    
    Usage:
        extractor = AudioFeatureExtractor(language="en")
        features = extractor.extract_features(audio_array, sample_rate)
    """
    
    def __init__(self,
                 language: str = "en",
                 whisper_model: str = "openai/whisper-large-v3",
                 device: str = "auto",
                 merge_threshold: float = 0.5):
        """
        Set up the extractor and load its models

        Args:
            language: Target language code
            whisper_model: Whisper model path or name
            device: Computing device ("auto", "cpu", "cuda"); "auto" picks
                CUDA when torch reports it as available, otherwise CPU
            merge_threshold: Word merging threshold (seconds)
        """
        self.language = language
        self.merge_threshold = merge_threshold

        # Resolve the "auto" placeholder to a concrete device name
        resolved_device = device
        if resolved_device == "auto":
            resolved_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = resolved_device

        # Startup banner
        banner = (
            "🚀 Initializing Audio Feature Extractor...",
            f"   Language: {language}",
            f"   Device: {self.device}",
            f"   Word merge threshold: {merge_threshold}s",
        )
        for line in banner:
            print(line)

        # Model loading happens eagerly at construction time
        self._load_models(whisper_model)
    
    def _load_models(self, whisper_model: str):
        """
        Load Whisper and Wav2Vec2 models
        
        Implementation Logic:
        1. Resolve the alignment model for the configured language first, so
           an unsupported language fails before any model is loaded
        2. Load Whisper model for speech recognition
        3. Load corresponding language Wav2Vec2 model for alignment
        4. Set models to evaluation mode
        5. Build the lowercase character -> token ID vocabulary
        
        Args:
            whisper_model: Whisper model path or name
            
        Raises:
            ValueError: If no alignment model is registered for self.language
            Exception: Any loading error is printed and then re-raised
        """
        try:
            # Validate the language up front: loading Whisper is expensive and
            # pointless when alignment cannot be performed afterwards
            align_model_name = DEFAULT_ALIGN_MODELS.get(self.language)
            if not align_model_name:
                raise ValueError(f"Unsupported language: {self.language}")
            
            # Load Whisper model
            print("📥 Loading Whisper model...")
            self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model)
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model)
            self.whisper_model.to(self.device)
            self.whisper_model.eval()
            
            # Load Wav2Vec2 alignment model
            print(f"📥 Loading Wav2Vec2 alignment model: {align_model_name}")
            self.align_processor = Wav2Vec2Processor.from_pretrained(align_model_name)
            self.align_model = Wav2Vec2ForCTC.from_pretrained(align_model_name)
            self.align_model.to(self.device)
            self.align_model.eval()
            
            # Build character-level vocabulary dictionary
            labels = self.align_processor.tokenizer.get_vocab()
            # Create character to ID mapping, convert all characters to lowercase
            # (special tokens are lowercased too, e.g. "[PAD]" becomes "[pad]")
            self.vocab = {char.lower(): code for char, code in labels.items()}
            self.id_to_token = {v: k for k, v in self.vocab.items()}
            
            print("✅ Model loading completed")
            print(f"Vocabulary size: {len(self.vocab)}")
            
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            raise

    # ===== Core Feature Extraction Functions =====
    def transcribe_audio(self, audio: np.ndarray, sampling_rate: int) -> str:
        """
        Transcribe audio using Whisper

        The audio is resampled to SAMPLE_RATE when needed, converted to model
        inputs by the Whisper processor, decoded with the Whisper model, and
        the first decoded hypothesis is returned with surrounding whitespace
        removed.

        Args:
            audio: Audio array
            sampling_rate: Sample rate

        Returns:
            Transcribed text, or an empty string if any step raises
        """
        try:
            # Bring the waveform to the model sample rate
            waveform = audio
            if sampling_rate != SAMPLE_RATE:
                waveform = librosa.resample(
                    waveform, orig_sr=sampling_rate, target_sr=SAMPLE_RATE
                )

            # Feature extraction, moved to the compute device
            features = self.whisper_processor(
                waveform,
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt"
            ).to(self.device)

            # Inference only: no gradients needed
            with torch.no_grad():
                token_ids = self.whisper_model.generate(features["input_features"])
                decoded = self.whisper_processor.batch_decode(
                    token_ids, skip_special_tokens=True
                )

            return decoded[0].strip()

        except Exception as e:
            # Best effort: report and signal failure with an empty transcript
            print(f"❌ Transcription failed: {e}")
            return ""
    
    def get_word_timestamps(self, audio: np.ndarray, text: str) -> List[AlignedSegment]:
        """
        Get word-level timestamps (based on Wav2Vec2 forced alignment)

        The text is cleaned with _preprocess_text, the audio is run through
        the Wav2Vec2 model, and the frame-wise log-probabilities are aligned
        to the cleaned text with _ctc_align. Durations are computed as
        len(audio) / SAMPLE_RATE.

        Args:
            audio: Audio array
            text: Transcribed text

        Returns:
            List of aligned segments. If the cleaned text is empty, a single
            segment spanning the whole audio with the original text is
            returned; if any step raises, an empty list is returned.
        """
        try:
            print("Starting Wav2Vec2 forced alignment...")

            cleaned = self._preprocess_text(text)
            if cleaned:
                # Run the alignment model on the waveform
                model_inputs = self.align_processor(
                    audio,
                    sampling_rate=SAMPLE_RATE,
                    return_tensors="pt"
                ).to(self.device)

                with torch.no_grad():
                    logits = self.align_model(model_inputs.input_values).logits

                # Log-probabilities per frame, batch dimension dropped
                frame_log_probs = torch.log_softmax(logits, dim=-1)[0]

                result = self._ctc_align(frame_log_probs, cleaned, audio)

                print("Forced alignment completed")
                return result

            # Nothing alignable survived preprocessing: cover the whole clip
            print("Warning: Text preprocessing resulted in empty string, returning original timestamps")
            total_seconds = len(audio) / SAMPLE_RATE
            whole_clip_word = WordSegment(
                word=text,
                start=0.0,
                end=total_seconds,
                score=0.0
            )
            return [AlignedSegment(
                text=text,
                start=0.0,
                end=total_seconds,
                words=[whole_clip_word]
            )]

        except Exception as e:
            print(f"❌ Word-level alignment failed: {e}")
            return []
    
    def _preprocess_text(self, text: str) -> str:
        """
        Normalize text so every character maps onto the alignment vocabulary

        The text is lowercased and stripped. For space-separated languages,
        spaces become "|" (Wav2Vec2 convention). Characters found in
        self.vocab are kept, spaces are dropped for languages without space
        separation, and every other character becomes the "*" wildcard.

        Args:
            text: Original text

        Returns:
            Cleaned text
        """
        normalized = text.lower().strip()
        spaceless_language = self.language in LANGUAGES_WITHOUT_SPACES

        if not spaceless_language:
            normalized = normalized.replace(" ", "|")

        # Filter: keep everything except out-of-vocabulary spaces in spaceless
        # languages. Map: out-of-vocabulary characters turn into "*".
        return "".join(
            ch if ch in self.vocab else "*"
            for ch in normalized
            if ch in self.vocab or not (spaceless_language and ch == " ")
        )
    
    def _ctc_align(self, emission: torch.Tensor, transcript: str, audio: np.ndarray) -> List[AlignedSegment]:
        """
        Perform CTC forced alignment
        
        Builds the trellis, backtracks a path through it, merges the path
        into character segments and groups those into words with timestamps.
        
        Args:
            emission: Model output emission log-probabilities, one row per frame
            transcript: Cleaned transcription text
            audio: Original audio (duration is len(audio) / SAMPLE_RATE)
            
        Returns:
            List containing one aligned segment. When backtracking reports
            failure (None), the segment holds a single word spanning the
            whole audio with score 0.0.
        """
        # Vocabulary keys are lowercased when the models are loaded, so special
        # tokens must be looked up under their lowercase spelling as well;
        # otherwise "[PAD]" / "[UNK]" can never match and 0 is always used.
        blank_id = next(
            (self.vocab[name] for name in ("[PAD]", "[pad]", "<pad>") if name in self.vocab),
            0,
        )
        unk_id = next(
            (self.vocab[name] for name in ("[UNK]", "[unk]") if name in self.vocab),
            0,
        )
        
        # Convert text to token IDs; characters missing from the vocabulary
        # (e.g. the "*" wildcard) map to the unknown-token ID
        tokens = [self.vocab.get(char, unk_id) for char in transcript]
        
        duration = len(audio) / SAMPLE_RATE
        display_text = transcript.replace("|", " ")
        
        # Build trellis (grid)
        trellis = self._get_trellis(emission, tokens, blank_id)
        
        # Backtrack optimal path
        path = self._backtrack(trellis, emission, tokens, blank_id)
        
        if path is None:
            print("Warning: CTC alignment failed, returning original timestamps")
            return [AlignedSegment(
                text=display_text,
                start=0.0,
                end=duration,
                words=[WordSegment(
                    word=display_text,
                    start=0.0,
                    end=duration,
                    score=0.0
                )]
            )]
        
        # Merge repeated characters
        char_segments = self._merge_repeats(path, transcript)
        
        # Seconds per emission frame. Path time indices run 0..T-1 and segment
        # ends are exclusive (last index + 1, i.e. up to T), so dividing by T
        # maps the final end exactly onto the audio duration. Dividing by T-1
        # would push the last word past the end of the audio and divide by
        # zero for a single-frame emission.
        time_ratio = duration / emission.size(0)
        
        # Generate word-level alignments
        words = self._generate_word_alignments(char_segments, transcript, time_ratio)
        
        return [AlignedSegment(
            text=display_text,
            start=words[0].start if words else 0.0,
            end=words[-1].end if words else duration,
            words=words
        )]
    
    # The kernel is only defined when Triton was imported successfully. The
    # @triton.jit decorator and the tl.float32 annotation are evaluated while
    # the class body executes, so referencing them unconditionally raises
    # NameError and makes the module unimportable on machines without Triton,
    # defeating the TRITON_AVAILABLE fallback.
    if TRITON_AVAILABLE:
        @staticmethod
        @triton.jit
        def _trellis_row_kernel_optimized(
            # Pointers
            trellis_t_ptr,      # row t of the trellis (written)
            trellis_tm1_ptr,    # row t-1 of the trellis (read)
            emission_t_ptr,     # row t of the emission matrix (read)
            tokens_ptr,         # token IDs, one per trellis column
            # Scalar arguments
            num_tokens,
            blank_emit: tl.float32,  # emission[t, blank_id], passed as a scalar
            t,
            # Tensor strides
            trellis_stride_n,
            # Meta-parameters
            BLOCK_SIZE_N: tl.constexpr,
        ):
            # Computes, for token columns j >= 1 of one frame t:
            #   trellis[t, j] = logaddexp(trellis[t-1, j]   + blank_emit,
            #                             trellis[t-1, j-1] + emission[t, tokens[j]])
            # Column 0 is filled by the caller.

            # Column indices start at 1 so that (offs_n - 1) is never negative
            pid = tl.program_id(axis=0)
            offs_n = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + 1

            # j must be less than num_tokens and less than or equal to current time step t
            mask = (offs_n < num_tokens) & (offs_n <= t)

            # Load trellis[t-1, j] (stay)
            prev_stay_ptr = trellis_tm1_ptr + offs_n * trellis_stride_n
            prev_stay_score = tl.load(prev_stay_ptr, mask=mask, other=float('-inf'))

            # Load trellis[t-1, j-1] (advance)
            prev_advance_ptr = trellis_tm1_ptr + (offs_n - 1) * trellis_stride_n
            prev_advance_score = tl.load(prev_advance_ptr, mask=mask, other=float('-inf'))

            # Load emission[t, tokens[j]]; the emission row is indexed with
            # unit stride, so the caller must pass a contiguous row
            tokens_j = tl.load(tokens_ptr + offs_n, mask=mask, other=0)  # other=0 is safe due to mask
            emission_token = tl.load(emission_t_ptr + tokens_j, mask=mask, other=float('-inf'))

            # Calculate path scores
            stay_score = prev_stay_score + blank_emit
            advance_score = prev_advance_score + emission_token

            # Numerically stable logsumexp:
            #   log(e^a + e^b) = max + log(1 + exp(min - max))
            # The tl.where guard returns -inf when both inputs are -inf,
            # avoiding the (-inf) - (-inf) = NaN case
            max_val = tl.maximum(stay_score, advance_score)
            min_val = tl.minimum(stay_score, advance_score)
            log_sum = tl.where(
                max_val > float('-inf'),
                max_val + tl.log(1.0 + tl.exp(min_val - max_val)),
                float('-inf')
            )

            # Write back result
            trellis_t_ptr_j = trellis_t_ptr + offs_n * trellis_stride_n
            tl.store(trellis_t_ptr_j, log_sum, mask=mask)
    else:
        # Placeholder so the attribute always exists; the Triton code path is
        # only selected when TRITON_AVAILABLE is true.
        _trellis_row_kernel_optimized = None
    
    def _get_trellis(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor:
        """
        Build the CTC alignment trellis, choosing the implementation

        The Triton implementation is used when Triton was imported, CUDA is
        available and the emission tensor lives on a CUDA device; otherwise a
        warning is printed and the pure-PyTorch implementation is used.

        Args:
            emission: Emission probability matrix [T, V]
            tokens: Token ID list
            blank_id: Blank token ID

        Returns:
            Trellis matrix [T, N]
        """
        can_use_triton = (
            TRITON_AVAILABLE
            and torch.cuda.is_available()
            and emission.device.type == 'cuda'
        )

        if not can_use_triton:
            print("Warning: Triton not enabled or CUDA unavailable, falling back to original CTC alignment implementation")
            return self._get_trellis_original(emission, tokens, blank_id)

        return self._get_trellis_triton(emission, tokens, blank_id)
    
    def _get_trellis_original(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor:
        """
        Original trellis construction implementation (as fallback)
        """
        num_frame = emission.size(0)
        num_tokens = len(tokens)
        
        # Initialize trellis
        trellis = torch.full((num_frame, num_tokens), float('-inf'), device=emission.device)
        trellis[0, 0] = emission[0, blank_id]
        
        # Fill first row
        for t in range(1, num_frame):
            trellis[t, 0] = trellis[t-1, 0] + emission[t, blank_id]
        
        # Fill trellis
        for t in range(1, num_frame):
            for j in range(1, min(num_tokens, t + 1)):
                # Stay at current token (insert blank)
                stay_score = trellis[t-1, j] + emission[t, blank_id]
                
                # Advance to next token
                advance_score = trellis[t-1, j-1] + emission[t, tokens[j]]
                
                trellis[t, j] = torch.logsumexp(torch.stack([stay_score, advance_score]), dim=0)
        
        return trellis
    
    def _get_trellis_triton(self, emission: torch.Tensor, tokens: List[int], blank_id: int) -> torch.Tensor:
        """
        Triton trellis construction
        
        Column 0 is filled with a cumulative sum of the blank emissions; every
        other column is filled frame by frame by launching
        _trellis_row_kernel_optimized once per frame.
        
        Args:
            emission: Emission log-probability matrix [T, V] on a CUDA device
            tokens: Token ID list
            blank_id: Blank token ID
            
        Returns:
            float32 trellis matrix [T, N] on the same device as emission
        """
        assert emission.is_cuda, "Input tensor must be on CUDA device"
        
        # The kernel addresses emission[t, token] as row pointer + token, i.e.
        # with unit stride; enforce that layout (no-op when already contiguous)
        emission = emission.contiguous()
        
        num_frame = emission.size(0)
        tokens_tensor = torch.tensor(tokens, device=emission.device, dtype=torch.long)
        num_tokens = len(tokens)

        trellis = torch.full((num_frame, num_tokens), float('-inf'),
                            device=emission.device,
                            dtype=torch.float32)

        if num_tokens == 0:
            return trellis

        # First column: cumulative blank log-probability starting from t=0
        blank_column = emission[:, blank_id]
        trellis[:, 0] = blank_column.cumsum(dim=0)

        # A single token has no j > 0 columns, so there is nothing to launch
        if num_tokens == 1:
            return trellis

        # Block size adapts to the token sequence length (capped at 1024)
        BLOCK_SIZE_N = min(1024, triton.next_power_of_2(num_tokens))
        # Only columns j >= 1 are computed by the kernel
        grid = (triton.cdiv(num_tokens - 1, BLOCK_SIZE_N),)

        # Copy all blank emissions to the host once. Reading one value per
        # frame with .item() would force a device synchronization on every
        # iteration and serialize the kernel launches.
        blank_emits = blank_column.tolist()
        row_stride = trellis.stride(1)

        # Main loop: one kernel launch per frame (row t depends on row t-1)
        for t in range(1, num_frame):
            self._trellis_row_kernel_optimized[grid](
                trellis_t_ptr=trellis[t],
                trellis_tm1_ptr=trellis[t-1],
                emission_t_ptr=emission[t],
                tokens_ptr=tokens_tensor,
                num_tokens=num_tokens,
                blank_emit=blank_emits[t],
                t=t,
                trellis_stride_n=row_stride,
                BLOCK_SIZE_N=BLOCK_SIZE_N,
            )
                
        return trellis
    
    def _backtrack(self, trellis: torch.Tensor, emission: torch.Tensor, tokens: List[int], blank_id: int) -> Optional[List]:
        """
        Backtrack optimal path from trellis
        
        Args:
            trellis: Trellis matrix
            emission: Emission probability matrix
            tokens: Token ID list
            blank_id: Blank token ID
            
        Returns:
            List of optimal path points
        """
        from dataclasses import dataclass
        
        @dataclass
        class Point:
            token_index: int
            time_index: int
            score: float
        
        t, j = trellis.size(0) - 1, trellis.size(1) - 1
        
        path = [Point(j, t, emission[t, blank_id].exp().item())]
        
        while j > 0 and t > 0:
            # Calculate stay and advance scores
            stay_score = trellis[t-1, j] + emission[t, blank_id]
            advance_score = trellis[t-1, j-1] + emission[t, tokens[j]]
            
            # Choose path with higher score
            if advance_score > stay_score:
                j -= 1
                prob = emission[t, tokens[j+1]].exp().item()
            else:
                prob = emission[t, blank_id].exp().item()
            
            t -= 1
            path.append(Point(j, t, prob))
        
        # Fill remaining time steps
        while t > 0:
            prob = emission[t-1, blank_id].exp().item()
            path.append(Point(j, t-1, prob))
            t -= 1
        
        return path[::-1]  # Reverse path
    
    def _merge_repeats(self, path: List, transcript: str) -> List:
        """
        Merge repeated characters in path
        
        Args:
            path: CTC path
            transcript: Transcription text
            
        Returns:
            List of character segments
        """
        from dataclasses import dataclass
        
        @dataclass
        class Segment:
            label: str
            start: int
            end: int
            score: float
        
        segments = []
        i1, i2 = 0, 0
        
        while i1 < len(path):
            # Find range of same token
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            
            # Calculate average score
            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
            
            # Create segment
            if path[i1].token_index < len(transcript):
                segments.append(Segment(
                    label=transcript[path[i1].token_index],
                    start=path[i1].time_index,
                    end=path[i2-1].time_index + 1,
                    score=score
                ))
            
            i1 = i2
        
        return segments
    
    def _generate_word_alignments(self, char_segments: List, transcript: str, time_ratio: float) -> List[WordSegment]:
        """
        Group character-level segments into word-level segments

        For languages without spaces every non-"|" character is its own word.
        Otherwise characters are accumulated until a "|" separator; a word
        starts at its first character's start and ends at the separator's
        start (or at the last character segment's end for the final word).
        Frame indices are multiplied by time_ratio and all times and scores
        are rounded to 3 decimals.

        Args:
            char_segments: Character segment list
            transcript: Transcription text
            time_ratio: Time ratio factor

        Returns:
            List of word segments
        """
        if self.language in LANGUAGES_WITHOUT_SPACES:
            # One word per character
            return [
                WordSegment(
                    word=seg.label,
                    start=round(seg.start * time_ratio, 3),
                    end=round(seg.end * time_ratio, 3),
                    score=round(seg.score, 3)
                )
                for seg in char_segments
                if seg.label != "|"
            ]

        aligned_words = []
        pending = []  # character segments of the word currently being built

        def flush(end_frame):
            # Emit the pending characters as one word ending at end_frame
            char_scores = [ch.score for ch in pending]
            aligned_words.append(WordSegment(
                word="".join(ch.label for ch in pending),
                start=round(pending[0].start * time_ratio, 3),
                end=round(end_frame * time_ratio, 3),
                score=round(sum(char_scores) / len(char_scores), 3)
            ))
            pending.clear()

        for seg in char_segments:
            if seg.label != "|":
                pending.append(seg)
            elif pending:
                # Separator closes the current word; repeated separators are ignored
                flush(seg.start)

        # Trailing word without a closing separator
        if pending:
            flush(char_segments[-1].end)

        return aligned_words

    def merge_short_words(self, word_segments: List[WordSegment]) -> List[WordSegment]:
        """
        Merge word segments with short duration
        
        Implementation Logic:
        1. Identify words with duration below threshold
        2. Merge with adjacent shortest word
        3. Update time boundaries and text
        4. Recalculate confidence scores
        
        Args:
            word_segments: Original word segment list
            
        Returns:
            Merged word segment list
        """
        if not word_segments or self.merge_threshold <= 0:
            return word_segments
        
        segments = word_segments.copy()
        merged = True
        
        while merged:
            merged = False
            
            for i, segment in enumerate(segments):
                duration = segment.end - segment.start
                
                if duration < self.merge_threshold:
                    # Find adjacent shortest segment to merge
                    merge_target_idx = self._find_shortest_neighbor(segments, i)
                    
                    if merge_target_idx is not None:
                        # Merge word segments
                        if merge_target_idx < i:
                            merged_segment = self._merge_two_segments(
                                segments[merge_target_idx], segments[i]
                            )
                            segments[merge_target_idx] = merged_segment
                            segments.pop(i)
                        else:
                            merged_segment = self._merge_two_segments(
                                segments[i], segments[merge_target_idx]
                            )
                            segments[i] = merged_segment
                            segments.pop(merge_target_idx)
                        
                        merged = True
                        break
        
        return segments
    
    def _find_shortest_neighbor(self, segments: List[WordSegment], current_idx: int) -> Optional[int]:
        """
        Find index of shortest adjacent segment to specified word segment
        
        Args:
            segments: Word segment list
            current_idx: Current word segment index
            
        Returns:
            Index of adjacent shortest segment
        """
        neighbors = []
        
        # Check previous segment
        if current_idx > 0:
            duration = segments[current_idx - 1].end - segments[current_idx - 1].start
            neighbors.append((current_idx - 1, duration))
        
        # Check next segment
        if current_idx < len(segments) - 1:
            duration = segments[current_idx + 1].end - segments[current_idx + 1].start
            neighbors.append((current_idx + 1, duration))
        
        if not neighbors:
            return None
        
        # Return index of adjacent segment with shortest duration
        return min(neighbors, key=lambda x: x[1])[0]
    
    def _merge_two_segments(self, segment1: WordSegment, segment2: WordSegment) -> WordSegment:
        """
        Combine two word segments into one

        The words are joined with a single space, the time range runs from the
        first segment's start to the second segment's end, and the score is
        the mean of the scores that are not None (None if both are None).

        Args:
            segment1: First word segment (earlier in time)
            segment2: Second word segment (later in time)

        Returns:
            Merged word segment
        """
        # Collect only the scores that are actually present
        known_scores = []
        for seg in (segment1, segment2):
            if seg.score is not None:
                known_scores.append(seg.score)

        if known_scores:
            mean_score = sum(known_scores) / len(known_scores)
        else:
            mean_score = None

        return WordSegment(
            word=f"{segment1.word} {segment2.word}",
            start=segment1.start,
            end=segment2.end,
            score=mean_score
        )
    
    def extract_audio_features(self, 
                             audio: np.ndarray, 
                             sampling_rate: int,
                             word_segments: List[WordSegment],
                             original_word_count: int) -> Dict[str, Any]:
        """
        Extract audio features based on word-level timestamps
        
        Implementation Logic:
        1. Use parselmouth to extract global pitch and intensity
        2. Iterate through each word segment to extract local features
        3. Calculate global statistical features
        4. Return complete feature dictionary
        
        Args:
            audio: Audio array
            sampling_rate: Sample rate
            word_segments: Word segment list (possibly merged)
            original_word_count: Original word count (for calculating speaking rate)
            
        Returns:
            Dictionary containing all features
        """
        # Create temporary audio file for parselmouth analysis
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio:
            sf.write(temp_audio.name, audio, sampling_rate)
            sound = parselmouth.Sound(temp_audio.name)
        
        try:
            # Calculate global features
            total_duration = len(audio) / sampling_rate
            speaking_rate = original_word_count / total_duration if total_duration > 0 else 0

            # Batch extract pitch and intensity (improve efficiency)
            try:
                pitch = sound.to_pitch(pitch_floor=50.0, pitch_ceiling=600.0)
                intensity = sound.to_intensity()
            except:
                pitch = None
                intensity = None
            
            word_features = []
            
            # Iterate through each word segment to extract features
            for word_segment in word_segments:
                if word_segment.start is None or word_segment.end is None:
                    continue
                
                start, end = word_segment.start, word_segment.end
                duration = end - start
                
                # Clean word text, remove special symbols
                clean_word = word_segment.word
                
                # Extract corresponding audio segment
                start_sample = int(start * sampling_rate)
                end_sample = int(end * sampling_rate)
                segment_audio = audio[start_sample:end_sample]
                
                if len(segment_audio) == 0:
                    continue
                
                # === Pitch feature extraction ===
                avg_pitch = np.nan
                pitch_slope = np.nan
                
                if pitch is not None:
                    try:
                        # Get pitch values for current word time segment
                        pitch_times = pitch.xs()
                        pitch_values = pitch.selected_array['frequency']
                        
                        # Find pitch values for corresponding time segment
                        mask = (pitch_times >= start) & (pitch_times <= end)
                        if np.any(mask):
                            segment_pitch = pitch_values[mask]
                            segment_pitch[segment_pitch == 0] = np.nan
                            
                            # Remove all null values, keep only valid pitch values
                            valid_pitch = segment_pitch[~np.isnan(segment_pitch)]
                            
                            if len(valid_pitch) > 0:
                                avg_pitch = np.mean(valid_pitch)
                                
                                # Calculate pitch change trend (slope) - safer method
                                if len(valid_pitch) >= 15:
                                    # Only calculate slope when there are enough points (using linear regression)
                                    time_points = np.linspace(0, duration, len(valid_pitch))
                                    coeffs = np.polyfit(time_points, valid_pitch, 1)
                                    pitch_slope = coeffs[0]  # Slope
                                else:
                                    # Set to NaN when too few data points, safer
                                    pitch_slope = np.nan
                            else:
                                # Set to default values when no valid pitch values
                                avg_pitch = np.nan
                                pitch_slope = np.nan
                    except:
                        pass
                
                # === Energy feature extraction ===
                rms_energy = np.sqrt(np.mean(segment_audio**2))  # RMS energy
                energy_slope = np.nan
                
                if intensity is not None:
                    try:
                        intensity_times = intensity.xs()
                        intensity_values = intensity.values[0]
                        
                        # Find intensity values for corresponding time segment
                        mask = (intensity_times >= start) & (intensity_times <= end)
                        if np.any(mask):
                            segment_intensity = intensity_values[mask]
                            
                            if len(segment_intensity) > 2:
                                start_energy = np.nanmean(segment_intensity[:len(segment_intensity)//3])
                                end_energy = np.nanmean(segment_intensity[len(segment_intensity)*2//3:])
                                if not (np.isnan(start_energy) or np.isnan(end_energy)):
                                    energy_slope = (end_energy - start_energy) / duration
                    except:
                        pass
                
                # === Spectral feature extraction ===
                try:
                    # Spectral centroid: reflects timbre characteristics
                    segment_length = len(segment_audio)
                    if segment_length < 2048:
                        n_fft = 2 ** int(np.log2(segment_length))
                        n_fft = max(n_fft, 512)  # Minimum value is 512
                    else:
                        n_fft = 2048  # Default value
                    
                    spectral_centroid = librosa.feature.spectral_centroid(
                        y=segment_audio, sr=sampling_rate, n_fft=n_fft)[0].mean()
                except:
                    spectral_centroid = np.nan
                
                # Store all features for current word
                word_features.append({
                    "word": clean_word,
                    "start_time": start,
                    "end_time": end,
                    "duration": duration,
                    "confidence_score": word_segment.score,
                    "pitch_mean": float(avg_pitch) if not np.isnan(avg_pitch) else None,
                    "pitch_slope": float(pitch_slope) if not np.isnan(pitch_slope) else None,
                    "energy_rms": float(rms_energy),
                    "energy_slope": float(energy_slope) if not np.isnan(energy_slope) else None,
                    "spectral_centroid": float(spectral_centroid) if not np.isnan(spectral_centroid) else None
                })
            
            # Return complete feature dictionary
            return {
                "total_duration": total_duration,
                "speaking_rate": speaking_rate,
                "original_word_count": original_word_count,
                "processed_word_count": len(word_features),
                "word_features": word_features
            }
            
        finally:
            # Clean up temporary file
            try:
                os.unlink(temp_audio.name)
            except:
                pass
    
    def extract_features_for_sample(self,
                                  audio_array: np.ndarray,
                                  sampling_rate: int,
                                  text: Optional[str] = None,
                                  enable_word_merging: bool = True) -> Dict[str, Any]:
        """
        Run the full word-level feature pipeline on one audio sample

        Pipeline:
        1. Resample the audio to SAMPLE_RATE when it arrives at another rate
        2. Transcribe via self.transcribe_audio when no text is supplied
        3. Align the text to the audio via self.get_word_timestamps
        4. Flatten the words of every aligned segment into a single list
        5. Optionally pass that list through self.merge_short_words
        6. Compute the features via self.extract_audio_features

        Args:
            audio_array: Audio array
            sampling_rate: Sample rate of audio_array
            text: Transcription text; transcribed automatically when None
            enable_word_merging: Whether to run merge_short_words before
                feature extraction

        Returns:
            On success, the dictionary returned by extract_audio_features
            with an added "transcribed_text" entry. On failure, a dictionary
            holding an "error" message and an empty "word_features" list;
            any Exception raised inside the pipeline is reported this way
            instead of being propagated.
        """
        def failure(message: str) -> Dict[str, Any]:
            # Common shape of every unsuccessful result
            return {"error": message, "word_features": []}

        try:
            # Bring the audio to the standard sample rate first
            if sampling_rate != SAMPLE_RATE:
                audio_array = librosa.resample(
                    audio_array, orig_sr=sampling_rate, target_sr=SAMPLE_RATE
                )
                sampling_rate = SAMPLE_RATE

            # Fall back to automatic transcription when no text was given
            transcript = text
            if transcript is None:
                transcript = self.transcribe_audio(audio_array, sampling_rate)

            if not transcript.strip():
                return failure("Transcription text is empty")

            # Word-level timestamps from forced alignment
            segments = self.get_word_timestamps(audio_array, transcript)
            if not segments:
                return failure("Word-level alignment failed")

            # Flatten the per-segment word lists; remember the count
            # before any merging changes it
            words = [word for segment in segments for word in segment.words]
            word_total = len(words)

            if enable_word_merging:
                words = self.merge_short_words(words)

            result = self.extract_audio_features(
                audio_array, sampling_rate, words, word_total
            )
            result["transcribed_text"] = transcript
            return result

        except Exception as exc:
            return failure(str(exc))


# ===== Example Usage =====
def example_usage():
    """
    Example usage of AudioFeatureExtractor

    Prints a banner, builds an extractor for English, and prints a hint on
    how to continue. The actual feature extraction call is kept commented
    out because it needs a real audio file path; uncomment it and set
    audio_path to try it.

    The device is picked at runtime: "cuda" when torch reports CUDA as
    available, otherwise "cpu", instead of always requesting a GPU.
    """
    print("🎵 Audio Feature Extractor Example")
    print("=" * 50)

    # Do not assume a GPU is present; fall back to CPU when CUDA is missing.
    # NOTE(review): assumes AudioFeatureExtractor accepts "cpu" as a device
    # value - confirm against its constructor.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize extractor
    extractor = AudioFeatureExtractor(
        language="en",  # Can be modified based on needs
        device=device,
        merge_threshold=0.5
    )

    # Example: Load audio file and extract features
    # audio_path = "path/to/your/audio.wav"  # Replace with your audio file path
    # audio_array, sampling_rate = librosa.load(audio_path, sr=None)

    # Extract features
    # features = extractor.extract_features_for_sample(
    #     audio_array,
    #     sampling_rate,
    #     text=None,  # Will use Whisper for transcription if None
    #     enable_word_merging=True
    # )

    # print("Feature extraction completed!")
    # print(f"Total duration: {features.get('total_duration', 0):.2f}s")
    # print(f"Speaking rate: {features.get('speaking_rate', 0):.2f} words/s")
    # print(f"Word count: {features.get('processed_word_count', 0)}")

    print("Please uncomment the code above and provide an audio file path to test.")


# Run the demo only when this file is executed as a script, not on import
if __name__ == "__main__":
    example_usage()