import torch
import torch.nn.functional as F

def _top_p_filter(logits, top_p, filter_value=-float("Inf")):
    """Return ``logits`` with every entry outside the top-p nucleus set to ``filter_value``.

    Works row-wise over the last dimension, so any leading batch size is supported.
    The highest-scoring token is always kept.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_to_remove = cumulative_probs > top_p
    # Shift right so the token that first crosses the threshold is still kept,
    # and never remove the most likely token.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, independently per row.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, filter_value)


def generate_transformer(
    model,
    tokenizer1,
    tokenizer2,
    encode_text,
    decode_text,
    max_length,
    entry_length=30,
    top_p=0.8,
    temperature=1.,
    eos_token='[SEP]',
    all_special_tokens=('[CLS]', '[PAD]', '[SEP]', '[UNK]'),
):
    """Continue ``decode_text`` by nucleus (top-p) sampling from an encoder-decoder model.

    ``encode_text`` is encoded with ``tokenizer1`` (padded/truncated to ``max_length``)
    as the source sequence. ``decode_text`` is encoded with ``tokenizer2`` without
    special tokens, and that encoding is the target prefix. Each step calls
    ``model(src, trg)``, reads the last position of the returned logits, and samples
    one target token.

    Args:
        model: module called as ``model(src, trg)`` returning a pair whose first item
            is the logits. Only ``logits[:, -1, :]`` is read; the layout is presumably
            (1, trg_len, vocab_size), which the caller should confirm.
        tokenizer1: source tokenizer (``encode``).
        tokenizer2: target tokenizer (``encode`` with ``add_special_tokens``, ``decode``).
        encode_text: source text.
        decode_text: target prefix to continue; may be empty.
        max_length: maximum source/target length in tokens. Generation stops once the
            target reaches this length.
        entry_length: maximum number of sampling steps (generated tokens).
        top_p: cumulative probability mass kept by nucleus filtering.
        temperature: logits divisor; values <= 0 are treated as 1.0.
        eos_token: string whose token id(s) terminate generation when sampled.
        all_special_tokens: strings removed from the decoded result.

    Returns:
        The decoded target (prefix plus generated tokens). Every string in
        ``all_special_tokens`` is removed and surrounding whitespace is stripped.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # Resolve EOS id(s) once. add_special_tokens=False prevents tokenizers that wrap
    # input (e.g. "[CLS] ... [SEP]") from making unrelated specials count as EOS.
    eos_ids = set(tokenizer2.encode(eos_token, add_special_tokens=False))
    divisor = temperature if temperature > 0 else 1.0

    with torch.no_grad():
        src_ids = tokenizer1.encode(
            encode_text, padding='max_length', max_length=max_length
        )[:max_length]
        trg_ids = tokenizer2.encode(decode_text, add_special_tokens=False)[:max_length]
        # Explicit long dtype: an empty id list would otherwise become a float tensor.
        src_generated = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)
        trg_generated = torch.tensor(trg_ids, dtype=torch.long, device=device).unsqueeze(0)

        entry_finished = False
        for _step in range(entry_length):
            # No room left: further samples would just be truncated away.
            if trg_generated.size(1) >= max_length:
                break

            logits, _ = model(src_generated, trg_generated)
            logits = _top_p_filter(logits[:, -1, :] / divisor, top_p)

            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            trg_generated = torch.cat((trg_generated, next_token), dim=1)

            if next_token.item() in eos_ids:
                entry_finished = True
                break

        # Index row 0 instead of squeeze() so a length-1 target stays 1-D.
        generated_text = tokenizer2.decode(trg_generated[0].cpu())

    if not entry_finished:
        generated_text = f"{generated_text}{eos_token}"

    for token in all_special_tokens:
        generated_text = generated_text.replace(token, '')

    return generated_text.strip()

class ScheduledOptim:
    """Wrap an optimizer and set its learning rate from a warm-up schedule.

    Before each wrapped ``step()`` the step counter ``n`` is advanced (counting
    from 1). Every param group's ``lr`` is then set to
    ``lr_mul * d_model**-0.5 * min(n**-0.5, n * n_warmup_steps**-1.5)``.
    This is a linear ramp until ``n_warmup_steps``, followed by inverse-square-root decay.
    """

    def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
        # Schedule hyper-parameters.
        self.lr_mul = lr_mul
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        # Scheduled steps taken so far; bumped before each rate update.
        self.n_steps = 0
        self._optimizer = optimizer

    def step(self):
        """Refresh the learning rate, then step the wrapped optimizer."""
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        """Delegate gradient clearing to the wrapped optimizer."""
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        """Return the schedule factor (excluding ``lr_mul``) for the current ``n_steps``.

        Raises ZeroDivisionError if ``n_steps`` or ``n_warmup_steps`` is zero.
        """
        model_scale = self.d_model ** -0.5
        step = self.n_steps
        decay = step ** (-0.5)  # inverse-sqrt decay branch
        ramp = step * self.n_warmup_steps ** (-1.5)  # linear warm-up branch
        return model_scale * min(decay, ramp)

    def _update_learning_rate(self):
        """Advance the step counter and write the new rate into every param group."""
        self.n_steps += 1
        new_lr = self.lr_mul * self._get_lr_scale()
        for group in self._optimizer.param_groups:
            group.update(lr=new_lr)