import torch
import random
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, LogitsProcessorList

class LLaDAGenerator(object):
    """Block-wise masked-diffusion (LLaDA-style) generator with watermark hooks.

    The response starts as ``gen_length`` copies of ``mask_id`` appended to the
    prompt and is decoded block by block. At every denoising step, the masked
    positions of the current block with the highest remasking score are
    committed. ``generate_watermarked_text`` can additionally route committed
    tokens through a watermark logits processor or exponential sampling, and
    can optionally refine the finished sequence afterwards.
    """

    def __init__(self, config, model, generation_tokenizer):
        # Watermark/generation settings; read for prefix_length, vocab_size,
        # hash_key, top_k and z_threshold.
        self.config = config
        # Mask predictor; called as ``self.model(x).logits``.
        self.model = model
        self.generation_tokenizer = generation_tokenizer
        # Fixed-seed generator on the model device for watermark-path sampling.
        self.rng = torch.Generator(device=self.model.device).manual_seed(42)
        # CPU generator, re-seeded from the prompt before drawing EXP random numbers.
        self.exp_rng = torch.Generator()

    def top_p_logits(self, logits, top_p=None):
        """Nucleus filtering over the last dimension.

        Keeps the smallest prefix of tokens (by descending probability) whose
        cumulative probability reaches ``top_p``. Every other entry is set to
        ``torch.finfo(logits.dtype).min``. When ``top_p`` is ``None``, the
        logits are returned unchanged.
        """
        if top_p is None:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Scatter the removal decisions back into the original vocabulary order.
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask = mask.scatter(-1, sorted_indices, sorted_indices_to_remove)
        return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

    def top_k_logits(self, logits, top_k=None):
        """Keep only entries >= the k-th largest value along the last dimension.

        Removed entries are set to ``torch.finfo(logits.dtype).min``. When
        ``top_k`` is ``None`` or <= 0, filtering is disabled and the logits are
        returned unchanged.
        """
        if top_k is None or top_k <= 0:
            return logits
        top_k = min(top_k, logits.size(-1))  # cannot keep more than the vocabulary
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        return logits.masked_fill(logits < kth_largest, torch.finfo(logits.dtype).min)

    def add_gumbel_noise(self, logits, temperature):
        '''
        Return Gumbel-perturbed scores whose argmax samples softmax(logits / temperature).

        The result is ``exp(logits) / (-log u) ** temperature`` with ``u ~ U(0, 1)``.
        These values are NOT logits: only their argmax is meaningful. When
        ``temperature == 0``, the logits are returned unchanged, so the argmax is
        greedy.

        According to arXiv:2409.02908, low-precision Gumbel-max improves perplexity
        for masked diffusion models but reduces generation quality, so float64 is used.
        '''
        if temperature == 0:
            return logits
        logits = logits.to(torch.float64)
        noise = torch.rand_like(logits, dtype=torch.float64)
        gumbel_noise = (- torch.log(noise)) ** temperature
        return logits.exp() / gumbel_noise

    def get_num_transfer_tokens(self, mask_index, steps):
        '''
        Precompute how many masked tokens to commit at each of ``steps`` steps.

        LLaDA uses a linear noise schedule, so each step should unmask roughly
        the same number of tokens. For a row with ``n`` masked tokens:
          - if ``n > steps``: each step gets ``n // steps`` tokens, and the first
            ``n % steps`` steps get one extra;
          - otherwise: every step gets exactly 1. Steps beyond ``n`` then select
            already-decoded positions, which the callers leave unchanged.

        Args:
            mask_index: Boolean tensor (B, block_len) marking masked positions.
            steps: Number of denoising steps for the block.

        Returns:
            int64 tensor of shape (B, steps).
        '''
        mask_num = mask_index.sum(dim=1, keepdim=True)  # (B, 1)
        many = mask_num > steps
        # Computed per row, so batches with B > 1 work as well.
        base = torch.where(many, mask_num // steps, torch.ones_like(mask_num))
        remainder = torch.where(many, mask_num % steps, torch.zeros_like(mask_num))
        step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
        num_transfer_tokens = base + (step_ids < remainder).to(torch.int64)
        return num_transfer_tokens.to(torch.int64)

    def get_logits(self, input_ids, logits, idx, logits_processor):
        """Run ``logits_processor`` on position ``idx`` of batch row 0, in place.

        Returns the same (mutated) ``logits`` tensor.
        """
        score = logits[0, idx, :].unsqueeze(0)
        logits[0, idx, :] = logits_processor(input_ids, score)
        return logits

    def compute_ppl(self, text: str, ppl_model, ppl_tokenizer) -> torch.Tensor:
        """Return the perplexity of ``text`` under ``ppl_model`` as a 0-d tensor.

        The text is tokenized without special tokens and scored with the mean
        next-token cross-entropy. Inputs are placed on ``ppl_model.device`` when
        the model exposes it, otherwise on the generator model's device.
        """
        device = getattr(ppl_model, "device", self.model.device)
        encoded_text = ppl_tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(device)
        logits = ppl_model(torch.unsqueeze(encoded_text, 0), return_dict=True).logits[0]
        # Position t predicts token t + 1: drop the last prediction and the first target.
        loss = F.cross_entropy(logits[:-1], encoded_text[1:])
        return torch.exp(loss)

    def seed_rng(self, input_ids: torch.LongTensor) -> None:
        """Seed ``self.exp_rng`` from the last ``config.prefix_length`` tokens.

        The seed is ``hash_key * (product of those tokens % vocab_size)``.
        """
        time_result = 1
        for i in range(0, self.config.prefix_length):
            time_result *= input_ids[-1 - i].item()
        prev_token = time_result % self.config.vocab_size
        self.exp_rng.manual_seed(self.config.hash_key * prev_token)

    def exp_sampling(self, probs: torch.Tensor, u: torch.Tensor, top_k) -> torch.Tensor:
        """Pick a token by exponential sampling: ``argmax_i u_i ** (1 / p_i)``.

        Args:
            probs: Probabilities of shape (B, V).
            u: Uniform random numbers, either (V,) or (B, V).
            top_k: When > 0, restrict the choice to the ``top_k`` most likely
                tokens; when <= 0, use the full vocabulary.

        Returns:
            Token ids of shape (B, 1).
        """
        # Broadcast u to probs' shape: torch.gather needs index and input of the same rank.
        u = u.expand_as(probs)
        if top_k <= 0:
            return torch.argmax(u ** (1 / probs), dim=1).unsqueeze(-1)

        top_k = min(top_k, probs.size(-1))
        top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
        sampled_indices = torch.argmax(u.gather(-1, top_indices) ** (1 / top_probs), dim=-1)
        # Map positions within the top-k back to vocabulary ids.
        return top_indices.gather(-1, sampled_indices.unsqueeze(-1))

    def compute_token_scores(self, x, prompt):
        """Return the negative log-likelihood of each current token of ``x``.

        The model is re-run on ``x``. Prompt positions get ``+inf``.
        Shape: (B, L).
        """
        logits = self.model(x).logits  # [B, L, V]
        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(dim=2, index=x.unsqueeze(2)).squeeze(2)  # [B, L]
        scores = -token_log_probs
        scores[:, :prompt.shape[1]] = float('inf')
        return scores

    def apply_scores_remasking(self, x, scores, prompt_len, mask_id, green_token_flags, watermark_algorithm='Unigram', remask_ratio=0.3):
        """Replace the highest-scoring non-green generated tokens with ``mask_id``.

        Candidates are positions >= ``prompt_len`` whose ``green_token_flags``
        entry equals 0. The same flags are used for every batch row. Of those
        candidates, the ``max(1, int(count * remask_ratio))`` with the largest
        ``scores`` are masked. ``watermark_algorithm`` is currently unused.
        Returns a new tensor; ``x`` is not modified.
        """
        x_new = x.clone()
        B, L = x.shape
        for b in range(B):
            candidate_indices = [idx for idx in range(prompt_len, L) if green_token_flags[idx] == 0]
            if not candidate_indices:
                continue
            candidate_scores = torch.tensor([scores[b, idx].item() for idx in candidate_indices])
            num_to_mask = max(1, int(len(candidate_indices) * remask_ratio))
            highest_local = torch.topk(candidate_scores, k=num_to_mask, largest=True).indices
            for li in highest_local.tolist():
                x_new[b, candidate_indices[li]] = mask_id
        return x_new

    def _apply_sampling_filters(self, logits, temperature, top_p, top_k):
        """Apply temperature scaling, then top-p and top-k filtering (each optional)."""
        if temperature > 0:
            logits = logits / temperature
        if top_p is not None and top_p < 1:
            logits = self.top_p_logits(logits, top_p)
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)
        return logits

    def _remasking_scores(self, p, x0, remasking):
        """Per-position score used to choose which masked tokens to commit (higher first)."""
        if remasking == 'confidence':
            return torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
        if remasking == 'random':
            return torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
        if remasking == 'entropy':
            # Negative entropy: peaked distributions score higher.
            return torch.sum(p * torch.log(p + 1e-12), dim=-1)
        if remasking == 'margin':
            top2_probs, _ = torch.topk(p, k=2, dim=-1)
            return top2_probs[..., 0] - top2_probs[..., 1]
        raise NotImplementedError(remasking)

    @torch.no_grad()
    def generate_watermarked_text(self, input_ids, watermark_algorithm=None, watermark_type=None, steps=128, gen_length=128, block_length=16, temperature=0., detector=None,
                ppl_model=None, ppl_tokenizer=None, top_p=0.95, top_k=50, cfg_scale=0., remasking='confidence', remask_ratio=0.15, logits_processor=None, mask_id=126336, attention_mask=None):
        '''
        Generate a watermarked continuation of ``input_ids``.

        Args:
            input_ids: Prompt token ids of shape (1, prompt_len).
            watermark_algorithm: A name in ``LOGITS_LIST`` (watermark applied via
                ``logits_processor``) or ``'EXP'`` (exponential sampling). Any other
                value commits the plain Gumbel-max / greedy candidates.
            watermark_type: ``'P'`` forces exponential sampling in the EXP path;
                ``'R'`` runs a post-hoc refinement pass on the finished sequence.
            steps: Total sampling steps; must be divisible by the number of blocks.
            gen_length: Number of generated tokens; must be divisible by ``block_length``.
            block_length: Block size for semi-autoregressive decoding.
            temperature: Gumbel-max / sampling temperature (0 gives greedy candidates).
            detector: Passed to ``logits_processor.compute_score`` during SynthID refinement.
            ppl_model, ppl_tokenizer: Used to compute perplexity during refinement.
            top_p, top_k: Filters applied to the watermarked logits before sampling.
            cfg_scale: Accepted for interface parity; not used by this method.
            remasking: 'confidence', 'random', 'entropy' or 'margin'.
            remask_ratio: Fraction of tokens chosen during 'R' refinement.
            logits_processor: Watermark processor; wrapped in a ``LogitsProcessorList``.
            mask_id: The token id of [MASK] (126336).
            attention_mask: Accepted for interface parity; not used by this method.

        Returns:
            Tensor of shape (1, prompt_len + gen_length): the prompt followed by generated tokens.
        '''
        prompt = input_ids
        prompt_len = prompt.shape[1]

        x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long).to(self.model.device)
        x[:, :prompt_len] = prompt.clone()

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        LOGITS_LIST = ['Unigram', 'SynthID', 'DIP', 'SWEET', 'EWD', 'MorphMark', 'SIR', 'UPV']
        SAMPLING_LIST = ['EXP']

        for num_block in range(num_blocks):
            block_start = prompt_len + num_block * block_length
            block_end = block_start + block_length
            block_mask_index = (x[:, block_start:block_end] == mask_id)
            num_transfer_tokens = self.get_num_transfer_tokens(block_mask_index, steps)
            for i in range(steps):
                mask_index = (x == mask_id)

                # Step 1: raw logits. The Gumbel-perturbed scores are used only for
                # the argmax; confidence comes from the un-noised distribution.
                logits = self.model(x).logits
                logits_with_noise = self.add_gumbel_noise(logits, temperature=temperature)
                p = F.softmax(logits.to(torch.float64), dim=-1)

                # Step 2: candidate token for every position, plus its remasking score.
                x_ = torch.argmax(logits_with_noise, dim=-1)  # b, l
                x_p = self._remasking_scores(p, x_, remasking)
                x_p[:, block_end:] = -np.inf  # never commit beyond the current block

                x_ = torch.where(mask_index, x_, x)
                confidence = torch.where(mask_index, x_p, -np.inf)

                # Step 3: choose the positions to commit at this step.
                transfer_index = torch.zeros_like(x_, dtype=torch.bool, device=x_.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True

                # Step 4: watermarking. Only the first selected position is passed
                # through the logits processor.
                if watermark_algorithm in LOGITS_LIST:
                    logits = self.get_logits(x, logits, select_index[0], LogitsProcessorList([logits_processor]))
                    logits = self._apply_sampling_filters(logits, temperature, top_p, top_k)
                    p = torch.softmax(logits, dim=-1)

                    # Sample where the filtered distribution is confident, otherwise take the argmax.
                    threshold = 0.5
                    max_prob, argmax_indices = p.max(dim=-1)
                    use_sampling = max_prob > threshold
                    sampled_indices = torch.multinomial(p.view(-1, p.size(-1)), 1, generator=self.rng).view(p.size(0), p.size(1))
                    x_ = torch.where(use_sampling, sampled_indices, argmax_indices)

                elif watermark_algorithm in SAMPLING_LIST:
                    # Random numbers r_1..r_V, seeded from the prompt.
                    self.seed_rng(prompt[0])
                    random_numbers = torch.rand(self.config.vocab_size, generator=self.exp_rng)

                    threshold = 0.5
                    p = torch.softmax(logits, dim=-1)
                    max_prob, argmax_indices = p.max(dim=-1)
                    use_sampling = max_prob > threshold
                    first_pos = torch.nonzero(transfer_index)[0][1]

                    if not use_sampling[0][first_pos] or watermark_type == 'P':
                        p = p[:, first_pos, :]
                        # Sample the token that carries the watermark.
                        x_[transfer_index] = self.exp_sampling(p.detach().cpu(), random_numbers, self.config.top_k).to(self.model.device)
                    else:
                        sampled_indices = torch.multinomial(p.view(-1, p.size(-1)), 1).view(p.size(0), p.size(1))
                        x_ = torch.where(use_sampling, sampled_indices, argmax_indices)

                # Resampling above covers every position. When a step selects
                # more positions than remain masked, this keeps the prompt and
                # already-committed tokens from being overwritten.
                x_ = torch.where(mask_index, x_, x)
                x[transfer_index] = x_[transfer_index]

        if watermark_type == 'R':
            return self._refine_watermarked_text(
                x, input_ids, watermark_algorithm, logits_processor, detector,
                ppl_model, ppl_tokenizer, temperature, top_p, top_k, remask_ratio, mask_id,
            )
        return x

    @torch.no_grad()
    def _refine_watermarked_text(self, x, input_ids, watermark_algorithm, logits_processor, detector,
                                 ppl_model, ppl_tokenizer, temperature, top_p, top_k, remask_ratio, mask_id):
        """Post-hoc refinement for ``watermark_type == 'R'``.

        If the detection score is below ``config.z_threshold``, non-green
        positions are resampled; a resample is kept when the candidate's score
        exceeds the threshold. Otherwise, if perplexity exceeds 10, the least
        likely non-green tokens are re-predicted greedily; a prediction is kept
        when it lowers the perplexity below the original. Unsupported
        algorithms, or sequences that trigger neither criterion, are returned
        unchanged.

        NOTE(review): this mutates ``logits_processor.config`` (delta/k_exp) and,
        for SynthID, ``self.config.z_threshold``; the changes persist across calls.
        """
        decoded_text = self.generation_tokenizer.batch_decode(x, skip_special_tokens=True)[0]
        original_ppl = self.compute_ppl(decoded_text, ppl_model, ppl_tokenizer)

        generated_x = x.clone()
        B, L = generated_x.shape
        # Alias: reflects tokens accepted below.
        encoded_text = generated_x

        if watermark_algorithm == 'Unigram':
            z_score, green_token_flags = logits_processor.utils.score_sequence(encoded_text[0])
            logits_processor.config.delta = 4.0
        elif watermark_algorithm in ['SWEET', 'EWD']:
            entropy_list = logits_processor.utils.calculate_entropy(self.model, encoded_text[0])
            z_score, green_token_flags, _ = logits_processor.utils.score_sequence(encoded_text[0], entropy_list)
            logits_processor.config.delta = 4.0
        elif watermark_algorithm == 'MorphMark':
            entropy_list = logits_processor.utils.calculate_entropy(self.model, encoded_text[0])
            z_score, green_token_flags, _ = logits_processor.utils.score_sequence(encoded_text[0], entropy_list)
            logits_processor.config.k_exp = 4.0
        elif watermark_algorithm == 'DIP':
            z_score, green_token_flags = logits_processor.utils.score_sequence(encoded_text[0])
            green_token_flags = [False] * L
        elif watermark_algorithm == 'UPV':
            green_token_flags, _green_token_count, z_score = logits_processor.utils.green_token_mask_and_stats(encoded_text[0])
            logits_processor.config.delta = 3.0
        elif watermark_algorithm == 'SynthID':
            g_values = logits_processor.compute_g_values(encoded_text)
            z_score, per_token_score = logits_processor.compute_score(encoded_text, g_values, detector)
            self.config.z_threshold = 0.52
            green_token_flags = per_token_score > 0.6
        else:
            return generated_x

        z_score_flag = False
        if z_score < self.config.z_threshold:
            logits_processor.config.delta = 4.0
            # Remask non-green positions, except a random remask_ratio fraction of them.
            # NOTE(review): this covers every position green_token_flags covers; if that
            # includes the prompt, prompt tokens can be resampled here — confirm.
            true_idx = [i for i, f in enumerate(green_token_flags) if not f]
            flip_idx = set(random.sample(true_idx, int(len(true_idx) * remask_ratio)))
            mask_flags = [(not f) ^ (i in flip_idx) for i, f in enumerate(green_token_flags)]
            mask_index = torch.tensor([mask_flags], dtype=torch.bool, device=generated_x.device)
            z_score_flag = True
        elif original_ppl > 10.0:
            logits_processor.config.delta = 1.5
            scores = self.compute_token_scores(generated_x, input_ids)
            x_remask = self.apply_scores_remasking(
                generated_x,
                scores,
                input_ids.shape[1],
                mask_id,
                green_token_flags,
                watermark_algorithm,
                remask_ratio=remask_ratio,
            )
            mask_index = (x_remask == mask_id)
        else:
            return x

        processors = LogitsProcessorList([logits_processor])
        x_updated = generated_x.clone()

        for b in range(B):
            for idx in torch.where(mask_index[b])[0]:
                # Mask only the position under consideration.
                x_updated[b, idx] = mask_id
                logits = self.model(x_updated).logits
                logits = self.get_logits(x_updated, logits, idx, processors)
                logits = self._apply_sampling_filters(logits, temperature, top_p, top_k)
                probs = F.softmax(logits, dim=-1)

                if z_score_flag:
                    new_token = torch.multinomial(probs[0, idx], num_samples=1).item()
                else:
                    new_token = torch.argmax(probs[0, idx], dim=-1).item()

                x_candidate = x_updated[b].clone()
                x_candidate[idx] = new_token

                if z_score_flag:
                    candidate_z_score = self._candidate_z_score(x_candidate, encoded_text, logits_processor, detector)
                    if candidate_z_score > self.config.z_threshold:
                        generated_x[b, idx] = new_token
                else:
                    candidate_text = self.generation_tokenizer.decode(x_candidate, skip_special_tokens=True)
                    candidate_ppl = self.compute_ppl(candidate_text, ppl_model, ppl_tokenizer)
                    if candidate_ppl < original_ppl:
                        generated_x[b, idx] = new_token

                # Restore the position (accepted or original token), so later
                # positions are evaluated in a fully populated context.
                x_updated[b, idx] = generated_x[b, idx]

        return generated_x

    def _candidate_z_score(self, x_candidate, encoded_text, logits_processor, detector):
        """Score one candidate sequence with whichever detection API the processor supports.

        Tries, in order: ``utils.score_sequence(seq)``, then
        ``utils.score_sequence(seq, entropy_list)``, then the SynthID-style
        ``compute_g_values`` / ``compute_score`` pair on the candidate itself.
        """
        try:
            candidate_z_score, _ = logits_processor.utils.score_sequence(x_candidate)
            return candidate_z_score
        except Exception:
            pass  # deliberate fallback: processor needs an entropy list or lacks `utils`
        try:
            entropy_list = logits_processor.utils.calculate_entropy(self.model, encoded_text[0])
            candidate_z_score, _, _ = logits_processor.utils.score_sequence(x_candidate, entropy_list)
            return candidate_z_score
        except Exception:
            pass  # deliberate fallback to the g-value based scorer
        candidate_ids = x_candidate.unsqueeze(0)
        g_values = logits_processor.compute_g_values(candidate_ids)
        candidate_z_score, _ = logits_processor.compute_score(candidate_ids, g_values, detector)
        return candidate_z_score

    @torch.no_grad()
    def generate_unwatermarked_text(self, input_ids, watermark_algorithm=None, watermark_type=None, steps=128, gen_length=128, block_length=16, temperature=0.,
                ppl_model=None, ppl_tokenizer=None, top_p=0.95, top_k=50, cfg_scale=0., remasking='confidence', logits_processor=None, mask_id=126336, attention_mask=None):
        '''
        Generate a continuation of ``input_ids`` without any watermark.

        Args:
            input_ids: Prompt token ids of shape (1, prompt_len).
            steps: Total sampling steps; must be divisible by the number of blocks.
            gen_length: Number of generated tokens; must be divisible by ``block_length``.
            block_length: Block size. If less than gen_length, decoding is semi-autoregressive.
            temperature: Gumbel-max sampling temperature (0 gives greedy candidates).
            cfg_scale: Unsupervised classifier-free guidance scale (prompt masked out
                for the unconditional pass).
            remasking: 'confidence', 'random', 'entropy' or 'margin'.
            mask_id: The token id of [MASK] (126336).
            Other arguments are accepted for interface parity and are not used.

        Returns:
            Tensor of shape (1, prompt_len + gen_length): the prompt followed by generated tokens.
        '''
        prompt = input_ids
        prompt_len = prompt.shape[1]

        x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long).to(self.model.device)
        x[:, :prompt_len] = prompt.clone()

        prompt_index = (x != mask_id)

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        for num_block in range(num_blocks):
            block_start = prompt_len + num_block * block_length
            block_end = block_start + block_length
            block_mask_index = (x[:, block_start:block_end] == mask_id)
            num_transfer_tokens = self.get_num_transfer_tokens(block_mask_index, steps)
            for i in range(steps):
                mask_index = (x == mask_id)
                if cfg_scale > 0.:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = self.model(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self.model(x).logits

                # The Gumbel-perturbed scores pick the token; confidence is read from
                # the un-noised distribution, as in the LLaDA reference sampler.
                logits_with_noise = self.add_gumbel_noise(logits, temperature=temperature)
                p = F.softmax(logits, dim=-1)

                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
                x0_p = self._remasking_scores(p, x0, remasking)
                x0_p[:, block_end:] = -np.inf  # never commit beyond the current block

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]

        return x
