import math
from copy import deepcopy
from pprint import pprint
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.configuration_utils import GenerationConfig
from sentence_transformers import SentenceTransformer
from vllm import LLM, SamplingParams, TokensPrompt

# Module-level logger. `logging` here is `transformers.utils.logging` (imported
# above), not the standard-library module.
logger = logging.get_logger(__name__)
# logger.setLevel(logging.WARNING)
# The standard-library logging module is imported under an alias so it does not
# shadow the transformers `logging` name used above.
import logging as _logging

# Raise the threshold of the stdlib "httpx" logger to ERROR so its lower-level
# messages are suppressed. NOTE(review): the level constant is read from
# transformers' `logging` module, which presumably re-exports the stdlib level
# values — confirm, or prefer `_logging.ERROR` for clarity.
_logging.getLogger("httpx").setLevel(logging.ERROR)


class EmbeddingModelWrapper:
    """
    Wrapper for an embedding model to standardize the encoding interface.

    Accepts either a single string or a batch (list) of strings, so callers can
    embed many candidate continuations in a single model call for efficiency.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2", device=None):
        """
        Initialize embedding model wrapper using SentenceTransformer.

        Args:
            model_name: Name of the SentenceTransformer model
            device: Device to run the model on; defaults to CUDA when available, else CPU
        """
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(model_name, device=self.device)

    def encode(self, input_text, max_chars=1000):
        """
        Encode input text into semantic embeddings.

        Each text is truncated to its last ``max_chars`` characters before
        encoding, so the embedding reflects the most recent part of the text.

        Args:
            input_text: A string, or a list/tuple of strings, to encode
            max_chars: Number of trailing characters of each text to keep
                (default 1000). ``None`` keeps the full text.

        Returns:
            The embedding tensor returned by the model for the (truncated)
            input — one embedding for a string, one row per string for a
            batch — moved to ``self.device``.
        """

        def _tail(text):
            # Keep only the trailing characters of a single text.
            return text if max_chars is None else text[-max_chars:]

        if isinstance(input_text, str):
            texts = _tail(input_text)
        else:
            # Truncate every string individually rather than slicing the list
            # itself: slicing the list would drop items from large batches (so
            # row i would no longer correspond to input i) and leave each
            # string untruncated, unlike the single-string path.
            texts = [_tail(text) for text in input_text]

        embedding = self.model.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=False,
        )
        return embedding.to(self.device)


class DirectVLLMWrapper:
    """
    Wrapper for interacting with vLLM directly with token-level tracking and batch processing.

    Prompts are passed to vLLM as token IDs (``TokensPrompt``); text prompts are
    tokenized with ``self.tokenizer`` first. Every generation is summarized as a
    ``(token_ids, tokens_generated, prompt_tokens, avg_log_probability,
    last_token_entropy)`` tuple.
    """

    def __init__(self, model_name, gpu_memory_utilization=0.85, tokenizer=None, top_p=0.8, repetition_penalty=1.2,
                 top_k=20):
        """
        Initialize a direct wrapper for vLLM model.

        Args:
            model_name: Name of the model to load
            gpu_memory_utilization: Fraction of GPU memory to use
            tokenizer: Tokenizer for processing inputs and outputs; loaded from
                ``model_name`` when not provided
            top_p: Top-p sampling parameter
            repetition_penalty: Repetition penalty parameter
            top_k: Top-k sampling parameter
        """
        self.model_name = model_name
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty

        # Load the model
        print(f"Loading vLLM model: {model_name}")
        self.vllm_model = LLM(
            model=model_name,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            max_model_len=8192,
            enable_prefix_caching=True,
        )

        # Honour a caller-supplied tokenizer; only load one when none was given.
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(model_name)

    def generate(self, prompt=None, max_tokens=100, temperature=1.0, stop=None, stop_token_ids=None, logprobs=19, n=1,
                 input_tokens=None):
        """
        Generate continuations with vLLM from token IDs (or text prompts).

        Args:
            prompt: Input text prompt or list of prompts (used if input_tokens is None)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0 for greedy)
            stop: String or list of strings to stop generation
            stop_token_ids: List of token IDs to stop generation
            logprobs: Number of top log probabilities to return per generated token
            n: Number of output sequences to generate per prompt
            input_tokens: Optional input token IDs (list or batch of lists) to use instead of text prompt

        Returns:
            If input is single:
                List of tuples, each containing:
                (token_ids, tokens_generated, prompt_tokens, avg_log_probability, last_token_entropy)
            If input is batch:
                List of lists, where each inner list contains the tuples described above for each prompt
            On any error the message is printed and every tuple is ``(None, 0, 0, 0, 0)``,
            ``n`` per prompt, in the same single/batch shape.
        """
        # Resolve the batch shape before the try-block so the failure path can
        # always build a correctly shaped fallback (an empty input list used to
        # fail before `is_batch` was bound, raising UnboundLocalError in `except`).
        if input_tokens is not None:
            is_batch = isinstance(input_tokens, list) and len(input_tokens) > 0 and isinstance(input_tokens[0], list)
            batch_size = len(input_tokens) if is_batch else 1
        else:
            is_batch = isinstance(prompt, list)
            batch_size = len(prompt) if is_batch else 1

        if is_batch and batch_size == 0:
            return []

        try:
            if input_tokens is not None:
                token_sequences = input_tokens if is_batch else [input_tokens]
            elif prompt is not None:
                # Text input: tokenize here so everything downstream works on token IDs.
                texts = prompt if is_batch else [prompt]
                token_sequences = [self.tokenizer.encode(text) for text in texts]
            else:
                raise ValueError("Either `prompt` or `input_tokens` must be provided")

            prompt_token_counts = [len(tokens) for tokens in token_sequences]

            sampling_params = SamplingParams(
                n=n,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=self.repetition_penalty,
                stop=stop,
                stop_token_ids=stop_token_ids,
                logprobs=logprobs,
                skip_special_tokens=True,
                include_stop_str_in_output=True,
            )

            outputs = self.vllm_model.generate(
                [TokensPrompt(prompt_token_ids=tokens) for tokens in token_sequences],
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            all_results = [
                [self._summarize_generation(generation, prompt_token_count)
                 for generation in request_output.outputs]
                for request_output, prompt_token_count in zip(outputs, prompt_token_counts)
            ]

            # If input was a single prompt, return just the results for that prompt
            return all_results if is_batch else all_results[0]

        except Exception as e:
            print(f"Error in vLLM generation: {e}")
            fallback = (None, 0, 0, 0, 0)
            if is_batch:
                return [[fallback] * n for _ in range(batch_size)]
            return [fallback] * n

    @staticmethod
    def _summarize_generation(generation, prompt_token_count):
        """Convert one vLLM completion into the (token_ids, count, prompt_count, avg_logprob, entropy) tuple."""
        # vLLM may expose token_ids as a tuple/array; callers concatenate the
        # result with Python lists, so normalise to a list.
        token_ids = list(generation.token_ids)
        generated_token_count = len(token_ids)

        # cumulative_logprob can be None when logprobs were not requested.
        cumulative_logprob = generation.cumulative_logprob
        if generated_token_count > 0 and cumulative_logprob is not None:
            avg_logprob = cumulative_logprob / generated_token_count
        else:
            avg_logprob = 0.0

        last_token_entropy = 0
        if generated_token_count > 0 and getattr(generation, "logprobs", None):
            last_token_entropy = DirectVLLMWrapper._top_k_entropy(generation.logprobs[-1])

        return (token_ids, generated_token_count, prompt_token_count, avg_logprob, last_token_entropy)

    @staticmethod
    def _top_k_entropy(distribution):
        """Entropy (nats) of a token->Logprob mapping after renormalizing its probabilities to sum to 1."""
        probs = [np.exp(entry.logprob) for entry in distribution.values() if hasattr(entry, "logprob")]
        total = sum(probs)
        if not probs or total <= 0:
            return 0
        # Skip zero-probability entries (logprob -inf): 0 * log(0) is taken as 0 instead of NaN.
        return -sum((p / total) * np.log(p / total) for p in probs if p > 0)


def process_single_sample(
        embedding_model,
        vllm_model,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
        use_cache: bool,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus: bool,
        tokenizer=None,
        device=None,
        num_beams=1,
        num_beam_groups=1,
        num_sub_beams=1,
        forward_steps=1,
        return_dict_in_generate=False,
        eos_token_id=None,
        pad_token_id=None,
):
    """
    Process a single sample using Semantic-guided Diverse Decoding (SemDiD).
    Fully parallelized version that processes all beams at once with single vLLM call.
    Modified to work directly with token IDs throughout and align with paper descriptions.
    """
    import math  # Ensure math module is imported

    # Ensure device is set
    if device is None:
        device = input_ids.device

    # Track punctuation tokens for exploration stopping
    punct_tokens = ['.', '!', '?']
    punct_tokens_ids = [tokenizer.encode(token)[0] for token in punct_tokens]

    # Parameters for dynamic exploration
    base_explore_width = int(3 * num_sub_beams)  # Base number of sequences to explore
    base_beam_width = num_sub_beams  # Base number of beams to keep
    beam_width_multiplier = 1.5  # Hyperparameter for width increase

    # Initialize parameters for quality assessment debiasing
    beta_seq = 0.001  # Sequence position decay rate
    beta_sent = 0.005  # Sentence position decay rate
    tau = -0.6  # Saturation threshold τ

    # Parameters for harmonic gain function and ε-constraint
    harmonic_strength = 2.0  # λ parameter for harmonic gain
    quality_relaxation = 0.2  # γ relaxation parameter for ε-constraint (0 < γ < 1)

    # Since log probabilities are negative, we need to be careful with the relaxation
    # γ < 1 means we're allowing lower (more negative) log probabilities than the greedy baseline

    max_ctx_len = 4096
    group_step = 0
    max_group_step = 5

    # Get initial context embedding (question embedding)
    with torch.no_grad():
        # Convert input IDs to list
        original_input_ids = input_ids[0].tolist()
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id in original_input_ids:
            # Find the index of the first non-padding token
            first_non_pad_idx = original_input_ids.index(
                next(token_id for token_id in original_input_ids if token_id != pad_token_id))
            # Remove padding by slicing
            original_input_ids = original_input_ids[first_non_pad_idx:]  # Remove padding tokens
        original_input_ids = original_input_ids[-max_ctx_len:]  # Limit to last 4096 tokens

        # Get text representation for initial embedding only
        context_text = tokenizer.decode(original_input_ids[-500:])
        context_embedding = embedding_model.encode(context_text)

    # Generate orthogonal direction vectors for guidance
    embedding_dim = context_embedding.shape[0]
    guidance_vectors = generate_orthogonal_vectors_efficient(
        dim=embedding_dim,
        num_vectors=num_beam_groups - 1,
        device=device
    )

    # Original length and token count
    prompt_token_count = len(original_input_ids)

    # Prepare EOS token handling
    if isinstance(eos_token_id, int):
        eos_token_ids = [eos_token_id]
    else:
        eos_token_ids = eos_token_id if eos_token_id else []

    # Set initial quality threshold epsilon to negative infinity (will be updated by greedy group)
    # Since log probabilities are negative, we start with negative infinity as the worst possible score
    epsilon = float('-inf')

    # Create initial beams - one for each group
    beam_groups = []
    for group_idx in range(num_beam_groups):
        # Initialize hypotheses for this group
        group_beams = []

        # First hypothesis is always the input sequence
        group_beams.append({
            "tokens": deepcopy(original_input_ids),  # Store token IDs directly
            "score": torch.tensor(0.0, device=device),
            "is_finished": False,
            "punct_seen": False,
            "dynamic_beam_width": base_beam_width,
            "dynamic_explore_width": base_explore_width,
            "token_count": 0,
            "cumulative_logprob": 0.0,
        })

        # Add this group
        beam_groups.append({
            "beams": group_beams,
            "best_score": torch.tensor(0.0, device=device),
            "direction": None if group_idx == 0 else guidance_vectors[group_idx - 1]
        })

    # Main generation loop
    finished_sequences = []
    max_length = 4096

    # Track all embeddings for each group
    group_embeddings = [[] for _ in range(num_beam_groups)]

    # Initialize group centroid embeddings with context embedding
    group_centroids = [context_embedding.clone() for _ in range(num_beam_groups)]

    # Current mix weights
    alpha_init = 1.0
    beta_init = 1.0
    gamma_init = 0.0
    alpha = alpha_init
    beta = beta_init
    gamma = gamma_init

    # Continue until all groups are finished
    all_groups_finished = False

    while not all_groups_finished:
        # Process each group separately
        for group_idx, group in enumerate(beam_groups):
            # Filter active beams - beams that haven't finished generation
            active_beams = [beam for beam in group["beams"] if not beam["is_finished"]]

            # If no active beams in this group, continue to next group
            if len(active_beams) == 0:
                logger.info(f"Group {group_idx} has no active beams, skipping")
                continue

            # If it's the first group, use greedy search with beam width of 1
            if group_idx == 0:
                # First group uses greedy search
                active_beams = active_beams[:1]  # Only keep the first beam

                # Perform model forward steps for this beam
                beam = active_beams[0]
                current_tokens = beam["tokens"]  # Use token IDs directly
                current_token_count = beam["token_count"]

                # Generate with vLLM using token IDs directly
                generation_results = vllm_model.generate(
                    prompt=None,  # Not using text prompt
                    max_tokens=max_length - current_token_count,
                    temperature=0.0,  # Use 0 for greedy search
                    stop_token_ids=eos_token_ids,
                    n=1,
                    input_tokens=current_tokens  # Pass tokens directly
                )

                # Get the first (and only) result
                generated_token_ids, generated_token_count, prompt_tokens, avg_log_prob, last_token_entropy = \
                    generation_results[0]

                # Apply position debiasing to the log probabilities
                debiased_logprob = avg_log_prob

                # Correctly apply position debiasing to log probabilities
                if generated_token_count > 0:
                    # Approximate median position of generated tokens
                    median_position = current_token_count + generated_token_count // 2
                    # Apply sequence position debiasing (adding to logprob, not multiplying)
                    seq_position_debiasing = -beta_seq * median_position
                    # Apply simplified sentence debiasing (adding to logprob, not multiplying)
                    sent_position_debiasing = -beta_sent * (generated_token_count // 4)
                    # Apply debiasing by addition in log space
                    debiased_logprob = avg_log_prob + seq_position_debiasing + sent_position_debiasing
                    # Apply saturation threshold in log space
                    debiased_logprob = max(debiased_logprob, tau)

                # Update cumulative logprob
                new_cumulative_logprob = beam["cumulative_logprob"] + (debiased_logprob * generated_token_count)

                # Update total token count
                new_token_count = current_token_count + generated_token_count

                # Calculate full sequence avg logprob
                full_seq_avg_logprob = new_cumulative_logprob / new_token_count if new_token_count > 0 else 0.0

                # Combine token sequences
                explore_tokens = current_tokens + generated_token_ids

                # Check if finished
                is_finished = False

                # Check for EOS token in generated tokens
                if any(token_id in generated_token_ids for token_id in eos_token_ids):
                    is_finished = True

                # Check if max length reached
                if new_token_count >= max_length:
                    is_finished = True

                # Check for punctuation - even for greedy search
                punct_seen = beam["punct_seen"]
                if generated_token_ids and generated_token_ids[-1] in punct_tokens_ids:
                    punct_seen = True
                    # We don't adjust beam width for greedy search, but we do mark it

                # Get embedding for the generated sequence
                with torch.no_grad():
                    # Use the last 500 tokens for embedding
                    tokens_for_embedding = explore_tokens[-500:] if len(explore_tokens) > 500 else explore_tokens
                    # Convert tokens to text for embedding
                    text_for_embedding = tokenizer.decode(tokens_for_embedding)
                    seq_embedding = embedding_model.encode(text_for_embedding)

                # Create updated beam
                updated_beam = {
                    "tokens": explore_tokens,  # Store token IDs
                    "score": torch.tensor(full_seq_avg_logprob, device=device),
                    "is_finished": is_finished,
                    "punct_seen": punct_seen,
                    "dynamic_explore_width": beam.get("dynamic_explore_width", base_explore_width),
                    "dynamic_beam_width": beam.get("dynamic_beam_width", base_beam_width),
                    "token_count": new_token_count,  # Updated token count
                    "cumulative_logprob": new_cumulative_logprob,  # Updated cumulative logprob
                    "entropy": last_token_entropy,
                    "quality_score": full_seq_avg_logprob,  # Store quality score for ε-constraint
                }

                # Update epsilon quality threshold based on greedy group quality
                # For log probabilities (negative values), the relaxation means making it more negative
                # If quality_relaxation = 0.2, and log prob = -2.0, then epsilon = -2.0 / 0.2 = -10.0
                # This means we're requiring other groups to have a log prob better than -10.0
                epsilon = max(epsilon, full_seq_avg_logprob / quality_relaxation)
                # By using division with a value between 0 and 1, we make the threshold more lenient
                # The smaller quality_relaxation is, the more lenient the threshold

                # Add to finished sequences if done
                if is_finished:
                    finished_sequences.append({
                        "tokens": explore_tokens,  # Store token IDs
                        "score": torch.tensor(full_seq_avg_logprob, device=device),
                        "group": group_idx,
                        "token_count": new_token_count,
                        "cumulative_logprob": new_cumulative_logprob,
                        "avg_logprob": full_seq_avg_logprob
                    })

                # Add embedding to group embeddings
                group_embeddings[group_idx].append(seq_embedding)

                # Update group beams with just this single beam
                group["beams"] = [updated_beam]

                # Update group centroid with new embedding
                group_centroids[group_idx] = seq_embedding

                # Continue to next group - we've handled this group differently
                continue

            else:
                # For diverse groups, parallelize multi-step lookahead exploration for all active beams
                # Collect all token prompts and beam information for batch processing
                batch_token_prompts = []
                beam_mapping = []  # Maps result index back to beam and explore index

                # Prepare all token prompts for a single batch call to vLLM
                for beam_idx, beam in enumerate(active_beams):
                    current_tokens = deepcopy(beam["tokens"])  # Use tokens directly
                    dynamic_explore_width = beam.get("dynamic_explore_width", base_explore_width)

                    # Add this beam's token prompt multiple times based on explore width
                    for explore_idx in range(dynamic_explore_width):
                        batch_token_prompts.append(deepcopy(current_tokens))
                        beam_mapping.append({
                            "beam_idx": beam_idx,
                            "beam": beam,
                            "explore_idx": explore_idx
                        })

                # Skip if no prompts to process
                if not batch_token_prompts:
                    continue

                # Generate all continuations in a single batch call
                # Add punctuation tokens to the stop tokens for non-greedy beams
                combined_stop_tokens = eos_token_ids + punct_tokens_ids  # Add punctuation as stop tokens

                all_generation_results = vllm_model.generate(
                    prompt=None,  # Not using text prompts
                    max_tokens=forward_steps,
                    temperature=1.5,
                    stop_token_ids=combined_stop_tokens,  # Include punctuation as stop tokens
                    n=1,  # Each prompt is duplicated based on explore width, so n=1 for each
                    input_tokens=batch_token_prompts  # Pass token IDs directly
                )

                # Collect tokens and candidate info
                token_sequences = []
                temp_candidates = []

                # Process all generated sequences
                for batch_idx, (mapping, results) in enumerate(zip(beam_mapping, all_generation_results)):
                    beam_idx = mapping["beam_idx"]
                    beam = mapping["beam"]

                    # Extract the single result (since n=1 for each prompt)
                    generated_token_ids, generated_token_count, prompt_tokens, avg_log_prob, last_token_entropy = \
                        results[0]

                    # Get beam's current state
                    current_tokens = deepcopy(beam["tokens"])
                    current_token_count = deepcopy(beam["token_count"])
                    current_cumulative_logprob = beam["cumulative_logprob"]
                    punct_seen = beam["punct_seen"]

                    # Correctly apply position debiasing to log probabilities
                    debiased_logprob = avg_log_prob

                    if generated_token_count > 0:
                        # Approximate median position of generated tokens
                        median_position = current_token_count + generated_token_count // 2
                        # Apply sequence position debiasing (adding to logprob, not multiplying)
                        seq_position_debiasing = -beta_seq * median_position
                        # Apply simplified sentence debiasing (adding to logprob, not multiplying)
                        sent_position_debiasing = -beta_sent * (generated_token_count // 4)
                        # Apply debiasing by addition in log space
                        debiased_logprob = avg_log_prob + seq_position_debiasing + sent_position_debiasing
                        # Apply saturation threshold in log space
                        debiased_logprob = max(debiased_logprob, tau)

                    # Update token counts and logprobs
                    new_cumulative_logprob = current_cumulative_logprob + (debiased_logprob * generated_token_count)
                    new_token_count = current_token_count + generated_token_count
                    full_seq_avg_logprob = new_cumulative_logprob / new_token_count if new_token_count > 0 else 0.0

                    # Combine token sequences
                    explore_tokens = current_tokens + generated_token_ids

                    # Check if finished
                    path_is_finished = False

                    # Check for EOS token in generated tokens
                    if any(token_id in generated_token_ids for token_id in eos_token_ids):
                        path_is_finished = True

                    # Check max length
                    if new_token_count >= max_length:
                        path_is_finished = True

                    # Check for punctuation and set flag if found
                    path_punct_seen = punct_seen
                    if generated_token_ids and generated_token_ids[-1] in punct_tokens_ids:
                        path_punct_seen = True
                        # This will also cause vLLM to stop generation for this sequence

                    # Add to token sequences for later embedding
                    token_sequences.append(explore_tokens)

                    # Save candidate info
                    temp_candidates.append({
                        "tokens": deepcopy(explore_tokens),  # Store token IDs
                        "raw_prob": full_seq_avg_logprob,
                        "punct_seen": path_punct_seen,
                        "is_finished": path_is_finished,
                        "entropy": last_token_entropy,
                        "source_beam_idx": beam_idx,
                        "parent_beam": beam,
                        "token_count": deepcopy(new_token_count),
                        "cumulative_logprob": new_cumulative_logprob,
                        "quality_score": full_seq_avg_logprob  # Store quality score for ε-constraint
                    })

                # Skip if no token sequences to embed
                if not token_sequences:
                    continue

                # Prepare text for embedding (last 500 tokens of each sequence)
                sequences_to_embed = []
                for tokens in token_sequences:
                    tokens_for_embedding = tokens[-500:] if len(tokens) > 500 else tokens
                    text_for_embedding = tokenizer.decode(tokens_for_embedding)
                    sequences_to_embed.append(text_for_embedding)

                # Batch compute embeddings for all sequences at once
                with torch.no_grad():
                    batch_embeddings = embedding_model.encode(sequences_to_embed)

                # Process embeddings and create candidates
                all_candidates = []

                for i, temp_candidate in enumerate(temp_candidates):
                    # Get embedding for this candidate
                    if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() > 1:
                        embedding = batch_embeddings[i]
                    else:
                        embedding = batch_embeddings

                    # Add complete candidate with embedding
                    all_candidates.append({
                        **temp_candidate,
                        "embedding": embedding
                    })

                # If no candidates were generated, continue to next group
                if len(all_candidates) == 0:
                    logger.warning(f"No candidates generated for group {group_idx}, skipping")
                    continue

                # Normalize quality scores
                raw_quality_scores = [candidate["raw_prob"] for candidate in all_candidates]
                min_quality = min(raw_quality_scores)
                max_quality = max(raw_quality_scores)
                quality_range = max_quality - min_quality if max_quality > min_quality else 1.0

                # Collect direction and repulsion scores for normalization
                direction_scores = []
                repulsion_scores = []

                # First pass: collect all scores
                for candidate in all_candidates:
                    candidate_embedding = candidate["embedding"]

                    # Calculate directional guidance score
                    direction_score = 0.0
                    if group_idx > 0:
                        direction_vector = group["direction"]
                        displacement = candidate_embedding - context_embedding
                        cos_sim = F.cosine_similarity(
                            displacement.unsqueeze(0),
                            direction_vector.unsqueeze(0)
                        )
                        direction_score = abs(cos_sim[0].item())
                    direction_scores.append(direction_score)

                    # Calculate inter-group repulsion score
                    repulsion_score = 0.0
                    for other_idx, other_centroid in enumerate(group_centroids):
                        if other_idx != group_idx:
                            cos_sim = F.cosine_similarity(
                                candidate_embedding.unsqueeze(0),
                                other_centroid.unsqueeze(0)
                            )
                            repulsion_score += 1.0 - cos_sim[0].item()

                    # Normalize repulsion score by number of other groups
                    if num_beam_groups > 1:
                        repulsion_score = repulsion_score / (num_beam_groups - 1)
                    repulsion_scores.append(repulsion_score)

                # Get ranges for normalization
                min_direction = min(direction_scores) if direction_scores else 0.0
                max_direction = max(direction_scores) if direction_scores else 1.0
                direction_range = max_direction - min_direction if max_direction > min_direction else 1.0

                min_repulsion = min(repulsion_scores) if repulsion_scores else 0.0
                max_repulsion = max(repulsion_scores) if repulsion_scores else 1.0
                repulsion_range = max_repulsion - min_repulsion if max_repulsion > min_repulsion else 1.0

                # Score candidates with weighted combination of factors
                # Correct: normalize within whole groups rather than one group
                for i, candidate in enumerate(all_candidates):
                    # Get probability score - normalize to [0,1] range for quality
                    raw_quality = candidate["raw_prob"]

                    # Normalize the quality score (higher is better)
                    # Since log probs are negative, we use relative improvement over worst quality
                    # if quality_range > 0:
                    #     norm_quality = (raw_quality - min_quality) / quality_range
                    # else:
                    #     norm_quality = 0.5  # If all candidates have same quality
                    norm_quality = math.e**raw_quality

                    prob_score = torch.tensor(norm_quality, device=device)

                    # Normalize direction score
                    # if direction_range > 0:
                    #     norm_direction = (direction_scores[i] - min_direction) / direction_range
                    # else:
                    #     norm_direction = 0.5 if direction_scores[i] > 0 else 0.0
                    norm_direction = direction_scores[i]

                    # Normalize repulsion score
                    # if repulsion_range > 0:
                    #     norm_repulsion = (repulsion_scores[i] - min_repulsion) / repulsion_range
                    # else:
                    #     norm_repulsion = 0.5 if repulsion_scores[i] > 0 else 0.0
                    norm_repulsion = repulsion_scores[i]

                    # Calculate diversity score as weighted combination of normalized directional guidance and repulsion
                    # Use the group_step based weighting for alpha transition as in original code
                    group_step_progress = group_step / max_group_step
                    current_beta = max(0, beta_init - group_step_progress)
                    current_gamma = min(1, gamma_init + group_step_progress)

                    # Normalize weights
                    total = current_beta + current_gamma
                    if total > 0:
                        current_beta = current_beta / total
                        current_gamma = current_gamma / total

                    # Calculate combined diversity score using normalized components
                    diversity_score = current_beta * norm_direction + current_gamma * norm_repulsion

                    # Apply ε-constraint - check if raw quality is above epsilon threshold
                    if raw_quality < epsilon:
                        combined_score = float('-inf')
                    else:
                        # Use harmonic gain function to prioritize improving the weakest aspect
                        # Using normalized scores for quality and diversity
                        if prob_score + diversity_score > 0:
                            combined_score = (
                                    harmonic_strength * prob_score * diversity_score /
                                    (prob_score + diversity_score)
                            )
                        else:
                            combined_score = 0.0
                    # logger.info(f"harmonic_strength: {harmonic_strength}, prob_score: {prob_score}, diversity_score: {diversity_score}, norm_direction: {norm_direction}, norm_repulsion: {norm_repulsion}, combined_score: {combined_score}")

                    # Store score and component scores for logging
                    candidate["score"] = combined_score
                    candidate["prob_score"] = norm_quality  # Store normalized quality score
                    candidate["direction_score"] = norm_direction  # Store normalized direction score
                    candidate["repulsion_score"] = norm_repulsion  # Store normalized repulsion score
                    candidate["diversity_score"] = diversity_score  # Store combined diversity score
                    candidate["raw_quality"] = raw_quality  # Store original quality score for threshold checks

                # Calculate mean dynamic beam width from active beams
                if active_beams:
                    group_dynamic_beam_width = sum(beam.get("dynamic_beam_width", base_beam_width)
                                                   for beam in active_beams) // len(active_beams)
                else:
                    group_dynamic_beam_width = base_beam_width

                # Sort and select top candidates
                sorted_candidates = sorted(all_candidates, key=lambda x: x["score"], reverse=True)
                selected_candidates = sorted_candidates[:group_dynamic_beam_width]

                # Update beam group with selected candidates
                updated_beams = []
                for candidate in selected_candidates:
                    # Get is_finished status from candidate
                    is_finished = candidate["is_finished"]

                    # Double-check sequence length constraint
                    if not is_finished:
                        # Check if token count exceeds max_length
                        if candidate["token_count"] >= max_length:
                            is_finished = True

                    # Calculate dynamic explore width for next iteration
                    dynamic_explore_width = base_explore_width
                    if candidate["punct_seen"]:
                        # Increase exploration width by the multiplier for sequences that reached punctuation
                        dynamic_explore_width = int(base_explore_width * beam_width_multiplier)

                    # Calculate dynamic beam width for next iteration
                    dynamic_beam_width = base_beam_width
                    if candidate["punct_seen"]:
                        # Increase beam width by the multiplier for sequences that reached punctuation
                        dynamic_beam_width = int(base_beam_width * beam_width_multiplier)

                    updated_beams.append({
                        "tokens": deepcopy(candidate["tokens"]),  # Store token IDs
                        "score": torch.tensor(candidate["score"], device=device),
                        "is_finished": is_finished,
                        "punct_seen": candidate["punct_seen"],
                        "entropy": candidate["entropy"],
                        "dynamic_explore_width": dynamic_explore_width,
                        "dynamic_beam_width": dynamic_beam_width,
                        "token_count": deepcopy(candidate["token_count"]),
                        "cumulative_logprob": candidate["cumulative_logprob"],
                        "raw_prob": candidate["raw_quality"],  # Store original raw probability
                        "quality_score": candidate["prob_score"],  # Store normalized quality score
                    })

                    # Add to finished sequences if done
                    if is_finished:
                        finished_sequences.append({
                            "tokens": deepcopy(candidate["tokens"]),  # Store token IDs
                            "score": torch.tensor(candidate["score"], device=device),
                            "group": group_idx,
                            "token_count": deepcopy(candidate["token_count"]),
                            "cumulative_logprob": candidate["cumulative_logprob"],
                            "avg_logprob": candidate["raw_quality"]  # Use raw log probability
                        })

                    # Add embedding to group embeddings
                    group_embeddings[group_idx].append(candidate["embedding"])

                # Update group beams
                group["beams"] = updated_beams

                # Update group centroid with new embeddings
                if len(group_embeddings[group_idx]) > 0:
                    group_centroids[group_idx] = torch.stack(group_embeddings[group_idx]).mean(dim=0)

        # Update progress based on the average token length across all groups
        current_token_counts = []
        for group in beam_groups:
            for beam in group["beams"]:
                if len(group["beams"]) > 0:
                    current_token_counts.append(deepcopy(beam["token_count"]))

        current_token_count = sum(current_token_counts) / len(current_token_counts) if current_token_counts else 0
        progress = min(1.0, current_token_count / max_length)

        # Update adaptive weights based on group step progress as in original code
        group_step_progress = group_step / max_group_step
        beta = max(0, beta_init - group_step_progress)
        gamma = min(1, gamma_init + group_step_progress)
        group_step += 1

        # Normalize weights
        total = beta + gamma
        # alpha 不计入统计
        if total > 0:
            beta = beta / total
            gamma = gamma / total

        # Log the current generation state
        logger.info(f"Generation progress: {progress:.2f}, Current Token Count: {current_token_counts}, "
                    f"Max Length: {max_length}, Alpha: {alpha:.4f}, Beta: {beta:.4f}, Gamma: {gamma:.4f}, "
                    f"group_step_progress: {group_step_progress:.4f}, Epsilon: {epsilon:.4f}")

        # Check if all groups are finished
        all_groups_finished = all([
            all([beam["is_finished"] for beam in group["beams"]])
            for group in beam_groups
            if len(group["beams"]) > 0
        ])

        # print(f"beam_groups:")
        # for group in beam_groups:
        #     for beam in group["beams"]:
        #         print(f"    tokens: {beam['tokens']}")
        #         print(f"    score: {beam['score']}")
        #         print(f"    is_finished: {beam['is_finished']}")
        #         print(f"    dynamic_beam_width: {beam['dynamic_beam_width']}")
        #         print(f"    dynamic_explore_width: {beam['dynamic_explore_width']}")
        #         print(f"    token_count: {beam['token_count']}")
        #         print(f"    cumulative_logprob: {beam['cumulative_logprob']}")

        # Apply stopping criteria
        if all_groups_finished:
            logger.info("All groups are finished, stopping generation")
            break

    # If no finished sequences, take the best ongoing sequences
    if len(finished_sequences) == 0:
        logger.warning("No finished sequences, using best ongoing sequences")
        for group_idx, group in enumerate(beam_groups):
            if len(group["beams"]) > 0:
                best_beam = max(group["beams"], key=lambda x: x["score"])
                finished_sequences.append({
                    "tokens": best_beam["tokens"],  # Store token IDs
                    "score": best_beam["score"],
                    "group": group_idx,
                    "token_count": best_beam["token_count"],
                    "cumulative_logprob": best_beam["cumulative_logprob"],
                    "avg_logprob": best_beam["cumulative_logprob"] / best_beam["token_count"] if best_beam[
                                                                                                     "token_count"] > 0 else 0.0
                })

    # Select best sequence from each group
    final_sequences = []
    final_scores = []
    final_avg_logprobs = []

    for group_idx in range(num_beam_groups):
        group_sequences = [seq for seq in finished_sequences if seq["group"] == group_idx]

        if len(group_sequences) > 0:
            # Take best sequence from group
            best_seq = max(group_sequences, key=lambda x: x["score"])

            # Convert token sequence to tensor
            seq_ids = torch.tensor([best_seq["tokens"]], device=device)
            final_sequences.append(seq_ids)
            final_scores.append(best_seq["score"])
            final_avg_logprobs.append(best_seq.get("avg_logprob", 0.0))

            # Log the best sequence from each group
            logger.info(f"Final sequence from group {group_idx} - "
                        f"Score: {best_seq['score']:.4f}, "
                        f"Avg LogProb: {best_seq.get('avg_logprob', 0.0):.4f}, "
                        f"Token Count: {best_seq['token_count']}, "
                        f"Cumulative LogProb: {best_seq['cumulative_logprob']:.4f}")
        else:
            # No finished sequence for this group, use the input as fallback
            final_sequences.append(input_ids)
            final_scores.append(torch.tensor(0.0, device=device))
            final_avg_logprobs.append(0.0)
            logger.info(f"No finished sequence for group {group_idx}, using input as fallback")

    # Final result processing
    decoded_sequences = []

    for seq_ids in final_sequences:
        # Convert tensor to list
        token_ids = seq_ids[0].tolist()

        # Skip the input portion by using token count from prompt
        output_tokens = token_ids[prompt_token_count:]

        # Decode only the output tokens
        decoded_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
        decoded_sequences.append(decoded_text)

    return decoded_sequences


def _semantic_diverse_group_beam_search(
        model,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        tokenizer=None,
        **model_kwargs,
):
    """
    Semantic guided diverse beam search generation (SemDiD) backed by a direct vLLM engine.

    A sentence-embedding model and a vLLM engine are created on every call (no
    module-level state). Each row of ``input_ids`` is then decoded independently
    by ``process_single_sample``.

    Method overview:
    1. Orthogonal Direction Guidance - first group is greedy, others follow orthogonal vectors
    2. Dynamic Multi-step Lookahead - explores multiple steps ahead before scoring
    3. Adaptive beam width based on distribution entropy
    4. Balance between probability and diversity using weighted scoring

    Args:
        model: Hugging Face model. Only ``model.name_or_path`` is read; it is the
            default model id for both vLLM and the fallback tokenizer.
        input_ids: Input token sequences (batch_size, seq_len).
        logits_processor: LogitsProcessorList, forwarded unchanged to each sample.
        stopping_criteria: StoppingCriteriaList, forwarded unchanged to each sample.
        generation_config: Source of ``num_beams``, ``num_beam_groups``,
            ``pad_token_id``, ``eos_token_id`` and ``return_dict_in_generate``.
        synced_gpus: Forwarded unchanged to each sample.
        tokenizer: Tokenizer for the main inference model. If None, it is loaded
            with ``AutoTokenizer.from_pretrained`` from the vLLM model id.
        **model_kwargs: ``attention_mask`` (sliced per row), ``past_key_values``
            (deep-copied per row) and ``use_cache`` (default True) are forwarded.
            Optional overrides (defaults match the previous hard-coded values):
            ``vllm_model_name`` (default ``model.name_or_path``),
            ``gpu_memory_utilization`` (default 0.7) and
            ``embedding_model_name`` (default ``"all-MiniLM-L6-v2"``).

    Returns:
        List of decoded output strings: the per-row outputs of
        ``process_single_sample`` concatenated in batch order.

    Raises:
        ValueError: if ``num_beam_groups`` is < 1, exceeds ``num_beams``, or does
            not divide ``num_beams``.
    """
    # Initialize values from generation_config
    pad_token_id = generation_config.pad_token_id
    eos_token_id = generation_config.eos_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_id
    return_dict_in_generate = generation_config.return_dict_in_generate

    # Lookahead depth is fixed; generation_config.forward_steps is not consulted.
    forward_steps = 30

    # Model / resource settings, overridable through model_kwargs.
    model_name = model_kwargs.get("vllm_model_name") or model.name_or_path
    gpu_memory_utilization = model_kwargs.get("gpu_memory_utilization")
    if gpu_memory_utilization is None:
        gpu_memory_utilization = 0.7
    embedding_model_name = model_kwargs.get("embedding_model_name") or "all-MiniLM-L6-v2"

    # Load a tokenizer matching the vLLM model when the caller did not supply one.
    # (from_pretrained raises on failure, so no further None check is needed.)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    device = input_ids.device
    batch_size = input_ids.shape[0]

    # Extract and validate beam search parameters
    num_beams = generation_config.num_beams
    num_beam_groups = generation_config.num_beam_groups
    if num_beam_groups < 1:
        raise ValueError("`num_beam_groups` has to be a positive integer")
    if num_beam_groups > num_beams:
        raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
    if num_beams % num_beam_groups != 0:
        raise ValueError("`num_beams` should be divisible by `num_beam_groups`")

    num_sub_beams = num_beams // num_beam_groups

    all_decoded_outputs = []  # decoded strings from every batch row
    attention_mask = model_kwargs.get("attention_mask")
    past_key_values = model_kwargs.get("past_key_values")
    use_cache = model_kwargs.get("use_cache", True)

    # Heavy components are built once and shared by all batch rows.
    embedding_model = EmbeddingModelWrapper(embedding_model_name, device=device)
    vllm_model = DirectVLLMWrapper(
        model_name=model_name,
        gpu_memory_utilization=gpu_memory_utilization,
        tokenizer=tokenizer
    )

    for batch_idx in tqdm(range(batch_size), desc="Batch Loop"):
        # Keep a leading batch dimension of 1 for the per-sample routine.
        single_input_ids = input_ids[batch_idx].unsqueeze(0)
        single_attention_mask = attention_mask[batch_idx].unsqueeze(0) if attention_mask is not None else None

        sample_outputs = process_single_sample(
            embedding_model=embedding_model,
            vllm_model=vllm_model,
            input_ids=single_input_ids,
            attention_mask=single_attention_mask,
            # Deep copy so one sample cannot mutate the cache seen by the next.
            past_key_values=deepcopy(past_key_values) if past_key_values is not None else None,
            use_cache=use_cache,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            tokenizer=tokenizer,
            device=device,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_sub_beams=num_sub_beams,
            forward_steps=forward_steps,
            return_dict_in_generate=return_dict_in_generate,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
        )

        # process_single_sample returns a list of decoded strings.
        all_decoded_outputs.extend(sample_outputs)

    return all_decoded_outputs


def generate_orthogonal_vectors_efficient(dim, num_vectors, device, scale_factor=0.1, generator=None):
    """
    Generate mutually orthogonal vectors efficiently using QR decomposition.

    A ``(dim, num_vectors)`` Gaussian matrix is factorized with reduced-mode
    ``torch.linalg.qr``. Q has ``min(dim, num_vectors)`` orthonormal columns, so
    at most that many vectors are returned. Every returned vector has Euclidean
    norm ``scale_factor``.

    Args:
        dim: Dimension of vectors.
        num_vectors: Number of orthogonal vectors requested.
        device: PyTorch device the vectors are created on.
        scale_factor: Norm of each returned vector. The default of 0.1 keeps the
            directions small enough to guide without dominating a question
            embedding.
        generator: Optional ``torch.Generator`` for reproducible output. It must
            live on ``device``.

    Returns:
        List of 1-D tensors of shape ``(dim,)``.
    """
    rand_matrix = torch.randn(dim, num_vectors, device=device, generator=generator)

    # Only Q is needed; its columns form an orthonormal set.
    q, _ = torch.linalg.qr(rand_matrix)

    # q.shape[1] == min(dim, num_vectors), so this never exceeds the request.
    return [q[:, i] * scale_factor for i in range(q.shape[1])]


# Add a patch function for transformers generation to use our semantic diverse beam search
def patch_transformers_generation():
    """
    Patch ``GenerationMixin._group_beam_search`` to dispatch to semantic diverse beam search.

    After patching, a model whose ``use_semantic_diverse_beam_search`` attribute
    is truthy is routed to ``_semantic_diverse_group_beam_search``. Every other
    model falls through to the original transformers implementation.

    The patch is idempotent: repeated calls leave the already-installed wrapper
    in place instead of wrapping it again.
    """
    import functools

    import transformers.generation.utils as gen_utils

    current_group_beam_search = gen_utils.GenerationMixin._group_beam_search
    if getattr(current_group_beam_search, "_is_semantic_diverse_patch", False):
        # Already patched; wrapping again would stack dispatchers on each call.
        return

    # Save the original method for the non-SemDiD path.
    original_group_beam_search = current_group_beam_search

    @functools.wraps(original_group_beam_search)
    def patched_group_beam_search(self, *args, **kwargs):
        # Per-model opt-in flag decides which implementation runs.
        if getattr(self, "use_semantic_diverse_beam_search", False):
            return _semantic_diverse_group_beam_search(self, *args, **kwargs)
        return original_group_beam_search(self, *args, **kwargs)

    # Marker used above to detect an existing patch.
    patched_group_beam_search._is_semantic_diverse_patch = True

    gen_utils.GenerationMixin._group_beam_search = patched_group_beam_search
    print("Patched transformers library for semantic diverse beam search")


# Example usage of the direct vLLM approach
def example_usage():
    """
    Demonstrate semantic diverse beam search through ``model.generate``.

    Loads a tokenizer and an 8-bit base model, then patches transformers so
    group beam search is routed to the SemDiD implementation. It generates for
    one prompt and prints one text per returned output.

    Returns:
        List of generated texts.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

    # Model parameters
    model_name = "meta-llama/Llama-2-7b-hf"  # Replace with your model name
    gpu_memory_utilization = 0.85  # Target GPU memory usage

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Patch transformers for semantic diverse beam search
    patch_transformers_generation()

    # Base model; generation itself is delegated to vLLM by the patched method.
    # `load_in_8bit=` as a direct kwarg is deprecated in favour of a quantization config.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    model.use_semantic_diverse_beam_search = True  # Enable semantic diverse search

    input_text = "Write a short story about a robot that discovers emotions."
    # With device_map="auto", follow the model's device instead of hard-coding "cuda".
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # NOTE(review): transformers' GenerationConfig validation may reject
    # do_sample=True combined with num_beam_groups > 1 - confirm for the installed version.
    generation_config = GenerationConfig(
        max_length=512,
        num_beams=8,
        num_beam_groups=4,
        diversity_penalty=1.0,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        # The SemDiD entry point uses a fixed lookahead and does not read this field.
        forward_steps=300
    )

    # Extra settings for vLLM and the embedding model.
    # NOTE(review): recent transformers versions validate model kwargs in generate()
    # and may reject unknown keys like these - confirm against the installed version.
    additional_params = {
        "vllm_model_name": model_name,
        "gpu_memory_utilization": gpu_memory_utilization,
        "embedding_model_name": "all-MiniLM-L6-v2"
    }

    outputs = model.generate(
        input_ids,
        generation_config=generation_config,
        **additional_params
    )

    # The semantic diverse search already returns decoded strings; only token
    # sequences (e.g. from the original transformers path) need decoding.
    generated_texts = [
        output if isinstance(output, str) else tokenizer.decode(output, skip_special_tokens=True)
        for output in outputs
    ]

    for i, text in enumerate(generated_texts):
        print(f"== Group {i + 1} ==")
        print(text)
        print("=" * 50)

    return generated_texts


if __name__ == "__main__":
    generated_texts = example_usage()  # run the end-to-end demo when executed as a script
