import transformers
import torch
from transformers import GenerationConfig, BeamScorer, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateBeamOutput, GenerateBeamEncoderDecoderOutput, GenerateBeamDecoderOnlyOutput
from transformers.generation.utils import _split_model_inputs, stack_model_outputs
from methods.model_forward.crops_qwen_forward import patch_qwen_forward
from typing import Union, Optional, List, Dict
import copy
import torch.nn.functional as F

def _opera_resize_mask(mask, length):
    """Return the 2-D attention ``mask`` right-padded with ones or truncated to ``length`` columns."""
    deficit = length - mask.shape[-1]
    if deficit > 0:
        return torch.cat([mask, mask.new_ones((mask.shape[0], deficit))], dim=-1)
    return mask[..., :length]


def _opera_candidate_attention(model, input_ids, model_kwargs, candidate_tokens):
    """Probe every candidate next token with one full, cache-free forward pass.

    For each column ``c`` of ``candidate_tokens`` the model is run on
    ``input_ids`` followed by ``candidate_tokens[:, c]``; the last layer's attention
    map is reduced with a max over dim 1 and cast to float32. The per-candidate maps
    are concatenated along dim 1. Assuming attentions are shaped
    ``(batch, heads, query, key)`` (TODO confirm for patched forwards), the result is
    ``(batch_beam, num_candidates, seq_len + 1, seq_len + 1)``.

    Args:
        model: the generating model (``self`` of the beam search).
        input_ids: current sequences, one row per beam.
        model_kwargs: the up-to-date generation kwargs; the KV cache in it is NOT used.
        candidate_tokens: ``(batch_beam, num_candidates)`` token ids to probe.

    Raises:
        ValueError: if the model returns no attention weights.
    """
    probe_len = input_ids.shape[-1] + 1

    # The probes are full forwards over the whole prefix, so the incremental-decoding
    # state (KV cache, cache positions) must not leak into them.
    base_kwargs = {key: value for key, value in model_kwargs.items() if key != "past_key_values"}
    if base_kwargs.get("attention_mask") is not None:
        # After `_update_model_kwargs_for_generation` the mask may already include the
        # slot for the probed token; size it to the probe length either way.
        base_kwargs["attention_mask"] = _opera_resize_mask(base_kwargs["attention_mask"], probe_len)
    position_ids = base_kwargs.get("position_ids")
    if position_ids is not None and position_ids.shape[-1] != probe_len:
        if position_ids.shape[-1] < probe_len:
            step = torch.arange(
                1, probe_len - position_ids.shape[-1] + 1, device=position_ids.device, dtype=position_ids.dtype
            )
            position_ids = torch.cat([position_ids, position_ids[..., -1:] + step], dim=-1)
        else:
            position_ids = position_ids[..., :probe_len]
        base_kwargs["position_ids"] = position_ids
    if "cache_position" in base_kwargs:
        # Without a cache the positions of a full forward start at 0.
        base_kwargs["cache_position"] = torch.arange(probe_len, device=input_ids.device)
    if "use_cache" in base_kwargs:
        base_kwargs["use_cache"] = False  # the probe's cache would be thrown away

    attn_maps = []
    for column in range(candidate_tokens.shape[-1]):
        probe_ids = torch.cat([input_ids, candidate_tokens[:, column : column + 1]], dim=-1)
        probe_out = model(**{**base_kwargs, "input_ids": probe_ids, "output_attentions": True, "return_dict": True})
        if not probe_out.attentions or probe_out.attentions[-1] is None:
            raise ValueError(
                "OPERA requires attention weights from the model; load it with attn_implementation='eager'."
            )
        # Reduce over dim 1 (heads) before the float cast to avoid a full-size fp32 copy.
        attn_maps.append(probe_out.attentions[-1].amax(dim=1, keepdim=True).float())
        del probe_out
    return torch.cat(attn_maps, dim=1)


def _opera_penalty(attn_last, response_start, image_start, image_end, scale_factor):
    """Compute the OPERA penalty and rollback locations from candidate attention maps.

    Args:
        attn_last: ``(batch_beam, num_candidates, L, L)`` attention maps from
            ``_opera_candidate_attention``; rows are re-normalised here.
        response_start: first position of the local window.
        image_start, image_end: inclusive span whose attention is rewarded early on.
        scale_factor: multiplier applied to the local attention before the products.

    Returns:
        ``(penalty_scores, rollback_locs)``, both ``(batch_beam, num_candidates)``.
        While the local window has at most 10 positions the penalty is the negated
        attention of the last row on the image span; afterwards it is the largest
        column-wise product score. ``rollback_locs`` is the window column with that
        largest score, or -1 everywhere when the window is empty.
    """
    attn_last = attn_last / attn_last.sum(-1, keepdim=True).clamp(min=1e-8)
    attn_local = scale_factor * attn_last[:, :, response_start:, response_start:]
    window = attn_local.shape[-1]
    attn_i = attn_last[:, :, -1, image_start : image_end + 1].sum(-1)

    if window == 0:
        # Misconfigured / very short window: no over-trust score, and a location that
        # can never match a real column (so it never triggers a rollback).
        return -attn_i, torch.full(attn_i.shape, -1, dtype=torch.long, device=attn_i.device)

    # For column j, multiply rows j..window-1 (the lower triangle incl. diagonal).
    lower = torch.ones(window, window, dtype=torch.bool, device=attn_local.device).tril()
    local_scores = 1e-7 * attn_local.masked_fill(~lower, 1.0).prod(dim=-2)
    rollback_scores, rollback_locs = local_scores.max(-1)

    penalty_scores = -attn_i if window <= 10 else rollback_scores
    return penalty_scores, rollback_locs


def _opera_pick_rollback_index(
    history_states,
    rollback_loc_gathers,
    per_beam_loc,
    threshold,
    response_start,
    decoder_prompt_len,
    rollback_counts,
    max_attempts,
):
    """Decide whether to roll back; return the index in ``history_states`` or ``None``.

    A rollback is triggered when, for some beam, its current location appears more
    than ``threshold`` times in its row of ``rollback_loc_gathers``. The target is the
    snapshot whose input ends right after that location (never inside the prompt,
    never older than the oldest snapshot). If a target has already been revisited
    ``max_attempts`` times, the next earlier position is tried. On success
    ``rollback_counts`` is incremented for the chosen input length.
    """
    hits = (rollback_loc_gathers == per_beam_loc.unsqueeze(-1)).sum(-1) > threshold
    if not bool(hits.any()):
        return None
    loc = int(per_beam_loc[hits].mode().values)
    if loc < 0:
        return None

    lengths = [state["input_ids"].shape[-1] for state in history_states]
    floor = max(decoder_prompt_len, lengths[0])
    target = max(response_start + loc + 1, floor)
    while target >= floor and rollback_counts.get(target, 0) >= max_attempts:
        target -= 1
    # A later snapshot must exist: it records the choices we are about to forbid.
    if target < floor or target >= lengths[-1]:
        return None
    try:
        index = lengths.index(target)
    except ValueError:
        return None
    if index + 1 >= len(history_states):
        return None
    rollback_counts[target] = rollback_counts.get(target, 0) + 1
    return index


def _opera_candidate_logits(reference, candidate_tokens, candidate_scores):
    """Return logits shaped like ``reference``: -999 everywhere except the candidates."""
    canvas = torch.full_like(reference, -999.0)
    return canvas.scatter_(-1, candidate_tokens, candidate_scores.to(canvas.dtype))


def _opera_rebuild_model_state(model, input_ids, model_kwargs_ori, output_attentions, output_hidden_states):
    """Re-run ``input_ids`` from the pristine kwargs to rebuild the incremental state.

    Returns ``(model_kwargs, outputs)`` where ``model_kwargs`` has been advanced past
    ``input_ids`` exactly as after a normal decoding step.
    """
    model_kwargs = copy.deepcopy(model_kwargs_ori)  # keep the pristine copy reusable
    if model_kwargs.get("attention_mask") is not None:
        model_kwargs["attention_mask"] = _opera_resize_mask(model_kwargs["attention_mask"], input_ids.shape[-1])
    if hasattr(model, "_get_initial_cache_position"):
        model_kwargs = model._get_initial_cache_position(input_ids, model_kwargs)
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    outputs = model(**model_inputs, return_dict=True)
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
    return model_kwargs, outputs


def opera_beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
    """
    Beam search with the OPERA (Over-trust Penalty and Retrospection-Allocation) strategy.

    Installed as ``GenerationMixin._beam_search`` by ``patch_beam_search``. Each step:

    1. Runs the model once (incrementally) to get next-token logits.
    2. Probes the ``num_attn_candidates`` highest-logit tokens of every beam with a
       full forward pass each and re-scores them: while the local window (positions
       from ``key_position["image_end"] + 1``) has at most 10 tokens, candidates gain
       ``penalty_weights`` times their last-layer attention on the image span;
       afterwards they lose ``penalty_weights`` times the over-trust score (largest
       column-wise product of the ``scale_factor``-scaled local attention).
    3. Gives every non-candidate token a logit of -999, so beams are extended with the
       re-scored candidates.
    4. Retrospection: when a beam's most over-trusted window column has been chosen more
       than ``threshold`` times, decoding returns to the snapshot taken right after
       that token (if it is among the last ``window_size`` snapshots), re-runs that
       prefix, and forbids the continuations chosen there before. A position is
       revisited at most ``num_attn_candidates`` times before stepping one further back.

    OPERA settings are read from ``generation_config`` when present, else defaulted:
    ``key_position`` (dict with ``image_start``/``image_end``; ``{}``),
    ``scale_factor`` (50.0), ``threshold`` (15), ``num_attn_candidates`` (5),
    ``window_size`` (512), ``penalty_weights`` (1.0).

    Returns:
        ``GenerateBeamEncoderDecoderOutput`` / ``GenerateBeamDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set, otherwise the ``sequences`` tensor.

    Raises:
        ValueError: if the batch dimension of ``input_ids`` is inconsistent with the
            scorer, or the model returns no attention weights.
        RuntimeError: for ``low_memory`` with unsupported model classes.
    """
    # init values from generation_config
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    sequential = generation_config.low_memory

    # OPERA hyper-parameters; attributes missing from generation_config use defaults.
    key_position = getattr(generation_config, "key_position", None) or {}
    scale_factor = getattr(generation_config, "scale_factor", 50.0)
    threshold = getattr(generation_config, "threshold", 15)
    num_attn_candidates = getattr(generation_config, "num_attn_candidates", 5)
    window_size = getattr(generation_config, "window_size", 512)
    penalty_weights = getattr(generation_config, "penalty_weights", 1.0)

    # Prompt positions used by the OPERA scores; they do not change while decoding.
    image_start = key_position.get("image_start", 0)
    image_end = key_position.get("image_end", 0)
    response_start = image_end + 1

    # OPERA state.
    model_kwargs_ori = copy.deepcopy(model_kwargs)  # pristine kwargs used to re-run a prefix after rollback
    history_length = max(1, int(window_size))  # number of per-step snapshots kept
    history_states = []  # snapshots, oldest first; consecutive input lengths
    history_rollback_locs = []  # one (batch_beam, 1) tensor of window locations per step
    rollback_counts = {}  # input length -> number of rollbacks to it
    reject_token_pos_gather = {}  # input length -> (batch, k) forbidden flat (beam * vocab + token) positions
    beam_next_tokens = None
    beam_idx = None

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape
    if hasattr(self, "_get_initial_cache_position"):
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    beam_indices = tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_beam_size,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # --- OPERA: snapshot the decoding state before this step ---
        current_state = {
            "input_ids": input_ids.clone(),
            "beam_scorer": copy.deepcopy(beam_scorer),
            "beam_indices": beam_indices,  # tuple of tuples: immutable, safe to share
            "cur_len": cur_len,
            "beam_scores": beam_scores.clone(),
            "beam_next_tokens": beam_next_tokens.clone() if beam_next_tokens is not None else None,
            "beam_idx": beam_idx.clone() if beam_idx is not None else None,
        }

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # OPERA takes its attention maps from the candidate probes below, so the main pass
        # only returns attentions when the caller asked for them.
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # if sequential is True, split the input to batches of batch_size and run sequentially
        if sequential:
            if any(
                model_name in self.__class__.__name__.lower()
                for model_name in [
                    "fsmt",
                    "reformer",
                    "ctrl",
                    "gpt_bigcode",
                    "transo_xl",
                    "xlnet",
                    "cpm",
                    "jamba",
                ]
            ):
                raise RuntimeError(
                    f"Currently generation for {self.__class__.__name__} is not supported "
                    f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
                )

            inputs_per_sub_batches = _split_model_inputs(
                model_inputs,
                split_size=batch_size,
                full_batch_size=batch_beam_size,
                config=self.config.get_text_config(),
            )
            outputs_per_sub_batch = [
                self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
            ]
            outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config())
        else:
            outputs = self(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        # Clone avoids keeping a reference to outputs.logits (large on the first step);
        # .float() keeps precision for the score manipulations below.
        raw_next_token_logits = outputs.logits[:, -1, :].clone().float().to(input_ids.device)

        # --- OPERA: probe and re-score the top candidates of every beam ---
        candidate_token_scores, candidate_tokens = torch.topk(
            raw_next_token_logits, num_attn_candidates, dim=-1, largest=True, sorted=True
        )
        attn_last = _opera_candidate_attention(self, input_ids, model_kwargs, candidate_tokens)
        penalty_scores, rollback_locs = _opera_penalty(
            attn_last, response_start, image_start, image_end, scale_factor
        )
        del attn_last
        candidate_token_scores = candidate_token_scores - penalty_weights * penalty_scores.to(
            candidate_token_scores.device
        )
        current_state["candidate_tokens"] = candidate_tokens
        current_state["candidate_token_scores"] = candidate_token_scores

        if len(history_states) >= history_length:
            history_states.pop(0)
        history_states.append(current_state)

        # Per beam: most frequent location among its candidates.
        per_beam_loc = rollback_locs.mode(dim=-1).values.to(input_ids.device)
        history_rollback_locs.append(per_beam_loc[:, None])
        rollback_loc_gathers = torch.cat(history_rollback_locs, dim=-1)

        # --- OPERA: retrospection (rollback) ---
        rollback_idx = _opera_pick_rollback_index(
            history_states,
            rollback_loc_gathers,
            per_beam_loc,
            threshold,
            response_start,
            decoder_prompt_len,
            rollback_counts,
            num_attn_candidates,
        )
        if rollback_idx is None:
            next_token_logits = _opera_candidate_logits(
                raw_next_token_logits, candidate_tokens, candidate_token_scores
            )
        else:
            restored = history_states[rollback_idx]
            followup = history_states[rollback_idx + 1]  # records what was chosen from `restored`
            steps_back = len(history_states) - 1 - rollback_idx
            del history_states[rollback_idx + 1 :]
            del history_rollback_locs[len(history_rollback_locs) - steps_back :]
            target_len = restored["input_ids"].shape[-1]
            # Rejections recorded for later positions belonged to a prefix that no longer exists.
            for stale_len in [length for length in reject_token_pos_gather if length > target_len]:
                del reject_token_pos_gather[stale_len]

            input_ids = restored["input_ids"]
            beam_scorer = copy.deepcopy(restored["beam_scorer"])  # keep the snapshot pristine
            beam_indices = restored["beam_indices"]
            cur_len = restored["cur_len"]
            beam_scores = restored["beam_scores"].clone()
            beam_next_tokens = restored["beam_next_tokens"]
            beam_idx = restored["beam_idx"]

            # The cache has been reordered by beam_idx since, so re-run the restored prefix.
            del outputs
            model_kwargs, outputs = _opera_rebuild_model_state(
                self, input_ids, model_kwargs_ori, output_attentions, output_hidden_states
            )
            raw_next_token_logits = outputs.logits[:, -1, :].clone().float().to(input_ids.device)

            # Forbid every (beam, token) continuation already taken from this prefix.
            step_vocab = raw_next_token_logits.shape[-1]
            rejected = ((followup["beam_idx"] % num_beams) * step_vocab + followup["beam_next_tokens"]).view(
                batch_size, num_beams
            )
            if target_len in reject_token_pos_gather:
                rejected = torch.cat([reject_token_pos_gather[target_len], rejected], dim=-1)
            reject_token_pos_gather[target_len] = rejected

            next_token_logits = _opera_candidate_logits(
                raw_next_token_logits, restored["candidate_tokens"], restored["candidate_token_scores"]
            )
            next_token_logits = (
                next_token_logits.view(batch_size, num_beams * step_vocab)
                .scatter_(-1, rejected, float("-inf"))
                .view(batch_beam_size, step_vocab)
            )

        next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed.detach(),)
            if output_logits:
                raw_logits += (raw_next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_logits.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
        # non eos token per beam.
        n_eos_tokens = len(eos_token_id) if eos_token_id is not None else 0
        n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
        )

        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )

        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        # Drop outputs (and its possibly large logits) BEFORE reordering the cache to lower the memory peak.
        del outputs

        if model_kwargs.get("past_key_values", None) is not None:
            model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                model_kwargs["past_key_values"], beam_idx
            )

        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # increase cur_len
        cur_len = cur_len + 1

        # The criteria return one flag per sequence; like the stock beam search, stop only
        # once every sequence is done (`any` would cut other batch rows short).
        if beam_scorer.is_done or stopping_criteria(input_ids, scores).all():
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequence_outputs["sequences"]

def patch_beam_search():
    """Install ``opera_beam_search`` as ``GenerationMixin._beam_search``.

    The method is replaced on the ``transformers.generation.utils.GenerationMixin``
    class itself, so the change applies process-wide to every class inheriting it.
    """
    mixin_cls = transformers.generation.utils.GenerationMixin
    setattr(mixin_cls, "_beam_search", opera_beam_search)


def patch_everything():
    """Apply every monkey-patch: ``patch_qwen_forward()`` first, then ``patch_beam_search()``."""
    for apply_patch in (patch_qwen_forward, patch_beam_search):
        apply_patch()
