import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from sentence_transformers import SentenceTransformer
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import euclidean_distances
import pandas as pd
import argparse
# Import the math_verify library
from math_verify import parse, verify

def _apply_chat_template(tokenizer, question):
    """Return *question* wrapped in the tokenizer's chat template for generation.

    Falls back to a hand-written ChatML-style prompt when
    ``tokenizer.apply_chat_template`` raises (e.g. no template configured).
    """
    messages = [{"role": "user", "content": question}]
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Narrower than a bare ``except:`` so Ctrl-C / SystemExit still propagate.
        return f"<|im_start|>user\n{question}\n<|im_end|>\n<|im_start|>assistant\n"


def _generate_responses(llm, prompt, k, temperature, top_p, max_tokens):
    """Sample *k* completions of *prompt* with a single vLLM ``generate`` call.

    Returns:
        ``(responses, avg_logprobs)``: the generated texts and, per text, the
        cumulative log-probability divided by the number of generated tokens
        (0.0 for an empty completion).
    """
    sampling_params = SamplingParams(
        n=k,
        temperature=temperature,
        top_p=top_p,
        top_k=-1,
        max_tokens=max_tokens,
        logprobs=2,  # per-token top-2 logprobs requested
        detokenize=True,
        skip_special_tokens=True,
    )
    request_output = llm.generate([prompt], sampling_params)[0]

    responses, avg_logprobs = [], []
    for completion in request_output.outputs:
        responses.append(completion.text)
        n_tokens = len(completion.token_ids)
        avg_logprobs.append(
            completion.cumulative_logprob / n_tokens if n_tokens > 0 else 0.0
        )
    return responses, avg_logprobs


def _verify_responses(responses, correct_answer):
    """Return one ``verify(gold, parse(response))`` verdict per response.

    A response whose parsing/verification raises is reported and counted as
    incorrect (``False``) rather than aborting the whole run.
    """
    gold = parse(correct_answer)
    verdicts = []
    for i, response in enumerate(responses):
        try:
            verdicts.append(verify(gold, parse(response)))
        except Exception as e:
            print(f"Error verifying response {i}: {e}")
            verdicts.append(False)
    return verdicts


def _marker_sizes(logprobs, min_size=200, max_size=2200):
    """Map average log-probs to scatter marker areas in ``[min_size, max_size]``.

    Values are min-max normalised to [0, 1] and then SQUARED (quadratic, not
    linear) before scaling, which exaggerates size differences between the
    most and least confident responses.
    """
    values = np.asarray(logprobs, dtype=float)
    lo, hi = values.min(), values.max()
    normalized = (values - lo) / (hi - lo + 1e-10)  # epsilon: all-equal input -> zeros
    return min_size + normalized ** 2 * (max_size - min_size)


def _plot_clusters(embeddings_2d, cluster_labels, correct_responses, sizes, viz_path):
    """Draw the 2-D cluster scatter and save it to *viz_path*.

    Colour encodes cluster; marker encodes correctness ('o' = correct,
    'X' = incorrect); marker area comes from *sizes*.
    """
    fig, ax = plt.subplots(figsize=(25, 12))

    unique_clusters = np.unique(cluster_labels)
    palette = plt.cm.tab10(np.linspace(0, 1, len(unique_clusters)))
    # Index the palette by a label's rank among the labels actually present, so
    # a label >= len(unique_clusters) (if KMeans ever returns fewer distinct
    # labels than requested) cannot index past the end of the palette.
    color_of = {label: palette[pos] for pos, label in enumerate(unique_clusters)}

    for (x, y), label, is_correct, size in zip(
        embeddings_2d, cluster_labels, correct_responses, sizes
    ):
        color = color_of[label]
        if is_correct:
            # Correct: filled circle, slightly enlarged, thick black edge.
            ax.scatter(x, y, s=size * 1.2, c=[color], alpha=0.9, marker='o',
                       edgecolors='black', linewidth=1.5)
        else:
            # Incorrect: 'X' marker with a thin edge.
            ax.scatter(x, y, s=size, c=[color], alpha=0.7,
                       edgecolors='k', linewidth=0.5, marker='X')

    # t-SNE axes carry no meaning, so hide ticks (and with them the tick labels).
    ax.set_xticks([])
    ax.set_yticks([])

    # Tick-less colorbar acting as a cluster-colour legend.
    mappable = plt.cm.ScalarMappable(
        cmap=plt.cm.tab10,
        norm=plt.Normalize(vmin=0, vmax=len(unique_clusters) - 1),
    )
    mappable.set_array([])
    cbar = fig.colorbar(mappable, ax=ax, ticks=[])
    cbar.set_label('Cluster', fontsize=52)

    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', label='Correct Answer',
                   markerfacecolor='gray', markersize=30),
        plt.Line2D([0], [0], marker='X', color='w', label='Incorrect Answer',
                   markerfacecolor='gray', markersize=30),
    ]
    legend = ax.legend(handles=legend_elements,
                       loc='upper left', fontsize=43, title_fontsize=43)
    legend.get_frame().set_alpha(0.7)

    fig.tight_layout()
    fig.savefig(viz_path, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open pyplot figures.
    plt.close(fig)


def _closest_to_center(center, embeddings, candidate_indices, count):
    """Return up to *count* of *candidate_indices*, nearest to *center* first."""
    distances = euclidean_distances([center], embeddings[candidate_indices])[0]
    return [candidate_indices[i] for i in np.argsort(distances)[:count]]


def _select_representatives(cluster_labels, correct_responses, embeddings,
                            centers, n_clusters):
    """Pick up to three representative response indices per non-empty cluster.

    - cluster with correct AND incorrect members: 2 nearest correct + 1 nearest incorrect
    - cluster with only correct members: 3 nearest correct
    - cluster with no correct members: 3 nearest overall
    "Nearest" is Euclidean distance to the KMeans centre in *embeddings* space.
    """
    representatives = {}
    for cluster_id in range(n_clusters):
        members = np.where(cluster_labels == cluster_id)[0]
        if len(members) == 0:
            continue
        center = centers[cluster_id]
        correct = [idx for idx in members if correct_responses[idx]]
        incorrect = [idx for idx in members if not correct_responses[idx]]

        if not correct:
            representatives[cluster_id] = _closest_to_center(center, embeddings, members, 3)
        elif incorrect:
            representatives[cluster_id] = (
                _closest_to_center(center, embeddings, correct, 2)
                + _closest_to_center(center, embeddings, incorrect, 1)
            )
        else:
            representatives[cluster_id] = _closest_to_center(center, embeddings, correct, 3)
    return representatives


def _print_representatives(representatives, cluster_stats, responses, logprobs,
                           correct_responses):
    """Print each cluster's accuracy summary and a preview of its representatives."""
    print("\nRepresentative Responses by Cluster:")
    for cluster_id, indices in representatives.items():
        print(f"\nCluster {cluster_id}:")
        # Scalar .loc access keeps integer columns printing as ints.
        print(f"Correct responses in cluster: {cluster_stats.loc[cluster_id, 'correct_count']}/{cluster_stats.loc[cluster_id, 'count']} ({cluster_stats.loc[cluster_id, 'correct_percent']}%)")
        print("=" * 50)
        for rank, idx in enumerate(indices, start=1):
            correctness_indicator = "✓ CORRECT" if correct_responses[idx] else "✗ INCORRECT"
            print(f"  Representative #{rank} (ID: {idx}, Logprob: {logprobs[idx]:.4f}) - {correctness_indicator}:")
            text = responses[idx]
            preview = text[:150] + "..." if len(text) > 150 else text
            # Indent every line (including the first) for both long and short texts.
            print("  " + preview.replace('\n', '\n  '))
            print("-" * 50)


def run_response_clustering(question, correct_answer="5", model_name="Qwen/Qwen2.5-3B-Instruct", k=500,
                            n_clusters=5, temperature=1.0, top_p=0.95,
                            repetition_penalty=1.1, max_tokens=1024,
                            embedding_model="all-MiniLM-L6-v2", output_dir="./results"):
    """
    Run the complete pipeline for generating, clustering, and visualizing LLM responses.

    Steps: sample ``k`` responses with vLLM, check each against
    ``correct_answer`` with math_verify, embed "templated question + response"
    with a SentenceTransformer, standardise and KMeans-cluster the embeddings,
    project points and centres to 2-D with t-SNE, then write
    ``response_clusters.pdf``, ``clustering_results.csv`` and
    ``cluster_statistics.csv`` into ``output_dir``.

    Args:
        question: Question to generate responses for
        correct_answer: The correct answer to the math problem (default: "5")
        model_name: Name of the VLLM model to use
        k: Number of responses to generate
        n_clusters: Number of clusters to form (must satisfy 1 <= n_clusters <= k)
        temperature: Sampling temperature
        top_p: Top-p sampling parameter
        repetition_penalty: Accepted for interface compatibility but currently
            NOT forwarded to vLLM's SamplingParams, so it has no effect
        max_tokens: Maximum tokens to generate per response
        embedding_model: Name of the SentenceTransformer model for embeddings
        output_dir: Directory to save results (created if missing)

    Returns:
        Dictionary with keys 'responses', 'logprobs' (per-token average),
        'cluster_labels', 'embeddings_2d', 'representative_indices',
        'original_embeddings' (the standardised embeddings), 'cluster_centers',
        'centers_2d' and 'correct_responses'.

    Raises:
        ValueError: if ``n_clusters`` is outside ``[1, k]`` (checked before the
            model is loaded; KMeans would otherwise fail after generation).
    """
    import os

    if not 1 <= n_clusters <= k:
        raise ValueError(f"n_clusters must be between 1 and k ({k}), got {n_clusters}")
    os.makedirs(output_dir, exist_ok=True)

    # 1. Model and tokenizer
    print(f"Loading model and tokenizer: {model_name}")
    llm = LLM(model_name, gpu_memory_utilization=0.6)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # 2. Prompt
    templated_question = _apply_chat_template(tokenizer, question)
    print(f"Applied template to question:\n{templated_question[:100]}...")

    # 3. Generation (all k samples in one request; repetition_penalty unused, see docstring)
    print(f"Generating {k} responses...")
    responses, logprobs = _generate_responses(
        llm, templated_question, k, temperature, top_p, max_tokens
    )
    print(f"Generated {len(responses)} responses")

    # 4. Correctness
    print("Verifying mathematical correctness of responses...")
    correct_responses = _verify_responses(responses, correct_answer)
    print(f"Found {sum(correct_responses)} correct responses out of {len(responses)}")

    # 5. Embeddings of question+response pairs
    print(f"Getting embeddings with model: {embedding_model}")
    encoder = SentenceTransformer(embedding_model)
    combined_texts = [f"{templated_question} {response}" for response in responses]
    embeddings = encoder.encode(combined_texts, show_progress_bar=True)

    # 6. Clustering
    print(f"Clustering embeddings into {n_clusters} clusters")
    normalized_embeddings = StandardScaler().fit_transform(embeddings)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(normalized_embeddings)

    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1.
    n_labels = len(np.unique(cluster_labels))
    if 1 < n_labels < len(cluster_labels):
        silhouette_avg = silhouette_score(normalized_embeddings, cluster_labels)
        print(f"Silhouette Score: {silhouette_avg:.4f}")

    # 7. Joint 2-D projection so points and centres share one layout
    print("Projecting embeddings to 2D space")
    combined_embeddings = np.vstack([normalized_embeddings, kmeans.cluster_centers_])
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(combined_embeddings) - 1),
                init='pca', learning_rate='auto')
    combined_2d = tsne.fit_transform(combined_embeddings)
    n_points = len(normalized_embeddings)
    embeddings_2d = combined_2d[:n_points]
    centers_2d = combined_2d[n_points:]

    # 8. Visualisation
    viz_path = os.path.join(output_dir, "response_clusters.pdf")
    _plot_clusters(embeddings_2d, cluster_labels, correct_responses,
                   _marker_sizes(logprobs), viz_path)
    print(f"Visualization saved to {viz_path}")

    # 9. Per-response and per-cluster tables
    df = pd.DataFrame({
        'response_id': range(len(responses)),
        'cluster': cluster_labels,
        'avg_logprob': logprobs,
        'is_correct': correct_responses,
        'response_text': responses,
    })
    results_path = os.path.join(output_dir, "clustering_results.csv")
    df.to_csv(results_path, index=False)
    print(f"Results saved to {results_path}")

    cluster_stats = df.groupby('cluster').agg({
        'response_id': 'count',
        'avg_logprob': 'mean',
        'is_correct': 'sum',
    }).rename(columns={'response_id': 'count', 'is_correct': 'correct_count'})
    cluster_stats['correct_percent'] = (
        cluster_stats['correct_count'] / cluster_stats['count'] * 100
    ).round(2)

    stats_path = os.path.join(output_dir, "cluster_statistics.csv")
    cluster_stats.to_csv(stats_path)
    print(f"Cluster statistics saved to {stats_path}")

    # 10. Representatives, prioritising correct responses
    representative_indices = _select_representatives(
        cluster_labels, correct_responses, normalized_embeddings,
        kmeans.cluster_centers_, n_clusters,
    )
    _print_representatives(representative_indices, cluster_stats, responses,
                           logprobs, correct_responses)

    return {
        'responses': responses,
        'logprobs': logprobs,
        'cluster_labels': cluster_labels,
        'embeddings_2d': embeddings_2d,
        'representative_indices': representative_indices,
        'original_embeddings': normalized_embeddings,
        'cluster_centers': kmeans.cluster_centers_,
        'centers_2d': centers_2d,
        'correct_responses': correct_responses,
    }

def build_arg_parser():
    """Build the command-line parser for the clustering script.

    Note that several CLI defaults (``--k`` 100, ``--clusters`` 10,
    ``--top_p`` 1.0, ``--correct_answer`` "98", ``--output_dir``
    "./tmp_results") differ from the keyword defaults of
    ``run_response_clustering``.
    """
    parser = argparse.ArgumentParser(description='LLM Response Clustering Visualizer with Math Verification')
    parser.add_argument('--question', type=str, default="A rectangular band formation is a formation with $m$ band members in each of $r$ rows, where $m$ and $r$ are integers. A particular band has less than 100 band members. The director arranges them in a rectangular formation and finds that he has two members left over. If he increases the number of members in each row by 1 and reduces the number of rows by 2, there are exactly enough places in the new formation for each band member. What is the largest number of members the band could have?",
                        help='Question to generate responses for')
    parser.add_argument('--correct_answer', type=str, default="98",
                        help='The correct answer to verify against')
    parser.add_argument('--model', type=str, default="Qwen/Qwen2.5-3B-Instruct",
                        help='VLLM model name')
    parser.add_argument('--k', type=int, default=100,
                        help='Number of responses to generate')
    parser.add_argument('--clusters', type=int, default=10,
                        help='Number of clusters')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Sampling temperature')
    parser.add_argument('--top_p', type=float, default=1.0,
                        help='Top-p sampling parameter')
    # The pipeline does not currently forward this value to vLLM; the help text says so.
    parser.add_argument('--repetition_penalty', type=float, default=1.1,
                        help='Repetition penalty (currently ignored by the pipeline)')
    parser.add_argument('--max_tokens', type=int, default=1024,
                        help='Maximum tokens to generate')
    parser.add_argument('--embedding_model', type=str, default="all-MiniLM-L6-v2",
                        help='SentenceTransformer model name')
    parser.add_argument('--output_dir', type=str, default="./tmp_results",
                        help='Directory to save results')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    results = run_response_clustering(
        question=args.question,
        correct_answer=args.correct_answer,
        model_name=args.model,
        k=args.k,
        n_clusters=args.clusters,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        max_tokens=args.max_tokens,
        embedding_model=args.embedding_model,
        output_dir=args.output_dir
    )

    print(f"Completed clustering analysis for question: '{args.question}'")
    print(f"Total correct responses: {sum(results['correct_responses'])} out of {len(results['responses'])}")