import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
import random
import json
import re

class ProbabilityQualityEvaluator:
    def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.9):
        """
        Build the vLLM engine and the matching tokenizer, and store sampling defaults.

        Args:
            model_name: Model identifier handed to both ``LLM`` and
                ``AutoTokenizer.from_pretrained``.
            gpu_memory_utilization: Forwarded unchanged to ``LLM``.
        """
        print(f"Loading model: {model_name}")
        self.model_name = model_name

        # Engine configuration; the context window is fixed at 18192 tokens here.
        engine_options = dict(
            model=model_name,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            max_model_len=18192,
        )
        self.model = LLM(**engine_options)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # top_p is read by generate_responses_batch; repetition_penalty is stored
        # but only referenced from a commented-out sampling option.
        self.top_p, self.repetition_penalty = 0.95, 1.2

    def format_prompt(self, prompt_text):
        """Render ``prompt_text`` as a single user turn through the tokenizer's chat template.

        The template is applied with ``tokenize=False`` and
        ``add_generation_prompt=True``; whatever the tokenizer returns is passed
        back unchanged.
        """
        user_turn = {"role": "user", "content": prompt_text}
        rendered = self.tokenizer.apply_chat_template(
            [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        return rendered

    def generate_responses_batch(self, prompts, n_samples=20, max_tokens=18092, temperature=1.0):
        """
        Sample ``n_samples`` completions per prompt in one ``generate`` call.

        Args:
            prompts: Raw prompt strings; each is passed through ``format_prompt``.
            n_samples: Number of completions requested per prompt.
            max_tokens: Generation length limit forwarded to ``SamplingParams``.
            temperature: Sampling temperature forwarded to ``SamplingParams``.

        Returns:
            A pair ``(all_responses, all_position_logprobs)``, both keyed by the
            raw prompt string:

            * ``all_responses[prompt]`` is a list of
              ``(text, avg_logprob, [(position, logprob), ...])`` tuples.
            * ``all_position_logprobs[prompt]`` maps a token position to the
              list of chosen-token logprobs observed at that position across
              this prompt's samples.
        """
        chat_inputs = [self.format_prompt(raw) for raw in prompts]

        # Nucleus + top-k sampling; logprobs=19 is requested for every position.
        sampling_params = SamplingParams(
            n=n_samples,
            temperature=temperature,
            top_p=self.top_p,
            top_k=40,
            max_tokens=max_tokens,
            logprobs=19,
            skip_special_tokens=True,
        )

        print(f"Generating {n_samples} responses for {len(prompts)} prompts...")
        outputs = self.model.generate(chat_inputs, sampling_params)

        all_responses = {}
        all_position_logprobs = {}

        for idx, request_output in enumerate(outputs):
            samples = []
            by_position = {}

            for completion in request_output.outputs:
                token_ids = completion.token_ids
                total_logprob = getattr(completion, 'cumulative_logprob', None)

                # (position, logprob) of the token that was actually emitted
                chosen = []
                per_step = getattr(completion, 'logprobs', None)
                if per_step:
                    for pos, candidates in enumerate(per_step):
                        if pos >= len(token_ids):
                            continue
                        token = token_ids[pos]
                        if token not in candidates:
                            continue
                        lp = candidates[token].logprob
                        chosen.append((pos, lp))
                        by_position.setdefault(pos, []).append(lp)

                # Prefer the engine-reported cumulative logprob; fall back to the
                # per-token values, and to 0 when neither is available.
                if total_logprob is not None and len(token_ids) > 0:
                    avg_logprob = total_logprob / len(token_ids)
                elif chosen:
                    avg_logprob = sum(lp for _, lp in chosen) / len(chosen)
                else:
                    avg_logprob = 0

                samples.append((completion.text, avg_logprob, chosen))

            key = prompts[idx]
            all_responses[key] = samples
            all_position_logprobs[key] = by_position

        return all_responses, all_position_logprobs

    def generate_greedy_responses(self, prompts, max_tokens=18092):
        """
        Produce one temperature-0 completion per prompt.

        Args:
            prompts: Raw prompt strings; each is passed through ``format_prompt``.
            max_tokens: Generation length limit forwarded to ``SamplingParams``.

        Returns:
            A triple of dicts keyed by the raw prompt string:
            ``(greedy_responses, greedy_logprobs, greedy_position_logprobs)``
            holding the response text, the mean logprob of the emitted tokens
            (0 when no logprobs were collected), and a ``{position: [logprob]}``
            mapping respectively.
        """
        chat_inputs = [self.format_prompt(raw) for raw in prompts]

        # Greedy decoding: temperature 0 with top-p / top-k filtering disabled.
        sampling_params = SamplingParams(
            n=1,
            temperature=0.0,
            top_p=1.0,
            top_k=-1,
            max_tokens=max_tokens,
            logprobs=19,
            skip_special_tokens=True,
        )

        print(f"Generating greedy responses for {len(prompts)} prompts...")
        outputs = self.model.generate(chat_inputs, sampling_params)

        greedy_responses = {}
        greedy_logprobs = {}
        greedy_position_logprobs = {}

        for idx, request_output in enumerate(outputs):
            key = prompts[idx]
            completion = request_output.outputs[0]  # n=1, so a single completion
            token_ids = completion.token_ids

            running_total = 0
            collected = 0
            by_position = {}

            per_step = getattr(completion, 'logprobs', None)
            if per_step:
                for pos, candidates in enumerate(per_step):
                    if pos >= len(token_ids):
                        continue
                    token = token_ids[pos]
                    if token not in candidates:
                        continue
                    lp = candidates[token].logprob
                    running_total += lp
                    collected += 1
                    by_position.setdefault(pos, []).append(lp)

            # Mean over the tokens whose logprob was actually found.
            avg_logprob = running_total / collected if collected else 0

            greedy_responses[key] = completion.text
            greedy_logprobs[key] = avg_logprob
            greedy_position_logprobs[key] = by_position

        return greedy_responses, greedy_logprobs, greedy_position_logprobs

    def extract_final_answer(self, response):
        """
        Extract the final numeric answer from a GSM8k-style response.

        Lookup order:
          1. the number that follows a ``####`` marker (an optional ``$`` is skipped);
          2. otherwise the last number anywhere in the response;
          3. otherwise the stripped response itself.

        Numbers may carry thousands separators and a decimal part. The returned
        string has the commas removed and a redundant all-zero fraction dropped,
        e.g. ``"1,000" -> "1000"`` and ``"18.00" -> "18"``. A plain ``\\d+``
        match would split such numbers and report "000" / "00" instead.

        Args:
            response: Model response text.

        Returns:
            The normalized number as a string, or ``response.strip()`` when the
            response contains no digits.
        """
        # Digits with optional ",ddd" groups and an optional fractional part.
        number = r'\d+(?:,\d{3})*(?:\.\d+)?'

        marked = re.search(r'####\s*\$?\s*(' + number + r')', response)
        if marked:
            raw = marked.group(1)
        else:
            candidates = re.findall(number, response)
            if not candidates:
                return response.strip()
            raw = candidates[-1]

        cleaned = raw.replace(',', '')
        if '.' in cleaned:
            # "18.00" -> "18", "10.50" -> "10.5"
            cleaned = cleaned.rstrip('0').rstrip('.')
        return cleaned

    def is_answer_correct(self, predicted, reference):
        """
        Decide whether a model response matches the GSM8k reference answer.

        The predicted answer comes from ``extract_final_answer``. The reference
        answer is the number after its ``####`` marker, or the last number in
        the reference when no marker with a number is present. Both sides are
        parsed with thousands separators and decimals in mind and compared
        numerically, so ``"1,000"`` equals ``"1000"`` and ``"18.0"`` equals
        ``"18"``. Signs are not interpreted.

        Args:
            predicted: Model response text.
            reference: Reference answer; converted with ``str`` (``None`` is
                treated as an empty string).

        Returns:
            ``True`` when both sides yield a number and the numbers are equal,
            ``False`` otherwise.
        """
        from decimal import Decimal

        # Digits with optional ",ddd" groups and an optional fractional part.
        number = r'\d+(?:,\d{3})*(?:\.\d+)?'

        def last_number(text):
            found = re.findall(number, text)
            if not found:
                return None
            return Decimal(found[-1].replace(',', ''))

        reference = str(reference) if reference is not None else ""

        # Prefer the number attached to the "####" marker in the reference.
        marked = re.search(r'####\s*\$?\s*(' + number + r')', reference)
        ref_value = last_number(marked.group(1) if marked else reference)

        pred_value = last_number(self.extract_final_answer(predicted))

        if pred_value is None or ref_value is None:
            return False
        return pred_value == ref_value

    def evaluate_against_ground_truth(self, prompts, all_responses, reference_answers):
        """
        Score every sampled response against its prompt's reference answer.

        Args:
            prompts: Prompt strings, index-aligned with ``reference_answers``.
            all_responses: Mapping of prompt to a list of
                ``(text, avg_logprob, token_logprobs)`` tuples.
            reference_answers: Reference answer for each prompt.

        Returns:
            Mapping of prompt to a list of 1.0 (correct) / 0.0 (incorrect)
            scores, in the same order as that prompt's responses.
        """
        scores_by_prompt = {}

        for idx, question in enumerate(prompts):
            gold = reference_answers[idx]
            scores_by_prompt[question] = [
                1.0 if self.is_answer_correct(text, gold) else 0.0
                for text, _, _ in all_responses[question]
            ]

        return scores_by_prompt

    def evaluate_greedy_against_ground_truth(self, prompts, greedy_responses, reference_answers):
        """
        Score each prompt's single greedy response against its reference answer.

        Args:
            prompts: Prompt strings, index-aligned with ``reference_answers``.
            greedy_responses: Mapping of prompt to its greedy response text.
            reference_answers: Reference answer for each prompt.

        Returns:
            Mapping of prompt to 1.0 (correct) or 0.0 (incorrect).
        """
        def grade(idx, question):
            gold = reference_answers[idx]
            matched = self.is_answer_correct(greedy_responses[question], gold)
            return 1.0 if matched else 0.0

        return {question: grade(idx, question) for idx, question in enumerate(prompts)}

    def evaluate_with_llm_judge(self, prompts, all_responses, reference_answers):
        """
        Have the loaded model judge each sampled response with a Yes/No verdict.

        One judge prompt is built per (prompt, response) pair and all of them
        are generated in a single batch at temperature 0 with at most 10 new
        tokens. A verdict scores 1.0 when the lowercased judge output contains
        "yes", otherwise 0.0.

        Args:
            prompts: Prompt strings.
            all_responses: Mapping of prompt to a list of
                ``(text, avg_logprob, token_logprobs)`` tuples.
            reference_answers: Accepted for signature parity; not used here.

        Returns:
            Mapping of prompt to a list of scores ordered like that prompt's
            responses.
        """
        judge_prompts = []
        response_map = []  # (prompt, response index) for each judge prompt, in order

        for prompt in prompts:
            for response_idx, (response, _, _) in enumerate(all_responses[prompt]):
                judge_prompt = f"""
                Task: {prompt}
                
                Response to evaluate: 
                {response}
                
                Is this response correct? Answer with only 'Yes' or 'No'.
                """
                judge_prompts.append(self.format_prompt(judge_prompt))
                response_map.append((prompt, response_idx))

        print(f"Evaluating {len(judge_prompts)} responses with LLM-as-judge...")
        judge_outputs = self.model.generate(
            judge_prompts,
            SamplingParams(
                temperature=0.0,
                max_tokens=10,
                skip_special_tokens=True,
            ),
        )

        llm_judge_results = {}
        for i, output in enumerate(judge_outputs):
            prompt, response_idx = response_map[i]
            verdict = output.outputs[0].text.strip().lower()
            score = 1.0 if "yes" in verdict else 0.0

            # Pad with None so the score lands at its response index.
            per_prompt = llm_judge_results.setdefault(prompt, [])
            shortfall = response_idx + 1 - len(per_prompt)
            if shortfall > 0:
                per_prompt.extend([None] * shortfall)
            per_prompt[response_idx] = score

        return llm_judge_results

    def evaluate_greedy_with_llm_judge(self, prompts, greedy_responses):
        """
        Have the loaded model judge each greedy response with a Yes/No verdict.

        Prompts without an entry in ``greedy_responses`` are skipped. Judging
        runs in one batch at temperature 0 with at most 10 new tokens; a verdict
        scores 1.0 when the lowercased judge output contains "yes", else 0.0.

        Args:
            prompts: Prompt strings.
            greedy_responses: Mapping of prompt to its greedy response text.

        Returns:
            Mapping of prompt to its judge score.
        """
        judge_prompts = []
        prompt_list = []  # prompts that were judged, in submission order

        for prompt in prompts:
            if prompt in greedy_responses:
                judge_prompt = f"""
                Task: {prompt}
                
                Response to evaluate: 
                {greedy_responses[prompt]}
                
                Is this response correct? Answer with only 'Yes' or 'No'.
                """
                judge_prompts.append(self.format_prompt(judge_prompt))
                prompt_list.append(prompt)

        print(f"Evaluating {len(judge_prompts)} greedy responses with LLM-as-judge...")
        judge_outputs = self.model.generate(
            judge_prompts,
            SamplingParams(
                temperature=0.0,
                max_tokens=10,
                skip_special_tokens=True,
            ),
        )

        llm_judge_results = {}
        for i, output in enumerate(judge_outputs):
            verdict = output.outputs[0].text.strip().lower()
            llm_judge_results[prompt_list[i]] = 1.0 if "yes" in verdict else 0.0

        return llm_judge_results

    def calculate_metrics(self, y_true, y_pred):
        """
        Compute accuracy, precision, recall and F1 for binary labels.

        Labels equal to 1 count as positive and labels equal to 0 as negative.
        Precision, recall and F1 fall back to 0 when their denominator is 0.

        Args:
            y_true: Sequence of ground-truth labels.
            y_pred: Sequence of predicted labels.

        Returns:
            Dict with float values under 'accuracy', 'precision', 'recall', 'f1'.
        """
        truth = np.array(y_true)
        guess = np.array(y_pred)

        truth_pos = truth == 1
        guess_pos = guess == 1

        tp = np.sum(truth_pos & guess_pos)
        fp = np.sum((truth == 0) & guess_pos)
        fn = np.sum(truth_pos & (guess == 0))

        precision = recall = f1 = 0
        if (tp + fp) > 0:
            precision = tp / (tp + fp)
        if (tp + fn) > 0:
            recall = tp / (tp + fn)
        if (precision + recall) > 0:
            f1 = 2 * precision * recall / (precision + recall)

        accuracy = np.mean(truth == guess)

        return dict(
            accuracy=float(accuracy),
            precision=float(precision),
            recall=float(recall),
            f1=float(f1),
        )

    def run_experiment(self, n_samples=20, max_tokens=18192, num_prompts=100, seed=None):
        """
        Run the GSM8k experiment comparing logprob-based quality estimates,
        LLM-as-judge verdicts and ground-truth correctness.

        Steps: load the GSM8k test split, draw a random subset of questions,
        generate sampled and greedy responses, score them against the reference
        answers and with the LLM judge, then sweep 20 average-logprob thresholds
        between -3.0 and -0.1 and compute classification metrics for each.

        Args:
            n_samples: Sampled responses generated per prompt.
            max_tokens: Generation length limit passed to both generators.
                NOTE(review): the default equals the ``max_model_len`` set in
                ``__init__``, leaving no room for the prompt - confirm how the
                engine handles this.
            num_prompts: Number of questions drawn from the test split; capped
                at the dataset size.
            seed: Optional seed for the question subset. ``None`` uses the
                module-level ``random`` state.

        Returns:
            ``{"GSM8k": task_results}``, or an empty dict when no logprobs
            could be collected.
        """
        results = {}

        # Load GSM8k dataset
        print("\n=== Running experiment for GSM8k ===")
        dataset = load_dataset("gsm8k", "main", split='test')

        # Sample a subset of the dataset, never asking for more rows than exist
        sample_size = min(num_prompts, len(dataset))
        sampler = random if seed is None else random.Random(seed)
        selected_indices = sampler.sample(range(len(dataset)), sample_size)
        dataset = dataset.select(selected_indices)

        # Extract prompts and reference answers
        prompts = []
        reference_answers = []

        for item in dataset:
            # Read both fields before appending either, so that a malformed
            # item cannot leave the two lists misaligned.
            try:
                question, answer = item["question"], item["answer"]
            except KeyError as e:
                print(f"Error accessing key: {e} in dataset. Available keys: {list(item.keys())}")
                continue
            prompts.append(question)
            reference_answers.append(answer)

        # Generate responses for all prompts
        all_responses, all_position_logprobs = self.generate_responses_batch(prompts, n_samples, max_tokens)

        # Generate greedy responses for all prompts
        greedy_responses, greedy_logprobs, greedy_position_logprobs = self.generate_greedy_responses(prompts, max_tokens)

        # Method 1: Evaluate with ground truth
        ground_truth_results = self.evaluate_against_ground_truth(prompts, all_responses, reference_answers)

        # Evaluate greedy responses with ground truth
        greedy_ground_truth_results = self.evaluate_greedy_against_ground_truth(prompts, greedy_responses, reference_answers)

        # Method 2: Evaluate with LLM-as-judge
        llm_judge_results = self.evaluate_with_llm_judge(prompts, all_responses, reference_answers)

        # Evaluate greedy responses with LLM-as-judge
        greedy_llm_judge_results = self.evaluate_greedy_with_llm_judge(prompts, greedy_responses)

        # Process position-wise logprobs
        position_stats = self.analyze_position_logprobs(all_position_logprobs)

        # Process greedy position-wise logprobs
        greedy_position_stats = self.analyze_greedy_position_logprobs(greedy_position_logprobs)

        # Initialize task results
        task_results = {
            'prompts': prompts,
            'reference_answers': reference_answers,
            'position_stats': position_stats,
            'greedy_position_stats': greedy_position_stats,
            'logprob_thresholds': [],
            'logprob_metrics': [],
            'llm_judge_metrics': {},
            'ground_truth_accuracy': 0.0,
            'greedy_metrics': {
                'avg_logprob': 0.0,
                'ground_truth_accuracy': 0.0,
                'llm_judge_accuracy': 0.0
            }
        }

        # Collect all logprobs, llm judge scores, and ground truth correctness scores
        all_logprobs = []
        all_llm_judge_scores = []
        all_ground_truth_scores = []

        for prompt in prompts:
            if prompt not in all_responses or prompt not in ground_truth_results or prompt not in llm_judge_results:
                continue

            responses = all_responses[prompt]
            gt_scores = ground_truth_results[prompt]
            judge_scores = llm_judge_results[prompt]

            # Skip if lengths don't match
            if len(responses) != len(gt_scores) or len(responses) != len(judge_scores):
                continue

            for (_, logprob, _), gt_score, judge_score in zip(responses, gt_scores, judge_scores):
                all_logprobs.append(logprob)
                all_ground_truth_scores.append(gt_score)
                all_llm_judge_scores.append(judge_score)

        if not all_logprobs:
            print("Warning: No valid logprobs collected. Skipping threshold analysis.")
            return results

        # Calculate LLM-as-judge metrics against ground truth
        llm_judge_metrics = self.calculate_metrics(all_ground_truth_scores, all_llm_judge_scores)
        task_results['llm_judge_metrics'] = llm_judge_metrics

        print("LLM-as-judge metrics against ground truth:")
        print(f"  Accuracy: {llm_judge_metrics['accuracy']:.4f}")
        print(f"  Precision: {llm_judge_metrics['precision']:.4f}")
        print(f"  Recall: {llm_judge_metrics['recall']:.4f}")
        print(f"  F1 Score: {llm_judge_metrics['f1']:.4f}")

        # Calculate ground truth accuracy (overall correct rate)
        ground_truth_accuracy = sum(all_ground_truth_scores) / len(all_ground_truth_scores) if all_ground_truth_scores else 0
        task_results['ground_truth_accuracy'] = float(ground_truth_accuracy)

        print(f"Ground truth accuracy: {ground_truth_accuracy:.4f}")

        # Calculate greedy metrics
        greedy_avg_logprob = np.mean(list(greedy_logprobs.values()))
        greedy_gt_accuracy = np.mean(list(greedy_ground_truth_results.values()))
        greedy_judge_accuracy = np.mean(list(greedy_llm_judge_results.values()))

        task_results['greedy_metrics'] = {
            'avg_logprob': float(greedy_avg_logprob),
            'ground_truth_accuracy': float(greedy_gt_accuracy),
            'llm_judge_accuracy': float(greedy_judge_accuracy)
        }

        print("\nGreedy decoding metrics:")
        print(f"  Average logprob: {greedy_avg_logprob:.4f}")
        print(f"  Ground truth accuracy: {greedy_gt_accuracy:.4f}")
        print(f"  LLM-as-judge accuracy: {greedy_judge_accuracy:.4f}")

        # Calculate metrics at predefined logprob thresholds: a response is
        # predicted "correct" when its average logprob reaches the threshold.
        predefined_thresholds = np.linspace(-3.0, -0.1, 20)
        for threshold in predefined_thresholds:
            # Logprob-based classification
            logprob_predictions = [1.0 if lp >= threshold else 0.0 for lp in all_logprobs]

            # Calculate metrics against ground truth
            metrics = self.calculate_metrics(all_ground_truth_scores, logprob_predictions)

            task_results['logprob_thresholds'].append(float(threshold))
            task_results['logprob_metrics'].append(metrics)

            print(f"Threshold {threshold:.4f} metrics:")
            print(f"  Accuracy: {metrics['accuracy']:.4f}")
            print(f"  Precision: {metrics['precision']:.4f}")
            print(f"  Recall: {metrics['recall']:.4f}")
            print(f"  F1 Score: {metrics['f1']:.4f}")

        results["GSM8k"] = task_results

        return results

    def analyze_position_logprobs(self, all_position_logprobs):
        """
        Pool logprobs by token position across prompts and summarize each position.

        Args:
            all_position_logprobs: Mapping of prompt to a
                ``{position: [logprob, ...]}`` dict.

        Returns:
            A list of ``{'position', 'mean_logprob', 'count'}`` dicts sorted by
            position, where the mean and count cover every pooled value.
        """
        pooled = {}
        for per_prompt in all_position_logprobs.values():
            for pos, values in per_prompt.items():
                pooled.setdefault(pos, []).extend(values)

        return [
            {
                'position': pos,
                'mean_logprob': float(np.mean(pooled[pos])),
                'count': len(pooled[pos]),
            }
            for pos in sorted(pooled)
        ]

    def analyze_greedy_position_logprobs(self, greedy_position_logprobs):
        """
        Summarize logprobs by token position across all greedy responses.

        The input has the same ``{prompt: {position: [logprob, ...]}}`` layout
        as the sampled case, so this delegates to ``analyze_position_logprobs``
        instead of repeating the aggregation.

        Args:
            greedy_position_logprobs: Mapping of prompt to a
                ``{position: [logprob, ...]}`` dict.

        Returns:
            A list of ``{'position', 'mean_logprob', 'count'}`` dicts sorted by
            position.
        """
        return self.analyze_position_logprobs(greedy_position_logprobs)

    def visualize_results(self, results):
        """
        Plot and print a summary of the GSM8k experiment results.

        Reads ``results["GSM8k"]`` and, depending on which entries are present
        and non-empty, writes up to three PNG files to the working directory:

        * ``GSM8k_accuracy_vs_threshold.png``
        * ``GSM8k_position_logprobs_histogram.png``
        * ``GSM8k_position_logprobs_comparison.png``

        It then prints the key findings to stdout.

        Args:
            results: Mapping of task name to a task-result dict. When it has no
                truthy "GSM8k" entry, a message is printed and nothing else
                happens.
        """
        task_name = "GSM8k"
        task_results = results.get(task_name)

        if not task_results:
            print(f"No results found for {task_name}")
            return

        greedy_metrics = task_results.get('greedy_metrics')
        judge_metrics = task_results.get('llm_judge_metrics')
        position_stats = task_results.get('position_stats')
        greedy_position_stats = task_results.get('greedy_position_stats')
        thresholds = task_results.get('logprob_thresholds')
        logprob_metrics = task_results.get('logprob_metrics')

        # 1. Plot accuracy vs logprob threshold
        if thresholds:
            plt.figure(figsize=(10, 6))
            accuracies = [m['accuracy'] for m in task_results['logprob_metrics']]
            plt.plot(thresholds, accuracies, 'o-', label='Logprob-based Accuracy')

            # Add LLM-as-judge accuracy line
            if judge_metrics:
                plt.axhline(y=judge_metrics['accuracy'], color='r',
                            linestyle='--', label='LLM-as-Judge Accuracy')

            # Add greedy accuracy lines if available
            if greedy_metrics:
                plt.axhline(y=greedy_metrics['ground_truth_accuracy'], color='g',
                            linestyle='--', label='Greedy Ground Truth Accuracy')
                plt.axhline(y=greedy_metrics['llm_judge_accuracy'], color='purple',
                            linestyle='--', label='Greedy LLM-as-Judge Accuracy')

                # Add greedy avg logprob as vertical line
                plt.axvline(x=greedy_metrics['avg_logprob'], color='orange',
                            linestyle='--', label=f'Greedy Avg Logprob: {greedy_metrics["avg_logprob"]:.4f}')

            plt.xlabel('Logprob Threshold')
            plt.ylabel('Accuracy')
            plt.title(f'Accuracy vs Logprob Threshold for {task_name}')
            plt.legend()
            plt.grid(True)
            plt.savefig(f'{task_name}_accuracy_vs_threshold.png')
            plt.close()

        if position_stats:
            # Limit to the first max_pos positions for readability
            max_pos = 1000
            positions = [stat['position'] for stat in position_stats][:max_pos]
            mean_logprobs = [stat['mean_logprob'] for stat in position_stats][:max_pos]

            greedy_in_range = [stat for stat in (greedy_position_stats or []) if stat['position'] < max_pos]
            greedy_positions = [stat['position'] for stat in greedy_in_range]
            greedy_mean_logprobs = [stat['mean_logprob'] for stat in greedy_in_range]

            # 2. Histogram of per-position mean logprobs (sampling vs greedy).
            # Everything that belongs in the legend is drawn on the primary
            # axes explicitly: creating the twin axis makes it the "current"
            # axes, so implicit plt.* calls would otherwise target it.
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.hist(mean_logprobs, bins=20, alpha=0.7, color='skyblue', label='Sampling')
            if greedy_mean_logprobs:
                ax.hist(greedy_mean_logprobs, bins=20, alpha=0.5, color='orange', label='Greedy')

            # Kernel density estimate on a secondary y-axis. gaussian_kde needs
            # at least two points with non-zero spread.
            if len(mean_logprobs) > 1 and min(mean_logprobs) < max(mean_logprobs):
                from scipy import stats
                density = ax.twinx()
                kde = stats.gaussian_kde(mean_logprobs)
                x_range = np.linspace(min(mean_logprobs), max(mean_logprobs), 100)
                density.plot(x_range, kde(x_range), 'r-')
                density.set_ylabel('Density', color='r')
                density.tick_params(axis='y', colors='r')

            # Add vertical line for sampling mean
            sampling_mean = np.mean(mean_logprobs)
            ax.axvline(sampling_mean, color='blue', linestyle='dashed', linewidth=2,
                       label=f'Sampling Mean: {sampling_mean:.4f}')

            # Add vertical line for greedy mean if available
            if greedy_metrics:
                greedy_mean = greedy_metrics['avg_logprob']
                ax.axvline(greedy_mean, color='green', linestyle='dashed', linewidth=2,
                           label=f'Greedy Mean: {greedy_mean:.4f}')

            ax.set_xlabel('Mean Logprob Value')
            ax.set_ylabel('Count')
            ax.set_title(f'Distribution of Logprobs Across First {max_pos} Token Positions for {task_name}')
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.savefig(f'{task_name}_position_logprobs_histogram.png')
            plt.close(fig)

            # 3. Scatter plot of mean logprob by token position, sampling vs greedy
            plt.figure(figsize=(12, 6))
            plt.scatter(positions, mean_logprobs, c='blue', alpha=0.6, s=30,
                        label='Sampling', edgecolors='k', linewidths=0.5)

            if greedy_positions and greedy_mean_logprobs:
                plt.scatter(greedy_positions, greedy_mean_logprobs, c='red', alpha=0.6, s=30,
                            label='Greedy', edgecolors='k', linewidths=0.5)

            plt.xlabel('Token Position')
            plt.ylabel('Mean Logprob')
            plt.title(f'Logprob Comparison by Token Position for {task_name}')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.savefig(f'{task_name}_position_logprobs_comparison.png')
            plt.close()

        # Print key findings
        print(f"\nKey findings for {task_name}:")

        # Ground Truth
        print(f"Ground truth accuracy: {task_results['ground_truth_accuracy']:.4f}")

        # Greedy Metrics
        if greedy_metrics:
            print("\nGreedy decoding metrics:")
            print(f"  Average logprob: {greedy_metrics['avg_logprob']:.4f}")
            print(f"  Ground truth accuracy: {greedy_metrics['ground_truth_accuracy']:.4f}")
            print(f"  LLM-as-judge accuracy: {greedy_metrics['llm_judge_accuracy']:.4f}")

        # LLM Judge vs Ground Truth
        if judge_metrics:
            print("\nLLM-as-judge metrics vs ground truth:")
            print(f"  Accuracy: {judge_metrics['accuracy']:.4f}")
            print(f"  Precision: {judge_metrics['precision']:.4f}")
            print(f"  Recall: {judge_metrics['recall']:.4f}")
            print(f"  F1 Score: {judge_metrics['f1']:.4f}")

        # Logprob Threshold Performance
        if logprob_metrics:
            # Find best threshold by accuracy
            best_acc_idx = np.argmax([m['accuracy'] for m in logprob_metrics])
            best_acc_threshold = task_results['logprob_thresholds'][best_acc_idx]
            best_acc_metrics = logprob_metrics[best_acc_idx]

            print("\nBest logprob threshold by accuracy:")
            print(f"  Threshold: {best_acc_threshold:.4f}")
            print(f"  Accuracy: {best_acc_metrics['accuracy']:.4f}")
            print(f"  Precision: {best_acc_metrics['precision']:.4f}")
            print(f"  Recall: {best_acc_metrics['recall']:.4f}")
            print(f"  F1 Score: {best_acc_metrics['f1']:.4f}")

        # Position-wise logprob analysis
        if position_stats:
            logprobs = [stat['mean_logprob'] for stat in position_stats]

            # Report average for first few tokens
            first_tokens = 5
            if len(logprobs) >= first_tokens:
                first_avg = np.mean(logprobs[:first_tokens])
                print(f"\nAverage logprob of first {first_tokens} tokens: {first_avg:.4f}")

            # Report overall average
            overall_avg = np.mean(logprobs)
            print(f"Overall average token logprob: {overall_avg:.4f}")

        # Compare greedy vs sampling logprobs
        if greedy_position_stats and position_stats:
            sampling_logprobs = [stat['mean_logprob'] for stat in position_stats]
            greedy_logprobs = [stat['mean_logprob'] for stat in greedy_position_stats]

            # Limit to the minimum length of both
            min_len = min(len(sampling_logprobs), len(greedy_logprobs))
            if min_len > 0:
                sampling_avg = np.mean(sampling_logprobs[:min_len])
                greedy_avg = np.mean(greedy_logprobs[:min_len])

                print(f"\nComparison of average logprobs (first {min_len} tokens):")
                print(f"  Sampling average: {sampling_avg:.4f}")
                print(f"  Greedy average: {greedy_avg:.4f}")
                print(f"  Difference (Greedy - Sampling): {greedy_avg - sampling_avg:.4f}")


# Main execution
if __name__ == "__main__":
    class NpEncoder(json.JSONEncoder):
        """JSON encoder that converts NumPy scalars and arrays to built-in types."""

        def default(self, obj):
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                # NaN is not valid JSON, so NumPy NaN scalars become null.
                return None if np.isnan(obj) else float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            # Anything else is unsupported: defer to the base class, which
            # raises the standard "not JSON serializable" TypeError.
            return super().default(obj)

    evaluator = ProbabilityQualityEvaluator()

    # Run experiment
    results = evaluator.run_experiment(n_samples=20)

    # Visualize results
    evaluator.visualize_results(results)

    # Save results
    with open('probability_quality_experiment_results.json', 'w') as f:
        json.dump(results, f, cls=NpEncoder, indent=2)