import fire
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import torch
from tqdm import trange, tqdm


def generate_continuations(
    model,
    batch_size,
    num_completions,
    temperature,
    max_new_tokens,
    max_input_length,
    tokenizer,
    prompt,
    accelerator,
    gpu_num=None,
    gpu_total=None,
):
    """Sample ``num_completions`` continuations of ``prompt`` and return their token ids.

    The prompt is tokenized once (truncated to ``max_input_length`` tokens) and
    ``model.generate`` is called repeatedly with sampling enabled, producing at
    most ``batch_size`` sequences per call, until exactly ``num_completions``
    sequences have been generated. Every batch is right-padded with the
    tokenizer's pad token id to ``max_new_tokens + max_input_length`` columns so
    that the batches can be concatenated into one tensor.

    Side effects: ``tokenizer.pad_token`` is overwritten with
    ``tokenizer.eos_token``, and the model is switched to eval mode and moved
    to ``accelerator.device``.

    Args:
        model: Object exposing ``eval()``, ``to(device)`` and ``generate(...)``
            (presumably a ``transformers`` causal LM -- confirm with caller).
        batch_size: Maximum number of sequences sampled per ``generate`` call.
        num_completions: Total number of sequences to sample.
        temperature: Sampling temperature forwarded to ``generate``.
        max_new_tokens: Generation length limit forwarded to ``generate``.
        max_input_length: Truncation length for the tokenized prompt.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        prompt: Text to continue (assumed to be a single prompt; no padding is
            requested from the tokenizer).
        accelerator: Object whose ``device`` attribute selects the device.
        gpu_num: Unused; accepted for interface compatibility.
        gpu_total: Unused; accepted for interface compatibility.

    Returns:
        The concatenation along dim 0 of all generated batches, each padded to
        ``max_new_tokens + max_input_length`` columns. The sequences can be
        turned back into text with
        ``tokenizer.decode(ids, skip_special_tokens=True)``.

    Raises:
        ValueError: If ``batch_size`` or ``num_completions`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_completions < 1:
        raise ValueError(f"num_completions must be >= 1, got {num_completions}")

    # Full batches first, then one smaller batch for the remainder, so that
    # exactly num_completions sequences are produced even when
    # num_completions is not a multiple of (or is smaller than) batch_size.
    full_batches, remainder = divmod(num_completions, batch_size)
    batch_sizes = [batch_size] * full_batches
    if remainder:
        batch_sizes.append(remainder)

    padded_length = max_new_tokens + max_input_length

    with torch.no_grad():
        model.eval()
        model = model.to(accelerator.device)  # Move model to device

        # Reuse EOS as the pad token so a pad id is always available.
        tokenizer.pad_token = tokenizer.eos_token
        inputs = tokenizer(
            prompt, truncation=True, max_length=max_input_length, return_tensors="pt"
        ).to(
            accelerator.device
        )  # Move inputs to device

        generated_batches = []
        for step in trange(len(batch_sizes)):
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=batch_sizes[step],
                do_sample=True,
                # Keep generate's own padding of early-finished sequences
                # consistent with the padding applied below.
                pad_token_id=tokenizer.pad_token_id,
            )
            # Right-pad the sequence dimension to a fixed width so batches of
            # different lengths can be concatenated.
            # NOTE(review): assumes outputs.shape[1] <= padded_length; a
            # negative pad amount would crop instead of pad.
            outputs = torch.nn.functional.pad(
                outputs,
                (0, padded_length - outputs.shape[1]),
                mode="constant",
                value=tokenizer.pad_token_id,
            )
            generated_batches.append(outputs)

        # Combine the tensors on the same device
        return torch.cat(generated_batches, dim=0)


if __name__ == "__main__":
    # Expose generate_continuations as a command-line entry point via Fire.
    # NOTE(review): model, tokenizer and accelerator are objects rather than
    # plain CLI values -- confirm how this entry point is meant to be invoked.
    fire.Fire(component=generate_continuations)
