import math 
import copy 
from typing import Optional, Union 

import torch 
from torch import nn

from transformers.generation.logits_process import (
  LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
  StoppingCriteriaList,
)
import transformers 
from transformers.generation.utils import (
  GenerateNonBeamOutput,
  GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
)

from transformers.generation.streamers import BaseStreamer 

from methods.generation_configs.contrastive_generation_config import GenerationConfigContrastive
from methods.utils.crops_samplers_utils import get_generations, get_next_token_logits 
from methods.utils.vcd_forward_utils import add_diffusion_noise

def vcd_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfigContrastive,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Drop-in replacement for ``GenerationMixin._sample`` implementing Visual Contrastive Decoding (VCD).

    Every decoding step runs the model twice, each branch with its own ``model_kwargs`` (and hence its
    own cache state):

    * the main branch on the original ``pixel_values``;
    * the contrast branch on ``add_diffusion_noise(pixel_values, noise_step)``, computed once before
      the loop.

    The next-token logits are combined as ``(1 + alpha_stat_bias) * logits - alpha_stat_bias * logits_vcd``.
    Every token whose *original* logit is below ``log(beta_cutoff) + max(logits)`` is then masked to
    ``-inf``. The result goes through ``logits_processor`` and is either sampled with
    ``torch.multinomial`` (``do_sample=True``) or decoded with ``argmax``.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            Processors applied to the contrasted (and cutoff-masked) logits at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            Criteria used to tell if the generation loop should stop.
        generation_config (`GenerationConfigContrastive`):
            Generation configuration. Besides the standard fields, this reads ``key_position``,
            ``noise_step``, ``alpha_stat_bias`` and ``beta_cutoff``.
        synced_gpus (`bool`):
            Whether to keep running the loop until max_length on every rank (needed to avoid deadlocking
            with `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). A finished rank still performs
            both forward passes each iteration so that all ranks issue the same sequence of forward calls.
        streamer (`BaseStreamer`, *optional*):
            Receives each newly generated token batch through `streamer.put(token_ids)`.
        model_kwargs:
            Additional model-specific kwargs. ``pixel_values`` is popped from here and handled explicitly.
            The remaining kwargs are deep-copied for the contrast branch and forwarded through
            ``get_generations``. For encoder-decoder models they should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        The generated token ids when `return_dict_in_generate=False`. Otherwise one of the two output classes,
        depending on `model.config.is_encoder_decoder`. Note that the `logits` field holds the contrasted,
        cutoff-masked logits rather than the raw model logits.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # The image is passed explicitly to each branch instead of travelling inside model_kwargs.
    pixel_values = model_kwargs.pop("pixel_values", None)
    key_position = generation_config.key_position

    # Contrast-branch state: an independent copy so its cache never mixes with the main branch.
    model_kwargs_vcd = copy.deepcopy(model_kwargs)
    noise_step = generation_config.noise_step
    alpha_stat_bias = generation_config.alpha_stat_bias

    # Distort the image once, like the reference VCD method. Re-noising every step only burns compute
    # and RNG draws, and would contrast against a different image per token if the image is re-encoded.
    pixel_values_vcd = add_diffusion_noise(pixel_values, noise_step)

    # log(beta) is loop-invariant. It is kept as a tensor op (not math.log) so beta_cutoff == 0
    # gives -inf, i.e. no cutoff, instead of raising.
    log_beta_cutoff = torch.log(torch.tensor(generation_config.beta_cutoff))

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
    model_kwargs_vcd = self._get_initial_cache_position(input_ids, model_kwargs_vcd)

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        # Main branch: clean image.
        outputs, model_kwargs = get_generations(
            self,
            input_ids,
            pixel_values=pixel_values,
            model_kwargs=model_kwargs,
            generation_config=generation_config,
            key_position=None,
            use_text_mask=False,
            use_fast_v=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Contrast branch: noised image.
        outputs_vcd, model_kwargs_vcd = get_generations(
            self,
            input_ids,
            pixel_values=pixel_values_vcd,
            model_kwargs=model_kwargs_vcd,
            generation_config=generation_config,
            key_position=key_position,
            use_text_mask=False,
            use_fast_v=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Only skip AFTER both forwards. A finished rank that skipped the second forward would issue
        # fewer forward calls than the unfinished ranks and deadlock FSDP / ZeRO-3.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = get_next_token_logits(outputs, input_ids)
        next_token_logits_vcd = get_next_token_logits(outputs_vcd, input_ids)

        # Remove the bias exposed by the distorted image.
        final_logits = (1 + alpha_stat_bias) * next_token_logits - alpha_stat_bias * next_token_logits_vcd

        # Cutoff: keep only tokens whose original logit is within log(beta) of the max.
        # Equivalently, p >= beta * max p under the original distribution.
        cutoff_th = log_beta_cutoff + next_token_logits.max(dim=-1, keepdim=True).values
        final_logits = final_logits.masked_fill(next_token_logits < cutoff_th, -float("inf"))

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, final_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (final_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # Drop references to the (possibly very large, first-iteration) outputs so their logits can be
        # freed before the next step.
        del outputs, outputs_vcd, next_token_logits_vcd
        torch.cuda.empty_cache()

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return input_ids

def patch_vcd_sampling():
    """Replace ``transformers`` ``GenerationMixin._sample`` with :func:`vcd_sample`.

    The assignment is made on the ``GenerationMixin`` class itself, so it is process-wide. Every model
    class that inherits ``_sample`` from ``GenerationMixin`` picks up the VCD loop after this call.
    The patched loop reads VCD-specific fields (``key_position``, ``noise_step``, ``alpha_stat_bias``,
    ``beta_cutoff``) from the generation config. Presumably a ``GenerationConfigContrastive`` must
    therefore be supplied to ``generate``; confirm with callers.
    """
    generation_mixin = transformers.generation.utils.GenerationMixin
    setattr(generation_mixin, "_sample", vcd_sample)