import json
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import argparse

# Default configuration: fallback values used by parse_arguments() whenever
# the corresponding command-line flag is omitted.

# Model names to process; main() derives the input file names
# original_<model>.json and rmu_<model>.json from each entry.
DEFAULT_MODELS_TO_ANALYZE = ["Zephyr-7b"]
# Values of k to analyze; each k selects the str(k) entry of every RMU sample
# and sets the number of rows in the per-model heatmap.
DEFAULT_K_LIST = [200]
# Directory that holds the original_*/rmu_* answer files.
DEFAULT_ANSWERS_DIR = "Passk/wmdp_rmu_zephyr_answers_seed0_temp1.0"
# Directory that receives the heatmap images and the statistics JSON.
DEFAULT_OUTPUT_DIR = "heatmaps/Passk/wmdp_rmu_zephyr_answers_seed0_temp1.0"

def load_json_safely(file_path):
    """Load a JSON file, returning ``None`` instead of raising on failure.

    The file is always decoded as UTF-8 rather than with the platform's
    locale encoding, so answer files containing non-ASCII text load the same
    way on every operating system.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document, or ``None`` when the file cannot be
        opened, is not valid JSON, or its top-level value has no length
        (the success message calls ``len()`` on it). The error is printed.
    """
    try:
        print(f"Loading {file_path}...")
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"✅ Successfully loaded {len(data)} entries")
        return data
    except Exception as e:
        # Deliberately broad: this is a best-effort loader and the caller
        # treats None as "skip this model", so every failure is reported
        # here instead of aborting the whole run.
        print(f"❌ Error loading {file_path}: {e}")
        return None

def letter_to_index(letter):
    """Map an answer letter to its zero-based option index.

    Matching is case-insensitive. Anything that is not a string spelling
    exactly A, B, C or D (no surrounding whitespace) yields -1.
    """
    if not isinstance(letter, str):
        return -1  # Non-string answers are invalid
    option_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    return option_index.get(letter.upper(), -1)

def index_to_letter(index):
    """Map an option index in 0..3 to its answer letter ('A'..'D').

    Returns None for non-integer input or an index outside that range.
    """
    if not isinstance(index, int):
        return None
    if index < 0 or index > 3:
        return None
    return "ABCD"[index]

def check_answer_correctness(predicted_answer, ground_truth_index):
    """Tell whether a predicted answer letter matches the ground-truth index.

    The prediction is upper-cased and stripped of surrounding whitespace
    before being compared with the letter for ``ground_truth_index``.
    Returns False when the ground truth is not an integer in 0..3 or the
    prediction is not a string.
    """
    truth_is_valid = isinstance(ground_truth_index, int) and 0 <= ground_truth_index <= 3
    if not truth_is_valid or not isinstance(predicted_answer, str):
        return False

    expected_letter = index_to_letter(ground_truth_index)
    normalized_prediction = predicted_answer.upper().strip()
    return normalized_prediction == expected_letter

def _sample_sort_key(sample_id):
    """Return a sort key that orders numeric-looking sample ids by value."""
    text = str(sample_id)
    try:
        return (0, int(text), text)
    except ValueError:
        # Ids that are not integers sort after the numeric ones, alphabetically.
        return (1, 0, text)


def extract_accuracy_matrix_for_k(original_data, rmu_data, k_value):
    """Extract the per-response correctness matrix used for the heatmap.

    Only samples present in both inputs are used. Sample ids are ordered
    numerically when they look like integers: ids loaded from JSON are
    strings, and a plain string sort would place "10" before "2" and
    scramble the heatmap columns.

    Args:
        original_data: Mapping of sample id to a record whose
            ``["wmdp_data"]`` entry holds the ground-truth ``"answer"``
            index and the ``"question"`` text.
        rmu_data: Mapping of sample id to a mapping keyed by ``str(k)``
            whose value may hold an ``"answers"`` list of answer letters.
        k_value: Number of responses per sample (rows of the matrix).

    Returns:
        ``(accuracy_matrix, sample_labels, sample_info)``, or
        ``(None, None, None)`` when the inputs share no sample.
        ``accuracy_matrix`` has shape ``(k_value, n_samples)`` with
        1 = correct, 0 = incorrect, -1 = missing/invalid.

    Raises:
        KeyError: If a shared sample lacks ``"wmdp_data"``, ``"answer"`` or
            ``"question"`` in ``original_data``.
    """
    # Get common indices and sort them for consistent ordering
    common_indices = sorted(set(original_data) & set(rmu_data), key=_sample_sort_key)

    print(f"   Found {len(common_indices)} common samples for k={k_value}")

    if not common_indices:
        return None, None, None

    # Initialize matrix: rows = k responses, columns = samples
    # 1 = correct, 0 = incorrect, -1 = missing/invalid
    accuracy_matrix = np.full((k_value, len(common_indices)), -1, dtype=int)
    sample_labels = []
    sample_info = []

    # The RMU records are keyed by the string form of k; compute it once.
    k_str = str(k_value)

    for col_idx, wmdp_idx in enumerate(tqdm(common_indices, desc=f"Processing k={k_value}")):
        wmdp_record = original_data[wmdp_idx]["wmdp_data"]
        ground_truth = wmdp_record["answer"]

        # Get RMU answers for this k; a sample without them keeps its -1 column.
        rmu_entry = rmu_data[wmdp_idx]
        rmu_k_data = rmu_entry[k_str] if k_str in rmu_entry else {}
        rmu_answers = rmu_k_data["answers"] if "answers" in rmu_k_data else []

        # Answers beyond the first k are ignored; blank or non-string answers
        # stay marked as missing.
        for row_idx, rmu_answer in enumerate(rmu_answers[:k_value]):
            if isinstance(rmu_answer, str) and rmu_answer.strip():
                is_correct = check_answer_correctness(rmu_answer, ground_truth)
                accuracy_matrix[row_idx, col_idx] = 1 if is_correct else 0

        # Store sample info
        sample_labels.append(f"S{wmdp_idx}")
        sample_info.append({
            "wmdp_idx": wmdp_idx,
            "ground_truth_index": ground_truth,
            "ground_truth_letter": index_to_letter(ground_truth),
            "question": wmdp_record["question"][:100] + "..."
        })

    return accuracy_matrix, sample_labels, sample_info

def create_accuracy_heatmap(accuracy_matrix, k_value, model_name, sample_labels, output_dir):
    """Create and save the per-response accuracy heatmap for one model and k.

    Cells are drawn white for missing (-1), red for incorrect (0) and green
    for correct (1). Matrices of at most 10 rows and 25 columns also get a
    symbol drawn in every cell.

    Args:
        accuracy_matrix: 2-D array of -1/0/1 values, rows = responses,
            columns = samples.
        k_value: Number of response rows to label on the y axis.
        model_name: Model name used in the title and the file name.
        sample_labels: One x-axis label per matrix column.
        output_dir: Existing directory that receives the PNG.

    Returns:
        Path of the PNG file that was written.
    """
    from matplotlib.colors import ListedColormap

    # Calculate figure size to make cells square: each cell is roughly
    # 0.4 inches, plus 2 inches of padding for labels.
    cell_size = 0.4
    fig_width = len(sample_labels) * cell_size + 2
    fig_height = k_value * cell_size + 2

    # Figures grow with the number of samples, so a fixed 300 dpi can ask for
    # an image of several hundred thousand pixels per side. Cap the pixel
    # size by lowering the dpi for oversized figures only.
    # NOTE(review): 60000 keeps a margin below the 2**16-pixels-per-side
    # limit that Matplotlib's Agg renderer has enforced -- confirm against
    # the installed version.
    target_dpi = 300
    max_pixels_per_side = 60000
    longest_side = max(fig_width, fig_height)
    save_dpi = max(1, min(target_dpi, int(max_pixels_per_side / longest_side)))
    # Layout is computed at the figure's own dpi, so cap that as well while
    # leaving the configured default untouched for normally sized figures.
    fig_dpi = min(plt.rcParams["figure.dpi"], save_dpi)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=fig_dpi)
    try:
        # Custom colormap: White (missing), Red (incorrect), Green (correct)
        cmap = ListedColormap(['#ffffff', '#d73027', '#2ca02c'])

        # Create heatmap using imshow with square aspect ratio
        im = ax.imshow(accuracy_matrix, cmap=cmap, aspect='equal', vmin=-1, vmax=1)

        # Minor ticks sit on the cell borders so the grid outlines each cell.
        ax.set_xticks(np.arange(-0.5, len(sample_labels), 1), minor=True)
        ax.set_yticks(np.arange(-0.5, k_value, 1), minor=True)
        ax.grid(which="minor", color="black", linestyle='-', linewidth=1)
        ax.tick_params(which="minor", size=0)

        # Add colorbar with custom labels
        cbar = fig.colorbar(im, ax=ax, shrink=0.8, ticks=[-1, 0, 1])
        cbar.set_ticklabels(['Missing', 'Incorrect', 'Correct'])
        cbar.set_label('Answer Accuracy', fontsize=12, fontweight='bold')

        # Set ticks and labels
        ax.set_xticks(range(len(sample_labels)))
        ax.set_xticklabels(sample_labels, rotation=45, ha='right', fontsize=8)

        ax.set_yticks(range(k_value))
        ax.set_yticklabels([f"Resp_{i+1}" for i in range(k_value)], fontsize=9)

        # Add title and labels
        ax.set_title(f'{model_name} - Answer Accuracy (k={k_value})\nRMU Responses vs Ground Truth',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Samples (WMDP Questions)', fontsize=12, fontweight='bold')
        ax.set_ylabel(f'RMU Generated Responses (1 to {k_value})', fontsize=12, fontweight='bold')

        # Add text annotations for smaller matrices only
        if k_value <= 10 and len(sample_labels) <= 25:
            for i in range(k_value):
                for j in range(len(sample_labels)):
                    value = accuracy_matrix[i, j]
                    if value == 1:
                        text, color = "✓", "white"
                    elif value == 0:
                        text, color = "✗", "white"
                    else:
                        text, color = "?", "black"

                    ax.text(j, i, text, ha="center", va="center",
                            color=color, fontsize=8, fontweight='bold')

        # Adjust layout
        fig.tight_layout()

        # Save the plot
        filename = f"{output_dir}/{model_name}_k{k_value}_accuracy_heatmap.png"
        fig.savefig(filename, dpi=save_dpi, bbox_inches='tight')
        print(f"   💾 Saved heatmap: {filename}")
    finally:
        # Always release this specific figure, even when drawing or saving
        # fails, so a long run does not accumulate open figures.
        plt.close(fig)

    return filename

def compute_accuracy_stats(accuracy_matrix, k_value, model_name):
    """Summarize a correctness matrix into overall accuracy and pass@k.

    Args:
        accuracy_matrix: 2-D array of -1 (missing), 0 (incorrect) and
            1 (correct) values, one column per sample.
        k_value: Number of responses expected per sample.
        model_name: Model name recorded in the result.

    Returns:
        Dict of plain Python values. ``accuracy`` divides the correct count
        by ``k_value * n_samples``, so missing responses count against it.
        ``pass_at_k`` is the fraction of samples with at least one correct
        response.
    """
    sample_count = accuracy_matrix.shape[1]
    response_slots = k_value * sample_count

    # Tally each cell category
    missing_count = np.count_nonzero(accuracy_matrix == -1)
    incorrect_count = np.count_nonzero(accuracy_matrix == 0)
    correct_count = np.count_nonzero(accuracy_matrix == 1)

    # Accuracy is measured against every possible response slot
    if response_slots > 0:
        overall_accuracy = correct_count / response_slots
    else:
        overall_accuracy = 0.0

    # A sample is solved when any of its responses is correct
    solved_mask = (accuracy_matrix == 1).any(axis=0)
    solved_fraction = solved_mask.mean() if solved_mask.size else 0.0

    return {
        "model": model_name,
        "k": k_value,
        "total_samples": sample_count,
        "total_possible_responses": int(response_slots),
        "missing_responses": int(missing_count),
        "incorrect_responses": int(incorrect_count),
        "correct_responses": int(correct_count),
        "accuracy": float(overall_accuracy),  # Based on total possible responses
        "pass_at_k": float(solved_fraction),  # Fraction of samples with at least one correct answer
        "samples_with_correct": int(solved_mask.sum())
    }

def create_summary_heatmap(all_stats, models_to_analyze, k_list, output_dir):
    """Draw side-by-side model-by-k heatmaps of overall accuracy and pass@k.

    Rows are the models and columns the k values. A (model, k) pair with no
    entry in ``all_stats`` is shown as 0. The figure is saved as
    ``summary_accuracy_heatmap.png`` inside ``output_dir``.
    """
    n_models = len(models_to_analyze)
    n_k = len(k_list)

    # Fill one grid per metric from the first matching stats entry
    accuracy_grid = np.zeros((n_models, n_k))
    pass_at_k_grid = np.zeros((n_models, n_k))
    for row, model in enumerate(models_to_analyze):
        for col, k in enumerate(k_list):
            match = next(
                (entry for entry in all_stats
                 if entry["model"] == model and entry["k"] == k),
                None,
            )
            if match is not None:
                accuracy_grid[row, col] = match["accuracy"]
                pass_at_k_grid[row, col] = match["pass_at_k"]

    # Size the figure from the grid dimensions: two panels side by side
    cell_size = 0.6
    fig_width = n_k * cell_size * 2 + 4
    fig_height = n_models * cell_size + 2
    fig, (ax_accuracy, ax_pass) = plt.subplots(1, 2, figsize=(fig_width, fig_height))

    column_labels = [f'k={k}' for k in k_list]
    panels = (
        (ax_accuracy, accuracy_grid, 'Overall Accuracy'),
        (ax_pass, pass_at_k_grid, 'Pass@K Accuracy\n(≥1 correct per sample)'),
    )

    images = []
    for axis, grid, title in panels:
        images.append(axis.imshow(grid, cmap='RdYlBu_r', aspect='equal', vmin=0, vmax=1))
        axis.set_xticks(range(n_k))
        axis.set_xticklabels(column_labels)
        axis.set_yticks(range(n_models))
        axis.set_yticklabels(models_to_analyze)
        axis.set_title(title, fontweight='bold')
        axis.set_xlabel('K Values')
        axis.set_ylabel('Models')

        # Minor ticks on the cell borders carry the grid lines
        axis.set_xticks(np.arange(-0.5, n_k, 1), minor=True)
        axis.set_yticks(np.arange(-0.5, n_models, 1), minor=True)
        axis.grid(which="minor", color="black", linestyle='-', linewidth=1)
        axis.tick_params(which="minor", size=0)

        # Print the value inside every cell
        for row in range(n_models):
            for col in range(n_k):
                axis.text(col, row, f'{grid[row, col]:.3f}',
                          ha="center", va="center", color="white", fontweight='bold')

    # Colorbars are attached only after both panels are fully drawn
    colorbar_labels = ('Accuracy', 'Pass@K Rate')
    for image, (axis, _, _), label in zip(images, panels, colorbar_labels):
        colorbar = plt.colorbar(image, ax=axis, shrink=0.8)
        colorbar.set_label(label)

    plt.tight_layout()

    # Write the figure to disk and release it
    summary_filename = f"{output_dir}/summary_accuracy_heatmap.png"
    plt.savefig(summary_filename, dpi=300, bbox_inches='tight')
    print(f"💾 Saved summary heatmap: {summary_filename}")
    plt.close()

def main(models_to_analyze, k_list, answers_dir, output_dir):
    """
    Main function to generate WMDP accuracy heatmaps

    For every model it loads ``original_<model>.json`` and
    ``rmu_<model>.json`` from ``answers_dir``, writes one heatmap per k
    value, then writes a summary heatmap and a JSON file with the collected
    statistics. Models whose files cannot be loaded, and k values without
    shared samples, are skipped. If nothing could be processed a warning is
    printed and no summary or statistics file is reported.

    Args:
        models_to_analyze: List of model names to analyze
        k_list: List of k values to analyze
        answers_dir: Directory containing the answer files
        output_dir: Directory to save output files
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("🎯 Starting WMDP Accuracy Heatmap Generation")
    print("=" * 60)
    print(f"📁 Input directory: {answers_dir}")
    print(f"📁 Output directory: {output_dir}")
    print(f"🤖 Models: {models_to_analyze}")
    print(f"📊 K values: {k_list}")
    print("=" * 60)

    all_stats = []
    heatmap_files = []  # Heatmaps actually written, for the final report

    for model_name in models_to_analyze:
        print(f"\n🤖 Processing {model_name}")
        print("-" * 40)

        # Load data
        orig_file = f"{answers_dir}/original_{model_name}.json"
        rmu_file = f"{answers_dir}/rmu_{model_name}.json"

        original_data = load_json_safely(orig_file)
        rmu_data = load_json_safely(rmu_file)

        if original_data is None or rmu_data is None:
            print(f"❌ Failed to load data for {model_name}")
            continue

        for k_value in k_list:
            print(f"\n📊 Generating accuracy heatmap for k={k_value}")

            # Only the matrix and labels are needed here; the per-sample
            # info returned third is not used.
            accuracy_matrix, sample_labels, _ = extract_accuracy_matrix_for_k(
                original_data, rmu_data, k_value
            )

            if accuracy_matrix is None:
                print(f"   ⚠️ No data available for k={k_value}")
                continue

            # Create heatmap
            heatmap_files.append(
                create_accuracy_heatmap(accuracy_matrix, k_value, model_name, sample_labels, output_dir)
            )

            # Compute and store statistics
            stats = compute_accuracy_stats(accuracy_matrix, k_value, model_name)
            all_stats.append(stats)
            print(f"   📈 Overall Accuracy: {stats['accuracy']:.3f} ({stats['correct_responses']}/{stats['total_possible_responses']})")
            print(f"   🎯 Pass@{k_value}: {stats['pass_at_k']:.3f} ({stats['samples_with_correct']}/{stats['total_samples']} samples)")
            print(f"   📊 Missing: {stats['missing_responses']}, Incorrect: {stats['incorrect_responses']}, Correct: {stats['correct_responses']}")

    if not all_stats:
        # Nothing was written, so do not claim that output files exist.
        print("\n⚠️ No accuracy data could be processed; no heatmaps or statistics were written.")
        return

    # Create summary heatmap
    print("\n📋 Creating summary heatmap...")
    create_summary_heatmap(all_stats, models_to_analyze, k_list, output_dir)

    # Save detailed statistics
    stats_file = f"{output_dir}/detailed_accuracy_statistics.json"
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(all_stats, f, indent=2)
    print(f"💾 Saved detailed statistics: {stats_file}")

    print("\n✅ Accuracy heatmap generation complete!")
    print(f"📁 All files saved in: {output_dir}/")
    print(f"   - Individual heatmaps: {len(heatmap_files)} files")
    print("   - Summary heatmap: summary_accuracy_heatmap.png")
    print("   - Statistics: detailed_accuracy_statistics.json")

    # Print final summary
    print("\n📊 ACCURACY SUMMARY")
    print("=" * 50)
    for model_name in models_to_analyze:
        print(f"\n🤖 {model_name}")
        model_stats = [s for s in all_stats if s["model"] == model_name]
        for stats in sorted(model_stats, key=lambda x: x["k"]):
            k = stats["k"]
            acc = stats["accuracy"]
            pass_k = stats["pass_at_k"]
            print(f"   k={k:2d}: Accuracy={acc:.3f}, Pass@{k}={pass_k:.3f}")

def parse_arguments():
    """Build the command-line parser and return the parsed arguments.

    Every option is optional and falls back to the matching module-level
    DEFAULT_* constant. The result exposes ``models``, ``k_values``,
    ``answers_dir`` and ``output_dir``.
    """
    cli = argparse.ArgumentParser(description='Generate WMDP accuracy heatmaps')

    # (flag, add_argument keyword settings), in the order they are registered
    option_specs = [
        ('--models', {
            'nargs': '+',
            'default': DEFAULT_MODELS_TO_ANALYZE,
            'help': f'List of model names to analyze (default: {DEFAULT_MODELS_TO_ANALYZE})',
        }),
        ('--k-values', {
            'nargs': '+',
            'type': int,
            'default': DEFAULT_K_LIST,
            'help': f'List of k values to analyze (default: {DEFAULT_K_LIST})',
        }),
        ('--answers-dir', {
            'default': DEFAULT_ANSWERS_DIR,
            'help': f'Directory containing the answer files (default: {DEFAULT_ANSWERS_DIR})',
        }),
        ('--output-dir', {
            'default': DEFAULT_OUTPUT_DIR,
            'help': f'Directory to save output files (default: {DEFAULT_OUTPUT_DIR})',
        }),
    ]
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: read the command-line options and run the analysis.
    args = parse_arguments()
    main(
        models_to_analyze=args.models,
        k_list=args.k_values,
        answers_dir=args.answers_dir,
        output_dir=args.output_dir,
    )