import json
import glob 
from pathlib import Path
import os
import sys
import argparse
# import matplotlib.pyplot as plt  # Disabled: plotting removed

class _FolderStats:
    """Additive statistics for one behavior file or for a whole output folder.

    Every attribute is a number (summed on merge), a list (concatenated) or a
    dict of counts (added key by key). merge() relies on that invariant, so any
    new attribute must be one of those three kinds.
    """

    def __init__(self):
        self.file_count = 0
        self.jb_true = 0
        self.jb_false = 0
        self.total_refinements = 0
        # Sum of len(conversation) / 2 per file (message pairs, not messages).
        self.total_conversation = 0
        # Counters copied from each file's feedback_result['defense_stats'].
        # They are the ONLY source of attempt/success counts: a turn's
        # defense_info records just the final result of that turn.
        self.defense_enabled_count = 0
        self.total_defense_attempts = 0
        self.total_defense_successes = 0
        self.total_increase_attempts = 0
        self.total_increase_successes = 0
        self.total_decrease_attempts = 0
        self.total_decrease_successes = 0
        # (original_score, new_score) -> count from per-turn defense_info,
        # split by the requested direction.
        self.increase_score_changes = {}
        self.decrease_score_changes = {}
        # Score deltas of per-turn changes that moved in the requested direction.
        self.increase_successful_score_diffs = []  # new_score - original_score
        self.decrease_successful_score_diffs = []  # original_score - new_score
        # Similarity scores: all of them, and only those with score_changed true.
        self.all_similarity_scores = []
        self.successful_similarity_scores = []

    def merge(self, other):
        """Add every statistic of *other* into this instance, in place."""
        for name, value in vars(other).items():
            current = getattr(self, name)
            if isinstance(value, dict):
                for key, count in value.items():
                    current[key] = current.get(key, 0) + count
            elif isinstance(value, list):
                current.extend(value)
            else:
                setattr(self, name, current + value)


def _find_folders(pattern, outputs_dir):
    """Return folders under *outputs_dir* selected by *pattern*.

    A pattern containing '*' or '?' is globbed (non-directories dropped);
    anything else is treated as a literal folder name.
    """
    candidate = os.path.join(outputs_dir, pattern)
    if '*' in pattern or '?' in pattern:
        return [p for p in glob.glob(candidate) if os.path.isdir(p)]
    return [candidate] if os.path.isdir(candidate) else []


def _record_score_change(stats, defense_info):
    """Tally one per-turn defense_info score transition into *stats*.

    Entries without a direction or without both scores are ignored, as are
    directions other than "increase" / "decrease".
    """
    direction = defense_info.get('direction')
    original_score = defense_info.get('original_score')
    new_score = defense_info.get('new_score')
    if not direction or original_score is None or new_score is None:
        return

    change_key = (original_score, new_score)
    if direction == "increase":
        changes = stats.increase_score_changes
        changes[change_key] = changes.get(change_key, 0) + 1
        if new_score > original_score:
            stats.increase_successful_score_diffs.append(new_score - original_score)
    elif direction == "decrease":
        changes = stats.decrease_score_changes
        changes[change_key] = changes.get(change_key, 0) + 1
        if new_score < original_score:
            stats.decrease_successful_score_diffs.append(original_score - new_score)


def _record_similarities(stats, defense_infos):
    """Collect non-None similarity_score values from *defense_infos* into *stats*.

    Non-dict entries are skipped. Each score is paired with its own entry's
    score_changed flag, so "successful" scores can never be misaligned.
    """
    for info in defense_infos:
        if not isinstance(info, dict):
            continue
        similarity = info.get('similarity_score')
        if similarity is None:
            continue
        stats.all_similarity_scores.append(similarity)
        if info.get('score_changed', False):
            stats.successful_similarity_scores.append(similarity)


def _parse_behavior_file(path):
    """Load one behavior_*.json file and return its statistics as a _FolderStats.

    Any error from open()/json.load() or from malformed data propagates; the
    caller discards the whole file in that case, so it contributes no partial
    counts.
    """
    with open(path, encoding='utf-8') as handle:
        data = json.load(handle)

    stats = _FolderStats()
    fb = data.get('feedback_result', {})
    jb = fb.get('jailbreak')
    stats.jb_true += jb is True
    stats.jb_false += jb is False

    turn_defense_infos = []
    turns = fb.get('turns') or []
    if isinstance(turns, list):
        for turn in turns:
            if not isinstance(turn, dict):
                continue
            refinements = turn.get('refinements')
            if isinstance(refinements, list):
                stats.total_refinements += len(refinements)
            elif isinstance(refinements, int):
                stats.total_refinements += refinements
            defense_info = turn.get('defense_info')
            if isinstance(defense_info, dict) and defense_info:
                turn_defense_infos.append(defense_info)
                _record_score_change(stats, defense_info)

    conv = fb.get('conversation') or []
    if isinstance(conv, list):
        stats.total_conversation += len(conv) / 2

    # Similarity scores come from exactly one source per file: all_defense_info
    # (every refinement step) when present and non-empty, otherwise the per-turn
    # defense_info (older files). Using both would double-count.
    similarity_infos = turn_defense_infos
    defense_stats = fb.get('defense_stats')
    if defense_stats:
        stats.defense_enabled_count += 1
        stats.total_defense_attempts += defense_stats.get('attempts', 0)
        stats.total_defense_successes += defense_stats.get('successes', 0)
        stats.total_increase_attempts += defense_stats.get('increase_attempts', 0)
        stats.total_increase_successes += defense_stats.get('increase_successes', 0)
        stats.total_decrease_attempts += defense_stats.get('decrease_attempts', 0)
        stats.total_decrease_successes += defense_stats.get('decrease_successes', 0)
        all_defense_info = defense_stats.get('all_defense_info')
        if isinstance(all_defense_info, list) and all_defense_info:
            similarity_infos = all_defense_info
    _record_similarities(stats, similarity_infos)

    stats.file_count = 1
    return stats


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _print_folder_summary(folder, stats, expected_count):
    """Print the always-on per-folder summary (jailbreak rate, refinements, conversation)."""
    avg_conv = round(stats.total_conversation / stats.file_count, 2)
    print(f"\n{folder}")
    print(f"  Total files: {stats.file_count}")

    failed_num = 0
    if expected_count and stats.jb_true + stats.jb_false != expected_count:
        failed_num = expected_count - stats.jb_true - stats.jb_false

    print(f"  jailbreak true: {stats.jb_true}, false: {stats.jb_false}", end="")
    if failed_num > 0:
        print(f", failed: {failed_num}")
    else:
        print()

    # A truthy expected_count is the denominator, so missing behaviors lower the rates.
    denominator = expected_count if expected_count else stats.file_count
    print(f'  jailbreak success rate: {stats.jb_true/denominator:.4f}')
    print(f"  avg refinements per sample: {stats.total_refinements/denominator:.2f}")

    # NOTE(review): the value is the sum of len(conversation)/2 per file, i.e.
    # message pairs, although the label says "messages"; label kept as-is.
    print(f"  conversation messages total: {stats.total_conversation} (avg {avg_conv}/file)")


def _print_score_changes(label, noun, changes, successful_diffs):
    """Print one direction's score transitions (most frequent first) and its successes."""
    print(f"  Score changes ({label} direction):")
    for (from_score, to_score), count in sorted(changes.items(), key=lambda item: item[1], reverse=True):
        print(f"    {from_score} -> {to_score}: {count} times")
    if successful_diffs:
        print(f"    Successful {noun}s: {len(successful_diffs)} times")
        print(f"    Average score {noun}: {sum(successful_diffs) / len(successful_diffs):.2f}")
    else:
        print(f"    Successful {noun}s: 0 times")
        print(f"    Average score {noun}: N/A")


def _print_similarity_summary(title, scores):
    """Print average/min/max and >=0.9 / >=0.8 shares of *scores*; prints nothing if empty."""
    if not scores:
        return
    total = len(scores)
    above_90 = sum(1 for score in scores if score >= 0.9)
    above_80 = sum(1 for score in scores if score >= 0.8)
    print("  ")
    print(f"  Similarity statistics ({title}):")
    print(f"    Average similarity: {sum(scores) / total:.4f}")
    print(f"    Min similarity: {min(scores):.4f}")
    print(f"    Max similarity: {max(scores):.4f}")
    print(f"    Total similarity scores: {total}")
    print(f"    Scores >= 0.9: {above_90} ({above_90 / total:.2%})")
    print(f"    Scores >= 0.8: {above_80} ({above_80 / total:.2%})")


def _print_defense_stats(stats):
    """Print defense attempt/success rates, score transitions and similarity statistics."""
    defense_rate = _ratio(stats.total_defense_successes, stats.total_defense_attempts)
    increase_rate = _ratio(stats.total_increase_successes, stats.total_increase_attempts)
    decrease_rate = _ratio(stats.total_decrease_successes, stats.total_decrease_attempts)

    print(f"  Defense enabled: {stats.defense_enabled_count} files")
    print(f"  Defense attempts: {stats.total_defense_attempts}")
    print(f"  Defense successes: {stats.total_defense_successes}")
    print(f"  Defense success rate (score changed): {defense_rate:.4f} ({defense_rate*100:.2f}%)")
    print("  ")
    print(f"  Increase attempts: {stats.total_increase_attempts}")
    print(f"  Increase successes: {stats.total_increase_successes}")
    print(f"  Increase success rate: {increase_rate:.4f} ({increase_rate*100:.2f}%)")
    print(f"  Decrease attempts: {stats.total_decrease_attempts}")
    print(f"  Decrease successes: {stats.total_decrease_successes}")
    print(f"  Decrease success rate: {decrease_rate:.4f} ({decrease_rate*100:.2f}%)")

    if stats.increase_score_changes or stats.decrease_score_changes:
        print("  ")
        if stats.increase_score_changes:
            _print_score_changes("Increase", "increase", stats.increase_score_changes,
                                 stats.increase_successful_score_diffs)
        if stats.decrease_score_changes:
            print("  ")
            _print_score_changes("Decrease", "decrease", stats.decrease_score_changes,
                                 stats.decrease_successful_score_diffs)

    _print_similarity_summary("all samples", stats.all_similarity_scores)
    _print_similarity_summary("successful score changes only", stats.successful_similarity_scores)


def analyze_results(pattern="*", expected_count=None, *, show_defense_stats=False, outputs_dir="outputs"):
    """
    Analyze results from output folders matching the pattern.

    For every matching folder (visited in sorted order), each ``behavior_*.json``
    file is parsed and a per-folder summary is printed: jailbreak counts, success
    rate, average refinements and conversation length. A file that fails to load
    or parse prints a warning and is excluded from every total.

    Args:
        pattern: Pattern to match folder names under *outputs_dir*
            (e.g., "*-TestRun" or "*-UnifiedFeedbackWith5Turns"). A pattern
            without '*' or '?' is treated as a literal folder name.
        expected_count: Expected number of behaviors. When truthy it is used as
            the denominator for the rates, and any shortfall is reported as
            "failed". When None (or 0) the number of valid files is used.
        show_defense_stats: Also print defense, score-change and similarity
            statistics for folders where at least one file has
            ``defense_stats``. Off by default, which keeps the original output.
        outputs_dir: Root directory that *pattern* is resolved against.
    """
    folders = _find_folders(pattern, outputs_dir)
    if not folders:
        print(f"No folders found matching pattern: {pattern}")
        return

    for folder in sorted(folders):
        totals = _FolderStats()
        for file in sorted(glob.glob(os.path.join(folder, 'behavior_*.json'))):
            try:
                file_stats = _parse_behavior_file(file)
            except Exception as e:
                # Deliberate best-effort boundary: one bad file must not abort the
                # report. Nothing from it has been merged into the totals.
                print(f"  Warning: Error processing {file}: {e}")
                continue
            totals.merge(file_stats)

        if totals.file_count == 0:
            print(f"{folder}: No valid behavior files found")
            continue

        _print_folder_summary(folder, totals, expected_count)
        if show_defense_stats and totals.defense_enabled_count > 0:
            _print_defense_stats(totals)

def _parse_cli_args(argv=None):
    """Build the command-line parser and parse *argv* (defaults to sys.argv[1:]).

    With no arguments every option keeps its default (pattern="*",
    expected_count=None), so a bare invocation analyzes every folder under
    ./outputs.
    """
    cli = argparse.ArgumentParser(description="Analyze experiment results")
    cli.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Pattern (or direct folder name) to match output folders under ./outputs (default: *)",
    )
    cli.add_argument(
        "--expected-count",
        type=int,
        default=None,
        help="Expected number of behaviors (default: use actual count)",
    )
    # Plotting has been removed; the flag is still accepted so existing
    # invocations that pass it keep working, but it has no effect.
    cli.add_argument("--no-plot", action="store_true", help="Don't generate plot")
    return cli.parse_args(argv)


if __name__ == "__main__":
    options = _parse_cli_args()
    analyze_results(options.pattern, options.expected_count)


# ---------------------------------------------------------------------------
# Plot summary from collected metrics (derived from the printed results above)
# ---------------------------------------------------------------------------
# Disabled: plotting functionality removed
# def plot_summary(output_path: str = "analysis_summary.png") -> None:
#     """Plot summary charts - DISABLED"""
#     pass

