import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
import random
import json
import re
import nltk
from nltk.util import ngrams
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

class GSM8kResponseFilteringEvaluator:
    """
    Compare ways of choosing k out of many sampled responses per GSM8k problem.

    Three filters are evaluated: highest mean token log-probability
    ("probability"), one response per distinct extracted final answer
    ("ngram"), and one response per K-means cluster of sentence embeddings
    ("semantic"). A problem counts as solved at a given k when any kept
    response has the reference's final answer.
    """

    # A number as it appears in free text: optional minus sign, optional
    # thousands separators, optional decimal part. The lookbehind stops a
    # match from starting right after a digit or "." and keeps the "-" of a
    # subtraction such as "10-3" from being read as a sign.
    _NUMBER_PATTERN = re.compile(
        r'(?<![\d.])-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?'
    )
    # GSM8k reference format "#### <answer>"; an optional "$" is tolerated.
    _HASH_ANSWER_PATTERN = re.compile(r'####\s*\$?\s*(' + _NUMBER_PATTERN.pattern + r')')
    # Content of a LaTeX \boxed{...} that contains no nested braces.
    _BOXED_PATTERN = re.compile(r'\\boxed\{([^{}]*)\}')

    def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.9, max_model_len=18192):
        """
        Load the vLLM model, its tokenizer and the sentence-embedding model.

        Args:
            model_name: Model id passed to both vLLM and ``AutoTokenizer``.
            gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
            max_model_len: Context length passed to vLLM.
        """
        self.model_name = model_name

        print(f"Loading model: {model_name}")
        self.model = LLM(
            model=model_name,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            max_model_len=max_model_len,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Sentence embeddings used by the semantic clustering filter.
        self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')

        # Make sure the NLTK punkt tokenizer data is available.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')

    def format_prompt(self, prompt_text):
        """Wrap ``prompt_text`` as a single user turn using the model's chat template."""
        messages = [{"role": "user", "content": prompt_text}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    @staticmethod
    def _mean_token_logprob(output):
        """
        Return the mean log-probability of the sampled tokens of one completion.

        ``output.logprobs`` is read as one ``{token_id: obj_with_.logprob}``
        mapping per generated position. Positions whose sampled token is missing
        from its mapping are skipped. When no sampled-token logprob is available
        at all, ``-inf`` is returned so such responses rank last rather than
        first (0 would be the best possible average).
        """
        token_ids = output.token_ids
        position_logprobs = getattr(output, 'logprobs', None) or []
        sampled = [
            candidates[token_ids[pos]].logprob
            for pos, candidates in enumerate(position_logprobs)
            if pos < len(token_ids) and token_ids[pos] in candidates
        ]
        if not sampled:
            return float('-inf')
        return sum(sampled) / len(sampled)

    def generate_responses_batch(self, prompts, n_samples=100, max_tokens=18092, temperature=1.0, num_logprobs=19):
        """
        Sample ``n_samples`` responses for every prompt in a single vLLM call.

        Args:
            prompts: Raw question strings; each is wrapped with the chat template.
            n_samples: Responses sampled per prompt.
            max_tokens: Generation budget per response.
            temperature: Sampling temperature (top_p=0.95 and top_k=40 are fixed).
            num_logprobs: Value passed as ``SamplingParams.logprobs``. Only the
                sampled token's log-probability is used below; the other
                returned alternatives are unused.

        Returns:
            ``(all_responses, all_logprobs)``: dicts keyed by prompt text, mapping
            to the list of response strings and the parallel list of mean
            per-token log-probabilities. Duplicate prompts share one key, so a
            later duplicate's samples overwrite the earlier ones.
        """
        formatted_prompts = [self.format_prompt(prompt) for prompt in prompts]

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=0.95,
            top_k=40,
            n=n_samples,
            logprobs=num_logprobs,
            skip_special_tokens=True,
        )

        print(f"Generating {n_samples} responses for {len(prompts)} prompts...")
        outputs = self.model.generate(formatted_prompts, sampling_params)

        all_responses = {}
        all_logprobs = {}
        for prompt, prompt_outputs in zip(prompts, outputs):
            completions = prompt_outputs.outputs
            all_responses[prompt] = [completion.text for completion in completions]
            all_logprobs[prompt] = [self._mean_token_logprob(completion) for completion in completions]

        return all_responses, all_logprobs

    @staticmethod
    def _canonical_number(text):
        """
        Normalize a matched number string so that equal values compare equal.

        Drops thousands separators, leading zeros of the integer part, trailing
        zeros of the fraction and the sign of zero: "1,000" -> "1000",
        "18.50" -> "18.5", "18.0" -> "18", "-0" -> "0".
        """
        negative = text.startswith('-')
        integer_part, _, fraction_part = text.lstrip('-').replace(',', '').partition('.')
        integer_part = integer_part.lstrip('0') or '0'
        fraction_part = fraction_part.rstrip('0')
        canonical = f"{integer_part}.{fraction_part}" if fraction_part else integer_part
        return f"-{canonical}" if negative and canonical != '0' else canonical

    def _extract_answer_number(self, text):
        """
        Return the canonical final-answer number found in ``text``, or None.

        Priority: the number after a "####" marker, then the last number inside
        the last ``\\boxed{...}`` that contains one, then the last number
        anywhere in the text.
        """
        match = self._HASH_ANSWER_PATTERN.search(text)
        if match:
            return self._canonical_number(match.group(1))

        for boxed in reversed(self._BOXED_PATTERN.findall(text)):
            numbers = self._NUMBER_PATTERN.findall(boxed)
            if numbers:
                return self._canonical_number(numbers[-1])

        numbers = self._NUMBER_PATTERN.findall(text)
        if numbers:
            return self._canonical_number(numbers[-1])
        return None

    def extract_final_answer(self, response):
        """
        Extract the final answer from a GSM8k response or reference.

        Numbers are returned in canonical form (see ``_canonical_number``) so
        "1,000" and "1000" compare and cluster together, and decimals such as
        "18.5" are kept whole instead of being split at the point.

        Returns:
            The canonical numeric string, or the stripped text itself when it
            contains no number.
        """
        response = response or ""
        number = self._extract_answer_number(response)
        return number if number is not None else response.strip()

    def extract_all_answers(self, responses):
        """Return ``extract_final_answer`` applied to each response, in order."""
        return [self.extract_final_answer(response) for response in responses]

    def is_answer_correct(self, predicted, reference):
        """
        Return True when ``predicted`` has the same final number as ``reference``.

        Both sides go through the same extraction, so the GSM8k "#### <n>"
        reference format, thousands separators, signs and decimals are handled
        consistently. A side without any number never counts as correct.
        """
        reference = str(reference) if reference is not None else ""
        predicted_number = self._extract_answer_number(predicted or "")
        reference_number = self._extract_answer_number(reference)
        return predicted_number is not None and predicted_number == reference_number

    def evaluate_method_accuracy(self, prompts, all_responses, all_logprobs, reference_answers, top_k_values, method):
        """
        Measure any-of-k accuracy of one filtering method for each k.

        Args:
            prompts: Prompts in the same order as ``reference_answers``; prompts
                missing from ``all_responses`` are skipped.
            all_responses / all_logprobs: Outputs of ``generate_responses_batch``.
            reference_answers: Reference answer per prompt index.
            top_k_values: The k values to evaluate.
            method: "probability", "ngram" or "semantic".

        Returns:
            ``{"top_k": top_k_values, "accuracy": [...], "variance": [...]}`` with
            one entry per k. Accuracy is the mean of the per-problem 0/1
            outcomes; variance is their population variance (``np.var``).

        Raises:
            ValueError: if ``method`` is unknown.
        """
        if method not in ("probability", "ngram", "semantic"):
            raise ValueError(f"Unknown method: {method}")

        results = {"top_k": top_k_values, "accuracy": [], "variance": []}
        # Clustering features (extracted answers or embeddings) do not depend on
        # k, so compute them at most once per prompt instead of once per k.
        feature_cache = {}

        for k in top_k_values:
            correct_counts = []

            for prompt_idx, prompt in enumerate(prompts):
                if prompt not in all_responses:
                    continue

                responses = all_responses[prompt]
                logprobs = all_logprobs[prompt]
                reference = reference_answers[prompt_idx]

                if method == "probability":
                    filtered_responses = self.filter_by_probability(responses, logprobs, k)
                else:
                    # The clustering filters keep everything when len(responses) <= k,
                    # so features are only needed (and computed) above that size.
                    features = None
                    if len(responses) > k:
                        if prompt_idx not in feature_cache:
                            feature_cache[prompt_idx] = (
                                self.extract_all_answers(responses)
                                if method == "ngram"
                                else self.sentence_transformer.encode(responses)
                            )
                        features = feature_cache[prompt_idx]

                    if method == "ngram":
                        filtered_responses = self.filter_by_ngram(responses, logprobs, k, answers=features)
                    else:
                        filtered_responses = self.filter_by_semantic_clustering(
                            responses, logprobs, k, embeddings=features
                        )

                any_correct = any(self.is_answer_correct(response, reference) for response in filtered_responses)
                correct_counts.append(1.0 if any_correct else 0.0)

            accuracy = np.mean(correct_counts) if correct_counts else 0.0
            variance = np.var(correct_counts) if correct_counts else 0.0

            results["accuracy"].append(accuracy)
            results["variance"].append(variance)

        return results

    def filter_by_probability(self, responses, logprobs, k):
        """
        Keep the k responses with the highest mean log-probability.

        All responses are returned unchanged when there are at most k. Ties keep
        their original relative order (``sorted`` is stable).
        """
        if len(responses) <= k:
            return responses

        sorted_pairs = sorted(zip(responses, logprobs), key=lambda pair: pair[1], reverse=True)
        return [response for response, _ in sorted_pairs[:k]]

    @staticmethod
    def _best_per_cluster(responses, logprobs, labels, k):
        """
        Pick the highest-logprob response of each label group, then the k best groups.

        Within a group the earliest response wins ties. Groups are ranked by their
        best response's logprob, with ties kept in first-appearance order.
        """
        best_by_label = {}
        for idx, label in enumerate(labels):
            if label not in best_by_label or logprobs[idx] > logprobs[best_by_label[label]]:
                best_by_label[label] = idx
        ranked = sorted(best_by_label.values(), key=lambda idx: logprobs[idx], reverse=True)
        return [responses[idx] for idx in ranked[:k]]

    def filter_by_ngram(self, responses, logprobs, k, answers=None):
        """
        Keep the best response for each of the k most probable distinct answers.

        Despite the name, responses are grouped by their exact extracted final
        answer (``extract_final_answer``), not by n-gram overlap. The result may
        hold fewer than k responses when there are fewer distinct answers.

        Args:
            answers: Optional precomputed ``extract_all_answers(responses)``;
                computed here when omitted.
        """
        if len(responses) <= k:
            return responses

        if answers is None:
            answers = self.extract_all_answers(responses)
        return self._best_per_cluster(responses, logprobs, answers, k)

    def filter_by_semantic_clustering(self, responses, logprobs, k, embeddings=None):
        """
        Cluster responses into k groups by sentence embedding and keep each group's best.

        Uses K-means (``random_state=42``, ``n_init=10``) over
        ``self.sentence_transformer`` embeddings and returns the highest-logprob
        response of each cluster, ordered by that logprob.

        Args:
            embeddings: Optional precomputed embeddings of ``responses``;
                encoded here when omitted.
        """
        if len(responses) <= k:
            return responses

        if embeddings is None:
            embeddings = self.sentence_transformer.encode(responses)

        # len(responses) > k here, so exactly k clusters are requested.
        kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)
        return self._best_per_cluster(responses, logprobs, cluster_labels, k)

    def run_experiment(self, n_samples=100, max_tokens=18192, n_problems=20, seed=None, top_k_values=None):
        """
        Run all three filtering methods on a random subset of the GSM8k test split.

        Args:
            n_samples: Responses sampled per problem.
            max_tokens: Generation budget per response. NOTE(review): the default
                equals the model's default max_model_len, leaving no room for the
                prompt; confirm how the installed vLLM version handles this.
            n_problems: Problems to sample (clamped to the dataset size).
            seed: ``None`` draws from the module-level ``random`` state (original
                behaviour); an int makes the problem subset reproducible.
            top_k_values: k values to evaluate; defaults to
                ``[3, 5, 7, 10, 13, 16, 20, 25]``.

        Returns:
            ``{"top_k_values": [...], "probability": ..., "ngram": ..., "semantic": ...}``
            where each method entry is an ``evaluate_method_accuracy`` result.
        """
        print("\n=== Running experiment for GSM8k ===")
        dataset = load_dataset("gsm8k", "main", split='test')

        n_problems = min(n_problems, len(dataset))
        sampler = random if seed is None else random.Random(seed)
        selected_indices = sampler.sample(range(len(dataset)), n_problems)
        dataset = dataset.select(selected_indices)

        prompts = []
        reference_answers = []
        for item in dataset:
            try:
                question, answer = item["question"], item["answer"]
            except KeyError as e:
                print(f"Error accessing key: {e} in dataset. Available keys: {list(item.keys())}")
                continue
            # Append only after both lookups succeed so the lists stay index-aligned.
            prompts.append(question)
            reference_answers.append(answer)

        all_responses, all_logprobs = self.generate_responses_batch(prompts, n_samples, max_tokens)

        if top_k_values is None:
            top_k_values = [3, 5, 7, 10, 13, 16, 20, 25]

        results = {"top_k_values": top_k_values}
        for method in ("probability", "ngram", "semantic"):
            results[method] = self.evaluate_method_accuracy(
                prompts, all_responses, all_logprobs, reference_answers, top_k_values, method
            )
        return results

    def visualize_results(self, results, output_path='gsm8k_filtering_methods_comparison.png'):
        """
        Plot accuracy vs. k for each method, save the figure and print a summary.

        Error bars are the square root of each method's reported variance, i.e.
        the spread of per-problem 0/1 outcomes.

        Args:
            results: Output of ``run_experiment``.
            output_path: Where the figure is saved.
        """
        top_k_values = results["top_k_values"]
        series = (
            ("probability", 'Probability-based', 'blue'),
            ("ngram", 'N-gram clustering', 'green'),
            ("semantic", 'Semantic clustering', 'red'),
        )

        plt.figure(figsize=(10, 6))

        for key, label, color in series:
            plt.plot(top_k_values, results[key]["accuracy"], 'o-', label=label, color=color)

        for key, _, color in series:
            plt.errorbar(top_k_values, results[key]["accuracy"],
                         yerr=np.sqrt(results[key]["variance"]),
                         fmt='none', ecolor=color, alpha=0.3)

        plt.xlabel('Top-k Value')
        plt.ylabel('Accuracy')
        plt.title('Accuracy Comparison of Different Filtering Methods')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.savefig(output_path)
        plt.close()

        print("\nResults Summary:")
        for k, prob_acc, ngram_acc, sem_acc in zip(
                top_k_values,
                results["probability"]["accuracy"],
                results["ngram"]["accuracy"],
                results["semantic"]["accuracy"]
        ):
            print(f"Top-{k}:")
            print(f"  Probability-based: {prob_acc:.4f}")
            print(f"  N-gram clustering: {ngram_acc:.4f}")
            print(f"  Semantic clustering: {sem_acc:.4f}")


class NpEncoder(json.JSONEncoder):
    """
    JSON encoder that converts NumPy scalars and arrays to built-in types.

    NumPy floating NaNs handed to ``default`` are written as ``null`` because
    standard JSON has no NaN literal; any other unsupported object falls
    through to ``json.JSONEncoder.default`` (which raises ``TypeError``).
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            # NaN is tested only for floats: np.isnan raises a ufunc TypeError
            # on arbitrary objects, which would hide the real encoding error.
            return None if np.isnan(obj) else float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


# Main execution
def main():
    """Run the GSM8k filtering experiment, plot the comparison and save the results as JSON."""
    evaluator = GSM8kResponseFilteringEvaluator()

    # Run experiment - adjust n_problems for faster testing if needed
    results = evaluator.run_experiment(n_samples=100, n_problems=50)

    # Visualize results
    evaluator.visualize_results(results)

    # Save results
    with open('gsm8k_filtering_methods_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, cls=NpEncoder, indent=2)


if __name__ == "__main__":
    main()