# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F

from . import DecodingStrategy, register_strategy
from .strategy_utils import generate_step_with_prob,assign_single_value_long, assign_single_value_byte, assign_multi_value_long, convert_tokens
import sys
import hashlib


@register_strategy('mask_predict')
class MaskPredict(DecodingStrategy):
    """Mask-predict decoding with a Markov-keyed green/red-list watermark.

    At every decoding step each target position adds ``delta`` to the logits
    of either the "green" half of the vocabulary (key 0) or the "red" half
    (any other key) before a token is sampled. The per-position keys come
    from a binary sequence drawn once per call to ``generate`` from a
    two-state Markov chain, and are reused across all refinement iterations.
    """

    # Seed of the fixed vocabulary permutation that defines the green list.
    WM_PRIVATE_KEY = 42
    # Fraction of the vocabulary placed in the green list.
    WM_GAMMA = 0.5

    def __init__(self, args):
        super().__init__()
        # Number of refinement iterations; None means "as many as the
        # padded target length" (see ``generate``).
        self.iterations = args.decoding_iterations

    def generate(self, model, encoder_out, tgt_tokens, tgt_dict, wm_args):
        """Run watermarked mask-predict decoding.

        Args:
            model: model exposing ``model.decoder(tgt_tokens, encoder_out)``.
            encoder_out: encoder output, forwarded unchanged to the decoder.
            tgt_tokens: (bsz, seq_len) tensor of initial target token ids;
                entries equal to ``tgt_dict.pad()`` are treated as padding.
            tgt_dict: dictionary providing ``pad()`` and ``mask()`` indices.
            wm_args: dict with keys ``'q_1'`` (probability that the first key
                is 0), ``'p_1_1'`` (probability that a key repeats its
                predecessor) and ``'delta'`` (logit boost).

        Returns:
            ``(tgt_tokens, lprobs)``: the final tokens and, per sentence, the
            sum of the log of the un-watermarked model probability of each
            chosen token (padding contributes ``log(1) = 0``).
        """
        q_1 = wm_args['q_1']
        p_1_1 = wm_args['p_1_1']
        delta = wm_args['delta']
        assert q_1 is not None
        assert p_1_1 is not None

        bsz, seq_len = tgt_tokens.size()
        pad_mask = tgt_tokens.eq(tgt_dict.pad())
        seq_lens = seq_len - pad_mask.sum(dim=1)

        # One key sequence per sentence, shared by every refinement iteration.
        k_seq = self._sample_key_sequence(bsz, seq_len, q_1, p_1_1, tgt_tokens.device)

        iterations = seq_len if self.iterations is None else self.iterations

        tgt_tokens, token_probs = self.generate_non_autoregressive(
            model, encoder_out, tgt_tokens, key_seq=k_seq, delta=delta)
        assign_single_value_byte(tgt_tokens, pad_mask, tgt_dict.pad())
        assign_single_value_byte(token_probs, pad_mask, 1.0)

        for counter in range(1, iterations):
            # Linearly shrinking number of tokens to re-predict per sentence.
            num_mask = (seq_lens.float() * (1.0 - (counter / iterations))).long()

            # Padding gets probability 1 so it ranks as maximally confident.
            assign_single_value_byte(token_probs, pad_mask, 1.0)
            mask_ind = self.select_worst(token_probs, num_mask)
            assign_single_value_long(tgt_tokens, mask_ind, tgt_dict.mask())
            assign_single_value_byte(tgt_tokens, pad_mask, tgt_dict.pad())

            decoder_out = model.decoder(tgt_tokens, encoder_out)
            new_tgt_tokens, new_token_probs, _ = self.generate_with_watermark_final(
                decoder_output=decoder_out, key_seq=k_seq, delta=delta)

            # Only the re-masked positions take the new predictions.
            assign_multi_value_long(token_probs, mask_ind, new_token_probs)
            assign_single_value_byte(token_probs, pad_mask, 1.0)

            assign_multi_value_long(tgt_tokens, mask_ind, new_tgt_tokens)
            assign_single_value_byte(tgt_tokens, pad_mask, tgt_dict.pad())

        lprobs = token_probs.log().sum(-1)
        return tgt_tokens, lprobs

    @staticmethod
    def _sample_key_sequence(bsz, seq_len, q_1, p_1_1, device):
        """Draw a (bsz, seq_len) int64 matrix of 0/1 keys from a Markov chain.

        The first key of each row is 0 with probability ``q_1`` (else 1);
        every later key repeats its predecessor with probability ``p_1_1``
        and flips otherwise.
        """
        # Uniforms come from torch's global CPU generator, so they follow
        # ``torch.manual_seed``; this is not a source of true randomness.
        uniforms = torch.rand(bsz, seq_len)
        first_is_one = uniforms[:, :1].ge(q_1)
        flips = uniforms[:, 1:].ge(p_1_1)
        steps = torch.cat([first_is_one, flips], dim=1).long()
        # key[t] = key[0] xor flip[1] xor ... xor flip[t]
        return (steps.cumsum(dim=1) % 2).to(device)

    def generate_non_autoregressive(self, model, encoder_out, tgt_tokens, key_seq, delta):
        """Decode every position of ``tgt_tokens`` in one watermarked pass.

        Returns:
            ``(tokens, token_probs)`` as produced by
            ``generate_with_watermark_final``; the full probability table is
            dropped.
        """
        decoder_out = model.decoder(tgt_tokens, encoder_out)
        tgt_tokens, token_probs, _ = self.generate_with_watermark_final(
            decoder_output=decoder_out, key_seq=key_seq, delta=delta)
        return tgt_tokens, token_probs

    def select_worst(self, token_probs, num_mask):
        """Return, per row, the indices of the ``num_mask`` least-confident tokens.

        Args:
            token_probs: (bsz, seq_len) tensor of token confidences.
            num_mask: per-row count of tokens to select; at least one index
                is always selected.

        Returns:
            (bsz, seq_len) index tensor. Rows select different numbers of
            tokens, so each row is padded to ``seq_len`` by repeating its
            first selected index; writing to the same position twice is
            harmless, and this keeps the result rectangular for stacking.
        """
        bsz, seq_len = token_probs.size()
        masks = [
            token_probs[batch, :].topk(max(1, int(num_mask[batch])), largest=False, sorted=False)[1]
            for batch in range(bsz)
        ]
        masks = [
            torch.cat([mask, mask[:1].expand(seq_len - mask.size(0))], dim=0)
            for mask in masks
        ]
        return torch.stack(masks, dim=0)

    def sample_from_dist(self, dist, dim):
        """Draw one index along ``dim`` from each distribution in ``dist``.

        Uses inverse-CDF sampling: one uniform per distribution is located in
        the cumulative sum. Any tensor rank is accepted.

        Args:
            dist: tensor of (approximately) normalised probabilities.
            dim: axis holding the categories.

        Returns:
            Int64 index tensor shaped like ``dist`` but with size 1 along ``dim``.
        """
        if dist.dim() == 0:
            raise NotImplementedError
        moved = dist.movedim(dim, -1)
        # Accumulate in at least float32 so low-precision probabilities do not
        # produce a coarse CDF; uniforms must share the CDF's dtype.
        acc_dtype = torch.promote_types(moved.dtype, torch.float32)
        cumsum = torch.cumsum(moved, dim=-1, dtype=acc_dtype)
        rand_v = torch.rand((*cumsum.shape[:-1], 1), device=cumsum.device, dtype=cumsum.dtype)
        indices = torch.searchsorted(cumsum, rand_v, right=True)
        # Rounding can leave the final cumsum slightly below 1; clamp
        # overshoots onto the last category.
        indices = torch.clamp(indices, 0, cumsum.shape[-1] - 1)
        return indices.movedim(-1, dim)

    def _green_list_mask(self, vocab_size, device):
        """Return a (vocab_size,) bool mask that is True for green-list ids.

        The green list is the first ``int(vocab_size * WM_GAMMA)`` entries of
        a permutation seeded with ``WM_PRIVATE_KEY``; every other id is red.
        """
        greenlist_size = int(vocab_size * self.WM_GAMMA)
        rng = torch.Generator(device=device)
        rng.manual_seed(self.WM_PRIVATE_KEY)
        vocab_permutation = torch.randperm(vocab_size, device=device, generator=rng)
        is_green = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        is_green[vocab_permutation[:greenlist_size]] = True
        return is_green

    def generate_with_watermark_final(self, decoder_output, key_seq=None, delta=None):
        """Sample one token per position from watermarked logits.

        Args:
            decoder_output: decoder output whose first element is a
                (batch, length, vocab) logits tensor; it is not modified.
            key_seq: (batch, >=length) tensor of keys; key 0 boosts the green
                list, any other value boosts the red list.
            delta: value added to the boosted logits.

        Returns:
            ``(tokens, token_probs, probs)``: sampled ids (batch, length); the
            un-watermarked model probability of each sampled id; and the full
            un-watermarked softmax (batch, length, vocab).
        """
        assert key_seq is not None
        assert delta is not None

        # Detached view: sampling needs no gradients and the caller's tensor
        # must not be mutated.
        logits = decoder_output[0].detach()
        _, max_length, vocabulary_size = logits.shape

        is_green = self._green_list_mask(vocabulary_size, logits.device)
        use_green = key_seq[:, :max_length].eq(0).unsqueeze(-1)   # (B, L, 1)
        boosted = torch.where(use_green, is_green, ~is_green)      # (B, L, V)
        wm_logits = torch.where(boosted, logits + delta, logits)

        indices = self.sample_from_dist(F.softmax(wm_logits, dim=-1), dim=-1)

        # Confidence is reported under the model's own (un-watermarked) distribution.
        probs = F.softmax(logits, dim=-1)
        sampled_prob = torch.gather(probs, dim=-1, index=indices).squeeze(dim=-1)
        return indices.squeeze(dim=-1), sampled_prob, probs
