import torch
import numpy as np
import torch.nn.functional as F
from .generate_ssd_cache import ssd_with_cache
from .generate_ssd import ssd_without_cache
from .generate_vanilla import vanilla,vanilla_with_cache


def generate(input_ids,attention_mask,model,gen_length,block_length,steps,
             temperature,cfg_scale,mask_id,draft_length,tree_strategy,remasking,
             verbose,kv_cache,ssd,refresh_interval,**kwargs):
    """Dispatch generation to one of four backends and return its result.

    Backend selection:
        ssd and kv_cache          -> ssd_with_cache
        ssd and not kv_cache      -> ssd_without_cache
        not ssd and kv_cache      -> vanilla_with_cache
        not ssd and not kv_cache  -> vanilla

    Every backend receives ``input_ids, attention_mask, model, gen_length,
    block_length`` as its leading positional arguments. It then gets a
    backend-specific positional tail, and finally ``**kwargs`` unchanged.

    Not every parameter reaches every backend:
        * ``steps`` is forwarded only to ``vanilla``.
        * ``draft_length``, ``tree_strategy`` and ``verbose`` are forwarded
          only to the two ssd backends.
        * ``remasking`` is forwarded only to the two vanilla backends.
        * ``refresh_interval`` is forwarded only to the two ``*_with_cache``
          backends.

    Returns:
        A 2-tuple ``(generated_answer, nfe)`` unpacked from the chosen
        backend's return value. A ValueError/TypeError is raised if the
        backend does not return exactly two items. ``nfe`` is presumably a
        count of model forward evaluations; confirm against the backends.
    """
    # Positional prefix common to all four backends. The order matters:
    # every call below is positional.
    shared = (input_ids, attention_mask, model, gen_length, block_length)

    if ssd:
        if kv_cache:
            runner = ssd_with_cache
            tail = (temperature, cfg_scale, mask_id, draft_length,
                    tree_strategy, verbose, refresh_interval)
        else:
            runner = ssd_without_cache
            tail = (temperature, cfg_scale, mask_id, draft_length,
                    tree_strategy, verbose)
    elif kv_cache:
        runner = vanilla_with_cache
        tail = (temperature, cfg_scale, mask_id, remasking, refresh_interval)
    else:
        runner = vanilla
        # Only the non-cached vanilla path takes an explicit step count.
        tail = (steps, temperature, cfg_scale, mask_id, remasking)

    # Unpack explicitly, so the caller always receives a plain 2-tuple and
    # a malformed backend result fails here rather than downstream.
    generated_answer, nfe = runner(*(shared + tail), **kwargs)
    return generated_answer, nfe