#!/usr/bin/env python3
"""
Final optimized feature map - STABLE VERSION
Based on the proven stable implementation from train_regularized_gradient_guided_cv.py
Only reorganized feature order to avoid overlaps, no complex validation logic.
"""

import numpy as np
from collections import defaultdict

# ==============================================================================
# FINAL OPTIMIZED FEATURE MAP - 48 Features (STABLE VERSION)
# Based on train_regularized_gradient_guided_cv.py implementation
# Reorganized into 4 clear categories without overlaps
# ==============================================================================

# Nested map: category -> subcategory -> ordered feature names.
#
# The flattened order (ALL_FEATURE_NAMES_FINAL_STABLE below) must match, position
# for position, the columns of the rows built by extract_final_stable_features().
# The trailing "# N" comments are the 0-based column indices.
FEATURE_MAP_FINAL_STABLE = {
    "Dynamics": {  # Temporal changes and gradients across steps (19 features)
        "1.1: Cross-step Gradients": [
            'top1_gradient_mean',             # 0
            'top1_gradient_std',              # 1
            'top1_gradient_max',              # 2
            'top1_gradient_min',              # 3
            'top1_gradient_trend',            # 4
            'topk_gradient_mean',             # 5
            'topk_gradient_std',              # 6
            'topk_gradient_max',              # 7
            'topk_gradient_min',              # 8
            'topk_gradient_trend',            # 9
        ],
        "1.2: Token-level Gradients": [
            'token_gradient_mean',            # 10
            'token_gradient_std',             # 11
            'token_gradient_max',             # 12
            'token_gradient_min',             # 13
        ],
        "1.3: Step Progression": [
            'step_progression_entropy',       # 14
            'step_progression_concentration',  # 15
            'step_progression_spread',        # 16
        ],
        "1.4: Confidence Change": [
            'top1_confidence_change',         # 17
            'topk_confidence_change',         # 18
        ],
    },
    "Position": {  # First-step / last-step specific information (14 features)
        "2.1: First Step Specific": [
            'first_attention_entropy',        # 19
            'first_attention_concentration',  # 20
            'first_attention_spread',         # 21
            'first_confidence_volatility',    # 22
            'first_confidence_skewness',      # 23
            'first_top1_avg',                 # 24
            'first_topk_avg',                 # 25
        ],
        "2.2: Last Step Specific": [
            'last_attention_entropy',         # 26
            'last_attention_concentration',   # 27
            'last_attention_spread',          # 28
            'last_confidence_volatility',     # 29
            'last_confidence_skewness',       # 30
            'last_top1_avg',                  # 31
            'last_topk_avg',                  # 32
        ],
    },
    "Stability": {  # Stability and consistency patterns (10 features)
        "3.1: Attention Stability": [
            'attention_entropy_mean',         # 33
            'attention_entropy_std',          # 34
            'attention_concentration_mean',   # 35
            'attention_concentration_std',    # 36
            'attention_spread_mean',          # 37
            'attention_spread_std',           # 38
        ],
        "3.2: Token-level Stability": [
            'token_volatility_mean',          # 39
            'token_volatility_std',           # 40
            'token_skewness_mean',            # 41
            'token_skewness_std',             # 42
        ],
    },
    "Structure": {  # Structural and derived information (5 features)
        # The extractor emits std_tokens_per_step first and normalized_step_count
        # last in this group; the names are listed in that same order so that
        # every name labels the column that actually holds its value.
        "4.1: Structural Metrics": [
            'std_tokens_per_step',            # 43
            'first_token_count',              # 44
            'last_token_count',               # 45
            'avg_tokens_per_step',            # 46
            'normalized_step_count',          # 47
        ],
    },
}

# Flat list of all feature names in column order (categories, then subcategories,
# then features, each in declaration order).
ALL_FEATURE_NAMES_FINAL_STABLE = [
    feature
    for category_dict in FEATURE_MAP_FINAL_STABLE.values()
    for category_features in category_dict.values()
    for feature in category_features
]

_EPS = 1e-8  # Added to denominators / log arguments to avoid division by zero and log(0).
_FINAL_STABLE_FEATURE_COUNT = 48  # Number of values in every feature row built below.


def _gradient_stats(step_values):
    """Return (mean, std, max, min, trend) of the first differences of step_values.

    All five values are 0 when fewer than two values are given; trend (last
    difference minus first difference) is 0 when there is only one difference.
    """
    if len(step_values) < 2:
        return 0, 0, 0, 0, 0
    gradients = np.diff(step_values)
    trend = gradients[-1] - gradients[0] if len(gradients) > 1 else 0
    return np.mean(gradients), np.std(gradients), np.max(gradients), np.min(gradients), trend


def _relative_spread(values):
    """Return std(values) / (mean(values) + eps)."""
    return np.std(values) / (np.mean(values) + _EPS)


def _peak_ratio(values):
    """Return max(values) / (mean(values) + eps)."""
    return np.max(values) / (np.mean(values) + _EPS)


def _skewness(values):
    """Return the mean cubed z-score of values (std guarded by eps)."""
    return np.mean(((values - np.mean(values)) / (np.std(values) + _EPS)) ** 3)


def _position_features(tokens):
    """Return (entropy, concentration, spread, volatility, skewness) for one step's tokens.

    All five values are 0 for an empty token list. Spread and volatility are
    the same std/mean ratio, reported under two feature names.
    """
    if not tokens:
        return 0, 0, 0, 0, 0
    tokens_array = np.array(tokens)
    # NOTE(review): unlike the per-step entropy in the main loop, this entropy is
    # computed on the raw values without normalising them to sum to 1. Kept as is
    # so existing feature values do not change -- confirm whether that is intended.
    entropy = -np.sum(tokens_array * np.log(tokens_array + _EPS))
    spread = _relative_spread(tokens_array)
    return entropy, _peak_ratio(tokens_array), spread, spread, _skewness(tokens_array)


def extract_final_stable_features(data):
    """
    Extract the final optimized 48 features using STABLE implementation.
    Based on train_regularized_gradient_guided_cv.py - no complex validation.

    Args:
        data: iterable of dict records. Each record is read through
            ``item.get("detailed_confidence_analysis", [])`` (a list of step
            dicts, each with an optional "confidence_metrics" dict holding
            "token_confidence_list", "average_trace_confidence_top1" and
            "average_trace_confidence") and
            ``item.get("evaluation", {}).get("judge_correct", False)``.

    Returns:
        (features, labels): ``features`` is a 2-D array with one 48-value row
        per record that has at least one step; ``labels`` is a 1-D array with
        1 where "judge_correct" is truthy and 0 otherwise. Records without any
        step are skipped entirely, so rows and labels always stay aligned.
        For input that yields no rows, ``features`` has shape (0, 48).

    Column order: 19 dynamics, 14 position, 10 stability values, followed by
    std_tokens_per_step, first_token_count, last_token_count,
    avg_tokens_per_step, normalized_step_count.
    """
    features_list = []
    labels = []

    for item in data:
        detailed_confidence_analysis = item.get("detailed_confidence_analysis", []) or []
        # Guard clause: a record without steps has nothing to compute. Falling
        # through here would reuse the previous record's values (or fail on the
        # first record because nothing has been computed yet).
        if not detailed_confidence_analysis:
            continue

        step_count = len(detailed_confidence_analysis)

        # Collect per-step information
        step_metrics = [step.get("confidence_metrics", {}) for step in detailed_confidence_analysis]
        all_step_tokens = [metrics.get("token_confidence_list", []) for metrics in step_metrics]
        all_step_token_counts = [len(tokens) for tokens in all_step_tokens]
        all_step_top1_avgs = [metrics.get("average_trace_confidence_top1", 0) for metrics in step_metrics]
        all_step_topk_avgs = [metrics.get("average_trace_confidence", 0) for metrics in step_metrics]

        # 1. Cross-step gradient features (using all steps)
        (top1_gradient_mean, top1_gradient_std, top1_gradient_max,
         top1_gradient_min, top1_gradient_trend) = _gradient_stats(all_step_top1_avgs)
        (topk_gradient_mean, topk_gradient_std, topk_gradient_max,
         topk_gradient_min, topk_gradient_trend) = _gradient_stats(all_step_topk_avgs)

        # 2. Step-wise attention patterns (using all steps); empty steps contribute zeros
        all_attention_entropies = []
        all_attention_concentrations = []
        all_attention_spreads = []
        for tokens in all_step_tokens:
            if tokens:
                tokens_array = np.array(tokens)
                # Scale the values to sum to 1 before taking the entropy
                normalized_tokens = tokens_array / (np.sum(tokens_array) + _EPS)
                all_attention_entropies.append(
                    -np.sum(normalized_tokens * np.log(normalized_tokens + _EPS))
                )
                all_attention_concentrations.append(_peak_ratio(tokens_array))
                all_attention_spreads.append(_relative_spread(tokens_array))
            else:
                all_attention_entropies.append(0)
                all_attention_concentrations.append(0)
                all_attention_spreads.append(0)

        # Aggregate attention patterns across steps
        attention_entropy_mean = np.mean(all_attention_entropies)
        attention_entropy_std = np.std(all_attention_entropies)
        attention_concentration_mean = np.mean(all_attention_concentrations)
        attention_concentration_std = np.std(all_attention_concentrations)
        attention_spread_mean = np.mean(all_attention_spreads)
        attention_spread_std = np.std(all_attention_spreads)

        # 3. Token-level patterns across all steps (only steps with >= 2 tokens contribute)
        all_token_gradients = []
        all_token_volatilities = []
        all_token_skewnesses = []
        for tokens in all_step_tokens:
            if len(tokens) > 1:
                tokens_array = np.array(tokens)
                all_token_gradients.extend(np.diff(tokens_array))
                all_token_volatilities.append(_relative_spread(tokens_array))
                all_token_skewnesses.append(_skewness(tokens_array))

        if all_token_gradients:
            token_gradient_mean = np.mean(all_token_gradients)
            token_gradient_std = np.std(all_token_gradients)
            token_gradient_max = np.max(all_token_gradients)
            token_gradient_min = np.min(all_token_gradients)
        else:
            token_gradient_mean = token_gradient_std = token_gradient_max = token_gradient_min = 0

        if all_token_volatilities:
            token_volatility_mean = np.mean(all_token_volatilities)
            token_volatility_std = np.std(all_token_volatilities)
        else:
            token_volatility_mean = token_volatility_std = 0

        if all_token_skewnesses:
            token_skewness_mean = np.mean(all_token_skewnesses)
            token_skewness_std = np.std(all_token_skewnesses)
        else:
            token_skewness_mean = token_skewness_std = 0

        # 4. Step progression patterns: relative variation of each per-step pattern
        if step_count > 1:
            step_progression_entropy = _relative_spread(all_attention_entropies)
            step_progression_concentration = _relative_spread(all_attention_concentrations)
            step_progression_spread = _relative_spread(all_attention_spreads)
        else:
            step_progression_entropy = step_progression_concentration = step_progression_spread = 0

        # 5. First and last step specific features (for comparison)
        first_tokens = all_step_tokens[0]
        last_tokens = all_step_tokens[-1]

        (first_attention_entropy, first_attention_concentration, first_attention_spread,
         first_confidence_volatility, first_confidence_skewness) = _position_features(first_tokens)
        (last_attention_entropy, last_attention_concentration, last_attention_spread,
         last_confidence_volatility, last_confidence_skewness) = _position_features(last_tokens)

        # 6. Traditional metrics
        first_top1_avg = all_step_top1_avgs[0]
        first_topk_avg = all_step_topk_avgs[0]
        last_top1_avg = all_step_top1_avgs[-1]
        last_topk_avg = all_step_topk_avgs[-1]

        # 7. Structural features (UPDATED: replaced total_tokens with std_tokens_per_step, kept normalized_step_count)
        first_token_count = len(first_tokens)
        last_token_count = len(last_tokens)
        total_tokens = sum(all_step_token_counts)
        avg_tokens_per_step = total_tokens / step_count  # step_count >= 1 is guaranteed by the guard above
        std_tokens_per_step = np.std(all_step_token_counts) if len(all_step_token_counts) > 1 else 0
        normalized_step_count = step_count / 10.0

        # 8. Confidence changes
        top1_confidence_change = last_top1_avg - first_top1_avg
        topk_confidence_change = last_topk_avg - first_topk_avg

        # Combine all features; the column order is part of the function's contract
        features = [
            # Dynamics (19 features)
            top1_gradient_mean, top1_gradient_std, top1_gradient_max, top1_gradient_min, top1_gradient_trend,
            topk_gradient_mean, topk_gradient_std, topk_gradient_max, topk_gradient_min, topk_gradient_trend,
            token_gradient_mean, token_gradient_std, token_gradient_max, token_gradient_min,
            step_progression_entropy, step_progression_concentration, step_progression_spread,
            top1_confidence_change, topk_confidence_change,

            # Position (14 features)
            first_attention_entropy, first_attention_concentration, first_attention_spread,
            first_confidence_volatility, first_confidence_skewness, first_top1_avg, first_topk_avg,
            last_attention_entropy, last_attention_concentration, last_attention_spread,
            last_confidence_volatility, last_confidence_skewness, last_top1_avg, last_topk_avg,

            # Stability (10 features)
            attention_entropy_mean, attention_entropy_std, attention_concentration_mean,
            attention_concentration_std, attention_spread_mean, attention_spread_std,
            token_volatility_mean, token_volatility_std, token_skewness_mean, token_skewness_std,

            # Structure (5 features)
            std_tokens_per_step, first_token_count, last_token_count, avg_tokens_per_step, normalized_step_count,
        ]

        features_list.append(features)
        labels.append(1 if item.get("evaluation", {}).get("judge_correct", False) else 0)

    # reshape keeps the result 2-D even when no rows were produced
    feature_matrix = np.array(features_list).reshape(-1, _FINAL_STABLE_FEATURE_COUNT)
    return feature_matrix, np.array(labels)

def get_feature_info_stable():
    """Build a lookup of per-feature metadata for the stable feature map.

    Returns:
        dict mapping each feature name in FEATURE_MAP_FINAL_STABLE to a dict with
        'index' (1-based position of the name in ALL_FEATURE_NAMES_FINAL_STABLE),
        'category' and 'subcategory'. Keys appear in feature-map order.
    """
    ordered_names = ALL_FEATURE_NAMES_FINAL_STABLE
    return {
        name: {
            # .index() yields the 0-based position; callers expect 1-based.
            'index': ordered_names.index(name) + 1,
            'category': category_name,
            'subcategory': group_name,
        }
        for category_name, groups in FEATURE_MAP_FINAL_STABLE.items()
        for group_name, group_members in groups.items()
        for name in group_members
    }

def print_feature_summary_stable():
    """Print the stable feature map to stdout.

    Output is a title, a rule, the total feature count, then for every category
    its feature count, each subcategory with its count, and each feature with
    its 1-based position in ALL_FEATURE_NAMES_FINAL_STABLE.
    """
    banner = (
        "FINAL STABLE FEATURE MAP SUMMARY",
        "=" * 60,
        f"Total Features: {len(ALL_FEATURE_NAMES_FINAL_STABLE)}",
    )
    for banner_line in banner:
        print(banner_line)
    print()

    for category_name, groups in FEATURE_MAP_FINAL_STABLE.items():
        category_size = sum(map(len, groups.values()))
        print(f"📁 {category_name}: {category_size} features")
        for group_name, members in groups.items():
            print(f"  └─ {group_name}: {len(members)} features")
            for member in members:
                position = ALL_FEATURE_NAMES_FINAL_STABLE.index(member) + 1
                print(f"     {position:2d}. {member}")
        # Blank line between categories
        print()

if __name__ == "__main__":
    print_feature_summary_stable()

    # Smoke test: run the extractor on the first few records of a local
    # evaluation file and print the first ten feature values of the first row.
    import json
    from itertools import islice

    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open('evaluations_0902/gpt-4.1_SimpleQA__full_20250902_evaluated.jsonl', 'r', encoding='utf-8') as f:
        # Skip blank lines (json.loads would raise on them) and stop after 5 records.
        non_blank_lines = (line for line in f if line.strip())
        sample_data = [json.loads(line) for line in islice(non_blank_lines, 5)]

    print(f' Testing with {len(sample_data)} samples...')

    try:
        features, labels = extract_final_stable_features(sample_data)

        if len(features) == 0:
            # Nothing to index into; report it instead of failing on features[0, ...].
            print(' No feature rows were extracted from the sample.')
        else:
            # Show first few features
            feature_info = get_feature_info_stable()
            for feature_name, info in islice(feature_info.items(), 10):
                feature_idx = info['index']
                value = features[0, feature_idx - 1]  # Convert to 0-based index
                print(f'  {feature_idx:2d}. {feature_name:<35}: {value:.4f}')

    except Exception as e:
        # Top-level boundary of a manual test run: report the failure with its traceback.
        print(f' Error: {e}')
        import traceback
        traceback.print_exc()
