import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset
import random
import json

class TokenProbabilityAnalyzer:
    """Measure how sampled-token logprobs vary with position and punctuation.

    Samples several responses per GSM8K question with vLLM, records the
    logprob of every sampled token, and aggregates those logprobs both by
    absolute token position and by distance (in tokens) to the previous
    punctuation token.
    """

    def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.9,
                 max_model_len=18192):
        """
        Load the vLLM engine and the matching tokenizer.

        Args:
            model_name: Hugging Face model id passed to both vLLM and the tokenizer.
            gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
            max_model_len: Context length (prompt + generated tokens) given to vLLM.
        """
        self.model_name = model_name

        print(f"Loading model: {model_name}")
        self.model = LLM(
            model=model_name,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            max_model_len=max_model_len,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Nucleus-sampling threshold used by generate_responses_batch.
        self.top_p = 0.95

        # Punctuation character -> token id, used for distance-to-punctuation tracking.
        self.punctuation_tokens = self.get_punctuation_token_ids()
        print(f"Punctuation tokens: {self.punctuation_tokens}")

    def get_punctuation_token_ids(self):
        """
        Return a dict mapping each common punctuation mark to a token id.

        Each mark is encoded on its own. If the tokenizer splits a mark into
        several tokens, only the first id is kept.

        NOTE(review): BPE vocabularies often contain merged tokens, such as a
        space followed by "(" or a period fused with a trailing newline. Their
        ids differ from the bare mark's id, so they are not in this map and are
        not treated as punctuation downstream. Confirm that is acceptable for
        the analysis.
        """
        punctuation = ['.', ',', '!', '?', ';', ':', '-', '(', ')', '[', ']', '{', '}', '"', "'"]
        punctuation_ids = {}

        for p in punctuation:
            tokens = self.tokenizer.encode(p, add_special_tokens=False)
            if tokens:
                punctuation_ids[p] = tokens[0]  # first token only if the mark splits

        return punctuation_ids

    def format_prompt(self, prompt_text):
        """Wrap ``prompt_text`` as a single user turn using the tokenizer's chat template."""
        messages = [{"role": "user", "content": prompt_text}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def generate_responses_batch(self, prompts, n_samples=20, max_tokens=18092, temperature=1.0):
        """
        Sample ``n_samples`` responses per prompt in one vLLM call and collect logprobs.

        Args:
            prompts: Raw question strings (not yet chat-formatted).
            n_samples: Responses sampled per prompt.
            max_tokens: Generation cap per response.
            temperature: Sampling temperature.

        Returns:
            ``(all_position_logprobs, all_distances_to_punct)``. The structure is
            described in ``_collect_logprobs``. Both are empty dicts when
            ``prompts`` is empty.
        """
        if not prompts:
            return {}, {}

        formatted_prompts = [self.format_prompt(prompt) for prompt in prompts]

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=self.top_p,
            top_k=40,
            n=n_samples,
            logprobs=19,  # top-19 alternatives per step; the sampled token is looked up in this dict
            skip_special_tokens=True,
        )

        print(f"Generating {n_samples} responses for {len(prompts)} prompts...")
        outputs = self.model.generate(formatted_prompts, sampling_params)

        return self._collect_logprobs(prompts, outputs)

    def _collect_logprobs(self, prompts, outputs):
        """
        Turn generation outputs into per-prompt logprob tables.

        Args:
            prompts: Raw prompt strings, aligned one-to-one with ``outputs``.
            outputs: Objects with an ``.outputs`` list, where each item has
                ``.token_ids`` and ``.logprobs``. ``.logprobs`` holds one mapping
                per step, from token id to an object with a ``.logprob`` attribute.

        Returns:
            ``(all_position_logprobs, all_distances_to_punct)``:
            ``{prompt: {position: [logprob, ...]}}`` and
            ``{prompt: [(distance_to_previous_punctuation, logprob), ...]}``.
            Identical prompt strings are merged rather than overwritten.
        """
        # Build once: O(1) membership instead of scanning dict values per token.
        punct_ids = set(self.punctuation_tokens.values())

        all_position_logprobs = {}
        all_distances_to_punct = {}

        for prompt, prompt_outputs in zip(prompts, outputs):
            # setdefault merges duplicated prompt texts instead of clobbering them.
            position_logprobs = all_position_logprobs.setdefault(prompt, {})
            distances_to_punct = all_distances_to_punct.setdefault(prompt, [])

            for output in prompt_outputs.outputs:
                token_ids = output.token_ids
                step_logprobs = getattr(output, 'logprobs', None)
                if not step_logprobs:
                    continue

                # -1 means "before the response start", so position 0 has distance 1.
                last_punct_pos = -1

                for pos, token_logprobs in enumerate(step_logprobs):
                    if pos >= len(token_ids):
                        break
                    selected_token = token_ids[pos]

                    # Distance is measured before this token can become the new anchor.
                    dist_to_prev_punct = pos - last_punct_pos
                    if selected_token in punct_ids:
                        last_punct_pos = pos

                    if selected_token in token_logprobs:
                        logprob = token_logprobs[selected_token].logprob
                        position_logprobs.setdefault(pos, []).append(logprob)
                        distances_to_punct.append((dist_to_prev_punct, logprob))

        return all_position_logprobs, all_distances_to_punct

    @staticmethod
    def _summarize_groups(groups, key_name):
        """Return ``[{key_name, mean_logprob, std_logprob, count}]`` sorted by key, skipping empty groups."""
        return [
            {
                key_name: key,
                'mean_logprob': float(np.mean(values)),
                'std_logprob': float(np.std(values)),  # population std (ddof=0)
                'count': len(values),
            }
            for key, values in sorted(groups.items())
            if values
        ]

    def analyze_position_stats(self, all_position_logprobs):
        """
        Pool logprobs across prompts and summarize them per token position.

        Args:
            all_position_logprobs: ``{prompt: {position: [logprob, ...]}}``.

        Returns:
            A list of ``{'position', 'mean_logprob', 'std_logprob', 'count'}``
            dicts, sorted by position.
        """
        pooled = {}
        for position_logprobs in all_position_logprobs.values():
            for pos, logprobs in position_logprobs.items():
                pooled.setdefault(pos, []).extend(logprobs)
        return self._summarize_groups(pooled, 'position')

    def analyze_punct_distance_stats(self, all_distances_to_punct):
        """
        Pool logprobs across prompts and summarize them per distance to the previous punctuation.

        Args:
            all_distances_to_punct: ``{prompt: [(distance, logprob), ...]}``.

        Returns:
            A list of ``{'distance', 'mean_logprob', 'std_logprob', 'count'}``
            dicts, sorted by distance.
        """
        pooled = {}
        for distances in all_distances_to_punct.values():
            for dist, logprob in distances:
                pooled.setdefault(dist, []).append(logprob)
        return self._summarize_groups(pooled, 'distance')

    def run_analysis(self, n_samples=20, max_tokens=18192, n_prompts=200, seed=None):
        """
        Run the full analysis on a random subset of the GSM8K test split.

        The results are also written to ``token_position_stats.json`` and
        ``punct_distance_stats.json``.

        Args:
            n_samples: Responses sampled per question.
            max_tokens: Generation cap per response. NOTE(review): the default
                equals the engine's default ``max_model_len``, so generation can
                hit the context limit before ``max_tokens``.
            n_prompts: Number of questions to sample. Clamped to the split size.
            seed: If given, a private ``random.Random(seed)`` chooses the subset
                reproducibly. Otherwise the global ``random`` state is used.

        Returns:
            ``(position_stats, punct_distance_stats)``, as produced by the two
            ``analyze_*`` methods.
        """
        print("\n=== Running analysis on GSM8k dataset ===")
        dataset = load_dataset("gsm8k", "main", split='test')

        # Clamp so random.sample cannot raise ValueError on a smaller split.
        k = min(n_prompts, len(dataset))
        sampler = random if seed is None else random.Random(seed)
        selected_indices = sampler.sample(range(len(dataset)), k)
        dataset = dataset.select(selected_indices)

        prompts = []
        for item in dataset:
            try:
                prompts.append(item["question"])
            except KeyError as e:
                print(f"Error accessing key: {e} in dataset. Available keys: {list(item.keys())}")

        all_position_logprobs, all_distances_to_punct = self.generate_responses_batch(prompts, n_samples, max_tokens)

        position_stats = self.analyze_position_stats(all_position_logprobs)
        punct_distance_stats = self.analyze_punct_distance_stats(all_distances_to_punct)

        with open('token_position_stats.json', 'w', encoding='utf-8') as f:
            json.dump(position_stats, f, indent=2)

        with open('punct_distance_stats.json', 'w', encoding='utf-8') as f:
            json.dump(punct_distance_stats, f, indent=2)

        return position_stats, punct_distance_stats

    @staticmethod
    def _plot_mean_with_band(x, means, stds, xlabel, title, filename):
        """Save a line plot of ``means`` over ``x`` with a shaded mean ± std band; skip if empty."""
        if not x:
            print(f"Skipping {filename}: no data to plot.")
            return
        plt.figure(figsize=(12, 6))
        plt.plot(x, means, 'b-', label='Mean Logprob')
        plt.fill_between(x,
                         [m - s for m, s in zip(means, stds)],
                         [m + s for m, s in zip(means, stds)],
                         color='blue', alpha=0.2, label='Mean ± Std Dev')
        plt.xlabel(xlabel)
        plt.ylabel('Logprob')
        plt.title(title)
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.savefig(filename)
        plt.close()

    def plot_statistics(self, position_stats, punct_distance_stats, max_pos=200, max_dist=50):
        """
        Save PNG plots and print a summary of the position and punctuation-distance statistics.

        Args:
            position_stats: Output of ``analyze_position_stats``.
            punct_distance_stats: Output of ``analyze_punct_distance_stats``.
            max_pos: Only positions ``< max_pos`` are plotted and summarized.
            max_dist: Only distances ``<= max_dist`` are plotted and summarized.
        """
        # Filter by value (not list index) so gaps in the keys cannot shift the window.
        shown_pos = [stat for stat in position_stats if stat['position'] < max_pos]
        positions = [stat['position'] for stat in shown_pos]
        means = [stat['mean_logprob'] for stat in shown_pos]
        stds = [stat['std_logprob'] for stat in shown_pos]

        self._plot_mean_with_band(positions, means, stds, 'Token Position',
                                  'Mean and Std Dev of Token Logprobs by Position',
                                  'mean_logprob_by_position.png')

        if positions:
            plt.figure(figsize=(12, 6))
            plt.plot(positions, stds, 'r-', label='Standard Deviation')
            plt.xlabel('Token Position')
            plt.ylabel('Standard Deviation of Logprob')
            plt.title('Standard Deviation of Token Logprobs by Position')
            plt.grid(True, alpha=0.3)
            plt.legend()
            plt.savefig('std_logprob_by_position.png')
            plt.close()

        # Always apply the distance cap. Previously it was skipped whenever there
        # were <= max_dist distinct distances, even if some exceeded max_dist.
        shown_dist = [stat for stat in punct_distance_stats if stat['distance'] <= max_dist]
        distances = [stat['distance'] for stat in shown_dist]
        punct_means = [stat['mean_logprob'] for stat in shown_dist]
        punct_stds = [stat['std_logprob'] for stat in shown_dist]

        self._plot_mean_with_band(distances, punct_means, punct_stds,
                                  'Distance to Previous Punctuation (tokens)',
                                  'Mean Logprob by Distance to Previous Punctuation',
                                  'logprob_by_punct_distance.png')

        print("\nToken Probability Statistics Summary:")
        if means:
            print(f"Average mean logprob across positions: {np.mean(means):.4f}")
            print(f"Average standard deviation across positions: {np.mean(stds):.4f}")

            first_n = 10
            if len(means) >= first_n:
                print(f"\nFirst {first_n} token positions:")
                print(f"  Average mean: {np.mean(means[:first_n]):.4f}")
                print(f"  Average std dev: {np.mean(stds[:first_n]):.4f}")

            max_mean_idx = np.argmax(means)
            min_mean_idx = np.argmin(means)
            print(f"\nPosition with highest mean logprob: {positions[max_mean_idx]} (Mean: {means[max_mean_idx]:.4f})")
            print(f"Position with lowest mean logprob: {positions[min_mean_idx]} (Mean: {means[min_mean_idx]:.4f})")

            max_std_idx = np.argmax(stds)
            min_std_idx = np.argmin(stds)
            print(f"\nPosition with highest std dev: {positions[max_std_idx]} (Std Dev: {stds[max_std_idx]:.4f})")
            print(f"Position with lowest std dev: {positions[min_std_idx]} (Std Dev: {stds[min_std_idx]:.4f})")
        else:
            print("No position statistics to summarize.")

        print("\nPunctuation Distance Statistics:")
        if punct_means:
            print(f"Average mean logprob across all distances: {np.mean(punct_means):.4f}")
            print(f"Average std dev across all distances: {np.mean(punct_stds):.4f}")

            max_mean_idx = np.argmax(punct_means)
            min_mean_idx = np.argmin(punct_means)
            print(f"\nDistance with highest mean logprob: {distances[max_mean_idx]} tokens (Mean: {punct_means[max_mean_idx]:.4f})")
            print(f"Distance with lowest mean logprob: {distances[min_mean_idx]} tokens (Mean: {punct_means[min_mean_idx]:.4f})")
        else:
            print("No punctuation-distance statistics to summarize.")

    @staticmethod
    def _write_pgfplots_table(path, header, key_name, stats):
        """Write ``key mean mean-std mean+std`` rows, space-separated, under a ``#`` header line."""
        with open(path, 'w', encoding='utf-8') as f:
            f.write(header)
            for stat in stats:
                key = stat[key_name]
                mean = stat['mean_logprob']
                std = stat['std_logprob']
                f.write(f"{key} {mean} {mean-std} {mean+std}\n")

    def save_for_latex(self, position_stats, punct_distance_stats):
        """
        Write the statistics as whitespace-separated ``.dat`` tables for PGFPlots.

        Produces ``position_stats_latex.dat`` and ``punct_distance_stats_latex.dat``.
        Each row holds the key, the mean, and the lower/upper ±1 std bounds.
        """
        self._write_pgfplots_table('position_stats_latex.dat',
                                   "# position mean_logprob lower_bound upper_bound\n",
                                   'position', position_stats)
        self._write_pgfplots_table('punct_distance_stats_latex.dat',
                                   "# distance mean_logprob lower_bound upper_bound\n",
                                   'distance', punct_distance_stats)


if __name__ == "__main__":
    # Build the analyzer (this loads the model), run the GSM8K sweep with
    # 20 samples per question, then hand the results to both exporters:
    # PNG plots first, PGFPlots .dat tables second.
    analyzer = TokenProbabilityAnalyzer()
    stats_by_position, stats_by_punct_distance = analyzer.run_analysis(n_samples=20)
    for export in (analyzer.plot_statistics, analyzer.save_for_latex):
        export(stats_by_position, stats_by_punct_distance)