import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sentence_transformers import SentenceTransformer
import pandas as pd
import argparse
import os
import json
import glob
from math_verify import parse, verify

def load_jsonl_data(file_path, doc_id):
    """
    Load data from a JSONL file and find the document with the specified doc_id.

    The file is scanned line by line and the first JSON object whose
    ``doc_id`` key equals ``doc_id`` is returned. Lines that are blank, are
    not valid JSON, or decode to something other than a JSON object are
    skipped instead of aborting the scan.

    Args:
        file_path: Path to the JSONL file
        doc_id: The document ID to find

    Returns:
        The document data if found, None otherwise
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A syntactically valid line may still be a list, number, string
            # or null; only JSON objects can carry a doc_id.
            if isinstance(data, dict) and data.get('doc_id') == doc_id:
                return data
    return None

def visualize_responses(responses, correct_responses, method_name, question, output_dir,
                        embedding_model="all-MiniLM-L6-v2"):
    """
    Generate embeddings and visualize the responses without clustering.

    The responses are embedded with a SentenceTransformer model, projected to
    2D with t-SNE, and drawn as a scatter plot coloured by correctness. Two
    files are written into ``output_dir``:
    ``{method_name}_response_embeddings.pdf`` (the plot) and
    ``{method_name}_embedding_results.csv`` (ids, correctness, text and 2D
    coordinates).

    Fewer than two responses cannot be projected (the perplexity passed to
    t-SNE is ``min(30, n - 1)``, which would be zero or negative), so in that
    case a warning is printed and nothing is written.

    Args:
        responses: List of response texts
        correct_responses: List of boolean values indicating correctness
        method_name: Name of the decoding method
        question: The question text (currently unused by this function)
        output_dir: Directory to save results (created if missing)
        embedding_model: Name of the SentenceTransformer model to load

    Raises:
        ValueError: If ``responses`` and ``correct_responses`` differ in length
    """
    if len(responses) != len(correct_responses):
        raise ValueError(
            f"responses ({len(responses)}) and correct_responses "
            f"({len(correct_responses)}) must have the same length"
        )

    # Bail out before loading the model: nothing useful can be plotted.
    if len(responses) < 2:
        print(f"Warning: Need at least 2 responses to visualize {method_name}, "
              f"got {len(responses)}; skipping")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Explicit boolean mask so indexing never depends on how NumPy interprets
    # a plain Python list.
    correct_mask = np.asarray(correct_responses, dtype=bool)
    num_total = len(correct_mask)
    num_correct = int(correct_mask.sum())

    # Get embeddings
    print(f"Getting embeddings with model: {embedding_model}")
    model = SentenceTransformer(embedding_model)
    embeddings = model.encode(responses, show_progress_bar=True)

    # Project to 2D for visualization (embeddings are used unscaled)
    print("Projecting embeddings to 2D space")
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(embeddings) - 1),
                init='pca', learning_rate='auto')
    embeddings_2d = tsne.fit_transform(embeddings)

    correct_points = embeddings_2d[correct_mask]
    incorrect_points = embeddings_2d[~correct_mask]

    point_size = 600
    viz_path = os.path.join(output_dir, f"{method_name}_response_embeddings.pdf")

    fig, ax = plt.subplots(figsize=(25, 12))
    try:
        # Correct answers: slightly larger green circles
        if len(correct_points) > 0:
            ax.scatter(
                correct_points[:, 0],
                correct_points[:, 1],
                s=point_size * 1.2,
                c='green',
                alpha=0.8,
                marker='o',
                label='Correct Answer',
                edgecolors='black',
                linewidth=1.5
            )

        # Incorrect answers: red crosses
        if len(incorrect_points) > 0:
            ax.scatter(
                incorrect_points[:, 0],
                incorrect_points[:, 1],
                s=point_size,
                c='red',
                alpha=0.6,
                marker='X',
                label='Incorrect Answer',
                edgecolors='black',
                linewidth=0.5
            )

        # t-SNE axes carry no interpretable scale, so hide ticks and labels
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        legend = ax.legend(loc='upper right', fontsize=42)
        legend.get_frame().set_alpha(0.7)

        # Summary of correctness in the top-left corner of the axes
        percent_correct = num_correct / num_total * 100
        ax.text(0.02, 0.98, f"Correct: {num_correct}/{num_total} ({percent_correct:.1f}%)",
                transform=ax.transAxes, fontsize=42, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        fig.tight_layout()
        fig.savefig(viz_path, bbox_inches='tight')
    finally:
        # Release the figure; otherwise every call keeps one alive in pyplot.
        plt.close(fig)
    print(f"Visualization saved to {viz_path}")

    # Save results to CSV with correctness information
    df = pd.DataFrame({
        'response_id': range(num_total),
        'is_correct': correct_mask,
        'response_text': responses,
        'x_coord': embeddings_2d[:, 0],
        'y_coord': embeddings_2d[:, 1],
    })
    results_path = os.path.join(output_dir, f"{method_name}_embedding_results.csv")
    df.to_csv(results_path, index=False)
    print(f"Results saved to {results_path}")

def main():
    """
    Command-line entry point.

    For the question selected by ``--index`` (matched against ``doc_id``),
    each configured decoding method's samples file is located, its responses
    are graded against the gold answer with ``math_verify`` and visualized
    individually. A combined figure is then produced for the methods that
    were actually processed in this run, so result files left over from an
    earlier run are not picked up for methods that were skipped.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='LLM Response Embedding Visualizer')
    parser.add_argument('--index', type=int, default=0,
                        help='Index of the question to visualize (doc_id)')
    parser.add_argument('--data_dir', type=str, default='combined_eval',
                        help='Root directory containing the decoding method folders')
    parser.add_argument('--output_dir', type=str, default='./visualization_results',
                        help='Directory to save results')
    parser.add_argument('--model_dir', type=str, default='Qwen__Qwen2.5-3B-Instruct',
                        help='Model sub-directory inside each decoding method folder')

    args = parser.parse_args()

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Define the decoding methods and their directories
    decoding_methods = {
        'diverse_beam_search': 'diverse_beam_search',
        'semantic_guided_search': 'semantic_guided_search',
        'temperature': 'temperature'
    }

    doc_id = args.index
    question = None
    gold = None
    # Methods that produced results in this run (name -> directory)
    processed_methods = {}

    for method_name, method_dir in decoding_methods.items():
        # Find the samples file; sort so the choice is deterministic when
        # several files match (glob returns matches in arbitrary order).
        pattern = os.path.join(args.data_dir, method_dir, args.model_dir, 'samples_gsm8k_*.jsonl')
        samples_files = sorted(glob.glob(pattern))

        if not samples_files:
            print(f"Warning: No samples file found for {method_name} at {pattern}")
            continue

        samples_file = samples_files[0]
        if len(samples_files) > 1:
            print(f"Warning: {len(samples_files)} samples files match {pattern}; using {samples_file}")
        print(f"Processing {method_name} with file: {samples_file}")

        # Load the data
        data = load_jsonl_data(samples_file, doc_id)
        if data is None:
            print(f"Warning: Document with doc_id {doc_id} not found in {samples_file}")
            continue

        # Extract question and gold answer once, from the first method that has the document
        if question is None:
            question = data['doc']['question']
            # The final answer is whatever follows the last '####' marker
            correct_answer = data['doc']['answer'].split('####')[-1].strip()
            gold = parse(correct_answer)
            print(f"Question: {question}")
            print(f"Correct Answer: {correct_answer}")

        # Flatten the responses.
        # NOTE(review): assumes the first element of each entry in data['resps']
        # is itself a list of response strings -- confirm against the samples schema.
        responses = []
        for resp_group in data['resps']:
            responses.extend(resp_group[0])

        print(f"Found {len(responses)} responses for {method_name}")

        # The t-SNE projection downstream needs at least two points
        if len(responses) < 2:
            print(f"Warning: Need at least 2 responses to visualize {method_name}; skipping")
            continue

        # Verify correctness; a response that cannot be parsed or compared counts as incorrect
        correct_responses = []
        for response in responses:
            try:
                answer = parse(response)
                is_correct = verify(gold, answer)
                correct_responses.append(is_correct)
            except Exception as e:
                print(f"Error verifying response: {e}")
                correct_responses.append(False)

        print(f"Found {sum(correct_responses)} correct responses out of {len(responses)} for {method_name}")

        # Visualize the responses
        visualize_responses(
            responses=responses,
            correct_responses=correct_responses,
            method_name=method_name,
            question=question,
            output_dir=args.output_dir
        )
        processed_methods[method_name] = method_dir

    print(f"Completed visualization for question with doc_id: {doc_id}")

    # Create a combined visualization from the methods processed in this run
    if processed_methods:
        create_combined_visualization(args.output_dir, processed_methods)
    else:
        print("Warning: No decoding method produced results; skipping combined visualization")

def create_combined_visualization(output_dir, decoding_methods):
    """
    Create a combined visualization showing embedding results from all methods.

    One subplot is drawn per method from the file
    ``{method_name}_embedding_results.csv`` in ``output_dir``. A method whose
    results cannot be loaded or plotted gets a "No data" placeholder panel
    instead of aborting the whole figure. The result is saved as
    ``combined_methods_comparison.pdf`` in ``output_dir``.

    If ``decoding_methods`` is empty there is nothing to draw; a warning is
    printed and no file is written.

    Args:
        output_dir: Directory containing individual method results
        decoding_methods: Dictionary mapping method names to directory names
            (only the method names are used)
    """
    method_names = list(decoding_methods)
    if not method_names:
        # plt.subplots cannot build a grid with zero columns
        print("Warning: No decoding methods given; skipping combined visualization")
        return

    # squeeze=False always yields a 2-D array of axes, so a single method
    # needs no special-casing.
    fig, axes_grid = plt.subplots(1, len(method_names), figsize=(24, 8), squeeze=False)
    axes = axes_grid[0]

    try:
        for ax, method_name in zip(axes, method_names):
            results_path = os.path.join(output_dir, f"{method_name}_embedding_results.csv")

            try:
                df = pd.read_csv(results_path)

                # Extract data
                x_coords = df['x_coord'].values
                y_coords = df['y_coord'].values
                correct_mask = df['is_correct'].values
                if len(correct_mask) == 0:
                    raise ValueError(f"{results_path} contains no rows")

                # Plot correct points
                if np.any(correct_mask):
                    ax.scatter(
                        x_coords[correct_mask],
                        y_coords[correct_mask],
                        s=100,
                        c='green',
                        alpha=0.8,
                        marker='o',
                        edgecolors='black',
                        linewidth=1
                    )

                # Plot incorrect points
                incorrect_mask = ~correct_mask
                if np.any(incorrect_mask):
                    ax.scatter(
                        x_coords[incorrect_mask],
                        y_coords[incorrect_mask],
                        s=80,
                        c='red',
                        alpha=0.6,
                        marker='X',
                        edgecolors='black',
                        linewidth=0.5
                    )

                # Set title and remove ticks
                ax.set_title(method_name, fontsize=16)
                ax.set_xticks([])
                ax.set_yticks([])

                # Add percentage text
                percent_correct = np.count_nonzero(correct_mask) / len(correct_mask) * 100
                ax.text(0.05, 0.95, f"{percent_correct:.1f}% correct",
                        transform=ax.transAxes, fontsize=14,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

            except Exception as e:
                # Best effort: one unusable method must not sink the whole figure
                print(f"Error creating combined plot for {method_name}: {e}")
                ax.text(0.5, 0.5, f"No data for {method_name}",
                        ha='center', va='center', fontsize=14)
                ax.set_xticks([])
                ax.set_yticks([])

        # Set the main title
        fig.suptitle("Response Embeddings Comparison Across Methods", fontsize=20)
        fig.tight_layout(rect=[0, 0, 1, 0.96])  # Leave room for the suptitle

        # Save the combined visualization
        combined_path = os.path.join(output_dir, "combined_methods_comparison.pdf")
        fig.savefig(combined_path, bbox_inches='tight')
        print(f"Combined visualization saved to {combined_path}")
    finally:
        # Release the figure so it does not stay registered with pyplot
        plt.close(fig)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()