#!/usr/bin/env python3
"""
Generate heatmaps visualizing the density/quality of retrieved knowledge as top-k increases.

This script creates two heatmaps:
1. GPQA Physics: 86 questions × 10 top-k values
2. GPQA Chemistry: 93 questions × 10 top-k values

Each heatmap shows the individual (marginal) relevance score at every top-k
rank for each question, illustrating that quality decreases as top-k
increases (deeper ranks introduce less relevant knowledge).
"""

import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import hashlib
import sys
import os
from pathlib import Path

# Append the project root (not its src/ subdirectory) to the import path.
# NOTE(review): nothing in this file currently imports from it — TODO confirm
# whether this entry is still needed.
_PROJECT_ROOT = '/home/<anonymous>/projects/Schema_ICL/schema_instructor'
sys.path.append(_PROJECT_ROOT)

def generate_hash(input_string):
    """
    Derive a 36-character hexadecimal ID from a string.

    The ID is the SHA-256 hex digest of the UTF-8 encoded input, truncated to
    its first 36 characters (stated to match ``Dataset.generate_hash``).
    """
    id_length = 36
    hasher = hashlib.sha256()
    hasher.update(input_string.encode('utf-8'))
    return hasher.hexdigest()[:id_length]

def load_dataset_question_ids(dataset_name, base_path=None):
    """
    Load question IDs from the raw GPQA CSV file, preserving row order.

    Each ID is ``generate_hash`` applied to the row's ``Question`` text.
    Presumably these match the keys of the rerank mapping JSON — verify
    against the mapping producer.

    Args:
        dataset_name: Either 'physics' or 'chemistry'.
        base_path: Directory containing the raw CSV files. Defaults to the
            project's ``src/entity/datasets/raw_files`` directory.

    Returns:
        List of question IDs in the order the rows appear in the CSV.

    Raises:
        ValueError: If ``dataset_name`` is not 'physics' or 'chemistry'.
    """
    if base_path is None:
        base_path = '/home/<anonymous>/projects/Schema_ICL/schema_instructor/src/entity/datasets/raw_files'

    # The two datasets do not share a filename pattern, so map them explicitly.
    csv_filenames = {
        'physics': 'gpqa_diamond_physics.csv',
        'chemistry': 'gpqa_chemistry.csv',
    }
    try:
        csv_filename = csv_filenames[dataset_name]
    except KeyError:
        raise ValueError("dataset_name must be 'physics' or 'chemistry'") from None

    df = pd.read_csv(Path(base_path) / csv_filename)
    # Only the Question column is needed; iterating it directly avoids
    # building a Series per row as iterrows() does.
    return [generate_hash(question) for question in df['Question']]

def load_mapping_data(filepath):
    """
    Load the rerank mapping JSON file.

    Args:
        filepath: Path (str or ``pathlib.Path``) to the JSON file.

    Returns:
        The decoded JSON object.
    """
    # JSON is UTF-8 by specification. Be explicit so decoding does not depend
    # on the platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def calculate_cumulative_averages(mapping_data, question_ids, top_k=10):
    """
    Build the per-question, per-rank relevance score matrix.

    Despite the historical function name, nothing is averaged or accumulated.
    Cell ``[i, k]`` holds the individual (marginal) ``relevance_score`` of the
    (k+1)-th entry in ``mapping_data[question_ids[i]]``.

    Args:
        mapping_data: Dict mapping question ID -> list of dicts, each with a
            ``'relevance_score'`` key. List order is assumed to be rank order
            (TODO confirm against the rerank mapping producer).
        question_ids: Question IDs in dataset order; defines the row order.
        top_k: Number of ranks (columns) to fill. Defaults to 10.

    Returns:
        numpy float array of shape ``(len(question_ids), top_k)``. These cells
        are left at 0, and a warning is printed for each:
        - rows for IDs missing from ``mapping_data``;
        - ranks beyond the length of a question's list.
    """
    heatmap_data = np.zeros((len(question_ids), top_k))

    for q_idx, question_id in enumerate(question_ids):
        if question_id not in mapping_data:
            print(f"Warning: Question ID {question_id} not found in mapping data")
            continue

        scores = [item['relevance_score'] for item in mapping_data[question_id][:top_k]]
        # Tolerate short retrieval lists instead of raising IndexError.
        # The remaining ranks stay 0, as for missing questions.
        if len(scores) < top_k:
            print(
                f"Warning: Question ID {question_id} has only {len(scores)} "
                f"retrieved items (expected {top_k})"
            )
        heatmap_data[q_idx, :len(scores)] = scores

    return heatmap_data


def create_heatmap(data, title, dataset_name, num_questions, output_dir=None):
    """
    Render a (top-k x question) relevance heatmap and save it as a PNG.

    Args:
        data: 2-D numpy array with one row per question and one column per
            top-k rank (the layout returned by ``calculate_cumulative_averages``).
        title: Figure title.
        dataset_name: Used in the output filename ``heatmap_{dataset_name}.png``.
        num_questions: Number of questions; used to build the x-axis labels
            (1-based question numbers).
        output_dir: Directory the PNG is written to. Defaults to the project's
            ``mappings`` directory.

    Returns:
        The path of the saved PNG, as a string.
    """
    if output_dir is None:
        output_dir = '/home/<anonymous>/projects/Schema_ICL/schema_instructor/mappings'
    output_path = f'{output_dir}/heatmap_{dataset_name}.png'

    num_ranks = data.shape[1]
    # Transpose so ranks become rows. Flip vertically so top-k=1 is at the bottom.
    heatmap_data = np.flipud(data.T)

    fig = plt.figure(figsize=(16, 8))
    try:
        ax = sns.heatmap(
            heatmap_data,
            cmap='Reds',
            cbar_kws={'label': 'Marginal Relevance Score'},
            xticklabels=range(1, num_questions + 1),
            yticklabels=range(num_ranks, 0, -1),  # reversed to match the flip
            # Fixed colour scale: values outside [0, 1] saturate the colormap.
            vmin=0,
            vmax=1,
            linewidths=0.1,
            linecolor='white',
        )

        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel('Question Number', fontsize=14)
        ax.set_ylabel('Top-k', fontsize=14)

        if num_questions > 50:
            # Label only every 10th question for readability. seaborn centres
            # cell i at x = i + 0.5, so offset the ticks to sit under the cells.
            tick_indices = range(0, num_questions, 10)
            ax.set_xticks([i + 0.5 for i in tick_indices])
            ax.set_xticklabels([i + 1 for i in tick_indices])

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved heatmap: {output_path}")
    finally:
        # Close even if plotting or saving fails, so figures do not accumulate.
        plt.close(fig)

    return output_path


def main():
    """
    Generate the GPQA physics and chemistry heatmaps, then print a summary.

    For each dataset:
    - load question IDs from the raw CSV;
    - load ``rerank_mapping_{dataset}.json`` from the mappings directory;
    - build the score matrix and save a heatmap.

    Once both heatmaps are written, print per-dataset statistics (shape, mean
    top-1 and top-10 scores, and their difference) and the saved file paths.
    """
    mappings_dir = Path('/home/<anonymous>/projects/Schema_ICL/schema_instructor/mappings')

    # (key used in file names, display name, expected number of questions)
    datasets = [
        ('physics', 'GPQA Physics', 86),
        ('chemistry', 'GPQA Chemistry', 93),
    ]

    results = []
    for key, display_name, expected_count in datasets:
        print(f"Processing {display_name} data...")
        question_ids = load_dataset_question_ids(key)
        mapping_data = load_mapping_data(mappings_dir / f'rerank_mapping_{key}.json')
        heatmap = calculate_cumulative_averages(mapping_data, question_ids)
        output_path = create_heatmap(
            heatmap,
            f'{display_name}: Density of Prior Knowledge via RAG Rerank',
            f'gpqa_{key}',
            len(question_ids),
        )
        results.append((display_name, expected_count, question_ids, heatmap, output_path))

    print("\n=== Summary Statistics ===")
    for index, (display_name, expected_count, question_ids, heatmap, _) in enumerate(results):
        top1_avg = heatmap[:, 0].mean()
        top10_avg = heatmap[:, 9].mean()
        # Blank line between dataset sections (none before the first).
        separator = "\n" if index else ""
        print(f"{separator}{display_name} - Shape: {heatmap.shape} (Expected: {expected_count} questions)")
        print(f"  Actual questions processed: {len(question_ids)}")
        print(f"  Top-1 avg: {top1_avg:.4f}")
        print(f"  Top-10 avg: {top10_avg:.4f}")
        print(f"  Quality drop: {(top1_avg - top10_avg):.4f}")

    print("\nHeatmaps saved:")
    for *_, output_path in results:
        print(f"  - {output_path}")


if __name__ == '__main__':
    # Only generate the heatmaps when run as a script, not on import.
    main()