import math
import copy
import numpy as np
from typing import Optional, Union

import torch
from torch import nn
import torch.nn.functional as F

from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList,
)
import transformers
from transformers.generation.utils import (
    GenerateNonBeamOutput, 
    GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
)

from transformers.generation.streamers import BaseStreamer

from methods.generation_configs.contrastive_generation_config import GenerationConfigContrastive
from methods.utils.crops_samplers_utils import get_generations, get_next_token_logits


# Output path for the optional Jensen-Shannon divergence log. It is only
# referenced by the JSD logging code in `crops_sample`, which is currently
# commented out.
_jsd_log_path = "jsd_values_2.npy"
def _jensen_shannon(p_logits: torch.Tensor, q_logits: torch.Tensor, eps: float = 1e-12) -> float:
    p = F.softmax(p_logits, dim=-1).clamp(min=eps)
    q = F.softmax(q_logits, dim=-1).clamp(min=eps)
    m = 0.5 * (p + q)
    kl_pm = (p * (p.log() - m.log())).sum(dim=-1)
    kl_qm = (q * (q.log() - m.log())).sum(dim=-1)
    jsd = 0.5 * (kl_pm + kl_qm)
    # assume batch_size=1 → scalar
    return jsd.item()

def crops_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfigContrastive,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids with the CRoPS contrastive decoding loop.

    The signature mirrors `GenerationMixin._sample`, which `patch_crops_sampling` replaces with this function.
    Every decoding step runs three forward passes through `get_generations`:

    1. the main pass on `input_ids` with `pixel_values`
       (`key_position=None`, `use_text_mask=False`, `use_fast_v=False`);
    2. the "language prior" pass on `generation_config.input_ids_lang_prior` with `pixel_values=None`
       and `use_text_mask=True`;
    3. the "statistical bias" pass on `input_ids` with `pixel_values` and `use_fast_v=True`.

    Main-pass logits below `log(beta_cutoff) + max_logit` are masked to `-inf`. For each row, if the largest
    main-pass probability exceeds `max_threshold_plausibility_constraint`, the masked main logits are used as is.
    Otherwise the scores are

        s = logp + ((1 - gamma) / gamma) * (logp - logp_lang_prior),   gamma = exp(-lambda_lang_prior * step)
        final = (1 + alpha_stat_bias) * s - alpha_stat_bias * logp_stat_bias

    The next token is then sampled (`do_sample=True`) or taken by argmax from the processed scores.

    Side effects: `pixel_values` is popped from `model_kwargs`, and `generation_config.minimum_text_tokens`
    is overwritten at every step with `new_text_tokens(step)`.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config (`GenerationConfigContrastive`):
            The generation configuration. Besides the standard fields, this function reads `key_position`,
            `input_ids_lang_prior`, `lambda_lang_prior`, `alpha_stat_bias`, `beta_cutoff` and
            `max_threshold_plausibility_constraint`.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    #### Initialise variables for CRoPS ####
    # pixel_values is passed explicitly to each get_generations call, so take it out of the shared kwargs
    # before they are copied for the auxiliary passes.
    pixel_values = model_kwargs.pop("pixel_values", None)
    key_position = generation_config.key_position

    # Lang Prior: independent kwargs because this pass runs on its own prompt
    model_kwargs_lang_prior = copy.deepcopy(model_kwargs)
    input_ids_lang_prior = generation_config.input_ids_lang_prior
    lambda_lang_prior = generation_config.lambda_lang_prior

    ## Update attention mask for lang prior (its prompt differs from `input_ids`)
    model_kwargs_lang_prior["attention_mask"] = torch.ones_like(input_ids_lang_prior)

    # Stat Bias: independent kwargs so the three passes never share mutable state
    model_kwargs_stat_bias = copy.deepcopy(model_kwargs)
    alpha_stat_bias = generation_config.alpha_stat_bias

    # Other
    jsd_vals = []  # only filled by the optional JSD logging below (currently commented out)
    time_step = 1
    beta_cutoff = torch.tensor(generation_config.beta_cutoff)
    max_threshold_plausibility_constraint = torch.tensor(
        generation_config.max_threshold_plausibility_constraint
    )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
    model_kwargs_lang_prior = self._get_initial_cache_position(input_ids_lang_prior, model_kwargs_lang_prior)
    model_kwargs_stat_bias = self._get_initial_cache_position(input_ids, model_kwargs_stat_bias)

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        # (1) Main pass.
        outputs, model_kwargs = get_generations(
            self,
            input_ids,
            pixel_values=pixel_values,
            model_kwargs=model_kwargs,
            generation_config=generation_config,
            key_position=None,
            use_text_mask=False,
            use_fast_v=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        next_token_logits = get_next_token_logits(outputs, input_ids)

        # Updated after the main pass and before the two auxiliary passes; the value persists on
        # generation_config, so the next iteration's main pass sees the value written here.
        generation_config.minimum_text_tokens = new_text_tokens(time_step)

        # (2) Language-prior pass: no pixel_values, text mask enabled.
        outputs_lang_prior, model_kwargs_lang_prior = get_generations(
            self,
            input_ids_lang_prior,
            pixel_values=None,
            model_kwargs=model_kwargs_lang_prior,
            generation_config=generation_config,
            key_position=key_position,
            use_text_mask=True,
            use_fast_v=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        next_token_logits_lang_prior = get_next_token_logits(outputs_lang_prior, input_ids_lang_prior)

        # (3) Statistical-bias pass: same inputs as the main pass but with use_fast_v=True.
        outputs_stat_bias, model_kwargs_stat_bias = get_generations(
            self,
            input_ids,
            pixel_values=pixel_values,
            model_kwargs=model_kwargs_stat_bias,
            generation_config=generation_config,
            key_position=key_position,
            use_text_mask=False,
            use_fast_v=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # With synced_gpus a finished peer must keep issuing the same forward passes as the other peers.
        # The check therefore sits after ALL three passes: skipping right after the first one would make
        # this peer run one forward per iteration while unfinished peers run three.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits_stat_bias = get_next_token_logits(outputs_stat_bias, input_ids)

        # Apply cutoff threshold: drop tokens whose main logit is below log(beta) + max logit,
        # i.e. whose probability is below beta times the top probability.
        cutoff_th = torch.log(beta_cutoff) + next_token_logits.max(dim=-1, keepdim=True).values
        next_token_logits = next_token_logits.masked_fill(next_token_logits < cutoff_th, -float("inf"))

        log_probs_next_token = torch.log_softmax(next_token_logits, dim=-1)
        probs_next_token = torch.softmax(next_token_logits, dim=-1)

        # One flag per row: True where the main pass is already confident enough to skip the correction.
        # Using all()/any() instead of a bare `if tensor:` keeps this working for more than one row.
        is_confident = (
            probs_next_token.max(dim=-1, keepdim=True).values > max_threshold_plausibility_constraint
        )

        if bool(is_confident.all()):
            final_logits = next_token_logits
        else:
            log_probs_next_token_lang_prior = torch.log_softmax(next_token_logits_lang_prior, dim=-1)
            log_probs_next_token_stat_bias = torch.log_softmax(next_token_logits_stat_bias, dim=-1)
            # NOTE(review): math.exp underflows to 0.0 for very large lambda_lang_prior * time_step,
            # which would make the division below raise ZeroDivisionError.
            gamma_lang_prior = math.exp(-lambda_lang_prior * time_step)
            lang_prior_weight = (1 - gamma_lang_prior) / gamma_lang_prior
            # jsd_val = _jensen_shannon(next_token_logits_stat_bias, next_token_logits_lang_prior)
            # with open(_jsd_log_path, "a") as _f:
            #     _f.write(f"{time_step}\t{jsd_val:.6f}\n")

            # Alternative ordering (stat bias first, then language prior):
            # final_logits = (1+alpha_stat_bias) * log_probs_next_token - alpha_stat_bias * log_probs_next_token_stat_bias
            # final_logits = final_logits + \
            #     (1-gamma_lang_prior)/gamma_lang_prior * (final_logits - log_probs_next_token_lang_prior)

            # Remove Language Prior
            if lang_prior_weight != 0:
                contrastive_logits = log_probs_next_token + lang_prior_weight * (
                    log_probs_next_token - log_probs_next_token_lang_prior
                )
            else:
                # A zero weight (lambda_lang_prior == 0) must leave the scores untouched; multiplying
                # it with the -inf entries created by the cutoff would yield NaN instead.
                contrastive_logits = log_probs_next_token

            # Remove Stat Bias
            contrastive_logits = (
                (1 + alpha_stat_bias) * contrastive_logits - alpha_stat_bias * log_probs_next_token_stat_bias
            )

            if bool(is_confident.any()):
                # Mixed batch: confident rows keep the masked main logits, the rest get the correction.
                final_logits = torch.where(is_confident, next_token_logits, contrastive_logits)
            else:
                final_logits = contrastive_logits

        # if time_step <= 100:
        #     jsd_val = _jensen_shannon(next_token_logits,
        #                               next_token_logits_lang_prior)
        #     jsd_vals.append(jsd_val)

        time_step += 1

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, final_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (final_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        # (the language-prior prompt receives the same token so both sequences stay aligned)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        input_ids_lang_prior = torch.cat([input_ids_lang_prior, next_tokens[:, None]], dim=-1)

        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs, outputs_lang_prior, outputs_stat_bias

    # np.save(_jsd_log_path, np.array(jsd_vals, dtype=np.float32))
    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids

def patch_crops_sampling():
    """Monkey-patch `GenerationMixin._sample` so that it points to `crops_sample`.

    The assignment is made on the class itself and is not undone by this function.
    """
    generation_mixin = transformers.generation.utils.GenerationMixin
    generation_mixin._sample = crops_sample

def new_text_tokens(t, b0=10, b1=30, lamb=0.001):
    """Return ``floor(b0 + b1 * (1 - exp(-lamb * t)))``.

    With the defaults the value starts at ``b0`` for ``t = 0`` and grows
    towards ``b0 + b1`` as ``t`` increases, at a rate set by ``lamb``.
    """
    # Fraction of the b1 range reached at step t: 0 at t = 0, approaching 1.
    saturation = 1 - math.exp(-lamb * t)
    return math.floor(b0 + b1 * saturation)

# def new_text_tokens(t):
#     return math.floor(5 + math.sqrt(t))
# def new_text_tokens(t):
#     return math.floor(20*math.exp(0.01*t))
# def new_text_tokens(t,b0=30,b1 = 0.06):
#     return math.floor(b0+ b1*t)
    # return math.floor(b0)

# def new_text_tokens1(len):
#     return math.floor(0.25*len)

# def new_text_tokens1(t,len,b0=0.1,b1=0.3,lamda = 0.001):
#     return math.floor((b0 + b1*(1 - math.exp(-lamda*t)))*len)