#!/usr/bin/env python3
"""
Cluster reasoning sentences separately for each model-dataset combination.
"""
import os
import json
import re
import sys
import numpy as np
import umap
import matplotlib.pyplot as plt
from sklearn.cluster import HDBSCAN
from sklearn.preprocessing import normalize
from openai import OpenAI
from tqdm import tqdm

# Fix Unicode encoding issues on Windows: switch stdout to UTF-8 so that
# printing non-ASCII text (presumably model-generated labels or sample
# sentences - TODO confirm) does not fail under the console's default encoding.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

# ------------------------------
# Config
# ------------------------------
# Response files - all available model-dataset combinations.
# Each name follows "responses_<model>_<dataset>.jsonl"; main() recovers the
# model and dataset names by splitting the stem on its last underscore.
RESPONSE_FILES = [
    "responses_gpt_4o_mini_GSM8K.jsonl",
    "responses_gpt_4o_mini_ASDiv.jsonl",
    "responses_gpt_4o_mini_SVAMP.jsonl",
    "responses_gpt_3.5_turbo_1106_ASDiv.jsonl",
    "responses_gpt_3.5_turbo_1106_SVAMP.jsonl"
]

# Models
EMBEDDING_MODEL = "text-embedding-3-large"  # embeds every reasoning sentence
LABELING_MODEL = "gpt-4o-mini"  # chat model asked to name each cluster

# HDBSCAN parameters
MIN_CLUSTER_SIZE = 6  # passed as min_cluster_size; clustering is skipped below 2x this many sentences
MIN_SAMPLES = 3  # passed as min_samples

# ------------------------------
# Load API key & Initialize Client
# ------------------------------
# The key is read once at import time; a missing or empty value aborts the
# script immediately with a ValueError.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")
# Shared client used for both the embedding and the chat-completion requests.
client = OpenAI(api_key=api_key)

# ------------------------------
# Helper Functions
# ------------------------------
def extract_number_from_answer(answer_str):
    """Extract the final answer from a ground-truth value.

    Handles the GSM8K convention where the answer follows ``####``, strips
    parenthesised units such as ``(pies)`` and removes thousands separators.

    Args:
        answer_str: Raw ground-truth value. Any type is accepted; it is
            converted with ``str()`` before parsing.

    Returns:
        A ``float`` when the cleaned text parses as a number, the cleaned
        text (``str``) for ratios such as ``"2:3"`` and other non-numeric
        answers, or ``None`` when the input is ``None`` or nothing is left
        after cleaning.
    """
    # Only None means "no answer". A numeric ground truth of 0 is falsy but
    # perfectly valid, so a plain truthiness test must not be used here.
    if answer_str is None:
        return None

    answer_str = str(answer_str).strip()

    # For GSM8K format: "#### 21" - the answer is whatever follows the last marker
    if '####' in answer_str:
        match_str = answer_str.split('####')[-1].strip()
    else:
        match_str = answer_str

    # Remove units in parentheses like "(pies)", "(books)", etc
    match_str = re.sub(r'\s*\([^)]*\)\s*', '', match_str)

    # Ratios (e.g., "2:3") cannot be represented as a float; keep them as text
    if ':' in match_str:
        return match_str.strip()

    # Remove thousands separators before the numeric conversion
    match_str = match_str.replace(',', '').strip()

    try:
        return float(match_str)
    except (ValueError, TypeError):
        # Non-numeric answer: hand back the cleaned text, or None if empty
        return match_str if match_str else None

def safe_to_float(value):
    """Convert a model answer to ``float``, falling back to its text form.

    Args:
        value: The model's answer. Any type is accepted; it is converted
            with ``str()`` before parsing.

    Returns:
        ``None`` when ``value`` is ``None`` or blank, the stripped text for
        ratios (anything containing ``:``) and other non-numeric answers,
        otherwise the parsed ``float``. Thousands separators are ignored, so
        ``"1,000"`` yields ``1000.0``.
    """
    if value is None:
        return None

    value_str = str(value).strip()

    # Ratios such as "2:3" are compared as text, never as numbers
    if ':' in value_str:
        return value_str

    try:
        # Drop commas for the numeric attempt only, matching how ground-truth
        # answers are parsed; otherwise "1,000" could never equal 1000.
        return float(value_str.replace(',', ''))
    except (ValueError, TypeError):
        # Not a number: return the original text untouched (None if empty)
        return value_str if value_str else None

def is_correct(ground_truth, model_answer, dataset=None):
    """Check whether a model answer matches the ground truth.

    Args:
        ground_truth: Reference answer as stored in the response file.
        model_answer: Answer produced by the model; may be ``None``.
        dataset: Dataset name. ``"MathQA"`` switches to a case-insensitive
            comparison of option letters; anything else compares parsed values.

    Returns:
        A plain Python ``bool``: ``True`` when the answers match, ``False``
        otherwise (including when either side is missing or unparsable).
    """
    # For MathQA, ground truth is a letter (a, b, c, d, e)
    if dataset == "MathQA":
        gt_str = str(ground_truth).strip().lower()
        ma_str = str(model_answer).strip().lower() if model_answer else ""
        return gt_str == ma_str

    # For other datasets, compare numerical values or strings
    gt = extract_number_from_answer(ground_truth)
    ma = safe_to_float(model_answer)

    if gt is None or ma is None:
        return False

    # If both are strings (e.g., ratios), do string comparison
    if isinstance(gt, str) and isinstance(ma, str):
        return gt.strip().lower() == ma.strip().lower()

    # Otherwise compare numerically. float() also covers the mixed case where
    # one side stayed a string: if it cannot be parsed the answers differ.
    try:
        # np.isclose returns numpy.bool_; coerce so callers get a real bool
        return bool(np.isclose(float(gt), float(ma), rtol=1e-5))
    except (ValueError, TypeError):
        return False

def split_into_sentences(text):
    """Break a reasoning trace into trimmed sentences.

    The text is cut at whitespace that follows ``.``, ``!`` or ``?`` and at
    runs of newlines. Fragments of 10 characters or fewer (after trimming)
    are discarded. A falsy ``text`` yields an empty list.
    """
    if not text:
        return []

    # Sentence-ending punctuation followed by whitespace, or any newline run
    pieces = re.split(r'(?<=[.!?])\s+|\n+', text)
    trimmed = (piece.strip() for piece in pieces)
    return [sentence for sentence in trimmed if len(sentence) > 10]

def generate_embeddings_batch(sentences, batch_size=2000):
    """Embed ``sentences`` with EMBEDDING_MODEL, one request per batch.

    Args:
        sentences: List of strings to embed.
        batch_size: Maximum number of sentences sent in a single request.

    Returns:
        A NumPy array built from the embeddings of all batches, in the order
        the API returned them.
    """
    vectors = []
    batch_starts = range(0, len(sentences), batch_size)
    for start in tqdm(batch_starts, desc="Generating Embeddings"):
        chunk = sentences[start:start + batch_size]
        result = client.embeddings.create(input=chunk, model=EMBEDDING_MODEL)
        vectors += [record.embedding for record in result.data]
    return np.array(vectors)

def _load_dataset_responses(response_file, dataset_name):
    """Return the JSONL entries in ``response_file`` whose "dataset" equals ``dataset_name``."""
    matching = []
    with open(response_file, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at end of file) are not
            # valid JSON; skip them instead of crashing in json.loads.
            if not line.strip():
                continue
            entry = json.loads(line)
            if entry.get("dataset") == dataset_name:
                matching.append(entry)
    return matching

def _request_cluster_label(sample_sentences):
    """Ask LABELING_MODEL for a short label describing ``sample_sentences``."""
    prompt = "The following are examples of a single reasoning sentence from a math problem. Provide a 4-8 word label that describes the common pattern or action.\n\n"
    prompt += "EXAMPLES:\n" + "\n".join(f"- \"{s}\"" for s in sample_sentences)
    prompt += "\n\nLABEL:"

    response = client.chat.completions.create(
        model=LABELING_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=20
    )
    # The message content can come back as None; treat that as an empty
    # label rather than failing on None.strip().
    content = response.choices[0].message.content or ""
    return content.strip().replace('"', '')

def cluster_and_label_combination(response_file, model_name, dataset_name):
    """Cluster and label reasoning sentences for a specific model-dataset combination.

    Loads the matching responses, splits each response into sentences, embeds
    and L2-normalizes them, clusters with HDBSCAN, asks the labeling model to
    name every cluster, and writes a report to
    ``clusters_<model_name>_<dataset_name>.json`` in the working directory.
    Progress and a top-10 summary are printed to stdout.

    Args:
        response_file: Path to a JSONL file of response entries. Each entry
            must have an ``id``; ``dataset``, ``ground_truth``,
            ``model_answer`` and ``response`` are read when present.
        model_name: Model identifier, used for display and the output name.
        dataset_name: Dataset identifier; only entries whose ``dataset``
            field equals it are used.

    Returns:
        The number of clusters in the report, or ``None`` when fewer than
        ``MIN_CLUSTER_SIZE * 2`` sentences were extracted (nothing is written
        in that case).
    """

    print(f"\n{'=' * 80}")
    print(f"CLUSTERING: {model_name} + {dataset_name}")
    print(f"{'=' * 80}")

    # Load responses
    print("Loading responses...")
    responses = _load_dataset_responses(response_file, dataset_name)

    print(f"Loaded {len(responses)} responses for {model_name} + {dataset_name}")

    # Extract sentences
    print("Extracting reasoning sentences...")
    all_sentences = []
    # Parallel to all_sentences: the response entry each sentence came from
    sentence_to_response_map = []
    # response id -> whether that response's final answer was correct
    response_correctness_map = {}

    for response_entry in responses:
        response_id = response_entry['id']
        ground_truth = response_entry.get('ground_truth')
        model_answer = response_entry.get('model_answer')

        response_correctness_map[response_id] = is_correct(ground_truth, model_answer, dataset_name)

        sentences = split_into_sentences(response_entry.get("response", ""))
        all_sentences.extend(sentences)
        sentence_to_response_map.extend([response_entry] * len(sentences))

    print(f"Extracted {len(all_sentences)} sentences")

    if len(all_sentences) < MIN_CLUSTER_SIZE * 2:
        print(f"Not enough sentences to cluster (need at least {MIN_CLUSTER_SIZE * 2})")
        return None

    # Generate embeddings
    print(f"Generating embeddings using {EMBEDDING_MODEL}...")
    sentence_embeddings = generate_embeddings_batch(all_sentences)

    # Normalize and cluster. On unit-length vectors Euclidean distance is a
    # monotonic function of cosine distance.
    print("Normalizing and clustering...")
    normalized_embeddings = normalize(sentence_embeddings, norm='l2', axis=1)

    clusterer = HDBSCAN(min_cluster_size=MIN_CLUSTER_SIZE, min_samples=MIN_SAMPLES, metric='euclidean')
    cluster_labels = clusterer.fit_predict(normalized_embeddings)

    # Label -1 marks noise points and is not counted as a cluster
    n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    n_noise = np.sum(cluster_labels == -1)
    print(f"Found {n_clusters} clusters and {n_noise} noise points")

    # Auto-label and analyze clusters
    print("Auto-labeling clusters...")
    cluster_report = []

    for cid in tqdm(range(n_clusters), desc="Labeling Clusters"):
        cluster_indices = np.where(cluster_labels == cid)[0]
        cluster_sentences = [all_sentences[i] for i in cluster_indices]

        # Share of the cluster's sentences that come from correctly answered responses
        num_correct = sum(1 for i in cluster_indices
                          if response_correctness_map.get(sentence_to_response_map[i]['id'], False))
        correctness_rate = (num_correct / len(cluster_indices)) * 100 if len(cluster_indices) > 0 else 0

        # Sample up to 10 sentences (without replacement) to show the labeling model
        sample_size = min(10, len(cluster_sentences))
        sample_indices = np.random.choice(len(cluster_sentences), sample_size, replace=False)
        sample_sentences = [cluster_sentences[i] for i in sample_indices]

        label_text = _request_cluster_label(sample_sentences)

        cluster_report.append({
            "cluster_id": cid,
            "auto_label": label_text,
            "sentence_count": len(cluster_sentences),
            "correctness_percentage": round(correctness_rate, 2),
            "sample_sentences": sample_sentences
        })

    # Largest clusters first
    cluster_report.sort(key=lambda x: x["sentence_count"], reverse=True)

    # Save results
    output_file = f"clusters_{model_name}_{dataset_name}.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "model": model_name,
            "dataset": dataset_name,
            "n_clusters": int(n_clusters),  # Convert numpy int to Python int
            "n_noise": int(n_noise),  # Convert numpy int to Python int
            "total_sentences": len(all_sentences),
            "clusters": cluster_report
        }, f, indent=2)

    print(f"Saved clustering results to {output_file}")

    # Print summary
    print(f"\n{'ID':<6} {'Count':<8} {'Accuracy %':<12} {'Label'}")
    print("-" * 70)
    for report in cluster_report[:10]:  # Show top 10
        print(f"{report['cluster_id']:<6} {report['sentence_count']:<8} {report['correctness_percentage']:<12.1f} {report['auto_label']}")

    return len(cluster_report)

def main():
    """Run clustering for every configured response file and print a grand total.

    Files that do not exist, or whose name cannot be split into a model and a
    dataset part, are reported and skipped.
    """
    banner = "=" * 80
    print(banner)
    print("CLUSTERING REASONING SENTENCES PER MODEL-DATASET COMBINATION")
    print(banner)

    cluster_total = 0

    for path in RESPONSE_FILES:
        if not os.path.exists(path):
            print(f"\nWarning: {path} not found, skipping...")
            continue

        # "responses_<model>_<dataset>.jsonl" -> "<model>_<dataset>"
        stem = path.replace('responses_', '').replace('.jsonl', '')
        # The dataset is whatever follows the last underscore
        pieces = stem.rsplit('_', 1)
        if len(pieces) != 2:
            print(f"\nSkipping {path} - cannot parse model and dataset from filename")
            continue

        model_name, dataset_name = pieces
        found = cluster_and_label_combination(path, model_name, dataset_name)
        # A None or zero result contributes nothing to the total
        cluster_total += found or 0

    print("\n" + banner)
    print("CLUSTERING COMPLETE")
    print(f"Total clusters created across all combinations: {cluster_total}")
    print(banner)

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
