import torch
import tqdm


def generate_synth_1(
    vocab_size=5000,
    batch_size=32,
    seq_len=128,
    device="cuda",
    noise_scale=1.0,
    draft_noise_scale=1.0,
):
    """Draw one batch of synthetic target/draft logits.

    The target logits are standard-normal samples multiplied by
    ``noise_scale``. The draft logits are the target logits plus an
    independent standard-normal perturbation multiplied by
    ``draft_noise_scale``. A larger ``draft_noise_scale`` therefore pushes
    the draft further from the target.

    Args:
        vocab_size: Size of the last dimension of the logit tensors.
        batch_size: Size of the first dimension of every returned tensor.
        seq_len: Size of the second dimension of every returned tensor.
        device: Device on which all returned tensors are allocated.
        noise_scale: Multiplier applied to the target sample.
        draft_noise_scale: Multiplier applied to the draft-only perturbation.

    Returns:
        Tuple ``(logits_q, logits_p, valid_mask)``. ``logits_q`` holds the
        draft logits and ``logits_p`` the target logits; both have shape
        ``(batch_size, seq_len, vocab_size)``. ``valid_mask`` is a boolean
        tensor of shape ``(batch_size, seq_len)`` filled with ``True``.
    """
    shape = (batch_size, seq_len, vocab_size)

    # Order matters for reproducibility under a fixed seed: the target
    # sample is drawn from the global RNG first, then the perturbation.
    target = torch.randn(*shape, device=device) * noise_scale
    perturbation = torch.randn(*shape, device=device) * draft_noise_scale
    draft = target + perturbation

    mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    return draft, target, mask


def get_logits_generator(
    data_kwargs,
    model_kwargs,
    generation_kwargs,
    reproducibility_kwargs,
    *,
    vocab_size=5000,
    seq_len=128,
    device="cuda",
    noise_scale=1.0,
    draft_noise_scale=1.0,
):
    """Yield synthetic logit batches covering ``num_prompts`` prompts.

    Prompts are processed in consecutive chunks of ``batch_size``; the last
    chunk may be smaller. Before each chunk the global torch RNG is reseeded
    with ``reproducibility_kwargs["seed"] + idx_start``. Each chunk's
    contents therefore depend only on the base seed and the chunk's start
    index.

    Args:
        data_kwargs: Mapping that provides ``"num_prompts"``.
        model_kwargs: Not used by this synthetic generator; accepted so the
            signature stays interchangeable with other logits generators
            (NOTE(review): presumed purpose — confirm with callers).
        generation_kwargs: Mapping that provides ``"batch_size"``.
        reproducibility_kwargs: Mapping that provides the base ``"seed"``.
        vocab_size: Forwarded to ``generate_synth_1``.
        seq_len: Forwarded to ``generate_synth_1``.
        device: Forwarded to ``generate_synth_1``.
        noise_scale: Forwarded to ``generate_synth_1``.
        draft_noise_scale: Forwarded to ``generate_synth_1``.
            The defaults of these five keyword-only arguments equal the values
            this function previously hard-coded, so existing callers are
            unaffected.

    Yields:
        ``(logits_q, logits_p, valid_mask)`` tuples as returned by
        ``generate_synth_1``, one per chunk.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    batch_size = generation_kwargs["batch_size"]
    num_prompts = data_kwargs["num_prompts"]

    # A zero step makes range() raise an obscure error, and a negative step
    # silently yields no batches at all; reject both explicitly.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    for idx_start in tqdm.trange(0, num_prompts, batch_size):
        idx_end = min(idx_start + batch_size, num_prompts)

        # Offset the seed by the chunk start so each chunk is reproducible
        # on its own. torch.manual_seed seeds the CPU and all CUDA
        # generators, so no separate torch.cuda.manual_seed call is needed.
        seed = reproducibility_kwargs["seed"] + idx_start
        torch.manual_seed(seed)

        yield generate_synth_1(
            vocab_size=vocab_size,
            batch_size=idx_end - idx_start,
            seq_len=seq_len,
            device=device,
            noise_scale=noise_scale,
            draft_noise_scale=draft_noise_scale,
        )
